Support for mutable params (#1016)

fehiepsi committed Jun 23, 2021
1 parent 42763d4 commit 5bcea01
Showing 14 changed files with 545 additions and 166 deletions.
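Before the per-file diffs, a minimal sketch of the workflow this commit enables, for orientation. This example is illustrative rather than taken from the diff: it assumes the numpyro.mutable primitive referenced in the docstrings below (using the dict-mutation idiom), the SVIRunResult object that svi.run now returns in the updated examples, and the num_particles=1 restriction enforced in elbo.py.

    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.infer.autoguide import AutoNormal

    def model(data):
        # A mutable site stores state that the model updates in place and
        # that SVI carries across optimization steps (hypothetical bookkeeping).
        stats = numpyro.mutable("stats", {"count": jnp.array(0.0)})
        stats["count"] = stats["count"] + 1
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    guide = AutoNormal(model)
    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000, jnp.ones(10))
    params, losses = svi_result.params, svi_result.losses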
2 changes: 1 addition & 1 deletion Makefile
@@ -17,7 +17,7 @@ install: FORCE
 	pip install -e .[dev,doc,test,examples]
 
 doctest: FORCE
-	$(MAKE) -C docs doctest
+	JAX_PLATFORM_NAME=cpu $(MAKE) -C docs doctest
 
 test: lint FORCE
 	pytest -v test
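For reference, the pinned doctest run can be reproduced manually (assuming the repository's standard Sphinx docs layout); JAX_PLATFORM_NAME=cpu forces JAX onto the CPU backend even when an accelerator is available:

    JAX_PLATFORM_NAME=cpu make -C docs doctest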
3 changes: 2 additions & 1 deletion examples/covtype.py
@@ -174,7 +174,8 @@ def benchmark_hmc(args, features, labels):
     subsample_size = 1000
     guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
     svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
-    params, losses = svi.run(random.PRNGKey(2), 2000, features, labels)
+    svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
+    params, losses = svi_result.params, svi_result.losses
     plt.plot(losses)
     plt.show()
5 changes: 2 additions & 3 deletions examples/hmcecs.py
@@ -50,9 +50,8 @@ def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
     optimizer = numpyro.optim.Adam(step_size=1e-3)
     guide = autoguide.AutoDelta(model)
     svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
-    params, losses = svi.run(
-        svi_key, args.num_svi_steps, data, obs, args.subsample_size
-    )
+    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
+    params, losses = svi_result.params, svi_result.losses
     ref_params = {"theta": params["theta_auto_loc"]}
 
     # taylor proxy estimates log likelihood (ll) by
2 changes: 1 addition & 1 deletion numpyro/contrib/control_flow/cond.py
@@ -111,7 +111,7 @@ def cond(pred, true_fun, false_fun, operand):
     ...     return cond(cluster > 0, true_fun, false_fun, None)
     >>>
     >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
-    >>> params, losses = svi.run(random.PRNGKey(0), num_steps=2500)
+    >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
 
     .. warning:: This is an experimental utility function that allows users to use
         JAX control flow with NumPyro's effect handlers. Currently, `sample` and
218 changes: 178 additions & 40 deletions numpyro/contrib/module.py

Large diffs are not rendered by default.
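Since the module.py diff is not rendered, here is a hedged sketch, for orientation only, of the kind of usage this file's changes are meant to support, inferred from the commit title and the mutable-state plumbing visible in the other files. The mutable=["batch_stats"] keyword on flax_module is an assumption about the API added here, not something shown in this diff:

    import flax.linen as nn
    import jax.numpy as jnp

    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import flax_module

    class Net(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(4)(x)
            # BatchNorm keeps running statistics in flax's "batch_stats" collection.
            x = nn.BatchNorm(use_running_average=False, momentum=0.9)(x)
            return nn.Dense(1)(x)

    def model(x, y=None):
        # Assumed API: collections named in `mutable` are registered as
        # numpyro.mutable sites so SVI can carry them across steps.
        net = flax_module(
            "net", Net(), input_shape=(1, x.shape[-1]), mutable=["batch_stats"]
        )
        mean = net(x).squeeze(-1)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)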

4 changes: 2 additions & 2 deletions numpyro/handlers.py
@@ -600,7 +600,7 @@ class scope(Messenger):
     """
     This handler prepend a prefix followed by a divider to the name of sample sites.
 
-    Example::
+    **Example**
 
     .. doctest::
@@ -745,7 +745,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
         super(substitute, self).__init__(fn)
 
     def process_message(self, msg):
-        if (msg["type"] not in ("sample", "param", "plate")) or msg.get(
+        if (msg["type"] not in ("sample", "param", "mutable", "plate")) or msg.get(
             "_control_flow_done", False
         ):
             if msg["type"] == "control_flow":
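This two-token change lets substitute intercept "mutable" sites the same way it already does "sample" and "param" sites, which is how previously stored mutable state can be re-injected when a model is re-run. A minimal sketch, assuming the numpyro.mutable primitive (site name and values are illustrative):

    import numpyro
    from numpyro import handlers

    def model():
        # numpyro.mutable returns the stored value for this site (here, a dict).
        state = numpyro.mutable("state", {"value": 0.0})
        return state["value"]

    # Without a substitute handler the initial value is used; with one, the
    # supplied value takes its place, exactly as for sample/param sites.
    with handlers.substitute(data={"state": {"value": 3.0}}):
        assert model() == 3.0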
195 changes: 124 additions & 71 deletions numpyro/infer/elbo.py
@@ -15,7 +15,60 @@
 from numpyro.infer.util import get_importance_trace, log_density
 
 
-class Trace_ELBO:
+class ELBO:
+    """
+    Base class for all ELBO objectives.
+
+    Subclasses should implement either :meth:`loss` or :meth:`loss_with_mutable_state`.
+
+    :param num_particles: The number of particles/samples used to form the ELBO
+        (gradient) estimators.
+    """
+
+    def __init__(self, num_particles=1):
+        self.num_particles = num_particles
+
+    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
+        """
+        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
+
+        :param jax.random.PRNGKey rng_key: random number generator seed.
+        :param dict param_map: dictionary of current parameter values keyed by site
+            name.
+        :param model: Python callable with NumPyro primitives for the model.
+        :param guide: Python callable with NumPyro primitives for the guide.
+        :param args: arguments to the model / guide (these can possibly vary during
+            the course of fitting).
+        :param kwargs: keyword arguments to the model / guide (these can possibly vary
+            during the course of fitting).
+        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
+        """
+        return self.loss_with_mutable_state(
+            rng_key, param_map, model, guide, *args, **kwargs
+        )["loss"]
+
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
+        """
+        Like :meth:`loss`, but additionally updates and returns the mutable state,
+        which stores the values at :func:`~numpyro.mutable` sites.
+
+        :param jax.random.PRNGKey rng_key: random number generator seed.
+        :param dict param_map: dictionary of current parameter values keyed by site
+            name.
+        :param model: Python callable with NumPyro primitives for the model.
+        :param guide: Python callable with NumPyro primitives for the guide.
+        :param args: arguments to the model / guide (these can possibly vary during
+            the course of fitting).
+        :param kwargs: keyword arguments to the model / guide (these can possibly vary
+            during the course of fitting).
+        :return: a dict with keys "loss" (the ELBO loss) and "mutable_state"
+        """
+        raise NotImplementedError("This ELBO objective does not support mutable state.")
+
+
+class Trace_ELBO(ELBO):
     """
     A trace implementation of ELBO-based SVI. The estimator is constructed
     along the lines of references [1] and [2]. There are no restrictions on the
@@ -43,52 +96,56 @@ class Trace_ELBO:
     def __init__(self, num_particles=1):
         self.num_particles = num_particles
 
-    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        """
-        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
-
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param dict param_map: dictionary of current parameter values keyed by site
-            name.
-        :param model: Python callable with NumPyro primitives for the model.
-        :param guide: Python callable with NumPyro primitives for the guide.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
-        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
-        """
-
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
         def single_particle_elbo(rng_key):
+            params = param_map.copy()
             model_seed, guide_seed = random.split(rng_key)
             seeded_model = seed(model, model_seed)
             seeded_guide = seed(guide, guide_seed)
             guide_log_density, guide_trace = log_density(
                 seeded_guide, args, kwargs, param_map
             )
+            mutable_params = {
+                name: site["value"]
+                for name, site in guide_trace.items()
+                if site["type"] == "mutable"
+            }
+            params.update(mutable_params)
             seeded_model = replay(seeded_model, guide_trace)
-            model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)
+            model_log_density, model_trace = log_density(
+                seeded_model, args, kwargs, params
+            )
+            mutable_params.update(
+                {
+                    name: site["value"]
+                    for name, site in model_trace.items()
+                    if site["type"] == "mutable"
+                }
+            )
 
             # log p(z) - log q(z)
-            elbo = model_log_density - guide_log_density
-            return elbo
+            elbo_particle = model_log_density - guide_log_density
+            if mutable_params:
+                if self.num_particles == 1:
+                    return elbo_particle, mutable_params
+                else:
+                    raise ValueError(
+                        "Currently, we only support mutable states with num_particles=1."
+                    )
+            else:
+                return elbo_particle, None
 
         # Return (-elbo) since by convention we do gradient descent on a loss and
         # the ELBO is a lower bound that needs to be maximized.
         if self.num_particles == 1:
-            return -single_particle_elbo(rng_key)
+            elbo, mutable_state = single_particle_elbo(rng_key)
+            return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
 
 
-class ELBO(Trace_ELBO):
-    def __init__(self, num_particles=1):
-        warnings.warn(
-            "Using ELBO directly in SVI is deprecated. Please use Trace_ELBO class instead.",
-            FutureWarning,
-        )
-        super().__init__(num_particles=num_particles)
-
-
 def _get_log_prob_sum(site):
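The new return convention can be exercised directly. A small self-contained sketch (the toy model, guide, and parameter values are illustrative, not from this diff); with no numpyro.mutable sites in play, the returned mutable state is None:

    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import Trace_ELBO

    def toy_model():
        numpyro.sample("x", dist.Normal(0.0, 1.0))

    def toy_guide():
        loc = numpyro.param("loc", 0.0)
        numpyro.sample("x", dist.Normal(loc, 1.0))

    elbo = Trace_ELBO()
    result = elbo.loss_with_mutable_state(
        random.PRNGKey(0), {"loc": jnp.array(0.0)}, toy_model, toy_guide
    )
    # result["loss"] is the scalar loss; result["mutable_state"] is None here,
    # and ELBO.loss simply indexes into this dict:
    loss = elbo.loss(random.PRNGKey(0), {"loc": jnp.array(0.0)}, toy_model, toy_guide)
    assert loss == result["loss"]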
@@ -128,7 +185,7 @@ def _check_mean_field_requirement(model_trace, guide_trace):
         )
 
 
-class TraceMeanField_ELBO(Trace_ELBO):
+class TraceMeanField_ELBO(ELBO):
     """
     A trace implementation of ELBO-based SVI. This is currently the only
     ELBO estimator in NumPyro that uses analytic KL divergences when those
@@ -146,30 +203,31 @@ class TraceMeanField_ELBO(Trace_ELBO):
     dependency structures.
     """
 
-    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        """
-        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
-
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param dict param_map: dictionary of current parameter values keyed by site
-            name.
-        :param model: Python callable with NumPyro primitives for the model.
-        :param guide: Python callable with NumPyro primitives for the guide.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
-        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
-        """
-
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
         def single_particle_elbo(rng_key):
+            params = param_map.copy()
             model_seed, guide_seed = random.split(rng_key)
             seeded_model = seed(model, model_seed)
             seeded_guide = seed(guide, guide_seed)
             subs_guide = substitute(seeded_guide, data=param_map)
             guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
-            subs_model = substitute(replay(seeded_model, guide_trace), data=param_map)
+            mutable_params = {
+                name: site["value"]
+                for name, site in guide_trace.items()
+                if site["type"] == "mutable"
+            }
+            params.update(mutable_params)
+            subs_model = substitute(replay(seeded_model, guide_trace), data=params)
             model_trace = trace(subs_model).get_trace(*args, **kwargs)
+            mutable_params.update(
+                {
+                    name: site["value"]
+                    for name, site in model_trace.items()
+                    if site["type"] == "mutable"
+                }
+            )
             _check_mean_field_requirement(model_trace, guide_trace)
 
             elbo_particle = 0
@@ -196,16 +254,26 @@ def single_particle_elbo(rng_key):
assert site["infer"].get("is_auxiliary")
elbo_particle = elbo_particle - _get_log_prob_sum(site)

return elbo_particle
if mutable_params:
if self.num_particles == 1:
return elbo_particle, mutable_params
else:
raise ValueError(
"Currently, we only support mutable states with num_particles=1."
)
else:
return elbo_particle, None

if self.num_particles == 1:
return -single_particle_elbo(rng_key)
elbo, mutable_state = single_particle_elbo(rng_key)
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


class RenyiELBO(Trace_ELBO):
class RenyiELBO(ELBO):
r"""
An implementation of Renyi's :math:`\alpha`-divergence
variational inference following reference [1].
Expand Down Expand Up @@ -235,24 +303,9 @@ def __init__(self, alpha=0, num_particles=2):
"for the case alpha = 1."
)
self.alpha = alpha
super(RenyiELBO, self).__init__(num_particles=num_particles)
super().__init__(num_particles=num_particles)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
"""
Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.
:param jax.random.PRNGKey rng_key: random number generator seed.
:param dict param_map: dictionary of current parameter values keyed by site
name.
:param model: Python callable with NumPyro primitives for the model.
:param guide: Python callable with NumPyro primitives for the guide.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
"""

def single_particle_elbo(rng_key):
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
Expand Down Expand Up @@ -458,7 +511,7 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
     return downstream_costs, downstream_guide_cost_nodes
 
 
-class TraceGraph_ELBO:
+class TraceGraph_ELBO(ELBO):
     """
     A TraceGraph implementation of ELBO-based SVI. The gradient estimator
     is constructed along the lines of reference [1] specialized to the case
@@ -479,7 +532,7 @@ class TraceGraph_ELBO:
     """
 
     def __init__(self, num_particles=1):
-        self.num_particles = num_particles
+        super().__init__(num_particles=num_particles)
 
     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
         """
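Taken together, the changes in this file establish a small subclass contract: an objective implements loss_with_mutable_state and returns a dict with "loss" and "mutable_state" entries, and it inherits loss from the new ELBO base class. A hedged single-particle sketch of a custom objective under that contract (equivalent in spirit to Trace_ELBO with num_particles=1; not code from this diff):

    from jax import random

    from numpyro.handlers import replay, seed
    from numpyro.infer.elbo import ELBO
    from numpyro.infer.util import log_density

    class MySimpleELBO(ELBO):
        def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
            model_seed, guide_seed = random.split(rng_key)
            # Run the guide, then replay its samples through the model.
            guide_ld, guide_trace = log_density(
                seed(guide, guide_seed), args, kwargs, param_map
            )
            model_ld, _ = log_density(
                replay(seed(model, model_seed), guide_trace), args, kwargs, param_map
            )
            # No mutable sites handled in this sketch, hence mutable_state=None.
            return {"loss": -(model_ld - guide_ld), "mutable_state": None}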
