Skip to content

Commit

Permalink
Add support for local variables in RenyiELBO (#1608)
Browse files Browse the repository at this point in the history
* add support for local variables in renyielbo

* remove deprecated default kwarg in jaxns constructor

* lint

* allow users to specify same plate indices across particles

* fix failing jaxns test

* address comment and add more throughout tests for renyi elbo

* add nonnested plate test for renyi elbo

* use scale[0] instead of scale.mean() to save a bit of time

* Make CI work again

* remove supporting mc pre renyi

* run black

* lint

* fix beta bernoulli test

* address comments

* fix nonnested test because we donot allow nonnested subsample plates now

* revise error message in renyi
  • Loading branch information
fehiepsi committed Jul 10, 2023
1 parent 428dee9 commit 4799b84
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 43 deletions.
23 changes: 12 additions & 11 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ class seed(Messenger):
:param fn: Python callable with NumPyro primitives.
:param rng_seed: a random number generator seed.
:type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey
:param list hide_types: an optional list of side types to skip seeding, e.g. ['plate'].
.. note::
Expand Down Expand Up @@ -703,7 +704,7 @@ class seed(Messenger):
>>> assert x == y
"""

def __init__(self, fn=None, rng_seed=None):
def __init__(self, fn=None, rng_seed=None, hide_types=None):
if isinstance(rng_seed, int) or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
):
Expand All @@ -715,19 +716,19 @@ def __init__(self, fn=None, rng_seed=None):
):
raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
self.rng_key = rng_seed
self.hide_types = [] if hide_types is None else hide_types
super(seed, self).__init__(fn)

def process_message(self, msg):
if (
msg["type"] == "sample"
and not msg["is_observed"]
and msg["kwargs"]["rng_key"] is None
) or msg["type"] in ["prng_key", "plate", "control_flow"]:
if msg["value"] is not None:
# no need to create a new key when value is available
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg["kwargs"]["rng_key"] = rng_key_sample
if msg["type"] in self.hide_types:
return
if msg["type"] not in ["sample", "prng_key", "plate", "control_flow"]:
return
if (msg["kwargs"]["rng_key"] is not None) or (msg["value"] is not None):
# no need to create a new key when value is available
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg["kwargs"]["rng_key"] = rng_key_sample


class substitute(Messenger):
Expand Down
137 changes: 110 additions & 27 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,27 @@ class RenyiELBO(ELBO):
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.
Example::
def model(data):
with numpyro.plate("batch", 10000, subsample_size=100):
latent = numpyro.sample("latent", dist.Normal(0, 1))
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("data", dist.Bernoulli(logits=latent), obs=batch)
def guide(data):
w_loc = numpyro.param("w_loc", 1.)
w_scale = numpyro.param("w_scale", 1.)
with numpyro.plate("batch", 10000, subsample_size=100):
batch = numpyro.subsample(data, event_dim=0)
loc = w_loc * batch
scale = jnp.exp(w_scale * batch)
numpyro.sample("latent", dist.Normal(loc, scale))
elbo = RenyiELBO(num_particles=10)
svi = SVI(model, guide, optax.adam(0.1), elbo)
**References:**
1. *Renyi Divergence Variational Inference*, Yingzhen Li, Richard E. Turner
Expand All @@ -327,37 +348,99 @@ def __init__(self, alpha=0, num_particles=2):
self.alpha = alpha
super().__init__(num_particles=num_particles)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
guide_log_density, guide_trace = log_density(
seeded_guide, args, kwargs, param_map
)
# NB: we only want to substitute params not available in guide_trace
model_param_map = {
k: v for k, v in param_map.items() if k not in guide_trace
}
seeded_model = replay(seeded_model, guide_trace)
model_log_density, model_trace = log_density(
seeded_model, args, kwargs, model_param_map
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")
def _single_particle_elbo(self, model, guide, param_map, args, kwargs, rng_key):
model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
model_trace, guide_trace = get_importance_trace(
seeded_model, seeded_guide, args, kwargs, param_map
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="loose")

site_plates = {
name: {frame for frame in site["cond_indep_stack"]}
for name, site in model_trace.items()
if site["type"] == "sample"
}
# We will compute Renyi elbos separately across dimensions
# defined in indep_plates. Then the final elbo is the sum
# of those independent elbos.
if site_plates:
indep_plates = set.intersection(*site_plates.values())
else:
indep_plates = set()
for frame in set.union(*site_plates.values()):
if frame not in indep_plates:
subsample_size = frame.size
size = model_trace[frame.name]["args"][0]
if size > subsample_size:
raise ValueError(
"RenyiELBO only supports subsampling in plates that are common"
" to all sample sites, e.g. a data plate that encloses the"
" entire model."
)

# log p(z) - log q(z)
elbo = model_log_density - guide_log_density
return elbo
indep_plate_scale = 1.0
for frame in indep_plates:
subsample_size = frame.size
size = model_trace[frame.name]["args"][0]
if size > subsample_size:
indep_plate_scale = indep_plate_scale * size / subsample_size
indep_plate_dims = {frame.dim for frame in indep_plates}

log_densities = {}
for trace_type, tr in {"guide": guide_trace, "model": model_trace}.items():
log_densities[trace_type] = 0.0
for site in tr.values():
if site["type"] != "sample":
continue
log_prob = site["log_prob"]
squeeze_axes = ()
for dim in range(log_prob.ndim):
neg_dim = dim - log_prob.ndim
if neg_dim in indep_plate_dims:
continue
log_prob = jnp.sum(log_prob, axis=dim, keepdims=True)
squeeze_axes = squeeze_axes + (dim,)
log_prob = jnp.squeeze(log_prob, squeeze_axes)
log_densities[trace_type] = log_densities[trace_type] + log_prob

# log p(z) - log q(z)
elbo = log_densities["model"] - log_densities["guide"]
# Log probabilities at indep_plates dimensions are scaled to MC approximate
# the "full size" log probabilities. Because we want to compute Renyi elbos
# separately across indep_plates dimensions, we will remove such scale now.
# We will apply such scale after getting those Renyi elbos.
return elbo / indep_plate_scale, indep_plate_scale

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
plate_key, rng_key = random.split(rng_key)
model = seed(
model, plate_key, hide_types=["sample", "prng_key", "control_flow"]
)
guide = seed(
guide, plate_key, hide_types=["sample", "prng_key", "control_flow"]
)
single_particle_elbo = partial(
self._single_particle_elbo, model, guide, param_map, args, kwargs
)

rng_keys = random.split(rng_key, self.num_particles)
elbos = vmap(single_particle_elbo)(rng_keys)
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
assert common_plate_scale.shape == (self.num_particles,)
assert elbos.shape[0] == self.num_particles
scaled_elbos = (1.0 - self.alpha) * elbos
avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
avg_log_exp = logsumexp(scaled_elbos, axis=0) - jnp.log(self.num_particles)
assert avg_log_exp.shape == elbos.shape[1:]
weights = jnp.exp(scaled_elbos - avg_log_exp)
renyi_elbo = avg_log_exp / (1.0 - self.alpha)
weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
return -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
weighted_elbo = (stop_gradient(weights) * elbos).mean(0)
assert renyi_elbo.shape == elbos.shape[1:]
assert weighted_elbo.shape == elbos.shape[1:]
loss = -(stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
# common_plate_scale should be the same across particles.
return loss.sum() * common_plate_scale[0]


def _get_plate_stacks(trace):
Expand Down Expand Up @@ -994,12 +1077,12 @@ def single_particle_elbo(rng_key):
for key in deps:
site = guide_trace[key]
if site["infer"].get("enumerate") == "parallel":
for plate in (
for p in (
frozenset(site["log_measure"].inputs) & elim_plates
):
raise ValueError(
"Expected model enumeration to be no more global than guide enumeration, but found "
f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
f"model enumeration sites upstream of guide site '{key}' in plate('{p}')."
"Try converting some model enumeration sites to guide enumeration sites."
)
cost_terms.append((cost, scale, deps))
Expand Down
119 changes: 114 additions & 5 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,116 @@ def renyi_loss_fn(x):
assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


def test_renyi_local():
def model(subsample_size=None):
with numpyro.plate("N", 100, subsample_size=subsample_size):
numpyro.sample("x", dist.Normal(0, 1))
numpyro.sample("obs", dist.Bernoulli(0.6), obs=1)

def guide(subsample_size=None):
with numpyro.plate("N", 100, subsample_size=subsample_size):
numpyro.sample("x", dist.Normal(0, 1))

def renyi_loss_fn(subsample_size=None):
return RenyiELBO(num_particles=10).loss(
random.PRNGKey(0), {}, model, guide, subsample_size
)

# Test that the scales are applied correctly.
# Here for each particle, log_p - log_q = log(0.6)
full_loss = renyi_loss_fn()
subsample_loss = renyi_loss_fn(subsample_size=2)
assert_allclose(full_loss, -jnp.log(0.6) * 100, rtol=1e-6)
assert_allclose(subsample_loss, full_loss, rtol=1e-6)


def test_renyi_nonnested_plates():
def model():
with numpyro.plate("N", 10):
numpyro.sample("x", dist.Normal(0, 1))

with numpyro.plate("M", 10):
numpyro.sample("y", dist.Normal(0, 1))

def get_elbo():
renyi_elbo = RenyiELBO(num_particles=10)
return renyi_elbo._single_particle_elbo(
model,
model,
{},
(),
{},
random.PRNGKey(0),
)

elbo, _ = get_elbo()
assert elbo.shape == ()


@pytest.mark.parametrize(
"n,k",
[(3, 5), (2, 5), (3, 3), (2, 3)],
ids=str,
)
def test_renyi_create_plates(n, k):
P = 10
N, M, K = 3, 4, 5
data = jnp.linspace(0.1, 0.9, N * M * K).reshape((N, M, K))

def model(data, n=N, k=K, fix_indices=True):
with numpyro.plate("N", N, subsample_size=n, dim=-3):
with numpyro.plate("M", M, dim=-2):
with numpyro.plate("K", K, subsample_size=k, dim=-1):
if fix_indices:
batch = data[:n, :, :k]
else:
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("data", dist.Bernoulli(batch), obs=1)

def guide(data, n=N, k=K, fix_indices=True):
pass

def get_elbo(n=N, k=K, fix_indices=True):
renyi_elbo = RenyiELBO(num_particles=P)
return renyi_elbo._single_particle_elbo(
model,
guide,
{},
(data,),
dict(n=n, k=k, fix_indices=fix_indices),
random.PRNGKey(0),
)

def get_renyi(n=N, k=K, fix_indices=True):
renyi_elbo = RenyiELBO(num_particles=P)
return -renyi_elbo.loss(
random.PRNGKey(0), {}, model, guide, data, n=n, k=k, fix_indices=fix_indices
)

elbo, scale = get_elbo(n=n, k=k)
expected_shape = (n, M, k)
expected_scale = N * K / n / k
expected_elbo = jnp.log(data)[:n, :, :k]
assert elbo.shape == expected_shape
assert_allclose(scale, expected_scale, rtol=1e-6)
assert_allclose(elbo, expected_elbo, rtol=1e-6)

renyi = get_renyi(n=n, k=k)
assert_allclose(renyi, elbo.sum() * scale, rtol=1e-6)

if (n, k) == (2, 5):
renyi_random = get_renyi(n=2, fix_indices=False)
renyi_idx01 = jnp.log(data)[jnp.array([0, 1])].sum() * N / 2
renyi_idx02 = jnp.log(data)[jnp.array([0, 2])].sum() * N / 2
renyi_idx12 = jnp.log(data)[jnp.array([1, 2])].sum() * N / 2
atol = jnp.min(
jnp.abs(jnp.stack([renyi_idx01, renyi_idx02, renyi_idx12]) - renyi_random)
)
assert_allclose(atol, 0.0, atol=1e-5)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
def test_beta_bernoulli(elbo, optimizer):
data = jnp.array([1.0] * 8 + [0.0] * 2)

Expand All @@ -85,13 +193,14 @@ def body_fn(i, val):
svi_state, _ = svi.update(val, data)
return svi_state

svi_state = fori_loop(0, 2000, body_fn, svi_state)
svi_state = fori_loop(0, 10000, body_fn, svi_state)
params = svi.get_params(svi_state)
actual_posterior_mean = (data.sum() + 1) / (data.shape[0] + 2)
assert_allclose(
params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
0.8,
atol=0.05,
rtol=0.05,
actual_posterior_mean,
atol=0.03,
rtol=0.03,
)


Expand Down

0 comments on commit 4799b84

Please sign in to comment.