Add initial SDVI implementation (#1758)
treigerm committed Mar 14, 2024
1 parent 79d9e6b commit 5da6fa5
Showing 5 changed files with 408 additions and 49 deletions.
12 changes: 12 additions & 0 deletions docs/source/contrib.rst
@@ -77,7 +77,19 @@ SteinVI Kernels
Stochastic Support
~~~~~~~~~~~~~~~~~~

.. autoclass:: numpyro.contrib.stochastic_support.dcc.StochasticSupportInference
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. autoclass:: numpyro.contrib.stochastic_support.sdvi.SDVI
:members:
:undoc-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions numpyro/contrib/stochastic_support/__init__.py
@@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from numpyro.contrib.stochastic_support.dcc import DCC
from numpyro.contrib.stochastic_support.sdvi import SDVI

__all__ = [
"DCC",
"SDVI",
]
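
With this export in place, both algorithms are importable from a single path; a minimal sketch:

.. code-block:: python

    # Public entry points after this commit; both expose the same run(rng_key, ...) API.
    from numpyro.contrib.stochastic_support import DCC, SDVI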
147 changes: 98 additions & 49 deletions numpyro/contrib/stochastic_support/dcc.py
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple

import jax
@@ -16,13 +17,15 @@
DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])


class DCC:
class StochasticSupportInference(ABC):
"""
Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
stochastic support from [1].
Base class for running inference in programs with stochastic support. Each subclass
decomposes the input model into so-called straight-line programs (SLPs), which are
the different control-flow paths in the model. Inference is then run in each SLP
separately and the results are combined to produce an overall posterior.
.. note:: This implementation assumes that all stochastic branching is done based on the
outcomes of discrete sampling sites that are annotated with `infer={"branching": True}`.
outcomes of discrete sampling sites that are annotated with ``infer={"branching": True}``.
For example,
.. code-block:: python
@@ -35,41 +38,18 @@ def model():
mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
**References:**
1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*,
Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth
:param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
:param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`.
:param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
local inference. Defaults to :class:`~numpyro.infer.NUTS`.
:param int num_slp_samples: Number of samples to draw from the prior to discover the
straight-line programs (SLPs).
:param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
on more than `max_slps`.
:param float proposal_scale: Scale parameter for the proposal distribution for
estimating the normalization constant of an SLP.
"""

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
proposal_scale=1.0,
):
def __init__(self, model, num_slp_samples, max_slps):
self.model = model
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.num_slp_samples = num_slp_samples
self.max_slps = max_slps
self.proposal_scale = proposal_scale

def _find_slps(self, rng_key, *args, **kwargs):
"""
@@ -111,7 +91,95 @@ def _get_branching_trace(self, tr):
branching_trace[site["name"]] = int(site["value"])
return branching_trace

def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):
@abstractmethod
def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def _combine_inferences(
self, rng_key, inferences, branching_traces, *args, **kwargs
):
raise NotImplementedError

def run(self, rng_key, *args, **kwargs):
"""
Run inference on each SLP separately and combine the results.
:param jax.random.PRNGKey rng_key: Random number generator key.
:param args: Arguments to the model.
:param kwargs: Keyword arguments to the model.
"""
rng_key, subkey = random.split(rng_key)
branching_traces = self._find_slps(subkey, *args, **kwargs)

inferences = dict()
for key, bt in branching_traces.items():
rng_key, subkey = random.split(rng_key)
inferences[key] = self._run_inference(subkey, bt, *args, **kwargs)

rng_key, subkey = random.split(rng_key)
return self._combine_inferences(
subkey, inferences, branching_traces, *args, **kwargs
)
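
``run`` is a template method: subclasses only supply ``_run_inference`` and ``_combine_inferences``. A hypothetical toy subclass (illustrative only, not part of this commit) makes the contract concrete:

.. code-block:: python

    from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference


    class UniformWeights(StochasticSupportInference):
        """Toy subclass: skips real inference and weights every discovered SLP equally."""

        def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
            # A real subclass would run MCMC or SVI on the model conditioned
            # on `branching_trace`; here we just record the trace itself.
            return branching_trace

        def _combine_inferences(self, rng_key, inferences, branching_traces, *args, **kwargs):
            # A real subclass would estimate per-SLP weights; here they are uniform.
            return {k: 1.0 / len(inferences) for k in inferences}

Instantiating it as ``UniformWeights(model, num_slp_samples=1000, max_slps=8)`` and calling ``run(random.PRNGKey(0))`` exercises the full discover/infer/combine loop.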


class DCC(StochasticSupportInference):
"""
Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
stochastic support from [1].
**References:**
1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*,
Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth
**Example:**
.. code-block:: python
def model():
model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
if model1 == 0:
mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
else:
mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
mcmc_kwargs = dict(
num_warmup=500, num_samples=1000
)
dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
dcc_result = dcc.run(random.PRNGKey(0))
:param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
:param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`.
:param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
local inference. Defaults to :class:`~numpyro.infer.NUTS`.
:param int num_slp_samples: Number of samples to draw from the prior to discover the
straight-line programs (SLPs).
:param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
on more than `max_slps`.
:param float proposal_scale: Scale parameter for the proposal distribution for
estimating the normalization constant of an SLP.
"""

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
proposal_scale=1.0,
):
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.proposal_scale = proposal_scale

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
"""
Run MCMC on the model conditioned on the given branching trace.
"""
@@ -122,7 +190,7 @@ def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):

return mcmc.get_samples()

def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs):
def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs):
"""
Weight each SLP in proportion to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
@@ -159,22 +227,3 @@ def log_weight(rng_key, i, slp_model, slp_samples):
normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values())))
slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
return DCCResult(samples, slp_weights)
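
The final weighting above is a numerically stable softmax over the estimated per-SLP log normalization constants; a standalone sketch of just that step, with made-up values:

.. code-block:: python

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    # Hypothetical log normalization constants for two SLPs.
    log_Zs = {"slp_a": -3.2, "slp_b": -1.1}
    normalizer = logsumexp(jnp.array(list(log_Zs.values())))
    slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
    # Weights are positive and sum to one: roughly 0.11 for "slp_a", 0.89 for "slp_b".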

def run(self, rng_key, *args, **kwargs):
"""
Run DCC and collect samples for all SLPs.
:param jax.random.PRNGKey rng_key: Random number generator key.
:param args: Arguments to the model.
:param kwargs: Keyword arguments to the model.
"""
rng_key, subkey = random.split(rng_key)
branching_traces = self._find_slps(subkey, *args, **kwargs)

samples = dict()
for key, bt in branching_traces.items():
rng_key, subkey = random.split(rng_key)
samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs)

rng_key, subkey = random.split(rng_key)
return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs)
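
End to end, usage of the refactored ``DCC`` matches its docstring; a slightly fuller sketch that also inspects the returned ``DCCResult`` (the exact format of the per-SLP keys is an internal detail of ``_find_slps``):

.. code-block:: python

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.stochastic_support import DCC

    def model():
        model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
        if model1 == 0:
            mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
        else:
            mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    dcc = DCC(model, mcmc_kwargs=dict(num_warmup=500, num_samples=1000))
    result = dcc.run(random.PRNGKey(0))
    for slp_key, weight in result.slp_weights.items():
        print(slp_key, weight, list(result.samples[slp_key].keys()))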
122 changes: 122 additions & 0 deletions numpyro/contrib/stochastic_support/sdvi.py
@@ -0,0 +1,122 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple

import jax
import jax.numpy as jnp

from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference
from numpyro.handlers import condition
from numpyro.infer import (
SVI,
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.autoguide import AutoNormal

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO)


class SDVI(StochasticSupportInference):
"""
Implements the Support Decomposition Variational Inference (SDVI) algorithm for models with
stochastic support from [1]. This implementation creates a separate guide for each SLP, trains
the guides separately, and then combines the guides by weighting them in proportion to their ELBO
estimates.
**References:**
1. *Rethinking Variational Inference for Probabilistic Programs with Stochastic Support*,
Tim Reichelt, Luke Ong, Tom Rainforth
**Example:**
.. code-block:: python
def model():
model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
if model1 == 0:
mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
else:
mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
sdvi = SDVI(model, numpyro.optim.Adam(step_size=0.001))
sdvi_result = sdvi.run(random.PRNGKey(0))
:param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
:param optimizer: An instance of :class:`~numpyro.optim._NumpyroOptim`, a
``jax.example_libraries.optimizers.Optimizer`` or an Optax
``GradientTransformation``. Gets passed to :class:`~numpyro.infer.SVI`.
:param int svi_num_steps: Number of steps to run SVI for each SLP.
:param int combine_elbo_particles: Number of particles used to estimate the ELBO when
computing SLP weights.
:param guide_init: A constructor for the guide. This should be a callable that returns a
:class:`~numpyro.infer.autoguide.AutoGuide` instance. Defaults to
:class:`~numpyro.infer.autoguide.AutoNormal`.
:param loss: ELBO loss for SVI. Defaults to :class:`~numpyro.infer.Trace_ELBO`.
:param bool svi_progress_bar: Whether to use a progress bar for SVI.
:param int num_slp_samples: Number of samples to draw from the prior to discover the
straight-line programs (SLPs).
:param int max_slps: Maximum number of SLPs to discover. SDVI will not run inference
on more than `max_slps`.
"""

def __init__(
self,
model,
optimizer,
svi_num_steps=1000,
combine_elbo_particles=1000,
guide_init=AutoNormal,
loss=Trace_ELBO(),
svi_progress_bar=False,
num_slp_samples=1000,
max_slps=124,
):
self.guide_init = guide_init
self.optimizer = optimizer
self.svi_num_steps = svi_num_steps
self.svi_progress_bar = svi_progress_bar

if not isinstance(loss, VALID_ELBOS):
err_str = ", ".join(x.__name__ for x in VALID_ELBOS)
raise ValueError(f"loss must be an instance of: ({err_str})")
self.loss = loss
self.combine_elbo_particles = combine_elbo_particles

super().__init__(model, num_slp_samples, max_slps)

def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
"""
Run SVI on a given SLP defined by its branching trace.
"""
slp_model = condition(self.model, branching_trace)
guide = self.guide_init(slp_model)
svi = SVI(slp_model, guide, self.optimizer, loss=self.loss)
svi_result = svi.run(
rng_key,
self.svi_num_steps,
*args,
progress_bar=self.svi_progress_bar,
**kwargs,
)
return guide, svi_result.params

def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs):
"""Weight each SLP proportional to its estimated ELBO."""
elbos = {}
for bt, (guide, param_map) in guides.items():
slp_model = condition(self.model, branching_traces[bt])
elbos[bt] = -Trace_ELBO(num_particles=self.combine_elbo_particles).loss(
rng_key, param_map, slp_model, guide, *args, **kwargs
)

normalizer = jax.scipy.special.logsumexp(jnp.array(list(elbos.values())))
slp_weights = {k: jnp.exp(v - normalizer) for k, v in elbos.items()}
return SDVIResult(guides, slp_weights)
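
Each entry of the returned ``guides`` mapping is a ``(guide, params)`` pair, so posterior draws for a given SLP can be taken with the standard autoguide ``sample_posterior`` method (part of the existing :class:`~numpyro.infer.autoguide.AutoGuide` API, not this commit); a sketch, reusing ``sdvi`` from the docstring example:

.. code-block:: python

    from jax import random

    sdvi_result = sdvi.run(random.PRNGKey(0))

    # Pick the SLP with the largest weight and draw from its trained guide.
    best = max(sdvi_result.slp_weights, key=lambda k: sdvi_result.slp_weights[k])
    guide, params = sdvi_result.guides[best]
    posterior = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))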
