Skip to content

Commit

Permalink
Modify HMC documentation to show docstrings for internal functions (#140
Browse files Browse the repository at this point in the history
)

* Add sphinx and rtd support for Numpyro

* fix lint

* add missing docs

* Modify HMC documentation to show docstrings for internal functions
  • Loading branch information
neerajprad authored and fehiepsi committed May 9, 2019
1 parent 85a8a79 commit fed0bc6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 48 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ numpyro/examples/.data

# docs
docs/build
docs/.DS_Store
105 changes: 57 additions & 48 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
Hamiltonian Monte Carlo inference, using either fixed number of
steps or the No U-Turn Sampler (NUTS) with adaptive path length.
**References**
**References:**
[1] `MCMC Using Hamiltonian Dynamics`,
Radford M. Neal
[2] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo`,
Matthew D. Hoffman, and Andrew Gelman.
1. *MCMC Using Hamiltonian Dynamics*, Radford M. Neal
2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
Matthew D. Hoffman, and Andrew Gelman.
:param potential_fn: Python callable that computes the potential energy
given input parameters. The input parameters to `potential_fn` can be
Expand All @@ -63,6 +62,54 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
:return init_kernel, sample_kernel: Returns a tuple of callables, the first
one to initialize the sampler, and the second one to generate samples
given an existing one.
The arguments taken by `init_kernel` and `sample_kernel` are as follows:
.. function:: init_kernel
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to ``potential_fn``.
:param int num_warmup_steps: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
or dense (if set to ``False``).
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
:param float trajectory_length: Length of a MCMC trajectory for HMC. Default
value is :math:`2\pi`.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Defaults to 10.
:param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
`init_kernel` returns an initial :func:`~numpyro.mcmc.HMCState` that
can be used to generate samples using MCMC. Else, returns the arguments
and callable that does the initial adaptation.
:param bool progbar: Whether to enable progress bar updates. Defaults to
``True``.
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
step size is done at the beginning of each adaptation window to achieve
`target_acceptance_prob`.
:param jax.random.PRNGKey rng: random key to be used as the source of
randomness.
.. function:: sample_kernel
Given a :func:`~numpyro.mcmc.HMCState`, run HMC with fixed (possibly
adapted) step size and return :func:`~numpyro.mcmc.HMCState`.
:param hmc_state: Current sample (and associated state).
:return: new proposed :func:`~numpyro.mcmc.HMCState` from simulating
Hamiltonian dynamics given existing state.
"""
if kinetic_fn is None:
kinetic_fn = _euclidean_ke
Expand All @@ -85,41 +132,6 @@ def init_kernel(init_samples,
progbar=True,
heuristic_step_size=True,
rng=PRNGKey(0)):
r"""
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to ``potential_fn``.
:param int num_warmup_steps: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
or dense (if set to ``False``).
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
:param float trajectory_length: Length of a MCMC trajectory for HMC. Default
value is :math:`2\pi`.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Default to 10.
:param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
`init_kernel` returns an initial :func:`~numpyro.mcmc.HMCState` that
can be used to generate samples using MCMC. Else, returns the arguments
and callable that does the initial adaptation.
:param bool progbar: Whether to enable progress bar updates. Defaults to
``True``.
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
step size is done at the beginning of each adaptation window to achieve
`target_acceptance_prob`.
:param jax.random.PRNGKey rng: random key to be used as the source of
randomness.
"""
step_size = float(step_size)
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
trajectory_len = float(trajectory_length)
Expand Down Expand Up @@ -202,14 +214,6 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):

@jit
def sample_kernel(hmc_state):
r"""
Given a :func:`~numpyro.mcmc.HMCState`, run HMC with fixed (possibly
adapted) step size and return :func:`~numpyro.mcmc.HMCState`.
:param hmc_state: Current sample (and associated state).
:return: new proposed :func:`~numpyro.mcmc.HMCState` from simulating
Hamiltonian dynamics given existing state.
"""
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
Expand All @@ -219,4 +223,9 @@ def sample_kernel(hmc_state):
return HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix, rng)

# populate docs for `init_kernel` and `sample_kernel`
component_docs = hmc.__doc__.split('.. function::')
init_kernel.__doc__ = '\n'.join(component_docs[1].split('\n')[1:])
sample_kernel.__doc__ = '\n'.join(component_docs[2].split('\n')[1:])

return init_kernel, sample_kernel

0 comments on commit fed0bc6

Please sign in to comment.