Skip to content

Commit

Permalink
Support pickling MCMC objects with enumeration (#1577)
Browse files Browse the repository at this point in the history
* Remove inner functions

* Add tests

* Run black and isort

* Remove unncessary else
  • Loading branch information
tare committed Apr 19, 2023
1 parent 0589859 commit ffca0b8
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 22 deletions.
55 changes: 33 additions & 22 deletions numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def plate_to_enum_plate():
numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)


def _config_enumerate_fn(site, default):
"""helper function used internally in config_enumerate"""
if (
site["type"] == "sample"
and (not site["is_observed"])
and site["fn"].has_enumerate_support
):
return {"enumerate": site["infer"].get("enumerate", default)}
return {}


def config_enumerate(fn=None, default="parallel"):
"""
Configures enumeration for all relevant sites in a NumPyro model.
Expand Down Expand Up @@ -69,16 +80,18 @@ def model(*args, **kwargs):
if fn is None: # support use as a decorator
return functools.partial(config_enumerate, default=default)

def config_fn(site):
if (
site["type"] == "sample"
and (not site["is_observed"])
and site["fn"].has_enumerate_support
):
return {"enumerate": site["infer"].get("enumerate", default)}
return {}
return infer_config(fn, functools.partial(_config_enumerate_fn, default=default))

return infer_config(fn, config_fn)

def _config_kl_fn(site, sites):
"""helper function used internally in config_kl"""
if (
site["type"] == "sample"
and (not site["is_observed"])
and (sites is None or site["name"] in sites)
):
return {"kl": site["infer"].get("kl", "analytic")}
return {}


def config_kl(fn=None, sites=None):
Expand Down Expand Up @@ -107,16 +120,7 @@ def model(*args, **kwargs):
if fn is None: # support use as a decorator
return functools.partial(config_kl, sites=sites)

def config_fn(site):
if (
site["type"] == "sample"
and (not site["is_observed"])
and (sites is None or site["name"] in sites)
):
return {"kl": site["infer"].get("kl", "analytic")}
return {}

return infer_config(fn, config_fn)
return infer_config(fn, functools.partial(_config_kl_fn, sites=sites))


def _get_shift(name):
Expand Down Expand Up @@ -225,7 +229,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
if name.startswith("_time"):
time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
history = max(
history, max(_get_shift(s) for s in dim_to_name.values())
history,
max(_get_shift(s) for s in dim_to_name.values()),
)
if history == 0:
log_factors.append(log_prob_factor)
Expand Down Expand Up @@ -282,7 +287,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
raise ValueError(
"Expected the joint log density is a scalar, but got {}. "
"There seems to be something wrong at the following sites: {}.".format(
result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
result.data.shape,
{k.split("__BOUND")[0] for k in result.inputs},
)
)
return result, model_trace, log_measures
Expand Down Expand Up @@ -310,6 +316,11 @@ def model(*args, **kwargs):
:return: log of joint density and a corresponding model trace
"""
result, model_trace, _ = _enum_log_density(
model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
model,
model_args,
model_kwargs,
params,
funsor.ops.logaddexp,
funsor.ops.add,
)
return result.data, model_trace
61 changes: 61 additions & 0 deletions test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from jax.tree_util import tree_all, tree_map

import numpyro
from numpyro.contrib.funsor import config_kl
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.constraints import (
boolean,
circular,
Expand Down Expand Up @@ -49,6 +51,7 @@
DiscreteHMCGibbs,
MixedHMC,
Predictive,
TraceEnum_ELBO,
)
from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal

Expand All @@ -69,6 +72,19 @@ def logistic_regression():
numpyro.sample("obs", dist.Bernoulli(logits=x), obs=batch)


def gmm(data, K):
mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
with numpyro.plate("num_clusters", K, dim=-1):
cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.0))
with numpyro.plate("data", data.shape[0], dim=-1):
assignments = numpyro.sample(
"assignments",
dist.Categorical(mix_proportions),
infer={"enumerate": "parallel"},
)
numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data)


@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
def test_pickle_hmc(kernel):
mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
Expand All @@ -77,6 +93,24 @@ def test_pickle_hmc(kernel):
tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))


@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
def test_pickle_hmc_enumeration(kernel):
K, N = 3, 1000

true_cluster_means = jnp.array([1.0, 5.0, 10.0])
true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
cluster_assignments = dist.Categorical(true_mix_proportions).sample(
random.PRNGKey(0), (N,)
)
data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(
random.PRNGKey(1)
)
mcmc = MCMC(kernel(gmm), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), data, K)
pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))


@pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC])
def test_pickle_discrete_hmc(kernel):
mcmc = MCMC(kernel(HMC(bernoulli_model)), num_warmup=10, num_samples=10)
Expand Down Expand Up @@ -176,3 +210,30 @@ def test_mcmc_pickle_post_warmup():
pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
pickled_mcmc.post_warmup_state = pickled_mcmc.last_state
pickled_mcmc.run(random.PRNGKey(1))


def bernoulli_regression(data):
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
with numpyro.plate("N", len(data)):
numpyro.sample("obs", dist.Bernoulli(f), obs=data)


def test_beta_bernoulli():
data = jnp.array([1.0] * 8 + [0.0] * 2)

def guide(data):
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

pickled_model = pickle.loads(pickle.dumps(config_kl(bernoulli_regression)))
optim = numpyro.optim.Adam(1e-2)
svi = SVI(config_kl(bernoulli_regression), guide, optim, TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 3, data)
params = svi_result.params

svi = SVI(pickled_model, guide, optim, TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 3, data)
pickled_params = svi_result.params

tree_all(tree_map(assert_allclose, params, pickled_params))

0 comments on commit ffca0b8

Please sign in to comment.