Skip to content

Commit

Permalink
Merge warmup update and sample kernel (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Jun 10, 2019
1 parent 3151fab commit cc5aaf1
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 90 deletions.
8 changes: 4 additions & 4 deletions examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ def benchmark_hmc(args, features, labels):
init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)
t0 = time.time()
# TODO: Use init_params from `initialize_model` instead of fixed params.
hmc_state, _, _ = init_kernel(init_params, num_warmup=0, step_size=step_size,
trajectory_length=trajectory_length,
adapt_step_size=False, run_warmup=False)
hmc_state = init_kernel(init_params, num_warmup=0, step_size=step_size,
trajectory_length=trajectory_length,
adapt_step_size=False, run_warmup=False)
t1 = time.time()
print("time for hmc_init: ", t1 - t0)

def transform(state): return {'coefs': state.z['coefs'],
'num_steps': state.num_steps}

hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state, transform=transform)
hmc_states = fori_collect(0, args.num_samples, sample_kernel, hmc_state, transform=transform)
num_leapfrogs = np.sum(hmc_states['num_steps'])
print('number of leapfrog steps: ', num_leapfrogs)
print('avg. time for each step: ', (time.time() - t1) / num_leapfrogs)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main(args):
init_params, potential_fn, constrain_fn = initialize_model(init_rng, model, returns)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
hmc_states = fori_collect(0, args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: constrain_fn(hmc_state.z))
print_results(hmc_states, dates)

Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run_inference(dept, male, applications, admit, rng, args):
rng, glmm, dept, male, applications, admit)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
hmc_states = fori_collect(0, args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: constrain_fn(hmc_state.z))
return hmc_states

Expand Down
99 changes: 53 additions & 46 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from numpyro.util import cond, fori_collect, fori_loop, identity

HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
'mean_accept_prob', 'step_size', 'inverse_mass_matrix', 'mass_matrix_sqrt',
'rng'])
'mean_accept_prob', 'adapt_state', 'rng'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
Expand All @@ -30,10 +29,16 @@
does not correspond to the proposal if it is rejected.
- **mean_accept_prob** - Mean acceptance probability until current iteration
during warmup adaptation or sampling (for diagnostics).
- **step_size** - Step size to be used by the integrator in the next iteration.
This is adapted during warmup.
- **inverse_mass_matrix** - The inverse mass matrix to be be used for the next
iteration. This is adapted during warmup.
- **adapt_state** - A ``AdaptState`` namedtuple which contains adaptation information
during warmup:
+ **step_size** - Step size to be used by the integrator in the next iteration.
+ **inverse_mass_matrix** - The inverse mass matrix to be used for the next
iteration.
+ **mass_matrix_sqrt** - The square root of mass matrix to be used for the next
iteration. In case of dense mass, this is the Cholesky factorization of the
mass matrix.
- **rng** - random number generator seed used for the iteration.
"""

Expand Down Expand Up @@ -80,7 +85,7 @@ def _euclidean_ke(inverse_mass_matrix, r):

def get_diagnostics_str(hmc_state):
return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(hmc_state.num_steps,
hmc_state.step_size,
hmc_state.adapt_state.step_size,
hmc_state.mean_accept_prob)


Expand Down Expand Up @@ -143,7 +148,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
>>> hmc_state = init_kernel(init_params,
... trajectory_length=10,
... num_warmup=300)
>>> samples = fori_collect(500, sample_kernel, hmc_state,
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
... transform=lambda state: constrain_fn(state.z))
>>> print(np.mean(samples['beta'], axis=0)) # doctest: +SKIP
[0.9153987 2.0754058 2.9621222]
Expand All @@ -155,6 +160,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
max_treedepth = None
momentum_generator = None
wa_update = None
wa_steps = None
if algo not in {'HMC', 'NUTS'}:
raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

Expand Down Expand Up @@ -206,7 +212,8 @@ def init_kernel(init_params,
randomness.
"""
step_size = float(step_size)
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
wa_steps = num_warmup
trajectory_len = float(trajectory_length)
max_treedepth = max_tree_depth
z = init_params
Expand All @@ -229,40 +236,24 @@ def init_kernel(init_params,
r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
vv_state = vv_init(z, r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
wa_state.step_size, wa_state.inverse_mass_matrix, wa_state.mass_matrix_sqrt,
rng_hmc)
wa_state, rng_hmc)

