Implement dense mass matrix adaptation (#170)
* stash

* Add support for dense mass matrix estimation

* rebase with master; address comments; expose diagnostics docs

* use numerically stable version

* fix lint

* fix remaining tests

* fix flaky test

* Xfail flaky test
neerajprad authored and fehiepsi committed May 29, 2019
1 parent 77b7042 commit c843cdc
Showing 8 changed files with 144 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ __pycache__/
# data files
numpyro/examples/.data
examples/.results
examples/*.pdf
numpyro/.DS_Store

# test related
8 changes: 8 additions & 0 deletions docs/source/diagnostics.rst
@@ -0,0 +1,8 @@
Diagnostics
===========

.. automodule:: numpyro.diagnostics
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -13,6 +13,14 @@ Numpyro documentation
svi


.. toctree::
:glob:
:maxdepth: 2
:caption: Diagnostics:

diagnostics


.. toctree::
:glob:
:maxdepth: 2
33 changes: 21 additions & 12 deletions numpyro/diagnostics.py
@@ -2,7 +2,7 @@

import numpy as onp

-from jax import device_get
+from jax import device_get, tree_flatten


def _compute_chain_variance_stats(x):
@@ -28,7 +28,8 @@ def gelman_rubin(x):
It is required that ``input.shape[0] >= 2`` and ``input.shape[1] >= 2``.
:param numpy.ndarray x: the input array.
-:returns numpy.ndarray: R-hat of ``x``.
+:return: R-hat of ``x``.
+:rtype: numpy.ndarray
"""
assert x.ndim >= 2
assert x.shape[0] >= 2
@@ -45,7 +46,8 @@ def split_gelman_rubin(x):
It is required that ``input.shape[1] >= 4``.
:param numpy.ndarray x: the input array.
-:returns numpy.ndarray: split R-hat of ``x``.
+:return: split R-hat of ``x``.
+:rtype: numpy.ndarray
"""
assert x.ndim >= 2
assert x.shape[1] >= 4
@@ -80,7 +82,8 @@ def autocorrelation(x, axis=0):
:param numpy.array x: the input array.
:param int axis: the dimension to calculate autocorrelation.
-:returns numpy.array: autocorrelation of ``x``.
+:return: autocorrelation of ``x``.
+:rtype: numpy.ndarray
"""
# Ref: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation
# Adapted from Stan implementation
@@ -115,7 +118,8 @@ def autocovariance(x, axis=0):
:param numpy.ndarray x: the input array.
:param int axis: the dimension to calculate autocovariance.
-:returns numpy.ndarray: autocovariance of ``x``.
+:return: autocovariance of ``x``.
+:rtype: numpy.ndarray
"""
return autocorrelation(x, axis) * x.var(axis=axis, keepdims=True)

@@ -126,13 +130,15 @@ def effective_sample_size(x):
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
**References:**
-[1] `Introduction to Markov Chain Monte Carlo`,
-    Charles J. Geyer
-[2] `Stan Reference Manual version 2.18`,
-    Stan Development Team
+1. *Introduction to Markov Chain Monte Carlo*,
+   Charles J. Geyer
+2. *Stan Reference Manual version 2.18*,
+   Stan Development Team
:param numpy.ndarray x: the input array.
-:returns numpy.ndarray: effective sample size of ``x``.
+:return: effective sample size of ``x``.
+:rtype: numpy.ndarray
"""
assert x.ndim >= 2
assert x.shape[1] >= 2
@@ -169,8 +175,9 @@ def hpdi(x, prob=0.89, axis=0):
:param numpy.ndarray x: the input array.
:param float prob: the probability mass of samples within the interval.
:param int axis: the dimension to calculate hpdi.
-:returns numpy.ndarray: quantiles of ``input`` at ``(1 - probs) / 2`` and
+:return: quantiles of ``input`` at ``(1 - probs) / 2`` and
    ``(1 + probs) / 2``.
+:rtype: numpy.ndarray
"""
x = onp.swapaxes(x, axis, 0)
sorted_x = onp.sort(x, axis=0)
@@ -192,7 +199,7 @@ def summary(samples, prob=0.89):
"""
Prints a summary table for diagnostics of ``samples``.
-:param numpy.ndarray samples: the input samples
+:param samples: a collection of input samples.
:param float prob: the probability mass of samples within the HPDI interval.
"""
# FIXME: handle variable with str len > 20
@@ -204,6 +211,8 @@ def summary(samples, prob=0.89):

# FIXME: maybe allow a `digits` arg to set how many floating points are needed?
row_format = '{:>20} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}'
if not isinstance(samples, dict):
samples = {'Param:{}'.format(i): v for i, v in enumerate(tree_flatten(samples)[0])}
# TODO: support summary for chains of samples
for name, value in samples.items():
value = device_get(value)
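With the tree_flatten fallback above, summary now accepts any pytree of sample arrays, not only a dict; unnamed leaves print as Param:0, Param:1, and so on. A hedged usage sketch (array shapes and values are illustrative, and assume a single chain of draws along axis 0):

    import numpy as onp
    from numpyro.diagnostics import summary

    draws = [onp.random.normal(size=(2000,)),     # a scalar parameter
             onp.random.normal(size=(2000, 3))]   # a vector parameter
    summary(draws, prob=0.89)                     # rows print as Param:0, Param:1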
47 changes: 34 additions & 13 deletions numpyro/hmc_util.py
@@ -11,8 +11,8 @@
from numpyro.util import cond, laxtuple, while_loop

AdaptWindow = laxtuple("AdaptWindow", ["start", "end"])
-AdaptState = laxtuple("AdaptState", ["step_size", "inverse_mass_matrix", "ss_state", "mm_state",
-"window_idx", "rng"])
+AdaptState = laxtuple("AdaptState", ["step_size", "inverse_mass_matrix", "mass_matrix_sqrt",
+"ss_state", "mm_state", "window_idx", "rng"])
IntegratorState = laxtuple("IntegratorState", ["z", "r", "potential_energy", "z_grad"])

_TreeInfo = laxtuple('_TreeInfo', ['z_left', 'r_left', 'z_left_grad',
@@ -22,6 +22,14 @@
'sum_accept_probs', 'num_proposals'])


def _cholesky_inverse(matrix):
# This formulation only takes the inverse of a triangular matrix
# which is more numerically stable.
# Refer to:
# https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
return np.linalg.inv(np.linalg.cholesky(matrix[::-1, ::-1])[::-1, ::-1]).T
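A quick way to sanity-check _cholesky_inverse: the returned matrix is a lower-triangular square root of the input's inverse. A standalone sketch with plain NumPy standing in for jax.numpy (the cov value is illustrative):

    import numpy as np

    def cholesky_inverse(matrix):
        # Flip, factor, flip back: only a triangular matrix is ever inverted.
        return np.linalg.inv(np.linalg.cholesky(matrix[::-1, ::-1])[::-1, ::-1]).T

    cov = np.array([[2.0, 0.3], [0.3, 1.0]])         # plays the inverse mass matrix
    r = cholesky_inverse(cov)
    assert np.allclose(r, np.tril(r))                # lower triangular
    assert np.allclose(r @ r.T, np.linalg.inv(cov))  # r @ r.T recovers the mass matrix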


def dual_averaging(t0=10, kappa=0.75, gamma=0.05):
"""
Dual Averaging is a scheme to solve convex optimization problems. It belongs
@@ -114,7 +122,11 @@ def final_fn(state, regularize=False):
cov = scaled_cov + shrinkage
else:
cov = scaled_cov + shrinkage * np.identity(mean.shape[0], dtype=mean.dtype)
-return cov
+if np.ndim(cov) == 2:
+    cov_inv_sqrt = _cholesky_inverse(cov)
+else:
+    cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
+return cov, cov_inv_sqrt

return init_fn, update_fn, final_fn
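In the diagonal regime the same quantity needs no matrix factorization: cov is a vector of variances and its inverse square root is elementwise. A small check with illustrative values:

    import numpy as np

    diag_cov = np.array([4.0, 0.25, 1.0])              # per-coordinate variances
    cov_inv_sqrt = np.sqrt(np.reciprocal(diag_cov))    # sqrt of the diagonal mass matrix
    assert np.allclose(cov_inv_sqrt ** 2 * diag_cov, 1.0)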

@@ -227,20 +239,26 @@ def _identity_step_size(inverse_mass_matrix, z, rng, step_size):

def warmup_adapter(num_adapt_steps, find_reasonable_step_size=_identity_step_size,
adapt_step_size=True, adapt_mass_matrix=True,
-diag_mass=True, target_accept_prob=0.8):
+dense_mass=False, target_accept_prob=0.8):
ss_init, ss_update = dual_averaging()
-mm_init, mm_update, mm_final = welford_covariance(diagonal=diag_mass)
+mm_init, mm_update, mm_final = welford_covariance(diagonal=not dense_mass)
adaptation_schedule = np.array(build_adaptation_schedule(num_adapt_steps))
num_windows = len(adaptation_schedule)

def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
    rng, rng_ss = random.split(rng)
    if inverse_mass_matrix is None:
        assert mass_matrix_size is not None
-       if diag_mass:
+       if dense_mass:
+           inverse_mass_matrix = np.identity(mass_matrix_size)
+       else:
            inverse_mass_matrix = np.ones(mass_matrix_size)
+       mass_matrix_sqrt = inverse_mass_matrix
+   else:
+       if dense_mass:
+           mass_matrix_sqrt = _cholesky_inverse(inverse_mass_matrix)
        else:
-           inverse_mass_matrix = np.identity(mass_matrix_size)
+           mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))

    if adapt_step_size:
        step_size = find_reasonable_step_size(inverse_mass_matrix, z, rng_ss, step_size)
@@ -249,23 +267,25 @@ def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
mm_state = mm_init(inverse_mass_matrix.shape[-1])

window_idx = 0
-return AdaptState(step_size, inverse_mass_matrix, ss_state, mm_state, window_idx, rng)
+return AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
+ss_state, mm_state, window_idx, rng)

def _update_at_window_end(z, rng_ss, state):
-step_size, inverse_mass_matrix, ss_state, mm_state, window_idx, rng = state
+step_size, inverse_mass_matrix, mass_matrix_sqrt, ss_state, mm_state, window_idx, rng = state

if adapt_mass_matrix:
-inverse_mass_matrix = mm_final(mm_state, regularize=True)
+inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state, regularize=True)
mm_state = mm_init(inverse_mass_matrix.shape[-1])

if adapt_step_size:
step_size = find_reasonable_step_size(inverse_mass_matrix, z, rng_ss, step_size)
ss_state = ss_init(np.log(10 * step_size))

-return AdaptState(step_size, inverse_mass_matrix, ss_state, mm_state, window_idx, rng)
+return AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
+ss_state, mm_state, window_idx, rng)

def update_fn(t, accept_prob, z, state):
-step_size, inverse_mass_matrix, ss_state, mm_state, window_idx, rng = state
+step_size, inverse_mass_matrix, mass_matrix_sqrt, ss_state, mm_state, window_idx, rng = state
rng, rng_ss = random.split(rng)

# update step size state
@@ -288,7 +308,8 @@ def update_fn(t, accept_prob, z, state):

t_at_window_end = t == adaptation_schedule[window_idx, 1]
window_idx = np.where(t_at_window_end, window_idx + 1, window_idx)
-state = AdaptState(step_size, inverse_mass_matrix, ss_state, mm_state, window_idx, rng)
+state = AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
+ss_state, mm_state, window_idx, rng)
state = cond(t_at_window_end & is_middle_window,
(z, rng_ss, state), lambda args: _update_at_window_end(*args),
state, lambda x: x)
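Putting the adapter together, a hedged sketch of the (init, update) cycle based on the signatures above; the position z and the call values are illustrative, and in practice mcmc.py drives update_fn once per warmup step:

    from jax import random
    import jax.numpy as np
    from numpyro.hmc_util import warmup_adapter

    wa_init, wa_update = warmup_adapter(500, dense_mass=True)
    z = {'theta': np.zeros(3)}                    # current position, a pytree
    wa_state = wa_init(z, random.PRNGKey(0), step_size=1.0, mass_matrix_size=3)
    wa_state = wa_update(0, 0.8, z, wa_state)     # args: (t, accept_prob, z, state)
    # In the dense regime wa_state.inverse_mass_matrix is a full (3, 3) matrix
    # and wa_state.mass_matrix_sqrt is the matching triangular square root.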
38 changes: 22 additions & 16 deletions numpyro/mcmc.py
@@ -10,13 +10,13 @@
from jax.random import PRNGKey
from jax.tree_util import register_pytree_node

-import numpyro.distributions as dist
from numpyro.diagnostics import summary
from numpyro.hmc_util import IntegratorState, build_tree, find_reasonable_step_size, velocity_verlet, warmup_adapter
from numpyro.util import cond, fori_collect, fori_loop, identity

HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
-'mean_accept_prob', 'step_size', 'inverse_mass_matrix', 'rng'])
+'mean_accept_prob', 'step_size', 'inverse_mass_matrix', 'mass_matrix_sqrt',
+'rng'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
- **i** - iteration. This is reset to 0 after warmup.
@@ -54,12 +54,16 @@ def _get_num_steps(step_size, trajectory_length):
return num_steps.astype(np.int64)


-def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
-    if inverse_mass_matrix.ndim == 1:
-        r = dist.Normal(0., np.sqrt(np.reciprocal(inverse_mass_matrix))).sample(rng)
+def _sample_momentum(unpack_fn, mass_matrix_sqrt, rng):
+    eps = random.normal(rng, np.shape(mass_matrix_sqrt)[:1])
+    if mass_matrix_sqrt.ndim == 1:
+        r = np.multiply(mass_matrix_sqrt, eps)
        return unpack_fn(r)
-    elif inverse_mass_matrix.ndim == 2:
-        raise NotImplementedError
+    elif mass_matrix_sqrt.ndim == 2:
+        r = np.dot(mass_matrix_sqrt, eps)
+        return unpack_fn(r)
+    else:
+        raise ValueError("Mass matrix has incorrect number of dims.")
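Since the momentum is now a deterministic transform of standard normal noise, its covariance is exactly the mass matrix M = mass_matrix_sqrt @ mass_matrix_sqrt.T. A plain-NumPy sketch of that fact (shapes and values illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    m_sqrt = np.array([[1.2, 0.0], [0.4, 0.9]])   # square root of the mass matrix
    eps = rng.standard_normal((2, 100000))        # eps ~ N(0, I)
    r = m_sqrt @ eps                              # hence r ~ N(0, m_sqrt @ m_sqrt.T)
    assert np.allclose(np.cov(r), m_sqrt @ m_sqrt.T, atol=0.05)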


def _euclidean_ke(inverse_mass_matrix, r):
@@ -151,7 +155,7 @@ def init_kernel(init_params,
step_size=1.0,
adapt_step_size=True,
adapt_mass_matrix=True,
-diag_mass=True,
+dense_mass=False,
target_accept_prob=0.8,
trajectory_length=2*math.pi,
max_tree_depth=10,
@@ -172,8 +176,8 @@
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
-:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
-    or dense (if set to ``False``).
+:param bool dense_mass: A flag to decide if mass matrix is dense or
+    diagonal (default when ``dense_mass=False``)
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
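For callers, switching to dense adaptation is a one-keyword change. A hedged sketch against the signature above (the potential and initial point are toy values; argument spelling should be checked against this revision of numpyro.mcmc):

    from jax import random
    import jax.numpy as np
    from numpyro.mcmc import hmc

    def potential_fn(z):                          # correlated 2-d Gaussian target
        prec = np.array([[1.0, 0.8], [0.8, 1.0]])
        return 0.5 * np.dot(z, np.dot(prec, z))

    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(np.array([1.0, 2.0]), num_warmup=500,
                            dense_mass=True,      # adapt a full mass matrix
                            rng=random.PRNGKey(0))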
@@ -208,16 +212,17 @@
wa_init, wa_update = warmup_adapter(num_warmup,
adapt_step_size=adapt_step_size,
adapt_mass_matrix=adapt_mass_matrix,
-diag_mass=diag_mass,
+dense_mass=dense_mass,
target_accept_prob=target_accept_prob,
find_reasonable_step_size=find_reasonable_ss)

rng_hmc, rng_wa = random.split(rng)
wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
-r = momentum_generator(wa_state.inverse_mass_matrix, rng)
+r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
vv_state = vv_init(z, r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
-wa_state.step_size, wa_state.inverse_mass_matrix, rng_hmc)
+wa_state.step_size, wa_state.inverse_mass_matrix, wa_state.mass_matrix_sqrt,
+rng_hmc)

wa_update = jit(wa_update)
if run_warmup:
@@ -243,7 +248,8 @@ def warmup_update(t, states):
hmc_state = sample_kernel(hmc_state)
wa_state = wa_update(t, hmc_state.accept_prob, hmc_state.z, wa_state)
hmc_state = hmc_state.update(step_size=wa_state.step_size,
-inverse_mass_matrix=wa_state.inverse_mass_matrix)
+inverse_mass_matrix=wa_state.inverse_mass_matrix,
+mass_matrix_sqrt=wa_state.mass_matrix_sqrt)
return hmc_state, wa_state

def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
@@ -286,7 +292,7 @@ def sample_kernel(hmc_state):
Hamiltonian dynamics given existing state.
"""
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
-r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
+r = momentum_generator(hmc_state.mass_matrix_sqrt, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
vv_state, num_steps, accept_prob = _next(hmc_state.step_size,
hmc_state.inverse_mass_matrix,
@@ -295,7 +301,7 @@ def sample_kernel(hmc_state):
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / itr
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, mean_accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix,
-rng)
+hmc_state.mass_matrix_sqrt, rng)

# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
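Continuing that sketch: after warmup, hmc_state.inverse_mass_matrix holds the adapted dense matrix and hmc_state.mass_matrix_sqrt its square root, and sampling iterates sample_kernel, e.g. via the fori_collect utility imported at the top of this file (the call shown is illustrative):

    from numpyro.util import fori_collect

    samples = fori_collect(1000, sample_kernel, hmc_state,
                           transform=lambda state: state.z)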
