Skip to content

Commit

Permalink
add batching rule for standard_gamma (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed May 25, 2019
1 parent 8a633e3 commit 3bf3552
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
3 changes: 2 additions & 1 deletion numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def summary(samples, prob=0.89):
header_format = '{:>20} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}'
columns = ['', 'mean', 'sd', '{:.1f}%'.format(50 * (1 - prob)),
'{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'Rhat']
print('\n', header_format.format(*columns))
print('\n')
print(header_format.format(*columns))

# FIXME: maybe allow a `digits` arg to set how many floatting points are needed?
row_format = '{:>20} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}'
Expand Down
7 changes: 4 additions & 3 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ def sample(self, key, size=()):
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
normalize_term = (np.sum(gammaln(self.concentration), axis=-1) -
gammaln(np.sum(self.concentration, axis=-1)))
return np.sum(np.log(value) * (self.concentration - 1.), axis=-1) - normalize_term
concentration = lax.convert_element_type(self.concentration, value.dtype)
normalize_term = (np.sum(gammaln(concentration), axis=-1) -
gammaln(np.sum(concentration, axis=-1)))
return np.sum(np.log(value) * (concentration - 1.), axis=-1) - normalize_term

@property
def mean(self):
Expand Down
9 changes: 7 additions & 2 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import jax.numpy as np
from jax import canonicalize_dtype, custom_transforms, device_get, jit, lax, random, vmap
from jax.interpreters import ad
from jax.interpreters import ad, batching
from jax.numpy.lax_numpy import _promote_args_like
from jax.scipy.special import gammaln
from jax.util import partial
Expand Down Expand Up @@ -54,8 +54,12 @@ def _next_kxv(kxv):

# TODO: use upstream implementation when available because it is 2x faster
def _standard_gamma_impl(key, alpha):
if key.ndim > 1:
keys = vmap(lambda k: random.split(k, np.size(alpha[0])))(key)
else:
keys = random.split(key, alpha.size)
alphas = np.reshape(alpha, -1)
keys = random.split(key, alphas.size)
keys = np.reshape(keys, (-1, 2))
samples = vmap(_standard_gamma_one)(keys, alphas)
return samples.reshape(alpha.shape)

Expand Down Expand Up @@ -175,6 +179,7 @@ def _standard_gamma_p(key, alpha):

ad.defjvp2(_standard_gamma_p.primitive, None,
lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))
batching.defvectorized(_standard_gamma_p.primitive)


@partial(jit, static_argnums=(2, 3))
Expand Down
12 changes: 11 additions & 1 deletion test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.testing import assert_allclose

import jax.numpy as np
from jax import grad, jacobian, jit, lax, random
from jax import grad, jacobian, jit, lax, random, vmap
from jax.scipy.special import expit
from jax.util import partial

Expand Down Expand Up @@ -160,6 +160,16 @@ def test_standard_gamma_grad(alpha):
assert_allclose(actual_grad, expected_grad, atol=1e-8, rtol=0.0005)


def test_standard_gamma_batch():
rng = random.PRNGKey(0)
alphas = np.array([1., 2., 3.])
rngs = random.split(rng, 3)

samples = vmap(lambda rng, alpha: standard_gamma(rng, alpha))(rngs, alphas)
for i in range(3):
assert_allclose(samples[i], standard_gamma(rngs[i], alphas[i]))


@pytest.mark.parametrize('p, shape', [
(np.array([0.1, 0.9]), ()),
(np.array([0.2, 0.8]), (2,)),
Expand Down

0 comments on commit 3bf3552

Please sign in to comment.