Add two ensemble sampling methods (#1692)
* ensemble sampling draft

* rewrite for loop as fori_loop

* added efficiency comment for ESS GaussianMove

* fix typo

* fixed ravel for mixed dtype

* add defaults

* add support for potential_fn

* AIES tests, warnings for AIES

* AIES input validation

* better docs, more input validation

* ESS passing test cases

* add tests for other files

* linting

* refactor ensemble_util

* make test result less close to margin in CI, swap deprecated function

* rename get_nondiagonal_indices, fix batch_ravel_pytree

* print ensemble kernel diagnostics, smoke test parallel arg

* fix docstring build

* documentation

* skip slow CI tests, unnest test if statements

* fix doctest

* doc rewrite

* fix distribution test
amifalk committed Jan 26, 2024
1 parent 3ef0423 commit 01089cf
Showing 9 changed files with 1,055 additions and 26 deletions.
34 changes: 33 additions & 1 deletion docs/source/mcmc.rst
@@ -9,7 +9,9 @@ We provide a high-level overview of the MCMC algorithms in NumPyro:
* `BarkerMH <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.barker.BarkerMH>`_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables.
* `HMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs>`_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
* `DiscreteHMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.DiscreteHMCGibbs>`_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
-* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
+* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ is a gradient-free MCMC method. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side, individual steps can be fast.
+* `AIES <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.ensemble.AIES>`_ is a gradient-free ensemble MCMC method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and can be robust for likelihood-free models. AIES generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).
+* `ESS <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.ensemble.ESS>`_ is a gradient-free ensemble MCMC method that shares information between chains to find good slice sampling directions. It tends to be more sample-efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions <https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence>`_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example <https://num.pyro.ai/en/stable/examples/annotation.html>`_.
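
A minimal usage sketch of the two new kernels, assuming the defaults this commit adds. Ensemble kernels share state across the whole ensemble, so the chains are run vectorized; the commented `moves` line mirrors the docstring pattern in this PR, and the specific move names and weights should be treated as assumptions rather than settled API:

    import jax
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, AIES, ESS

    def model():
        # 10 latent dimensions, so use at least 20 chains per the guidance above
        numpyro.sample("x", dist.Normal().expand([10]))

    kernel = AIES(model)  # default move set
    # kernel = AIES(model, moves={AIES.DEMove(): 0.5, AIES.StretchMove(): 0.5})

    # All chains step jointly, hence chain_method="vectorized" rather than
    # independent parallel chains.
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000,
                num_chains=20, chain_method="vectorized")
    mcmc.run(jax.random.PRNGKey(0))
    mcmc.print_summary()

    # ESS is a drop-in replacement in the same setup:
    # mcmc = MCMC(ESS(model), num_warmup=1000, num_samples=2000,
    #             num_chains=20, chain_method="vectorized")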

@@ -101,6 +103,30 @@ SA
    :show-inheritance:
    :member-order: bysource

+EnsembleSampler
+^^^^^^^^^^^^^^^
+.. autoclass:: numpyro.infer.ensemble.EnsembleSampler
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+AIES
+^^^^
+.. autoclass:: numpyro.infer.ensemble.AIES
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+ESS
+^^^
+.. autoclass:: numpyro.infer.ensemble.ESS
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
.. autofunction:: numpyro.infer.hmc.hmc

.. autofunction:: numpyro.infer.hmc.hmc.init_kernel
@@ -117,6 +143,12 @@

.. autodata:: numpyro.infer.sa.SAState

+.. autodata:: numpyro.infer.ensemble.EnsembleSamplerState
+
+.. autodata:: numpyro.infer.ensemble.AIESState
+
+.. autodata:: numpyro.infer.ensemble.ESSState
+

TensorFlow Kernels
------------------
3 changes: 3 additions & 0 deletions numpyro/infer/__init__.py
@@ -10,6 +10,7 @@
    TraceGraph_ELBO,
    TraceMeanField_ELBO,
)
+from numpyro.infer.ensemble import AIES, ESS
from numpyro.infer.hmc import HMC, NUTS
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
from numpyro.infer.initialization import (
@@ -29,6 +30,7 @@
from . import autoguide, reparam

__all__ = [
"AIES",
"autoguide",
"init_to_feasible",
"init_to_mean",
@@ -41,6 +43,7 @@
"BarkerMH",
"DiscreteHMCGibbs",
"ELBO",
"ESS",
"HMC",
"HMCECS",
"HMCGibbs",
