From d648254d3556ac561bd261e2254f62cb3955da67 Mon Sep 17 00:00:00 2001
From: Neeraj Pradhan
Date: Sat, 18 May 2019 19:26:45 -0700
Subject: [PATCH] Adding some more documentation (#157)

* Refactor docs
* stash
* add svi docs
* clean up
* address comments
* fix variable names
* Fix args/kwargs
* fix inv transform
---
 Makefile             |   3 +
 docs/source/conf.py  |  10 +++-
 docs/source/mcmc.rst |  18 ++++--
 docs/source/svi.rst  |  12 ++--
 examples/baseball.py |   2 +-
 numpyro/hmc_util.py  |  31 ++++++++--
 numpyro/mcmc.py      | 138 ++++++++++++++++++++++++++-----------------
 numpyro/svi.py       |  87 +++++++++++++++++++++++++--
 8 files changed, 224 insertions(+), 77 deletions(-)

diff --git a/Makefile b/Makefile
index bfc50a828..9d258f3a2 100644
--- a/Makefile
+++ b/Makefile
@@ -6,6 +6,9 @@ lint: FORCE
 format: FORCE
 	isort -rc .
 
+doctest: FORCE
+	$(MAKE) -C docs doctest
+
 test: lint FORCE
 	pytest -v test
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6bd7329f0..d83227fad 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,6 +21,14 @@
 # sys.path.insert(0, os.path.abspath('../..'))
 
+# HACK: This is to ensure that local functions are documented by sphinx.
+from numpyro.mcmc import hmc  # noqa: E402
+from numpyro.svi import svi  # noqa: E402
+
+os.environ['SPHINX_BUILD'] = '1'
+hmc(None, None)
+svi(None, None, None, None, None, None)
+
 # -- Project information -----------------------------------------------------
 
 project = u'Numpyro'
@@ -61,7 +69,7 @@
     'show-inheritance': True,
     'special-members': True,
     'undoc-members': True,
-    'exclude-members': '__dict__,__module__,__weakref__',
+    # 'exclude-members': '__dict__,__module__,__weakref__',
 }
 
 # Add any paths that contain templates here, relative to this directory.
diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst
index a9392aea2..044ddad52 100644
--- a/docs/source/mcmc.rst
+++ b/docs/source/mcmc.rst
@@ -1,8 +1,16 @@
 Markov Chain Monte Carlo (MCMC)
 ===============================
 
-.. automodule:: numpyro.mcmc
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :member-order: bysource
+.. autofunction:: numpyro.mcmc.hmc
+
+.. autofunction:: numpyro.mcmc.hmc.init_kernel
+
+.. autofunction:: numpyro.mcmc.hmc.sample_kernel
+
+.. autodata:: numpyro.mcmc.HMCState
+
+
+MCMC Utilities
+--------------
+
+.. autofunction:: numpyro.hmc_util.initialize_model
diff --git a/docs/source/svi.rst b/docs/source/svi.rst
index fc012d9dd..83210c8b7 100644
--- a/docs/source/svi.rst
+++ b/docs/source/svi.rst
@@ -1,8 +1,10 @@
 Stochastic Variational Inference (SVI)
 ======================================
 
-.. automodule:: numpyro.svi
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :member-order: bysource
+.. autofunction:: numpyro.svi.svi
+
+.. autofunction:: numpyro.svi.svi.init_fn
+
+.. autofunction:: numpyro.svi.svi.update_fn
+
+.. autofunction:: numpyro.svi.svi.evaluate
diff --git a/examples/baseball.py b/examples/baseball.py
index 3213acd65..b21c26a2e 100644
--- a/examples/baseball.py
+++ b/examples/baseball.py
@@ -5,7 +5,7 @@
 import jax.numpy as np
 import jax.random as random
 from jax.config import config as jax_config
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
 
 import numpyro.distributions as dist
 from numpyro.examples.datasets import BASEBALL, load_dataset
diff --git a/numpyro/hmc_util.py b/numpyro/hmc_util.py
index 2df802a1b..37e4a2c89 100644
--- a/numpyro/hmc_util.py
+++ b/numpyro/hmc_util.py
@@ -561,16 +561,35 @@ def _potential_energy(params):
     return _potential_energy
 
 
