Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MetropolisHastings algorithm to example of MCMCKernel (#680)
- Loading branch information
Showing
1 changed file
with
60 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,8 +37,10 @@ def get_diagnostics_str(mcmc_state): | |
return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(mcmc_state.num_steps, | ||
mcmc_state.adapt_state.step_size, | ||
mcmc_state.mean_accept_prob) | ||
else: | ||
elif hasattr(mcmc_state, "mean_accept_prob"): | ||
return 'acc. prob={:.2f}'.format(mcmc_state.mean_accept_prob) | ||
else: | ||
return '' | ||
|
||
|
||
def get_progbar_desc_str(num_warmup, i): | ||
|
@@ -51,6 +53,47 @@ class MCMCKernel(ABC): | |
""" | ||
Defines the interface for the Markov transition kernel that is | ||
used for :class:`~numpyro.infer.MCMC` inference. | ||
If the MCMC state is a namedtuple with `z` field, the | ||
method :meth:`MCMC.get_samples()` will return the result in `z` field. | ||
Otherwise, that method will return the collection of full states. | ||
**Example:** | ||
.. doctest:: | ||
>>> from collections import namedtuple | ||
>>> from jax import random | ||
>>> import jax.numpy as jnp | ||
>>> import numpyro | ||
>>> import numpyro.distributions as dist | ||
>>> from numpyro.infer import MCMC | ||
>>> MHState = namedtuple("MHState", ["z", "rng_key"]) | ||
>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel): | ||
... def __init__(self, potential_fn, step_size=0.1): | ||
... self.potential_fn = potential_fn | ||
... self.step_size = step_size | ||
... | ||
... def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): | ||
... return MHState(init_params, rng_key) | ||
... | ||
... def sample(self, state, model_args, model_kwargs): | ||
... z, rng_key = state | ||
... rng_key, key_proposal, key_accept = random.split(rng_key, 3) | ||
... z_proposal = dist.Normal(z, self.step_size).sample(key_proposal) | ||
... accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal)) | ||
... z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z) | ||
... return MHState(z_new, rng_key) | ||
>>> def f(x): | ||
... return ((x - 2) ** 2).sum() | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
fehiepsi
Author
Member
|
||
>>> kernel = MetropolisHastings(f) | ||
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000) | ||
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.])) | ||
>>> samples = mcmc.get_samples() | ||
""" | ||
def postprocess_fn(self, model_args, model_kwargs): | ||
""" | ||
|
@@ -125,7 +168,10 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs): | |
def _collect_fn(collect_fields): | ||
@cached_by(_collect_fn, collect_fields) | ||
def collect(x): | ||
return attrgetter(*collect_fields)(x[0]) | ||
if collect_fields: | ||
return attrgetter(*collect_fields)(x[0]) | ||
else: | ||
return x[0] | ||
|
||
return collect | ||
|
||
|
@@ -252,6 +298,15 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col | |
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) | ||
lower_idx = self._collection_params["lower"] | ||
upper_idx = self._collection_params["upper"] | ||
# filter out fields not available in init_state | ||
avail_fields = [] | ||
for field in collect_fields: | ||
try: | ||
attrgetter(field)(init_state) | ||
avail_fields.append(field) | ||
except AttributeError: | ||
pass | ||
collect_fields = tuple(avail_fields) if 'z' in avail_fields else () | ||
|
||
collect_vals = fori_collect(lower_idx, | ||
upper_idx, | ||
|
@@ -266,7 +321,9 @@ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, col | |
states, last_val = collect_vals | ||
# Get first argument of type `HMCState` | ||
last_state = last_val[0] | ||
if len(collect_fields) == 1: | ||
if len(collect_fields) <= 1: | ||
# if collect_fields == (), we put the result in `z` field | ||
collect_fields = ('z',) | ||
states = (states,) | ||
states = dict(zip(collect_fields, states)) | ||
# Apply constraints if number of samples is non-zero | ||
|
Would it make sense to change this to have the potential function be the log-loss, instead of the negative log loss? I was testing my understanding by changing this to a uniform distribution and initially assumed that
would be the way to specify it.