Skip to content

Commit

Permalink
Adding some more documentation (#157)
Browse files Browse the repository at this point in the history
* Refactor docs

* stash

* add svi docs

* clean up

* address comments

* fix variable names

* Fix args/kwargs

* fix inv transform
  • Loading branch information
neerajprad authored and fehiepsi committed May 19, 2019
1 parent 6ed0135 commit d648254
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 77 deletions.
3 changes: 3 additions & 0 deletions Makefile
Expand Up @@ -6,6 +6,9 @@ lint: FORCE
format: FORCE
isort -rc .

doctest: FORCE
$(MAKE) -C docs doctest

test: lint FORCE
pytest -v test

Expand Down
10 changes: 9 additions & 1 deletion docs/source/conf.py
Expand Up @@ -21,6 +21,14 @@
#
sys.path.insert(0, os.path.abspath('../..'))

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.mcmc import hmc # noqa: E402
from numpyro.svi import svi # noqa: E402

os.environ['SPHINX_BUILD'] = '1'
hmc(None, None)
svi(None, None, None, None, None, None)

# -- Project information -----------------------------------------------------

project = u'Numpyro'
Expand Down Expand Up @@ -61,7 +69,7 @@
'show-inheritance': True,
'special-members': True,
'undoc-members': True,
'exclude-members': '__dict__,__module__,__weakref__',
# 'exclude-members': '__dict__,__module__,__weakref__',
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
18 changes: 13 additions & 5 deletions docs/source/mcmc.rst
@@ -1,8 +1,16 @@
Markov Chain Monte Carlo (MCMC)
===============================

.. automodule:: numpyro.mcmc
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autofunction:: numpyro.mcmc.hmc

.. autofunction:: numpyro.mcmc.hmc.init_kernel

.. autofunction:: numpyro.mcmc.hmc.sample_kernel

.. autodata:: numpyro.mcmc.HMCState


MCMC Utilities
--------------

.. autofunction:: numpyro.hmc_util.initialize_model
12 changes: 7 additions & 5 deletions docs/source/svi.rst
@@ -1,8 +1,10 @@
Stochastic Variational Inference (SVI)
======================================

.. automodule:: numpyro.svi
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autofunction:: numpyro.svi.svi

.. autofunction:: numpyro.svi.svi.init_fn

.. autofunction:: numpyro.svi.svi.update_fn

.. autofunction:: numpyro.svi.svi.evaluate
2 changes: 1 addition & 1 deletion examples/baseball.py
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as np
import jax.random as random
from jax.config import config as jax_config
from jax.scipy.misc import logsumexp
from jax.scipy.special import logsumexp

import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
Expand Down
31 changes: 25 additions & 6 deletions numpyro/hmc_util.py
Expand Up @@ -561,16 +561,35 @@ def _potential_energy(params):
return _potential_energy


def transform_fn(transforms, params, invert=False):
return {k: transforms[k](v) if not invert else transforms[k].inv(v)
def transform_fn(inv_transforms, params, constrain=True):
return {k: inv_transforms[k](v) if constrain else inv_transforms[k].inv(v)
for k, v in params.items()}


def initialize_model(rng, model, *model_args, **model_kwargs):
"""
Given a model with Pyro primitives, returns a function which, given
unconstrained parameters, evaluates the potential energy (negative
joint density). In addition, this also returns initial parameters
sampled from the prior to initiate MCMC sampling and functions to
transform unconstrained values at sample sites to constrained values
within their respective support.
:param jax.random.PRNGKey rng: random number generator seed to
sample from the prior.
:param model: Python callable containing Pyro primitives.
:param `*model_args`: args provided to the model.
:param `**model_kwargs`: kwargs provided to the model.
:return: tuple of (`init_params`, `potential_fn`, `inv_transform_fn`)
`init_params` are values from the prior used to initiate MCMC.
`inv_transform_fn` is a callable that uses inverse transforms
to convert unconstrained HMC samples to constrained values that
lie within the site's support.
"""
model = seed(model, rng)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
sample_sites = {k: v for k, v in model_trace.items() if v['type'] == 'sample' and not v['is_observed']}
transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
init_params = transform_fn(transforms, {k: v['value'] for k, v in sample_sites.items()}, invert=True)
return init_params, potential_energy(model, model_args, model_kwargs, transforms), \
jax.partial(transform_fn, transforms)
inv_transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
init_params = transform_fn(inv_transforms, {k: v['value'] for k, v in sample_sites.items()}, constrain=False)
return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
jax.partial(transform_fn, inv_transforms, constrain=True)
138 changes: 85 additions & 53 deletions numpyro/mcmc.py
@@ -1,4 +1,5 @@
import math
import os

import tqdm

Expand Down Expand Up @@ -52,8 +53,8 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
:param potential_fn: Python callable that computes the potential energy
given input parameters. The input parameters to `potential_fn` can be
any python collection type, provided that ``init_samples`` argument to
``init_kernel`` has the same type.
any python collection type, provided that `init_samples` argument to
`init_kernel` has the same type.
:param kinetic_fn: Python callable that returns the kinetic energy given
inverse mass matrix and momentum. If not provided, the default is
euclidean kinetic energy.
Expand All @@ -63,53 +64,40 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
one to initialize the sampler, and the second one to generate samples
given an existing one.
The arguments taken by `init_kernel` and `sample_kernel` are as follows:
.. function:: init_kernel
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to ``potential_fn``.
:param int num_warmup_steps: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
or dense (if set to ``False``).
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
:param float trajectory_length: Length of a MCMC trajectory for HMC. Default
value is :math:`2\pi`.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Defaults to 10.
:param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
`init_kernel` returns an initial :func:`~numpyro.mcmc.HMCState` that
can be used to generate samples using MCMC. Else, returns the arguments
and callable that does the initial adaptation.
:param bool progbar: Whether to enable progress bar updates. Defaults to
``True``.
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
step size is done at the beginning of each adaptation window to achieve
`target_acceptance_prob`.
:param jax.random.PRNGKey rng: random key to be used as the source of
randomness.
.. function:: sample_kernel
Given a :func:`~numpyro.mcmc.HMCState`, run HMC with fixed (possibly
adapted) step size and return :func:`~numpyro.mcmc.HMCState`.
:param hmc_state: Current sample (and associated state).
:return: new proposed :func:`~numpyro.mcmc.HMCState` from simulating
Hamiltonian dynamics given existing state.
**Example**
.. testsetup::
import jax
from jax import random
import jax.numpy as np
import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect
.. doctest::
>>> true_coefs = np.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
... coefs_mean = np.zeros(dim)
... coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
... return sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(0), model, data, labels)
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
... trajectory_length=10,
... num_warmup_steps=300)
>>> hmc_states = fori_collect(500, sample_kernel, hmc_state,
... transform=lambda x: transform_fn(x.z))
>>> print(np.mean(hmc_states['beta'], axis=0))
[0.9153987 2.0754058 2.9621222]
"""
if kinetic_fn is None:
kinetic_fn = _euclidean_ke
Expand All @@ -132,6 +120,41 @@ def init_kernel(init_samples,
progbar=True,
heuristic_step_size=True,
rng=PRNGKey(0)):
"""
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to `potential_fn`.
:param int num_warmup_steps: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
or dense (if set to ``False``).
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
:param float trajectory_length: Length of a MCMC trajectory for HMC. Default
value is :math:`2\\pi`.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Defaults to 10.
:param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
`init_kernel` returns an initial :data:`HMCState` that can be used to
generate samples using MCMC. Else, returns the arguments and callable
that does the initial adaptation.
:param bool progbar: Whether to enable progress bar updates. Defaults to
``True``.
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
step size is done at the beginning of each adaptation window to achieve
`target_acceptance_prob`.
:param jax.random.PRNGKey rng: random key to be used as the source of
randomness.
"""
step_size = float(step_size)
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
trajectory_len = float(trajectory_length)
Expand Down Expand Up @@ -214,6 +237,14 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):

@jit
def sample_kernel(hmc_state):
"""
Given an existing :data:`HMCState`, run HMC with fixed (possibly adapted)
step size and return a new :data:`HMCState`.
:param hmc_state: Current sample (and associated state).
:return: new proposed :data:`HMCState` from simulating
Hamiltonian dynamics given existing state.
"""
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
Expand All @@ -223,9 +254,10 @@ def sample_kernel(hmc_state):
return HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix, rng)

# populate docs for `init_kernel` and `sample_kernel`
component_docs = hmc.__doc__.split('.. function::')
init_kernel.__doc__ = '\n'.join(component_docs[1].split('\n')[1:])
sample_kernel.__doc__ = '\n'.join(component_docs[2].split('\n')[1:])
# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
if 'SPHINX_BUILD' in os.environ:
hmc.init_kernel = init_kernel
hmc.sample_kernel = sample_kernel

return init_kernel, sample_kernel

0 comments on commit d648254

Please sign in to comment.