-def transform_fn(transforms, params, invert=False):
-    return {k: transforms[k](v) if not invert else transforms[k].inv(v)
+def transform_fn(inv_transforms, params, constrain=True):
+    return {k: inv_transforms[k](v) if constrain else inv_transforms[k].inv(v)
             for k, v in params.items()}
 
 
 def initialize_model(rng, model, *model_args, **model_kwargs):
+    """
+    Given a model with Pyro primitives, returns a function which, given
+    unconstrained parameters, evaluates the potential energy (negative
+    joint density). In addition, this also returns initial parameters
+    sampled from the prior to initiate MCMC sampling and functions to
+    transform unconstrained values at sample sites to constrained values
+    within their respective support.
+
+    :param jax.random.PRNGKey rng: random number generator seed to
+        sample from the prior.
+    :param model: Python callable containing Pyro primitives.
+    :param `*model_args`: args provided to the model.
+    :param `**model_kwargs`: kwargs provided to the model.
+    :return: tuple of (`init_params`, `potential_fn`, `inv_transform_fn`)
+        `init_params` are values from the prior used to initiate MCMC.
+        `inv_transform_fn` is a callable that uses inverse transforms
+        to convert unconstrained HMC samples to constrained values that
+        lie within the site's support.
+    """
     model = seed(model, rng)
     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     sample_sites = {k: v for k, v in model_trace.items() if v['type'] == 'sample' and not v['is_observed']}
-    transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
-    init_params = transform_fn(transforms, {k: v['value'] for k, v in sample_sites.items()}, invert=True)
-    return init_params, potential_energy(model, model_args, model_kwargs, transforms), \
-        jax.partial(transform_fn, transforms)
+    inv_transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
+    init_params = transform_fn(inv_transforms, {k: v['value'] for k, v in sample_sites.items()}, constrain=False)
+    return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
+        jax.partial(transform_fn, inv_transforms, constrain=True)
diff --git a/numpyro/mcmc.py b/numpyro/mcmc.py
index de6b3f9cd..c7fb22f05 100644
--- a/numpyro/mcmc.py
+++ b/numpyro/mcmc.py
@@ -1,4 +1,5 @@
 import math
+import os
 
 import tqdm
 
@@ -52,8 +53,8 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
 
     :param potential_fn: Python callable that computes the potential energy
         given input parameters. The input parameters to `potential_fn` can be
-        any python collection type, provided that ``init_samples`` argument to
-        ``init_kernel`` has the same type.
+        any python collection type, provided that `init_samples` argument to
+        `init_kernel` has the same type.
     :param kinetic_fn: Python callable that returns the kinetic energy given
         inverse mass matrix and momentum. If not provided, the default is
         euclidean kinetic energy.
@@ -63,53 +64,40 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
         one to initialize the sampler, and the second one to generate samples
         given an existing one.
 
-    The arguments taken by `init_kernel` and `sample_kernel` are as follows:
-
-    .. function:: init_kernel
-
-    Initializes the HMC sampler.
-
-    :param init_samples: Initial parameters to begin sampling. The type can
-        must be consistent with the input type to ``potential_fn``.
-    :param int num_warmup_steps: Number of warmup steps; samples generated
-        during warmup are discarded.
-    :param float step_size: Determines the size of a single step taken by the
-        verlet integrator while computing the trajectory using Hamiltonian
-        dynamics. If not specified, it will be set to 1.
-    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
-        during warm-up phase using Dual Averaging scheme.
-    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
-        matrix during warm-up phase using Welford scheme.
-    :param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
-        or dense (if set to ``False``).
-    :param float target_accept_prob: Target acceptance probability for step size
-        adaptation using Dual Averaging. Increasing this value will lead to a smaller
-        step size, hence the sampling will be slower but more robust. Default to 0.8.
-    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
-        value is :math:`2\pi`.
-    :param int max_tree_depth: Max depth of the binary tree created during the doubling
-        scheme of NUTS sampler. Defaults to 10.
-    :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
-        `init_kernel` returns an initial :func:`~numpyro.mcmc.HMCState` that
-        can be used to generate samples using MCMC. Else, returns the arguments
-        and callable that does the initial adaptation.
-    :param bool progbar: Whether to enable progress bar updates. Defaults to
-        ``True``.
-    :param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
-        step size is done at the beginning of each adaptation window to achieve
-        `target_acceptance_prob`.
-    :param jax.random.PRNGKey rng: random key to be used as the source of
-        randomness.
-
-
-    .. function:: sample_kernel
-
-    Given a :func:`~numpyro.mcmc.HMCState`, run HMC with fixed (possibly
-    adapted) step size and return :func:`~numpyro.mcmc.HMCState`.
-
-    :param hmc_state: Current sample (and associated state).
-    :return: new proposed :func:`~numpyro.mcmc.HMCState` from simulating
-        Hamiltonian dynamics given existing state.
+    **Example**
+
+    .. testsetup::
+
+        import jax
+        from jax import random
+        import jax.numpy as np
+        import numpyro.distributions as dist
+        from numpyro.handlers import sample
+        from numpyro.hmc_util import initialize_model
+        from numpyro.mcmc import hmc
+        from numpyro.util import fori_collect
+
+    .. doctest::
+
+        >>> true_coefs = np.array([1., 2., 3.])
+        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
+        >>> dim = 3
+        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
+        >>>
+        >>> def model(data, labels):
+        ...     coefs_mean = np.zeros(dim)
+        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
+        ...     return sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
+        >>>
+        >>> init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(0), model, data, labels)
+        >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
+        >>> hmc_state = init_kernel(init_params,
+        ...                         trajectory_length=10,
+        ...                         num_warmup_steps=300)
+        >>> hmc_states = fori_collect(500, sample_kernel, hmc_state,
+        ...                           transform=lambda x: transform_fn(x.z))
+        >>> print(np.mean(hmc_states['beta'], axis=0))
+        [0.9153987 2.0754058 2.9621222]
     """
     if kinetic_fn is None:
         kinetic_fn = _euclidean_ke
@@ -132,6 +120,41 @@ def init_kernel(init_samples,
                     progbar=True,
                     heuristic_step_size=True,
                     rng=PRNGKey(0)):
+        """
+        Initializes the HMC sampler.
+
+        :param init_samples: Initial parameters to begin sampling. The type must
+            be consistent with the input type to `potential_fn`.
+        :param int num_warmup_steps: Number of warmup steps; samples generated
+            during warmup are discarded.
+        :param float step_size: Determines the size of a single step taken by the
+            verlet integrator while computing the trajectory using Hamiltonian
+            dynamics. If not specified, it will be set to 1.
+        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
+            during warm-up phase using Dual Averaging scheme.
+        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
+            matrix during warm-up phase using Welford scheme.
+        :param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
+            or dense (if set to ``False``).
+        :param float target_accept_prob: Target acceptance probability for step size
+            adaptation using Dual Averaging. Increasing this value will lead to a smaller
+            step size, hence the sampling will be slower but more robust. Defaults to 0.8.
+        :param float trajectory_length: Length of an MCMC trajectory for HMC. Default
+            value is :math:`2\\pi`.
+        :param int max_tree_depth: Max depth of the binary tree created during the doubling
+            scheme of NUTS sampler. Defaults to 10.
+        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
+            `init_kernel` returns an initial :data:`HMCState` that can be used to
+            generate samples using MCMC. Else, returns the arguments and callable
+            that does the initial adaptation.
+        :param bool progbar: Whether to enable progress bar updates. Defaults to
+            ``True``.
+        :param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
+            step size is done at the beginning of each adaptation window to achieve
+            `target_accept_prob`.
+        :param jax.random.PRNGKey rng: random key to be used as the source of
+            randomness.
+        """
         step_size = float(step_size)
         nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
         trajectory_len = float(trajectory_length)
@@ -214,6 +237,14 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
 
     @jit
     def sample_kernel(hmc_state):
+        """
+        Given an existing :data:`HMCState`, run HMC with fixed (possibly adapted)
+        step size and return a new :data:`HMCState`.
+
+        :param hmc_state: Current sample (and associated state).
+        :return: new proposed :data:`HMCState` from simulating
+            Hamiltonian dynamics given existing state.
+        """
         rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
         r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
         vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
@@ -223,9 +254,10 @@ def sample_kernel(hmc_state):
         return HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
                         accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix, rng)
 
-    # populate docs for `init_kernel` and `sample_kernel`
-    component_docs = hmc.__doc__.split('.. function::')
-    init_kernel.__doc__ = '\n'.join(component_docs[1].split('\n')[1:])
-    sample_kernel.__doc__ = '\n'.join(component_docs[2].split('\n')[1:])
+    # Make `init_kernel` and `sample_kernel` visible from the global scope once
+    # `hmc` is called for sphinx doc generation.
+    if 'SPHINX_BUILD' in os.environ:
+        hmc.init_kernel = init_kernel
+        hmc.sample_kernel = sample_kernel
 
     return init_kernel, sample_kernel
diff --git a/numpyro/svi.py b/numpyro/svi.py
index 3597af863..1d0f12c21 100644
--- a/numpyro/svi.py
+++ b/numpyro/svi.py
@@ -1,3 +1,5 @@
+import os
+
 from jax import random, value_and_grad
 
 from numpyro.handlers import replay, seed, substitute, trace
@@ -12,7 +14,34 @@ def _seed(model, guide, rng):
 
 
 def svi(model, guide, loss, optim_init, optim_update, get_params, **kwargs):
+    """
+    Stochastic Variational Inference given an ELBo loss objective.
+
+    :param model: Python callable with Pyro primitives for the model.
+    :param guide: Python callable with Pyro primitives for the guide
+        (recognition network).
+    :param loss: ELBo loss, i.e. negative Evidence Lower Bound, to minimize.
+    :param optim_init: initialization function returned by a JAX optimizer.
+        See :mod:`jax.experimental.optimizers`.
+    :param optim_update: update function for the optimizer.
+    :param get_params: function to get current parameter values given the
+        optimizer state.
+    :param `**kwargs`: static arguments for the model / guide, i.e. arguments
+        that remain constant during fitting.
+    :return: tuple of `(init_fn, update_fn, evaluate)`.
+    """
     def init_fn(rng, model_args=(), guide_args=(), params=None):
+        """
+
+        :param jax.random.PRNGKey rng: random number generator seed.
+        :param tuple model_args: arguments to the model (these can possibly vary during
+            the course of fitting).
+        :param tuple guide_args: arguments to the guide (these can possibly vary during
+            the course of fitting).
+        :param dict params: initial parameter values to condition on. This can be
+            useful for initializing parameters at user-specified values instead of sampling them from the prior.
+        :return: initial optimizer state.
+        """
         assert isinstance(model_args, tuple)
         assert isinstance(guide_args, tuple)
         model_init, guide_init = _seed(model, guide, rng)
@@ -29,6 +58,18 @@ def init_fn(rng, model_args=(), guide_args=(), params=None):
         return optim_init(params)
 
     def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
+        """
+        Take a single step of SVI (possibly on a batch / minibatch of data),
+        using the optimizer.
+
+        :param int i: represents the i'th iteration over the epoch, passed as an
+            argument to the optimizer's update function.
+        :param opt_state: current optimizer state.
+        :param jax.random.PRNGKey rng: random number generator seed.
+        :param tuple model_args: dynamic arguments to the model.
+        :param tuple guide_args: dynamic arguments to the guide.
+        :return: tuple of `(loss_val, opt_state, rng)`.
+        """
         model_init, guide_init = _seed(model, guide, rng)
         params = get_params(opt_state)
         loss_val, grads = value_and_grad(loss)(params, model_init, guide_init, model_args, guide_args, kwargs)
@@ -37,20 +78,54 @@ def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
         return loss_val, opt_state, rng
 
     def evaluate(opt_state, rng, model_args=(), guide_args=()):
+        """
+        Evaluate the ELBo loss given the current parameter values.
+
+        :param opt_state: current optimizer state.
+        :param jax.random.PRNGKey rng: random number generator seed.
+        :param tuple model_args: arguments to the model (these can possibly vary during
+            the course of fitting).
+        :param tuple guide_args: arguments to the guide (these can possibly vary during
+            the course of fitting).
+        :return: ELBo loss given the current parameter values
+            (held within `opt_state`).
+        """
         model_init, guide_init = _seed(model, guide, rng)
         params = get_params(opt_state)
         return loss(params, model_init, guide_init, model_args, guide_args, kwargs)
 
+    # Make local functions visible from the global scope once
+    # `svi` is called for sphinx doc generation.
+    if 'SPHINX_BUILD' in os.environ:
+        svi.init_fn = init_fn
+        svi.update_fn = update_fn
+        svi.evaluate = evaluate
+
     return init_fn, update_fn, evaluate
 
 
-# This is a basic implementation of the Evidence Lower Bound, which is the
-# fundamental objective in Variational Inference.
-# See http://pyro.ai/examples/svi_part_i.html for details.
-# This implementation has various limitations (for example it only supports
-# random variablbes with reparameterized samplers), but all the ELBO
-# implementations in Pyro share the same basic logic.
 def elbo(param_map, model, guide, model_args, guide_args, kwargs):
+    """
+    This is the most basic implementation of the Evidence Lower Bound, which is the
+    fundamental objective in Variational Inference. This implementation has various
+    limitations (for example it only supports random variables with reparameterized
+    samplers) but can be used as a template to build more sophisticated loss
+    objectives.
+
+    For more details, refer to http://pyro.ai/examples/svi_part_i.html.
+
+    :param dict param_map: dictionary of current parameter values keyed by site
+        name.
+    :param model: Python callable with Pyro primitives for the model.
+    :param guide: Python callable with Pyro primitives for the guide
+        (recognition network).
+    :param tuple model_args: arguments to the model (these can possibly vary during
+        the course of fitting).
+    :param tuple guide_args: arguments to the guide (these can possibly vary during
+        the course of fitting).
+    :param dict kwargs: static keyword arguments to the model / guide.
+    :return: negative of the Evidence Lower Bound (ELBo) to be minimized.
+    """
     guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
     model_log_density, _ = log_density(replay(model, guide_trace), model_args, kwargs, param_map)
     # log p(z) - log q(z)