Skip to content

Commit

Permalink
Add explicit reparametrizer. (#1754)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Mar 18, 2024
1 parent 013e54c commit fe7b693
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/source/reparam.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ Circular Distributions
:show-inheritance:
:member-order: bysource
:special-members: __call__

Explicit Reparameterization
---------------------------
.. autoclass:: numpyro.infer.reparam.ExplicitReparam
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:special-members: __call__
46 changes: 46 additions & 0 deletions numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from abc import ABC, abstractmethod
import math
from typing import Iterable

import numpy as np

Expand Down Expand Up @@ -346,3 +347,48 @@ def __call__(self, name, fn, obs):
# Simulate a pyro.deterministic() site.
numpyro.factor(f"{name}_factor", fn.log_prob(value))
return None, value


class ExplicitReparam(Reparam):
"""
Explicit reparametrizer of a latent variable :code:`x` to a transformed space
:code:`y = transform(x)` with more amenable geometry. This reparametrizer is similar
to :class:`.TransformReparam` but allows reparametrizations to be decoupled from the
model declaration.
:param transform: Bijective transform to the reparameterized space.
**Example:**
.. doctest::
>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import handlers, distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> from numpyro.infer.reparam import ExplicitReparam
>>>
>>> def model():
... numpyro.sample("x", dist.Gamma(4, 4))
>>>
>>> # Sample in unconstrained space using a soft-plus instead of exp transform.
>>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
>>> reparametrized = handlers.reparam(model, {"x": reparam})
>>> kernel = NUTS(model=reparametrized)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
>>> mcmc.run(random.PRNGKey(2)) # doctest: +SKIP
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
"""
def __init__(self, transform):
if isinstance(transform, Iterable) and all(
isinstance(t, dist.transforms.Transform) for t in transform
):
transform = dist.transforms.ComposeTransform(transform)
self.transform = transform

def __call__(self, name, fn, obs):
assert obs is None, "ExplicitReparam does not support observe statements"
transformed = dist.TransformedDistribution(fn, self.transform)
x = numpyro.sample(f"{name}_base", transformed)
return None, self.transform.inv(x)
14 changes: 14 additions & 0 deletions test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.reparam import (
CircularReparam,
ExplicitReparam,
LocScaleReparam,
NeuTraReparam,
ProjectedNormalReparam,
Expand Down Expand Up @@ -399,3 +400,16 @@ def model():
reparam_model = handlers.reparam(model, config={"x": LocScaleReparam(0)})
with pytest.raises(ValueError, match="LocScaleReparam.*"):
handlers.seed(reparam_model, rng_seed=0)()


def test_explicit_reparam():
def model():
numpyro.sample("x", dist.Gamma(4, 4))

reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
reparametrized = handlers.reparam(model, {"x": reparam})
kernel = NUTS(model=reparametrized)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(random.PRNGKey(2))
samples = mcmc.get_samples()
assert abs(samples["x"].mean() - 1) < 0.1

0 comments on commit fe7b693

Please sign in to comment.