Add auto-batched (low-rank) multivariate normal guides. (#1737)
* Add `ReshapeTransform`.

* Add `AutoBatchedMultivariateNormal`.

* Refactor to use `AutoBatchedMixin`.

* Add `AutoLowRankMultivariateNormal`.

* Fix import order.

* Disable batching along event dimensions.
tillahoffmann committed Feb 21, 2024
1 parent a92bd0d commit b35fcec
Showing 4 changed files with 308 additions and 2 deletions.
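Taken together, these changes let an autoguide fit one (low-rank) multivariate normal per element of a plate. A minimal end-to-end sketch (the model and optimizer settings here are illustrative, not part of the commit)::

    import numpyro
    import numpyro.distributions as dist
    import optax
    from jax import random
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.infer.autoguide import AutoBatchedMultivariateNormal

    def model():
        # One 4-dimensional latent per plate element; the guide batches over "N".
        with numpyro.plate("N", 3):
            numpyro.sample("x", dist.Normal(0, 1).expand([4]).to_event(1))

    guide = AutoBatchedMultivariateNormal(model, batch_ndim=1)
    svi = SVI(model, guide, optax.adam(1e-3), Trace_ELBO())
    result = svi.run(random.PRNGKey(0), 2_000)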
59 changes: 59 additions & 0 deletions numpyro/distributions/transforms.py
@@ -40,6 +40,7 @@
"LowerCholeskyAffine",
"PermuteTransform",
"PowerTransform",
"ReshapeTransform",
"SigmoidTransform",
"SimplexToOrderedTransform",
"SoftplusTransform",
@@ -1141,6 +1142,64 @@ def __eq__(self, other):
return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn


def _get_target_shape(shape, forward_shape, inverse_shape):
batch_ndims = len(shape) - len(inverse_shape)
return shape[:batch_ndims] + forward_shape


class ReshapeTransform(Transform):
"""
Reshape a sample, leaving batch dimensions unchanged.
:param forward_shape: Shape to transform the sample to.
:param inverse_shape: Shape of the sample for the inverse transform.
"""

domain = constraints.real
codomain = constraints.real

def __init__(self, forward_shape, inverse_shape) -> None:
forward_size = math.prod(forward_shape)
inverse_size = math.prod(inverse_shape)
if forward_size != inverse_size:
raise ValueError(
f"forward shape {forward_shape} (size {forward_size}) and inverse "
f"shape {inverse_shape} (size {inverse_size}) are not compatible"
)
self._forward_shape = forward_shape
self._inverse_shape = inverse_shape

def forward_shape(self, shape):
return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

def inverse_shape(self, shape):
return _get_target_shape(shape, self._inverse_shape, self._forward_shape)

def __call__(self, x):
return jnp.reshape(x, self.forward_shape(jnp.shape(x)))

def _inverse(self, y):
return jnp.reshape(y, self.inverse_shape(jnp.shape(y)))

def log_abs_det_jacobian(self, x, y, intermediates=None):
return 0.0

def tree_flatten(self):
aux_data = {
"_forward_shape": self._forward_shape,
"_inverse_shape": self._inverse_shape,
}
return (), ((), aux_data)

def __eq__(self, other):
return (
isinstance(other, ReshapeTransform)
and self._forward_shape == other._forward_shape
and self._inverse_shape == other._inverse_shape
)



##########################################################
# CONSTRAINT_REGISTRY
##########################################################
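For reference, a round-trip sketch of the new ReshapeTransform (the shapes here are chosen for illustration)::

    import jax.numpy as jnp
    from numpyro.distributions.transforms import ReshapeTransform

    t = ReshapeTransform(forward_shape=(2, 3), inverse_shape=(6,))
    y = t(jnp.arange(6.0))  # reshaped from (6,) to (2, 3)
    assert t.inv(y).shape == (6,)  # the inverse restores the flat shape
    assert t(jnp.ones((5, 6))).shape == (5, 2, 3)  # leading batch dims pass through
    # Reshaping is volume-preserving, so the log-determinant is zero.
    assert t.log_abs_det_jacobian(jnp.arange(6.0), y) == 0.0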
154 changes: 154 additions & 0 deletions numpyro/infer/autoguide.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack
from functools import partial
import math
import warnings

import numpy as np
@@ -29,6 +30,7 @@
IndependentTransform,
LowerCholeskyAffine,
PermuteTransform,
ReshapeTransform,
UnpackTransform,
biject_to,
)
@@ -50,6 +52,8 @@
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"AutoBatchedLowRankMultivariateNormal",
"AutoBatchedMultivariateNormal",
"AutoContinuous",
"AutoGuide",
"AutoGuideList",
@@ -1808,6 +1812,106 @@ def quantiles(self, params, quantiles):
return self._unpack_and_constrain(latent, params)


class AutoBatchedMixin:
"""
Mixin to infer the batch and event shapes of batched auto guides.
"""

# Available from AutoContinuous.
latent_dim: int

def __init__(self, *args, **kwargs):
self._batch_shape = None
self._event_shape = None
# Pop the number of batch dimensions and pass the rest to the other constructor.
self.batch_ndim = kwargs.pop("batch_ndim")
super().__init__(*args, **kwargs)

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)

# Extract the batch shape.
batch_shape = None
for site in self.prototype_trace.values():
if site["type"] == "sample" and not site["is_observed"]:
shape = site["value"].shape
if site["value"].ndim < self.batch_ndim + site["fn"].event_dim:
raise ValueError(
f"Expected {self.batch_ndim} batch dimensions, but site "
f"`{site['name']}` only has shape {shape}."
)
shape = shape[:self.batch_ndim]
if batch_shape is None:
batch_shape = shape
elif shape != batch_shape:
raise ValueError("Encountered inconsistent batch shapes.")
self._batch_shape = batch_shape

# Save the event shape of the non-batched part. This will always be a vector.
batch_size = math.prod(self._batch_shape)
if self.latent_dim % batch_size:
raise RuntimeError(
f"Incompatible batch shape {batch_shape} (size {batch_size}) and "
f"latent dims {self.latent_dim}."
)
self._event_shape = (self.latent_dim // batch_size,)

def _get_batched_posterior(self):
raise NotImplementedError

def _get_posterior(self):
return dist.TransformedDistribution(
self._get_batched_posterior(),
ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape),
)


class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous):
"""
This implementation of :class:`AutoContinuous` uses a batched MultivariateNormal
distribution to construct a guide over the entire latent space.
The guide does not depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoBatchedMultivariateNormal(model, batch_ndim=1, ...)
        svi = SVI(model, guide, ...)
    """

scale_tril_constraint = constraints.scaled_unit_lower_cholesky

def __init__(
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_uniform,
init_scale=0.1,
batch_ndim=1,
):
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
self._init_scale = init_scale
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
)

def _get_batched_posterior(self):
init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
init_scale = (
jnp.ones(self._batch_shape + (1, 1))
* jnp.identity(init_latent.shape[-1])
* self._init_scale
)
scale_tril = numpyro.param(
"{}_scale_tril".format(self.prefix),
init_scale,
constraint=self.scale_tril_constraint,
)
return dist.MultivariateNormal(loc, scale_tril=scale_tril)


class AutoLowRankMultivariateNormal(AutoContinuous):
"""
This implementation of :class:`AutoContinuous` uses a LowRankMultivariateNormal
@@ -1886,6 +1990,56 @@ def quantiles(self, params, quantiles):
return self._unpack_and_constrain(latent, params)


class AutoBatchedLowRankMultivariateNormal(AutoBatchedMixin, AutoContinuous):
"""
    This implementation of :class:`AutoContinuous` uses a batched
    LowRankMultivariateNormal distribution to construct a guide over the entire
    latent space. The guide does not depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoBatchedLowRankMultivariateNormal(model, batch_ndim=1, ...)
        svi = SVI(model, guide, ...)
    """

scale_constraint = constraints.softplus_positive

def __init__(
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_uniform,
init_scale=0.1,
rank=None,
batch_ndim=1,
):
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
self._init_scale = init_scale
self.rank = rank
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
)

def _get_batched_posterior(self):
rank = int(round(self._event_shape[0]**0.5)) if self.rank is None else self.rank
init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
cov_factor = numpyro.param(
"{}_cov_factor".format(self.prefix),
jnp.zeros(self._batch_shape + self._event_shape + (rank,))
)
scale = numpyro.param(
"{}_scale".format(self.prefix),
jnp.full(self._batch_shape + self._event_shape, self._init_scale),
constraint=self.scale_constraint,
)
cov_diag = scale * scale
cov_factor = cov_factor * scale[..., None]
return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)


class AutoLaplaceApproximation(AutoContinuous):
r"""
Laplace approximation (quadratic approximation) approximates the posterior
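The low-rank guide parameterizes each batch element's covariance as cov_factor @ cov_factor.T + diag(scale ** 2). A standalone sketch of that construction (shapes and values are illustrative)::

    import jax.numpy as jnp
    import numpyro.distributions as dist

    batch_shape, event_shape, rank = (3,), (4,), 2
    loc = jnp.zeros(batch_shape + event_shape)
    # Zero-initialized factors, as in the guide; training moves them off zero.
    cov_factor = jnp.zeros(batch_shape + event_shape + (rank,))
    scale = jnp.full(batch_shape + event_shape, 0.1)
    posterior = dist.LowRankMultivariateNormal(
        loc, cov_factor * scale[..., None], scale * scale
    )
    assert posterior.batch_shape == (3,)
    assert posterior.event_shape == (4,)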
62 changes: 61 additions & 1 deletion test/infer/test_autoguide.py
@@ -7,7 +7,7 @@
from numpy.testing import assert_allclose
import pytest

from jax import jacobian, jit, lax, random
from jax import jacobian, jit, lax, random, vmap
from jax.example_libraries.stax import Dense
import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map
@@ -23,6 +23,8 @@
from numpyro.handlers import substitute
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from numpyro.infer.autoguide import (
AutoBatchedLowRankMultivariateNormal,
AutoBatchedMultivariateNormal,
AutoBNAFNormal,
AutoDAIS,
AutoDelta,
@@ -1251,3 +1253,61 @@ def model():
assert_allclose(
samples["x"].mean(axis=0), jnp.arange(-5, 5), atol=0.2, rtol=0.1
)


@pytest.mark.parametrize(
"auto_class",
[
AutoBatchedMultivariateNormal,
AutoBatchedLowRankMultivariateNormal,
],
)
def test_auto_batched(auto_class) -> None:
# Model for batched multivariate normal.
off_diag = jnp.asarray([-0.2, 0, 0.5])
covs = off_diag[:, None, None] + jnp.eye(4)

def model():
with numpyro.plate("N", off_diag.shape[0]):
numpyro.sample("x", dist.MultivariateNormal(0, covs))

# Run inference.
guide = auto_class(model)
svi = SVI(model, guide, optax.adam(0.001), Trace_ELBO())
result = svi.run(random.PRNGKey(0), 10000)
samples = guide.sample_posterior(
random.PRNGKey(1), result.params, sample_shape=(1000,)
)

# Verify off-diagonal entries are correlated.
empirical_covs = vmap(jnp.cov)(jnp.moveaxis(samples["x"], 0, 2))
i, j = jnp.triu_indices(3, 1)
empirical_off_diag = empirical_covs[:, i, j].mean(axis=1)
corrcoef = jnp.corrcoef(off_diag, empirical_off_diag)[0, 1]
assert corrcoef > 0.99


@pytest.mark.parametrize(
"auto_class",
[
AutoBatchedMultivariateNormal,
AutoBatchedLowRankMultivariateNormal,
],
)
def test_auto_batched_shapes(auto_class) -> None:
def model(n, m):
distribution = dist.Normal().expand([7]).to_event(1)
with numpyro.plate("n", n):
x = numpyro.sample("x", distribution)
with numpyro.plate("m", m):
y = numpyro.sample("y", distribution)
return x, y

with numpyro.handlers.seed(rng_seed=0):
auto_class(model)(3, 3)

with pytest.raises(ValueError, match="inconsistent batch shapes"):
auto_class(model)(3, 4)

with pytest.raises(ValueError, match="Expected 2 batch dimensions"):
auto_class(model, batch_ndim=2)(3, 3)

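As a companion to test_auto_batched_shapes above, a sketch of the shape bookkeeping done by AutoBatchedMixin (note that _batch_shape and _event_shape are private attributes introduced in this diff)::

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer.autoguide import AutoBatchedMultivariateNormal

    def model(n):
        with numpyro.plate("n", n):
            numpyro.sample("x", dist.Normal().expand([7]).to_event(1))

    guide = AutoBatchedMultivariateNormal(model, batch_ndim=1)
    with numpyro.handlers.seed(rng_seed=0):
        guide(3)  # the first call traces the model and infers shapes
    # latent_dim = 3 * 7 = 21 splits into batch (3,) and event (7,) parts.
    assert guide._batch_shape == (3,)
    assert guide._event_shape == (7,)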