Skip to content

Commit

Permalink
Add diagnostic information to progress bar (#163)
Browse files Browse the repository at this point in the history
* stash

* progress bar updates

* clean up

* clean up docstring

* address comments

* update docs

* update doctest; add to travis

* Remove default options for autodoc

* fix lint

* skip exact output check in doctest
  • Loading branch information
neerajprad authored and fehiepsi committed May 22, 2019
1 parent ae9ef67 commit f9dcb47
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 45 deletions.
8 changes: 6 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ branches:

jobs:
include:
- stage: lint
- stage: lint / docs
before_install: pip install .[doc]
python: 3.6
script: make lint
script:
- make lint
- make docs
- make doctest
- stage: unit
name: "unit tests"
python: 3.6
Expand Down
14 changes: 7 additions & 7 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@

autodoc_inherit_docstrings = False

autodoc_default_options = {
'member-order': 'bysource',
'show-inheritance': True,
'special-members': True,
'undoc-members': True,
# 'exclude-members': '__dict__,__module__,__weakref__',
}
# autodoc_default_options = {
# 'member-order': 'bysource',
# 'show-inheritance': True,
# 'special-members': True,
# 'undoc-members': True,
# 'exclude-members': '__dict__,__module__,__weakref__',
# }

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
Expand Down
2 changes: 2 additions & 0 deletions docs/source/mcmc.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Markov Chain Monte Carlo (MCMC)
===============================

.. autofunction:: numpyro.mcmc.mcmc

.. autofunction:: numpyro.mcmc.hmc

.. autofunction:: numpyro.mcmc.hmc.init_kernel
Expand Down
9 changes: 3 additions & 6 deletions examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect
from numpyro.mcmc import mcmc

"""
Original example from Pyro:
Expand Down Expand Up @@ -137,10 +136,8 @@ def partially_pooled_with_logit(at_bats, hits=None):

def run_inference(model, at_bats, hits, rng, args):
init_params, potential_fn, constrain_fn = initialize_model(rng, model, at_bats, hits)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: constrain_fn(hmc_state.z))
hmc_states = mcmc(args.num_warmup, args.num_samples, init_params,
sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
return hmc_states


Expand Down
111 changes: 95 additions & 16 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,30 @@
from jax.tree_util import register_pytree_node

import numpyro.distributions as dist
from numpyro.diagnostics import summary
from numpyro.hmc_util import IntegratorState, build_tree, find_reasonable_step_size, velocity_verlet, warmup_adapter
from numpyro.util import cond, fori_loop

HMCState = namedtuple('HMCState', ['z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
'step_size', 'inverse_mass_matrix', 'rng'])
from numpyro.util import cond, fori_loop, fori_collect, identity

HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
'mean_accept_prob', 'step_size', 'inverse_mass_matrix', 'rng'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
- **i** - iteration. This is reset to 0 after warmup.
- **z** - Python collection representing values (unconstrained samples from
the posterior) at latent sites.
- **z_grad** - Gradient of potential energy w.r.t. latent sample sites.
- **potential_energy** - Potential energy computed at the given value of ``z``.
- **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
- **accept_prob** - Acceptance probability of the proposal. Note that ``z``
does not correspond to the proposal if it is rejected.
- **mean_accept_prob** - Mean acceptance probability until current iteration
during warmup adaptation or sampling (for diagnostics).
- **step_size** - Step size to be used by the integrator in the next iteration.
This is adapted during warmup.
- **inverse_mass_matrix** - The inverse mass matrix to be used for the next
iteration. This is adapted during warmup.
- **rng** - random number generator seed used for the iteration.
"""


register_pytree_node(
Expand Down Expand Up @@ -52,6 +71,12 @@ def _euclidean_ke(inverse_mass_matrix, r):
return 0.5 * np.dot(v, r)


def get_diagnostics_str(hmc_state):
    """
    Build the one-line diagnostics string shown next to the progress bar.

    :param hmc_state: object exposing ``num_steps``, ``step_size`` and
        ``mean_accept_prob`` attributes (e.g. an ``HMCState``).
    :return: string reporting the number of steps, the step size in
        scientific notation (two decimals) and the mean acceptance
        probability (two decimals).
    """
    template = '{} steps of size {:.2e}. acc. prob={:.2f}'
    fields = (hmc_state.num_steps, hmc_state.step_size, hmc_state.mean_accept_prob)
    return template.format(*fields)


def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
r"""
Hamiltonian Monte Carlo inference, using either fixed number of
Expand All @@ -65,7 +90,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
:param potential_fn: Python callable that computes the potential energy
given input parameters. The input parameters to `potential_fn` can be
any python collection type, provided that `init_samples` argument to
any python collection type, provided that `init_params` argument to
`init_kernel` has the same type.
:param kinetic_fn: Python callable that returns the kinetic energy given
inverse mass matrix and momentum. If not provided, the default is
Expand Down Expand Up @@ -105,10 +130,10 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
... trajectory_length=10,
... num_warmup_steps=300)
... num_warmup=300)
>>> hmc_states = fori_collect(500, sample_kernel, hmc_state,
... transform=lambda x: constrain_fn(x.z))
>>> print(np.mean(hmc_states['beta'], axis=0))
>>> print(np.mean(hmc_states['beta'], axis=0)) # doctest: +SKIP
[0.9153987 2.0754058 2.9621222]
"""
if kinetic_fn is None:
Expand All @@ -119,7 +144,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
momentum_generator = None
wa_update = None

def init_kernel(init_samples,
def init_kernel(init_params,
num_warmup,
step_size=1.0,
adapt_step_size=True,
Expand All @@ -134,8 +159,8 @@ def init_kernel(init_samples,
"""
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to `potential_fn`.
:param init_params: Initial parameters to begin sampling. The type must
be consistent with the input type to `potential_fn`.
:param int num_warmup: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
Expand Down Expand Up @@ -170,7 +195,7 @@ def init_kernel(init_samples,
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
trajectory_len = float(trajectory_length)
max_treedepth = max_tree_depth
z = init_samples
z = init_params
z_flat, unravel_fn = ravel_pytree(z)
momentum_generator = partial(_sample_momentum, unravel_fn)

Expand All @@ -189,7 +214,7 @@ def init_kernel(init_samples,
wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
r = momentum_generator(wa_state.inverse_mass_matrix, rng)
vv_state = vv_init(z, r)
hmc_state = HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0.,
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
wa_state.step_size, wa_state.inverse_mass_matrix, rng_hmc)

wa_update = jit(wa_update)
Expand All @@ -200,8 +225,12 @@ def init_kernel(init_samples,
warmup_update,
(hmc_state, wa_state))
else:
for i in tqdm.trange(num_warmup):
hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
with tqdm.trange(num_warmup, desc='warmup') as t:
for i in t:
hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=True)
# Reset `i` and `mean_accept_prob` for fresh diagnostics.
hmc_state.update(i=0, mean_accept_prob=0)
return hmc_state
else:
return hmc_state, wa_state, warmup_update
Expand Down Expand Up @@ -259,8 +288,11 @@ def sample_kernel(hmc_state):
vv_state, num_steps, accept_prob = _next(hmc_state.step_size,
hmc_state.inverse_mass_matrix,
vv_state, rng_transition)
return HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix, rng)
itr = hmc_state.i + 1
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / itr
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, mean_accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix,
rng)

# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
Expand All @@ -269,3 +301,50 @@ def sample_kernel(hmc_state):
hmc.sample_kernel = sample_kernel

return init_kernel, sample_kernel


def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints
    diagnostic summary and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites. Defaults to the identity
        function when not provided.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

         - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
           :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
           that all arguments must be provided as keywords, and that
           `potential_fn` is required.
    :return: collection of samples from the posterior.
    :raises ValueError: if `sampler` is not a recognized sampler name.
    :raises KeyError: if `sampler` is ``'hmc'`` and `potential_fn` is not
        given in `sampler_kwargs`.
    """
    # Reject unsupported samplers up front so the HMC path below reads linearly.
    if sampler != 'hmc':
        raise ValueError('sampler: {} not recognized'.format(sampler))

    # `potential_fn` is mandatory for HMC; fail with an actionable message
    # rather than a bare ``KeyError: 'potential_fn'`` from ``dict.pop``.
    if 'potential_fn' not in sampler_kwargs:
        raise KeyError("mcmc(sampler='hmc') requires a `potential_fn` keyword argument")

    if constrain_fn is None:
        constrain_fn = identity

    # Split off the arguments consumed by `hmc` and by the progress bar; the
    # remaining keywords are forwarded untouched to `init_kernel`.
    potential_fn = sampler_kwargs.pop('potential_fn')
    kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
    algo = sampler_kwargs.pop('algo', 'NUTS')
    progbar = sampler_kwargs.pop('progbar', True)

    init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
    hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, **sampler_kwargs)
    # Only the (constrained) latent values `z` of each state are collected.
    samples = fori_collect(num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z),
                           progbar=progbar,
                           diagnostics_fn=get_diagnostics_str,
                           progbar_desc='sample')
    if print_summary:
        summary(samples)
    return samples
15 changes: 10 additions & 5 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def fori_loop(lower, upper, body_fun, init_val):
return lax.fori_loop(lower, upper, body_fun, init_val)


def _identity(x):
def identity(x):
return x


def fori_collect(n, body_fun, init_val, transform=_identity, progbar=True):
def fori_collect(n, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
# works like lax.fori_loop but ignores i in body_fn, supports
# postprocessing `transform`, and collects values during the loop
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
Expand All @@ -118,12 +118,17 @@ def _body_fn(i, vals):
_, collection = jit(lax.fori_loop, static_argnums=(2,))(0, n, _body_fn,
(init_val, collection))
else:
diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
progbar_desc = progbar_opts.pop('progbar_desc', '')
collection = []

val = init_val
for _ in tqdm.trange(n):
val = body_fun(val)
collection.append(jit(ravel_fn)(val))
with tqdm.trange(n, desc=progbar_desc) as t:
for _ in t:
val = body_fun(val)
collection.append(jit(ravel_fn)(val))
if diagnostics_fn:
t.set_postfix_str(diagnostics_fn(val), refresh=True)

# XXX: jax.numpy.stack/concatenate is currently so slow
collection = onp.stack(collection)
Expand Down
14 changes: 5 additions & 9 deletions test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.mcmc import hmc, mcmc
from numpyro.util import fori_collect


Expand All @@ -23,8 +23,8 @@ def potential_fn(z):
return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
init_samples = np.array(0.)
hmc_state = init_kernel(init_samples,
init_params = np.array(0.)
hmc_state = init_kernel(init_params,
trajectory_length=10,
num_warmup=warmup_steps)
hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
Expand All @@ -48,12 +48,8 @@ def model(labels):
return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, labels)
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
hmc_state = init_kernel(init_params,
trajectory_length=10,
num_warmup=warmup_steps)
hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
transform=lambda x: constrain_fn(x.z))
hmc_states = mcmc(warmup_steps, num_samples, init_params, sampler='hmc',
potential_fn=potential_fn, trajectory_length=10, constrain_fn=constrain_fn)
assert_allclose(np.mean(hmc_states['coefs'], 0), true_coefs, atol=0.2)


Expand Down

0 comments on commit f9dcb47

Please sign in to comment.