Initial support for chains (#194)
fehiepsi authored and neerajprad committed Jun 9, 2019
1 parent 3291ea6 commit 3151fab
Showing 10 changed files with 64 additions and 23 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -38,6 +38,7 @@ jobs:
       script:
         - pytest -vs -m "not test_examples"
         - JAX_ENABLE_x64=1 pytest -vs test/test_mcmc.py
+        - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/test_mcmc.py -k pmap
     - name: "examples"
       python: 3.6
       script: pytest -vs -m test_examples
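
The new CI line runs the pmap-specific tests with XLA_FLAGS forcing two host (CPU) devices, so the multi-chain code path can be exercised on a single machine. The following is a small sketch (not part of the commit) of doing the same thing from Python; it assumes a CPU-only process in which jax has not been imported yet, since the flag is read at first import.

    # Emulate the CI setting: present the host CPU as two XLA devices.
    # XLA_FLAGS must be set before jax is imported for the first time.
    import os
    os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

    from jax.lib import xla_bridge
    print(xla_bridge.device_count())  # expected: 2
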
6 changes: 5 additions & 1 deletion examples/hmm.py
@@ -1,4 +1,5 @@
 import argparse
+import time

 import numpy as onp

@@ -155,8 +156,11 @@ def main(args):
         transition_prior, emission_prior, supervised_categories,
         supervised_words, unsupervised_words,
     )
+    start = time.time()
     samples = mcmc(args.num_warmup, args.num_samples, init_params,
-                   potential_fn=potential_fn, constrain_fn=constrain_fn)
+                   potential_fn=potential_fn, constrain_fn=constrain_fn,
+                   progbar=True)
+    print('\nMCMC elapsed time:', time.time() - start)
     print_results(samples, transition_prob, emission_prob)


5 changes: 4 additions & 1 deletion numpyro/distributions/constraints.py
@@ -122,8 +122,11 @@ def __call__(self, x):

 class _PositiveDefinite(Constraint):
     def __call__(self, x):
+        # check for symmetric
+        symmetric = np.all(np.all(x == np.swapaxes(x, -2, -1), axis=-1), axis=-1)
         # check for the smallest eigenvalue is positive
-        return np.linalg.eigh(x)[0][..., 0] > 0
+        positive = np.linalg.eigh(x)[0][..., 0] > 0
+        return symmetric & positive


 class _Real(Constraint):
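
The constraint now requires a matrix to be exactly symmetric in addition to having a positive smallest eigenvalue. Below is a standalone sketch of the same check (plain jax.numpy, not the numpyro Constraint class), using the batch of matrices from the updated test further down together with the symmetric positive-definite case already in the test suite:

    import jax.numpy as np  # the diff aliases jax.numpy as np

    x = np.array([[[1., 0.3], [0.3, 1.]],    # symmetric, positive definite
                  [[1., 0.1], [0.1, 0.]],    # symmetric, but smallest eigenvalue <= 0
                  [[2., 0.4], [0.3, 2.]]])   # not symmetric

    symmetric = np.all(np.all(x == np.swapaxes(x, -2, -1), axis=-1), axis=-1)
    positive = np.linalg.eigh(x)[0][..., 0] > 0  # eigh returns eigenvalues in ascending order
    print(symmetric & positive)                  # expected: [ True False False]
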
12 changes: 8 additions & 4 deletions numpyro/mcmc.py
@@ -236,14 +236,18 @@ def init_kernel(init_params,
         if run_warmup:
-            # JIT if progress bar updates not required
             if not progbar:
-                hmc_state, _ = jit(fori_loop, static_argnums=(2,))(0, num_warmup,
-                                                                   warmup_update,
-                                                                   (hmc_state, wa_state))
+                # PERF: jitting the for loop may be faster on certain models or high
+                # number of samples.
+                # TODO: remove this condition when the issue is resolved
+                if progbar is None:  # NB: if progbar=None, we jit fori_loop
+                    hmc_state, _ = jit(fori_loop, static_argnums=(2,))(0, num_warmup, warmup_update,
+                                                                       (hmc_state, wa_state))
+                else:
+                    hmc_state, _ = fori_loop(0, num_warmup, warmup_update, (hmc_state, wa_state))
             else:
                 with tqdm.trange(num_warmup, desc='warmup') as t:
                     for i in t:
                         hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
                         # TODO: set refresh=True when its performance issue is resolved
                         t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
             # Reset `i` and `mean_accept_prob` for fresh diagnostics.
             hmc_state.update(i=0, mean_accept_prob=0)
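
With this change the warmup loop distinguishes three settings of progbar in this version of the API: True runs a Python loop under tqdm so per-iteration diagnostics can be shown, False runs the loop through fori_loop without jit-compiling the loop wrapper, and None additionally wraps fori_loop in jit (the NB comment above), which the PERF note says can be faster on some models. A standalone sketch of that dispatch on a toy loop follows; run_warmup_loop and the lambda body are illustrative names, not numpyro's init_kernel.

    import jax.numpy as np
    from jax import jit, lax
    import tqdm

    def run_warmup_loop(num_warmup, warmup_update, state, progbar=True):
        if not progbar:
            if progbar is None:
                # progbar=None: jit the whole loop; the loop body (positional arg 2) is static
                state = jit(lax.fori_loop, static_argnums=(2,))(0, num_warmup, warmup_update, state)
            else:
                # progbar=False: run fori_loop without jitting the loop wrapper
                state = lax.fori_loop(0, num_warmup, warmup_update, state)
        else:
            # progbar=True: Python loop so tqdm can display per-iteration info
            for i in tqdm.trange(num_warmup, desc='warmup'):
                state = warmup_update(i, state)
        return state

    # toy "warmup": accumulate the iteration index into the state
    print(run_warmup_loop(10, lambda i, s: s + i, np.array(0.), progbar=None))  # 45.0
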
13 changes: 9 additions & 4 deletions numpyro/util.py
@@ -128,16 +128,22 @@ def fori_collect(n, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
     ravel_fn = lambda x: ravel_pytree(transform(x))[0]  # noqa: E731

     if not progbar:
-        collection = np.zeros((n,) + init_val_flat.shape, dtype=init_val_flat.dtype)
+        collection = np.zeros((n,) + init_val_flat.shape)

         def _body_fn(i, vals):
             val, collection = vals
             val = body_fun(val)
             collection = ops.index_update(collection, i, ravel_fn(val))
             return val, collection

-        _, collection = jit(lax.fori_loop, static_argnums=(2,))(0, n, _body_fn,
+        # PERF: jitting the for loop may be faster on certain models or high
+        # number of samples.
+        # TODO: remove this condition when the issue is resolved
+        if progbar is None:  # NB: if progbar=None, we jit fori_loop
+            _, collection = jit(fori_loop, static_argnums=(2,))(0, n, _body_fn,
                                                                 (init_val, collection))
+        else:
+            _, collection = fori_loop(0, n, _body_fn, (init_val, collection))
     else:
         diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
         progbar_desc = progbar_opts.pop('progbar_desc', '')
@@ -149,10 +155,9 @@ def _body_fn(i, vals):
                 val = body_fun(val)
                 collection.append(jit(ravel_fn)(val))
                 if diagnostics_fn:
-                    # TODO: set refresh=True when its performance issue is resolved
                     t.set_postfix_str(diagnostics_fn(val), refresh=False)

-        # XXX: jax.numpy.stack/concatenate is currently so slow
+        # XXX: jax.numpy.stack/concatenate is currently slow
         collection = onp.stack(collection)

     return vmap(unravel_fn)(collection)
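
In the branch without a progress bar, fori_collect preallocates an (n, flat_dim) array and writes each iterate into it with ops.index_update inside a fori_loop, optionally jitting the whole loop when progbar=None. Below is a standalone sketch of that collection pattern (not numpyro's fori_collect itself); ops.index_update is the jax API of the versions pinned in this commit's setup.py, while newer jax spells the same operation collection.at[i].set(value).

    import jax.numpy as np
    from jax import lax, ops

    def collect(n, body_fun, init_val):
        collection = np.zeros((n,) + np.shape(init_val))

        def _body_fn(i, vals):
            val, collection = vals
            val = body_fun(val)
            collection = ops.index_update(collection, i, val)  # newer jax: collection.at[i].set(val)
            return val, collection

        # numpyro additionally wraps this loop in jit when progbar=None (see the diff above)
        _, collection = lax.fori_loop(0, n, _body_fn, (init_val, collection))
        return collection

    # toy iteration x_{t+1} = 0.5 * x_t + 1, started at 0
    print(collect(5, lambda x: 0.5 * x + 1., np.array(0.)))  # [1. 1.5 1.75 1.875 1.9375]
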
6 changes: 3 additions & 3 deletions setup.py
@@ -30,9 +30,9 @@
     author='Uber AI Labs',
     author_email='npradhan@uber.com',
     install_requires=[
-        # TODO: Remove soon as JAX's API becomes stable
-        'jax==0.1.35',
-        'jaxlib>=0.1.14',
+        # TODO: pin to a specific version for the next release (unless JAX's API becomes stable)
+        'jax>=0.1.36',
+        'jaxlib>=0.1.18',
         'tqdm',
     ],
     extras_require={
8 changes: 2 additions & 6 deletions test/test_distributions.py
@@ -531,8 +531,6 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
     with pytest.raises(ValueError):
         jax_dist(*oob_params, validate_args=True)

-    if jax_dist is dist.MultivariateNormal and jax_dist(*valid_params).batch_shape:
-        pytest.xfail('numpy.linalg.eigh batch rule is not available yet.')
     d = jax_dist(*valid_params, validate_args=True)

     # Test agreement of log density evaluation on randomly generated samples
@@ -582,10 +580,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
     (constraints.positive, 3, True),
     (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
     (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
-    pytest.param(constraints.positive_definite,
-                 np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
-                 np.array([False, False]),
-                 marks=pytest.mark.xfail(reason="np.linalg.eigh batching rule is not available yet")),
+    (constraints.positive_definite, np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
+     np.array([False, False])),
     (constraints.positive_integer, 3, True),
     (constraints.positive_integer, np.array([-1., 0., 5.]), np.array([False, False, True])),
     (constraints.real, -1, True),
3 changes: 1 addition & 2 deletions test/test_distributions_contrib.py
@@ -208,8 +208,7 @@ def test_discrete_validate_args(jax_dist, valid_args, invalid_args, invalid_samp
     dist.norm,
     dist.pareto,
     dist.t,
-    pytest.param(dist.trunccauchy, marks=pytest.mark.xfail(
-        reason='jvp rule for np.arctan is not yet available')),
+    dist.trunccauchy,
     dist.truncnorm,
     dist.uniform,
 ], ids=idfn)
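
The xfail on dist.trunccauchy is removed, presumably because the newer jax pinned in setup.py above ships a differentiation (jvp) rule for np.arctan, which the truncated Cauchy log-density needs. A one-line check of that capability:

    import jax.numpy as np
    from jax import grad

    print(grad(np.arctan)(0.5))  # d/dx arctan(x) = 1 / (1 + x**2) = 0.8
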
3 changes: 2 additions & 1 deletion test/test_hmc_util.py
@@ -300,7 +300,8 @@ def find_reasonable_step_size(m_inv, z, rng, step_size):
     # the second window.
     welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
     assert_allclose(inverse_mass_matrix,
-                    np.full((mass_matrix_size,), welford_regularize_term))
+                    np.full((mass_matrix_size,), welford_regularize_term),
+                    atol=1e-7)

     window = adaptation_schedule[2]
     for t in range(window.start, window.end + 1):
30 changes: 29 additions & 1 deletion test/test_mcmc.py
Expand Up @@ -5,7 +5,8 @@
from numpy.testing import assert_allclose

import jax.numpy as np
from jax import random
from jax import pmap, random
from jax.lib import xla_bridge
from jax.scipy.special import logit

import numpyro.distributions as dist
@@ -201,3 +202,30 @@ def model(data):

     if 'JAX_ENABLE_x64' in os.environ:
         assert hmc_states['p'].dtype == np.float64
+
+
+@pytest.mark.parametrize('algo', ['HMC', 'NUTS'])
+def test_pmap(algo):
+    if xla_bridge.device_count() == 1:
+        pytest.skip('pmap test requires device_count greater than 1.')
+
+    true_mean, true_std = 1., 2.
+    warmup_steps, num_samples = 1000, 8000
+
+    def potential_fn(z):
+        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)
+
+    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
+    init_params = np.array([0., -1.])
+    rngs = random.split(random.PRNGKey(0), 2)
+
+    init_kernel_pmap = pmap(lambda init_param, rng: init_kernel(
+        init_param, trajectory_length=9, num_warmup=warmup_steps, progbar=False, rng=rng))
+    init_states = init_kernel_pmap(init_params, rngs)
+
+    fori_collect_pmap = pmap(lambda hmc_state: fori_collect(num_samples, sample_kernel, hmc_state,
+                                                            transform=lambda x: x.z, progbar=False))
+    chain_samples = fori_collect_pmap(init_states)
+
+    assert_allclose(np.mean(chain_samples, axis=1), np.repeat(true_mean, 2), rtol=0.05)
+    assert_allclose(np.std(chain_samples, axis=1), np.repeat(true_std, 2), rtol=0.05)
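
The test runs one chain per XLA device: pmap maps over the leading axis of init_params and rngs, so each device gets its own starting point and PRNG key, and the collected draws come back with a leading chain dimension (hence the per-chain mean and std over axis 1). A toy sketch of the same pattern without numpyro follows; it needs at least two devices, for example via the XLA_FLAGS setting added to .travis.yml above.

    import jax.numpy as np
    from jax import pmap, random

    def toy_chain(init, rng):
        # stand-in for init_kernel + fori_collect: 100 pseudo "draws" per chain
        return init + random.normal(rng, (100,))

    init_params = np.array([0., -1.])            # one initial value per chain
    rngs = random.split(random.PRNGKey(0), 2)    # one PRNG key per chain
    chain_samples = pmap(toy_chain)(init_params, rngs)
    print(chain_samples.shape)                   # (2, 100): chains x samples per chain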
