Skip to content

Commit

Permalink
Docs/stein mixtures (#1605)
Browse files Browse the repository at this point in the history
* Fixing docs for SteinVi.

* Added mixture_guide_predictive to einstein __init__.py. Added docs for MixtureGuidePredictive and SteinVI.

* added `language='en'` to `conf.py`

* removed docstring from `steinvi.run`
  • Loading branch information
OlaRonning committed Jun 15, 2023
1 parent 0b9e0f0 commit 7291cba
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
Expand Down
2 changes: 1 addition & 1 deletion examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def main(args):
guide=stein.guide,
params=stein.get_params(result.state),
num_samples=100,
guide_sites=stein.guide_param_names,
guide_sites=stein.guide_sites,
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
Expand Down
2 changes: 1 addition & 1 deletion examples/stein_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def main(args):
guide,
params=results.params,
num_samples=1,
guide_sites=steinvi.guide_param_names,
guide_sites=steinvi.guide_sites,
)
seqs, rev_seqs, lengths = load_data("valid")
pred_notes = pred(
Expand Down
2 changes: 2 additions & 0 deletions numpyro/contrib/einstein/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from numpyro.contrib.einstein.stein_kernels import (
GraphicalKernel,
IMQKernel,
Expand All @@ -23,4 +24,5 @@
"GraphicalKernel",
"MixtureKernel",
"ProbabilityProductKernel",
"MixtureGuidePredictive",
]
28 changes: 21 additions & 7 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,24 @@


class MixtureGuidePredictive:
"""
For single mixture component use numpyro.infer.Predictive.
"""(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for
:class:`numpyro.contrib.einstein.steinvi.SteinVi`
.. Note:: For single mixture component use numpyro.infer.Predictive.
.. warning::
The `MixtureGuidePredictive` is experimental and will likely be replaced by
:class:`numpyro.infer.util.Predictive` in the future.
:param Callable model: Python callable containing Pyro primitives.
:param Callable guide: Python callable containing Pyro primitives to get posterior samples of sites.
:param Dict params: Dictionary of values for param sites of model/guide
:param Sequence guide_sites: Names of sites that contribute to the Stein mixture.
:param Optional[int] num_samples:
:param Optional[Sequence[str]] return_sites: Sites to return. By default, only sample sites not present
in the guide are returned.
:param str mixture_assignment_sitename: Name of site for mixture component assignment for sites not in the Stein
mixture.
"""

def __init__(
Expand All @@ -25,8 +41,6 @@ def __init__(
guide_sites: Sequence,
num_samples: Optional[int] = None,
return_sites: Optional[Sequence[str]] = None,
infer_discrete: bool = False,
parallel: bool = False,
mixture_assignment_sitename="mixture_assignments",
):
self.model_predictive = partial(
Expand All @@ -37,11 +51,11 @@ def __init__(
},
num_samples=num_samples,
return_sites=return_sites,
infer_discrete=infer_discrete,
parallel=parallel,
infer_discrete=False,
parallel=False,
)
self._batch_shape = (num_samples,)
self.parallel = parallel
self.parallel = False
self.guide_params = {
name: param for name, param in params.items() if name in guide_sites
}
Expand Down
127 changes: 85 additions & 42 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable

from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.random import KeyArray
from jax.tree_util import tree_map

from numpyro import handlers
Expand All @@ -24,6 +25,7 @@
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
from numpyro.optim import _NumPyroOptim
from numpyro.util import fori_collect, ravel_pytree

SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
Expand All @@ -35,36 +37,68 @@ def _numel(shape):


class SteinVI:
"""Variational inference with stein mixtures.
"""Variational inference with Stein mixtures.
:param model: Python callable with Pyro primitives for the model.
**Example:**
.. doctest::
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel
>>> def model(data):
... f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
... with numpyro.plate("N", data.shape[0] if data is not None else 10):
... numpyro.sample("obs", dist.Bernoulli(f), obs=data)
>>> def guide(data):
... alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
... beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
... constraint=constraints.positive)
... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel())
>>> stein_result = stein.run(random.PRNGKey(0), 2000, data)
>>> params = stein_result.params
>>> # use guide to make predictive
>>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with Pyro primitives for the model.
:param guide: Python callable with Pyro primitives for the guide
(recognition network).
:param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`.
:param kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein inference
:param num_stein_particles: number of particles for Stein inference.
(More particles give more mixture components and therefore likely capture more of the posterior distribution)
:param num_elbo_particles: number of particles for to approximate the attractive force gradient.
:param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumpyroOptim`.
:param SteinKernel kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein mixture
inference.
:param num_stein_particles: Number of particles (i.e., mixture components) in the Stein mixture.
:param num_elbo_particles: Number of Monte Carlo draws used to approximate the attractive force gradient.
(More particles give better gradient approximations)
:param loss_temperature: scaling of loss factor
:param repulsion_temperature: scaling of repulsive forces (Non-linear Stein)
:param classic_guide_param_fn: predicate on names of parameters in guide which should be optimized classically
without Stein (E.g. parameters for large normal networks or other transformation)
:param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments
that remain constant during fitting.
:param Float loss_temperature: Scaling factor of the attractive force.
:param Float repulsion_temperature: Scaling factor of the repulsive force (Non-linear Stein)
:param Callable non_mixture_guide_param_fn: predicate on names of parameters in guide which should be optimized
classically without Stein (E.g. parameters for large normal networks or other transformation)
:param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments that remain constant
during inference.
"""

def __init__(
self,
model,
guide,
optim,
model: Callable,
guide: Callable,
optim: _NumPyroOptim,
kernel_fn: SteinKernel,
num_stein_particles: int = 10,
num_elbo_particles: int = 10,
loss_temperature: float = 1.0,
repulsion_temperature: float = 1.0,
classic_guide_params_fn: Callable[[str], bool] = lambda name: False,
non_mixture_guide_params_fn: Callable[[str], bool] = lambda name: False,
enum=True,
**static_kwargs,
):
Expand All @@ -82,8 +116,8 @@ def __init__(
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.enum = enum
self.model_params_fn = classic_guide_params_fn
self.guide_param_names = None
self.non_mixture_params_fn = non_mixture_guide_params_fn
self.guide_sites = None
self.constrain_fn = None
self.uconstrain_fn = None
self.particle_transform_fn = None
Expand Down Expand Up @@ -178,14 +212,17 @@ def _reinit(seed):

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
model_uparams = {
p: v
for p, v in unconstr_params.items()
if p not in self.guide_param_names or self.model_params_fn(p)
}
non_mixture_uparams = (
{ # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}
)
stein_uparams = {
p: v for p, v in unconstr_params.items() if p not in model_uparams
p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
}

# 1. Collect each guide parameter into monolithic particles that capture correlations
# between parameter values across each individual particle
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
Expand All @@ -197,7 +234,9 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
attractive_key, classic_key = random.split(rng_key)

# 2. Calculate gradients for each particle
def kernel_particles_loss_fn(rng_key, particles):
def kernel_particles_loss_fn(
rng_key, particles
): # TODO: rewrite using def to utilize jax caching
particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
grads = vmap(
lambda i: grad(
Expand All @@ -215,7 +254,7 @@ def kernel_particles_loss_fn(rng_key, particles):
select_index=i,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(model_uparams),
param_map=self.constrain_fn(non_mixture_uparams),
)
)(
random.split(
Expand All @@ -237,13 +276,16 @@ def particle_transform_fn(particle):
ctparticle, _ = ravel_pytree(ctparams)
return tparticle, ctparticle

# 2.1 Lift particles to constraint space
tstein_particles, ctstein_particles = vmap(particle_transform_fn)(
stein_particles
)

# 2.2 Compute particle gradients (for attractive force)
particle_ljp_grads = kernel_particles_loss_fn(attractive_key, ctstein_particles)

classic_param_grads = grad(
# 2.2 Compute non-mixture parameter gradients
non_mixture_param_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
Expand All @@ -253,14 +295,14 @@ def particle_transform_fn(particle):
*args,
**kwargs,
)
)(model_uparams)
)(non_mixture_uparams)

# 3. Calculate kernel on monolithic particle
kernel = self.kernel_fn.compute( # TODO: Fix to use Stein loss
# 3. Calculate kernel of particles
kernel = self.kernel_fn.compute(
stein_particles, particle_info, kernel_particles_loss_fn
)

# 4. Calculate the attractive force and repulsive force on the monolithic particles
# 4. Calculate the attractive force and repulsive force on the particles
attractive_force = vmap(
lambda y: jnp.sum(
vmap(
Expand Down Expand Up @@ -317,16 +359,17 @@ def _update_force(attr_force, rep_force, jac):
stein_param_grads = unravel_pytree_batched(particle_grads)

# 6. Return loss and gradients (based on parameter forces)
res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads})
res_grads = tree_map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
)
return jnp.linalg.norm(particle_grads), res_grads

def init(self, rng_key, *args, **kwargs):
"""
:param jax.random.PRNGKey rng_key: random number generator seed.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
def init(self, rng_key: KeyArray, *args, **kwargs):
"""Register random variable transformations, constraints and determine initialize positions of the particles.
:param KeyArray rng_key: Random number generator seed.
:param args: Arguments to the model / guide.
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
"""
rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4)
Expand Down Expand Up @@ -373,7 +416,7 @@ def init(self, rng_key, *args, **kwargs):
)
if site["name"] in guide_init_params:
pval, _ = guide_init_params[site["name"]]
if self.model_params_fn(site["name"]):
if self.non_mixture_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
else:
pval = site["value"]
Expand All @@ -384,7 +427,7 @@ def init(self, rng_key, *args, **kwargs):
if should_enum:
mpn = _guess_max_plate_nesting(model_trace)
self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
self.guide_param_names = guide_param_names
self.guide_sites = guide_param_names
self.constrain_fn = partial(transform_fn, inv_transforms)
self.uconstrain_fn = partial(transform_fn, transforms)
self.particle_transforms = particle_transforms
Expand Down

0 comments on commit 7291cba

Please sign in to comment.