Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MetropolisHastings algorithm to example of MCMCKernel #680

Merged
merged 2 commits into from Jul 15, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 60 additions & 3 deletions numpyro/infer/mcmc.py
Expand Up @@ -37,8 +37,10 @@ def get_diagnostics_str(mcmc_state):
return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(mcmc_state.num_steps,
mcmc_state.adapt_state.step_size,
mcmc_state.mean_accept_prob)
else:
elif hasattr(mcmc_state, "mean_accept_prob"):
return 'acc. prob={:.2f}'.format(mcmc_state.mean_accept_prob)
else:
return ''


def get_progbar_desc_str(num_warmup, i):
Expand All @@ -51,6 +53,47 @@ class MCMCKernel(ABC):
"""
Defines the interface for the Markov transition kernel that is
used for :class:`~numpyro.infer.MCMC` inference.

If the MCMC state is a namedtuple with `z` field, the
method :meth:`MCMC.get_samples()` will return the result in `z` field.
Otherwise, that method will return the collection of full states.

**Example:**

.. doctest::

>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC

>>> MHState = namedtuple("MHState", ["z", "rng_key"])

>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
... def __init__(self, potential_fn, step_size=0.1):
... self.potential_fn = potential_fn
... self.step_size = step_size
...
... def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
... return MHState(init_params, rng_key)
...
... def sample(self, state, model_args, model_kwargs):
... z, rng_key = state
... rng_key, key_proposal, key_accept = random.split(rng_key, 3)
... z_proposal = dist.Normal(z, self.step_size).sample(key_proposal)
... accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal))
... z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z)
... return MHState(z_new, rng_key)

>>> def f(x):
... return ((x - 2) ** 2).sum()

>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> samples = mcmc.get_samples()
"""
def postprocess_fn(self, model_args, model_kwargs):
"""
Expand Down Expand Up @@ -125,7 +168,10 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs):
def _collect_fn(collect_fields):
@cached_by(_collect_fn, collect_fields)
def collect(x):
return attrgetter(*collect_fields)(x[0])
if collect_fields:
return attrgetter(*collect_fields)(x[0])
else:
return x[0]

return collect

Expand Down Expand Up @@ -252,6 +298,15 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
# filter out fields not available in init_state
avail_fields = []
for field in collect_fields:
try:
attrgetter(field)(init_state)
avail_fields.append(field)
except AttributeError:
pass
collect_fields = tuple(avail_fields) if 'z' in avail_fields else ()

collect_vals = fori_collect(lower_idx,
upper_idx,
Expand All @@ -266,7 +321,9 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col
states, last_val = collect_vals
# Get first argument of type `HMCState`
last_state = last_val[0]
if len(collect_fields) == 1:
if len(collect_fields) <= 1:
# if collect_fields == (), we put the result in `z` field
collect_fields = ('z',)
states = (states,)
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
Expand Down