Skip to content

Commit

Permalink
Add MetropolisHastings algorithm to example of MCMCKernel (#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Jul 15, 2020
1 parent 6f6984a commit 0857645
Showing 1 changed file with 60 additions and 3 deletions.
63 changes: 60 additions & 3 deletions numpyro/infer/mcmc.py
Expand Up @@ -37,8 +37,10 @@ def get_diagnostics_str(mcmc_state):
return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(mcmc_state.num_steps,
mcmc_state.adapt_state.step_size,
mcmc_state.mean_accept_prob)
else:
elif hasattr(mcmc_state, "mean_accept_prob"):
return 'acc. prob={:.2f}'.format(mcmc_state.mean_accept_prob)
else:
return ''


def get_progbar_desc_str(num_warmup, i):
Expand All @@ -51,6 +53,47 @@ class MCMCKernel(ABC):
"""
Defines the interface for the Markov transition kernel that is
used for :class:`~numpyro.infer.MCMC` inference.
If the MCMC state is a namedtuple with `z` field, the
method :meth:`MCMC.get_samples()` will return the result in `z` field.
Otherwise, that method will return the collection of full states.
**Example:**
.. doctest::
>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC
>>> MHState = namedtuple("MHState", ["z", "rng_key"])
>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
... def __init__(self, potential_fn, step_size=0.1):
... self.potential_fn = potential_fn
... self.step_size = step_size
...
... def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
... return MHState(init_params, rng_key)
...
... def sample(self, state, model_args, model_kwargs):
... z, rng_key = state
... rng_key, key_proposal, key_accept = random.split(rng_key, 3)
... z_proposal = dist.Normal(z, self.step_size).sample(key_proposal)
... accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal))
... z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z)
... return MHState(z_new, rng_key)
>>> def f(x):
... return ((x - 2) ** 2).sum()

This comment has been minimized.

Copy link
@aflaxman

aflaxman Dec 30, 2021

Would it make sense to change this to have the potential function be the log-loss, instead of the negative log loss? I was testing my understanding by changing this to a uniform distribution and initially assumed that

def f(x):
    return jnp.where(x <= 1, 
                     jnp.where(x>=0, 0, -jnp.inf),
                     -jnp.inf
                    )

would be the way to specify it.

This comment has been minimized.

Copy link
@fehiepsi

fehiepsi Dec 30, 2021

Author Member

This code assumes f is a potential function. If you want to use log prob, either you can provide: lambda x: -f(x) to the kerbel, or revise the implementation of the kernel to account for that.

This comment has been minimized.

Copy link
@aflaxman

aflaxman Dec 30, 2021

I see. I think I got mixed up with pymc2 notation, where there are "potential" objects that serve the same purpose as your numpyro.factor term. Thanks for the quick response, and for all of your work on this!

>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> samples = mcmc.get_samples()
"""
def postprocess_fn(self, model_args, model_kwargs):
"""
Expand Down Expand Up @@ -125,7 +168,10 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs):
def _collect_fn(collect_fields):
@cached_by(_collect_fn, collect_fields)
def collect(x):
return attrgetter(*collect_fields)(x[0])
if collect_fields:
return attrgetter(*collect_fields)(x[0])
else:
return x[0]

return collect

Expand Down Expand Up @@ -252,6 +298,15 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
# filter out fields not available in init_state
avail_fields = []
for field in collect_fields:
try:
attrgetter(field)(init_state)
avail_fields.append(field)
except AttributeError:
pass
collect_fields = tuple(avail_fields) if 'z' in avail_fields else ()

collect_vals = fori_collect(lower_idx,
upper_idx,
Expand All @@ -266,7 +321,9 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col
states, last_val = collect_vals
# Get first argument of type `HMCState`
last_state = last_val[0]
if len(collect_fields) == 1:
if len(collect_fields) <= 1:
# if collect_fields == (), we put the result in `z` field
collect_fields = ('z',)
states = (states,)
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
Expand Down

0 comments on commit 0857645

Please sign in to comment.