Skip to content

Commit

Permalink
Add pathwise tests for some samplers (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed May 29, 2019
1 parent c843cdc commit 692788c
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
import jax.numpy as np
import jax.random as random
from jax import grad, lax, vmap
from jax import grad, jacfwd, lax, vmap

import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
Expand Down Expand Up @@ -249,6 +249,31 @@ def fn(args):
assert_allclose(np.sum(actual_grad[i]), expected_grad, rtol=0.02)


@pytest.mark.parametrize('jax_dist, sp_dist, params', [
(dist.Gamma, osp.gamma, (1.,)),
(dist.Gamma, osp.gamma, (0.1,)),
(dist.Gamma, osp.gamma, (10.,)),
# TODO: add more test cases for Beta/StudentT (and Dirichlet too) when
# their pathwise grad (independent of standard_gamma grad) is implemented.
pytest.param(dist.Beta, osp.beta, (1., 1.), marks=pytest.mark.xfail(
reason='currently, variance of grad of beta sampler is large')),
pytest.param(dist.StudentT, osp.t, (1.,), marks=pytest.mark.xfail(
reason='currently, variance of grad of t sampler is large')),
])
def test_pathwise_gradient(jax_dist, sp_dist, params):
rng = random.PRNGKey(0)
N = 100
z = jax_dist(*params).sample(key=rng, size=(N,))
actual_grad = jacfwd(lambda x: jax_dist(*x).sample(key=rng, size=(N,)))(params)
eps = 1e-3
for i in range(len(params)):
args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
cdf_dot = (sp_dist(*args_rhs).cdf(z) - sp_dist(*args_lhs).cdf(z)) / (2 * eps)
expected_grad = -cdf_dot / sp_dist(*params).pdf(z)
assert_allclose(actual_grad[i], expected_grad, rtol=0.005)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('prepend_shape', [
(),
Expand Down

0 comments on commit 692788c

Please sign in to comment.