Skip to content

Commit

Permalink
bump jaxns to >=2.0.1 (#1546)
Browse files Browse the repository at this point in the history
* * bump jaxns to >= 2.0.0
* adjusted the wrapped to use new structure.

* * points user to JAXNS's readthedocs

* * addressed comments
* make format run

* * addressed comments
* make format run

* * bump docs to python 3.9

* * skip pytest when jaxns import error. (Should still be an error for < 3.9)

* * Only install jaxns if python>=3.9

* * make format

* * Noticed that JAXNS was organised under the MCMC headings, however it's not an MCMC method, so I broke it out into its own section.

* * I made jaxns compatible with 3.8 again by removing some annotations.

* * Downgrade readthedocs again to 3.8

* * make format

* * Fix E402
* make format

* * nits
* make format

---------

Co-authored-by: Du Phan <fehiepsi@gmail.com>
  • Loading branch information
Joshuaalbert and fehiepsi committed Mar 15, 2023
1 parent 849d4cf commit bb9e1ba
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 39 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ As discussed above, model [reparameterization](https://num.pyro.ai/en/latest/rep
- [HMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#hmcgibbs) combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
- [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#discretehmcgibbs) combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
- [SA](https://num.pyro.ai/en/latest/mcmc.html#sa) is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [here](https://github.com/pyro-ppl/numpyro/blob/master/examples/gaussian_shells.py) for an example.

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see [restrictions](https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence)). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the [annotation example](https://num.pyro.ai/en/stable/examples/annotation.html).

### Nested Sampling
- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [JAXNS's readthedocs](https://jaxns.readthedocs.io/en/latest/) for examples and [Nested Sampling for Gaussian Shells](https://num.pyro.ai/en/stable/examples/gaussian_shells.html) example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations.

### Stochastic variational inference
- Variational objectives
- [Trace_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.Trace_ELBO) is our basic ELBO implementation.
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ funsor
ipython<=8.6.0 # strict the version for https://github.com/ipython/ipython/issues/13845
jax
jaxlib
jaxns==1.0.0
jaxns>=2.0.1
Jinja2<3.1
matplotlib
multipledispatch
Expand Down
8 changes: 7 additions & 1 deletion docs/source/mcmc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ We provide a high-level overview of the MCMC algorithms in NumPyro:
* `HMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs>`_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
* `DiscreteHMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.DiscreteHMCGibbs>`_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
* `NestedSampler <https://num.pyro.ai/en/latest/contrib.html#nested-sampling>`_ offers a wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_. See `here <https://github.com/pyro-ppl/numpyro/blob/master/examples/gaussian_shells.py>`_ for an example.

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions <https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence>`_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example <https://num.pyro.ai/en/stable/examples/annotation.html>`_.

Expand All @@ -20,6 +19,13 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l
:show-inheritance:
:member-order: bysource

Nested Sampling
===============================

Nested Sampling is a non-MCMC approach that works for arbitrary probability models, and is particularly well suited to complex posteriors:

* `NestedSampler <https://num.pyro.ai/en/latest/contrib.html#nested-sampling>`_ offers a wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_. See `JAXNS's readthedocs <https://jaxns.readthedocs.io/en/latest/>`_ for examples and `Nested Sampling for Gaussian Shells <https://num.pyro.ai/en/stable/examples/gaussian_shells.html>`_ example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations.

MCMC Kernels
------------

Expand Down
97 changes: 63 additions & 34 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@

from functools import singledispatch

from jax import nn, random, tree_util
from jax import random
import jax.numpy as jnp

try:
from jaxns import (
NestedSampler as OrigNestedSampler,
ExactNestedSampler as OrigNestedSampler,
Model,
NestedSamplerResults,
Prior,
PriorModelGen,
TerminationCondition,
plot_cornerplot,
plot_diagnostics,
resample,
summary,
)
from jaxns.prior_transforms import ContinuousPrior, PriorChain
from jaxns.prior_transforms.prior import UniformBase
except ImportError as e:
raise ImportError(
"To use this module, please install `jaxns` package. It can be"
" installed with `pip install jaxns`"
" installed with `pip install jaxns` with python>=3.8"
) from e

import tensorflow_probability.substrates.jax as tfp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam, seed, trace
Expand All @@ -30,14 +36,7 @@

__all__ = ["NestedSampler"]


class UniformPrior(ContinuousPrior):
def __init__(self, name, shape):
prior_base = UniformBase(shape, jnp.result_type(float))
super().__init__(name, shape, parents=[], tracked=True, prior_base=prior_base)

def transform_U(self, U, **kwargs):
return U
tfpd = tfp.distributions


@singledispatch
Expand Down Expand Up @@ -118,8 +117,6 @@ def __call__(self, name, fn, obs):
return None, transform(x)


# TODO: Consider deprecating this wrapper. It might be better to only provide some
# utilities to help converting a NumPyro model to a Jaxns loglikelihood function.
class NestedSampler:
"""
(EXPERIMENTAL) A wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_ ,
Expand Down Expand Up @@ -189,7 +186,7 @@ def __init__(
)
self._samples = None
self._log_weights = None
self._results = None
self._results: NestedSamplerResults | None = None

def run(self, rng_key, *args, **kwargs):
"""
Expand Down Expand Up @@ -246,24 +243,58 @@ def run(self, rng_key, *args, **kwargs):
loglik_fn = local_dict["loglik_fn"]

# use NestedSampler with identity prior chain
prior_chain = PriorChain()
for name in param_names:
prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape())
prior_chain.push(prior)
# XXX: the `marginalised` keyword in jaxns can be used to get expectation of some
# quantity over posterior samples; it can be helpful to expose it in this wrapper
ns = OrigNestedSampler(
loglik_fn,
prior_chain,
def prior_model() -> PriorModelGen:
params = []
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
param = yield Prior(
tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)),
name=name + "_base",
)
params.append(param)
return tuple(params)

model = Model(prior_model=prior_model, log_likelihood=loglik_fn)

default_constructor_kwargs = dict(
num_live_points=model.U_ndims * 25,
num_parallel_samplers=1,
max_samples=1e4,
uncert_improvement_patience=2,
)
default_termination_kwargs = dict(live_evidence_frac=1e-4)
# Fill-in missing values with defaults. This allows user to inspect what was actually used by inspecting
# these dictionaries
list(
map(
lambda item: self.constructor_kwargs.setdefault(*item),
default_constructor_kwargs.items(),
)
)
list(
map(
lambda item: self.termination_kwargs.setdefault(*item),
default_termination_kwargs.items(),
)
)

exact_ns = OrigNestedSampler(
model=model,
**self.constructor_kwargs,
)
results = ns(rng_sampling, **self.termination_kwargs)

termination_reason, state = exact_ns(
rng_sampling,
term_cond=TerminationCondition(**self.termination_kwargs),
)
results = exact_ns.to_results(state, termination_reason)

# transform base samples back to original domains
# Here we only transform the first valid num_samples samples
# NB: the number of weighted samples obtained from jaxns is results.num_samples
# and only the first num_samples values of results.samples are valid.
num_samples = results.total_num_samples
samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples)
samples = results.samples
predictive = Predictive(
reparam_model, samples, return_sites=param_names + deterministics
)
Expand All @@ -283,11 +314,10 @@ def get_samples(self, rng_key, num_samples):
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
)

samples, log_weights = self.get_weighted_samples()
p = nn.softmax(log_weights)
idx = random.choice(rng_key, log_weights.shape[0], (num_samples,), p=p)
return {k: v[idx] for k, v in samples.items()}
weighted_samples, sample_weights = self.get_weighted_samples()
return resample(
rng_key, weighted_samples, sample_weights, S=num_samples, replace=True
)

def get_weighted_samples(self):
"""
Expand All @@ -298,8 +328,7 @@ def get_weighted_samples(self):
"NestedSampler.run(...) method should be called first to obtain results."
)

num_samples = self._results.total_num_samples
return self._results.samples, self._results.log_dp_mean[:num_samples]
return self._results.samples, self._results.log_dp_mean

def print_summary(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"flax",
"funsor>=0.4.1",
"graphviz",
"jaxns==1.0.0",
"jaxns>=2.0.1",
"matplotlib",
"optax>=0.0.6",
"pylab-sdk", # jaxns dependency
Expand Down
6 changes: 5 additions & 1 deletion test/contrib/test_nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import jax.numpy as jnp

import numpyro
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam

try:
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
except ImportError:
pytestmark = pytest.mark.skip(reason="jaxns is not installed")
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform, ExpTransform

Expand Down

0 comments on commit bb9e1ba

Please sign in to comment.