wa_update = jit(wa_update)
if run_warmup:
if run_warmup and num_warmup > 0:
# JIT if progress bar updates not required
if not progbar:
# PERF: jitting the for loop may be faster on certain models or high
# number of samples.
# TODO: remove this condition when the issue is resolved
# TODO: keep jit version and remove non-jit version for the next jax release
if progbar is None: # NB: if progbar=None, we jit fori_loop
hmc_state, _ = jit(fori_loop, static_argnums=(2,))(0, num_warmup, warmup_update,
(hmc_state, wa_state))
hmc_state = jit(fori_loop, static_argnums=(2,))(
0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
else:
hmc_state, _ = fori_loop(0, num_warmup, warmup_update, (hmc_state, wa_state))
hmc_state = fori_loop(
0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
else:
with tqdm.trange(num_warmup, desc='warmup') as t:
for i in t:
hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
hmc_state = sample_kernel(hmc_state)
t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
# Reset `i` and `mean_accept_prob` for fresh diagnostics.
hmc_state.update(i=0, mean_accept_prob=0)
return hmc_state
else:
return hmc_state, wa_state, warmup_update

def warmup_update(t, states):
hmc_state, wa_state = states
hmc_state = sample_kernel(hmc_state)
wa_state = wa_update(t, hmc_state.accept_prob, hmc_state.z, wa_state)
hmc_state = hmc_state.update(step_size=wa_state.step_size,
inverse_mass_matrix=wa_state.inverse_mass_matrix,
mass_matrix_sqrt=wa_state.mass_matrix_sqrt)
return hmc_state, wa_state
return hmc_state

def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
num_steps = _get_num_steps(step_size, trajectory_len)
Expand Down Expand Up @@ -304,16 +295,26 @@ def sample_kernel(hmc_state):
Hamiltonian dynamics given existing state.
"""
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
r = momentum_generator(hmc_state.mass_matrix_sqrt, rng_momentum)
r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
vv_state, num_steps, accept_prob = _next(hmc_state.step_size,
hmc_state.inverse_mass_matrix,
vv_state, num_steps, accept_prob = _next(hmc_state.adapt_state.step_size,
hmc_state.adapt_state.inverse_mass_matrix,
vv_state, rng_transition)
# not update adapt_state after warmup phase
adapt_state = cond(hmc_state.i < wa_steps,
(hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
lambda args: wa_update(*args),
hmc_state.adapt_state,
lambda x: x)

itr = hmc_state.i + 1
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / itr
n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
# Reset `mean_accept_prob` for fresh diagnostics.
mean_accept_prob = np.where(hmc_state.i == wa_steps, 0., hmc_state.mean_accept_prob)
mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
accept_prob, mean_accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix,
hmc_state.mass_matrix_sqrt, rng)
accept_prob, mean_accept_prob, adapt_state, rng)

# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
Expand Down Expand Up @@ -398,12 +399,18 @@ def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
progbar = sampler_kwargs.pop('progbar', True)

init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, **sampler_kwargs)
samples = fori_collect(num_samples, sample_kernel, hmc_state,
transform=lambda x: constrain_fn(x.z),
progbar=progbar,
diagnostics_fn=get_diagnostics_str,
progbar_desc='sample')
if progbar:
hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, **sampler_kwargs)
samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
transform=lambda x: constrain_fn(x.z),
progbar=progbar,
diagnostics_fn=get_diagnostics_str,
progbar_desc='sample')
else:
hmc_state = init_kernel(init_params, num_warmup, run_warmup=False, **sampler_kwargs)
samples = fori_collect(num_warmup, num_warmup + num_samples, sample_kernel, hmc_state,
transform=lambda x: constrain_fn(x.z),
progbar=progbar)
if print_summary:
summary(samples)
return samples
Expand Down
29 changes: 16 additions & 13 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,23 @@ def identity(x):
return x


def fori_collect(n, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
"""
This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
effect of collecting values from the loop body. In addition, this allows for
post-processing of these samples via `transform`, and progress bar updates.
Note that, in some cases, `progbar=False` can be faster, when collecting a
Note that, `progbar=False` will be faster, especially when collecting a
lot of samples. Refer to example usage in :func:`~numpyro.mcmc.hmc`.
:param int n: number of times to run the loop body.
:param int lower: the index to start the collective work. In other words,
we will skip collecting the first `lower` values.
:param int upper: number of times to run the loop body.
:param body_fun: a callable that takes a collection of
`np.ndarray` and returns a collection with the same shape and
`dtype`.
:param init_val: initial value to pass as argument to `body_fun`. Can
be any Python collection type containing `np.ndarray` objects.
:param transform: A callable
:param transform: a callable to post-process the values returned by `body_fn`.
:param progbar: whether to post progress bar updates.
:param `**progbar_opts`: optional additional progress bar arguments. A
`diagnostics_fn` can be supplied which when passed the current value
Expand All @@ -124,36 +126,37 @@ def fori_collect(n, body_fun, init_val, transform=identity, progbar=True, **prog
:return: collection with the same type as `init_val` with values
collected along the leading axis of `np.ndarray` objects.
"""
assert lower < upper
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
ravel_fn = lambda x: ravel_pytree(transform(x))[0] # noqa: E731

if not progbar:
collection = np.zeros((n,) + init_val_flat.shape)
collection = np.zeros((upper - lower,) + init_val_flat.shape)

def _body_fn(i, vals):
val, collection = vals
val = body_fun(val)
i = np.where(i >= lower, i - lower, 0)
collection = ops.index_update(collection, i, ravel_fn(val))
return val, collection

# PERF: jitting the for loop may be faster on certain models or high
# number of samples.
# TODO: remove this condition when the issue is resolved
# TODO: keep jit version and remove non-jit version for the next jax release
if progbar is None: # NB: if progbar=None, we jit fori_loop
_, collection = jit(fori_loop, static_argnums=(2,))(0, n, _body_fn,
_, collection = jit(fori_loop, static_argnums=(2,))(0, upper, _body_fn,
(init_val, collection))
else:
_, collection = fori_loop(0, n, _body_fn, (init_val, collection))
_, collection = fori_loop(0, upper, _body_fn, (init_val, collection))
else:
diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
progbar_desc = progbar_opts.pop('progbar_desc', '')
collection = []

val = init_val
with tqdm.trange(n, desc=progbar_desc) as t:
for _ in t:
with tqdm.trange(upper, desc=progbar_desc) as t:
for i in t:
val = body_fun(val)
collection.append(jit(ravel_fn)(val))
if i >= lower:
collection.append(jit(ravel_fn)(val))
if diagnostics_fn:
t.set_postfix_str(diagnostics_fn(val), refresh=False)

Expand Down

0 comments on commit cc5aaf1

Please sign in to comment.