Skip to content

Commit

Permalink
Support multi sample guides in Trace_ELBO (#1666)
Browse files Browse the repository at this point in the history
* support multi sample guide

* support multi_sample_guide in Trace_ELBO

* Add docs for multi_sample_guide

* make multi_sample_guide specific to Trace_ELBO only

* fix test to catch mutable state warning

* validate model trace
  • Loading branch information
fehiepsi committed Nov 5, 2023
1 parent eaa29a0 commit d9b52d7
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 24 deletions.
101 changes: 82 additions & 19 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,23 @@ class ELBO:
"""
Determines whether the ELBO objective can support inference of discrete latent variables.
Subclasses that are capable of inferring discrete latent variables should override to `True`
Subclasses that are capable of inferring discrete latent variables should override to `True`.
"""
can_infer_discrete = False

# Record the particle configuration on the instance.
#   num_particles: number of samples/particles the ELBO estimator uses (default 1).
#   vectorize_particles: per the Trace_ELBO docstring in this diff, True means the
#     per-particle ELBOs are computed with `jax.vmap`, False with `jax.lax.map`.
def __init__(self, num_particles=1, vectorize_particles=True):
self.num_particles = num_particles
self.vectorize_particles = vectorize_particles

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def loss(
self,
rng_key,
param_map,
model,
guide,
*args,
**kwargs,
):
"""
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
Expand Down Expand Up @@ -116,15 +124,30 @@ class Trace_ELBO(ELBO):
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
:param multi_sample_guide: Whether to make an assumption that the guide proposes
multiple samples.
"""

# Store the `multi_sample_guide` flag, then forward the particle settings to the
# parent constructor.
#   multi_sample_guide: when True, `loss_with_mutable_state` treats the leading
#     axis of the guide's sample sites as a batch of proposals and evaluates the
#     model density once per proposal via `vmap`. `SVI.init` also looks the flag
#     up with `getattr(self.loss, "multi_sample_guide", False)`.
def __init__(
self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
):
# Set before calling the parent constructor; the new keyword defaults to False.
self.multi_sample_guide = multi_sample_guide
super().__init__(
num_particles=num_particles, vectorize_particles=vectorize_particles
)

def loss_with_mutable_state(
self, rng_key, param_map, model, guide, *args, **kwargs
self,
rng_key,
param_map,
model,
guide,
*args,
**kwargs,
):
def single_particle_elbo(rng_key):
params = param_map.copy()
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
guide_log_density, guide_trace = log_density(
seeded_guide, args, kwargs, param_map
Expand All @@ -135,29 +158,68 @@ def single_particle_elbo(rng_key):
if site["type"] == "mutable"
}
params.update(mutable_params)
seeded_model = replay(seeded_model, guide_trace)
model_log_density, model_trace = log_density(
seeded_model, args, kwargs, params
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")
mutable_params.update(
{
if self.multi_sample_guide:
plates = {
name: site["value"]
for name, site in model_trace.items()
if site["type"] == "mutable"
for name, site in guide_trace.items()
if site["type"] == "plate"
}
)

def get_model_density(key, latent):
with seed(rng_seed=key), substitute(data={**latent, **plates}):
model_log_density, model_trace = log_density(
model, args, kwargs, params
)
_validate_model(model_trace, plate_warning="loose")
return model_log_density

num_guide_samples = None
for name, site in guide_trace.items():
if site["type"] == "sample":
num_guide_samples = site["value"].shape[0]
break
if num_guide_samples is None:
raise ValueError("guide is missing `sample` sites.")
seeds = random.split(model_seed, num_guide_samples)
latents = {
name: site["value"]
for name, site in guide_trace.items()
if (site["type"] == "sample" and site["value"].size > 0)
or (site["type"] == "deterministic")
}
model_log_density = vmap(get_model_density)(seeds, latents)
assert model_log_density.ndim == 1
model_log_density = model_log_density.sum(0)
# log p(z) - log q(z)
elbo_particle = (model_log_density - guide_log_density) / seeds.shape[0]
else:
seeded_model = seed(model, model_seed)
replay_model = replay(seeded_model, guide_trace)
model_log_density, model_trace = log_density(
replay_model, args, kwargs, params
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")
mutable_params.update(
{
name: site["value"]
for name, site in model_trace.items()
if site["type"] == "mutable"
}
)
# log p(z) - log q(z)
elbo_particle = model_log_density - guide_log_density

# log p(z) - log q(z)
elbo_particle = model_log_density - guide_log_density
if mutable_params:
if self.num_particles == 1:
return elbo_particle, mutable_params
else:
raise ValueError(
"Currently, we only support mutable states with num_particles=1."
warnings.warn(
"mutable state is currently ignored when num_particles > 1."
)
return elbo_particle, None
else:
return elbo_particle, None

Expand Down Expand Up @@ -288,9 +350,10 @@ def single_particle_elbo(rng_key):
if self.num_particles == 1:
return elbo_particle, mutable_params
else:
raise ValueError(
"Currently, we only support mutable states with num_particles=1."
warnings.warn(
"mutable state is currently ignored when num_particles > 1."
)
return elbo_particle, None
else:
return elbo_particle, None

Expand Down
25 changes: 21 additions & 4 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import tqdm

import jax
from jax import jit, lax, random
from jax.example_libraries import optimizers
import jax.numpy as jnp
Expand Down Expand Up @@ -189,9 +190,25 @@ def init(self, rng_key, *args, init_params=None, **kwargs):
}
if init_params is not None:
init_guide_params.update(init_params)
model_trace = trace(
substitute(replay(model_init, guide_trace), init_guide_params)
).get_trace(*args, **kwargs, **self.static_kwargs)
if getattr(self.loss, "multi_sample_guide", False):
latents = {
name: site["value"][0]
for name, site in guide_trace.items()
if site["type"] == "sample" and site["value"].size > 0
}
latents.update(init_guide_params)
with trace() as model_trace, substitute(data=latents):
model_init(*args, **kwargs, **self.static_kwargs)
for site in model_trace.values():
if site["type"] == "mutable":
raise ValueError(
"mutable state in model is not supported for "
"multi-sample guide."
)
else:
model_trace = trace(
substitute(replay(model_init, guide_trace), init_guide_params)
).get_trace(*args, **kwargs, **self.static_kwargs)

params = {}
inv_transforms = {}
Expand Down Expand Up @@ -363,7 +380,7 @@ def body_fn(svi_state, _):
batch = max(num_steps // 20, 1)
for i in t:
svi_state, loss = jit(body_fn)(svi_state, None)
losses.append(loss)
losses.append(jax.device_get(loss))
if i % batch == 0:
if stable_update:
valid_losses = [x for x in losses[i - batch :] if x == x]
Expand Down
21 changes: 20 additions & 1 deletion test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def guide():

svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles))
if num_particles > 1:
with pytest.raises(ValueError, match="mutable state"):
with pytest.warns(UserWarning, match="mutable state"):
svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
return
svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
Expand Down Expand Up @@ -738,3 +738,22 @@ def guide(difficulty=0.0):

for i in range(3):
assert_allclose(max_errors[i], 0, atol=atol)


def test_multi_sample_guide():
    """Check that SVI with ``Trace_ELBO(multi_sample_guide=True)`` fits a guide
    that proposes several values for a single model site.

    The model has one scalar Normal latent ``x``; the guide samples ``x`` from a
    Normal expanded to shape ``[10]``. After 2000 Adam steps the learned ``loc``
    and ``scale`` must match the model's values within 10% relative tolerance.
    """
    actual_loc, actual_scale = 3.0, 2.0

    def model():
        # One scalar latent site; the guide supplies a batch of values for it.
        target = dist.Normal(actual_loc, actual_scale)
        numpyro.sample("x", target)

    def guide():
        loc = numpyro.param("loc", 0.0)
        scale = numpyro.param("scale", 1.0, constraint=constraints.positive)
        # Expanding to [10] makes the guide's value for "x" carry ten draws.
        proposal = dist.Normal(loc, scale).expand([10])
        numpyro.sample("x", proposal)

    optimizer = optim.Adam(0.1)
    objective = Trace_ELBO(multi_sample_guide=True)
    svi = SVI(model, guide, optimizer, objective)
    fitted = svi.run(random.PRNGKey(0), 2000).params

    assert_allclose(fitted["loc"], actual_loc, rtol=0.1)
    assert_allclose(fitted["scale"], actual_scale, rtol=0.1)

0 comments on commit d9b52d7

Please sign in to comment.