Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Circular Reparameterization #1080

Merged
merged 32 commits into from Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4a46f7d
Add CircularReparam
alexlyttle Jun 23, 2021
e40e1f7
Change to use jnp.remainder
alexlyttle Jun 23, 2021
cf50a9e
Add circular constraint
alexlyttle Jun 25, 2021
031ef32
Change Von Mises constraint to circular
alexlyttle Jun 25, 2021
e26bdd9
Remove comment
alexlyttle Jun 25, 2021
3fe4bad
Lint and format
alexlyttle Jun 25, 2021
0963ec0
Add helpful error for circular support
alexlyttle Jul 2, 2021
427cc12
Add docstring to VonMises distribution
alexlyttle Jul 2, 2021
a98f78c
Remove trailing whitespace
alexlyttle Jul 2, 2021
a3ca41e
Add circular constraint to gen_values functions
alexlyttle Jul 2, 2021
2e94075
Review response
alexlyttle Jul 8, 2021
6a69e19
Circular now inherits from interval
alexlyttle Jul 9, 2021
3531bcd
Move to after _Interval definition
alexlyttle Jul 9, 2021
fce4e0c
Replace Normal with ImproperUniform
alexlyttle Jul 9, 2021
243312f
Add `test_circular` and `get_circular_moments`
alexlyttle Jul 9, 2021
008c9e4
Lint and format
alexlyttle Jul 9, 2021
8d0db6c
Update docstring
alexlyttle Jul 12, 2021
89ac5a0
Add circular to all
alexlyttle Jul 12, 2021
d6f1499
Assert circular support and simplify
alexlyttle Jul 12, 2021
78747c5
Change from error to warning
alexlyttle Jul 12, 2021
82fad9e
Revert changes to gen_values functions
alexlyttle Jul 12, 2021
a19c417
Add CircularReparm docs
alexlyttle Jul 14, 2021
5e3ce48
Add circular constraint
alexlyttle Jul 14, 2021
26fa700
Change to autofunction
alexlyttle Jul 14, 2021
d9e5a4c
Merge branch 'master' into circular-reparam
alexlyttle Jul 14, 2021
f11c072
Change circular to autodata
alexlyttle Jul 14, 2021
2af2fb6
Make circular instance of _Interval
alexlyttle Jul 19, 2021
55c2fe1
Modify assert for changes to constraints.py
alexlyttle Jul 19, 2021
0e35508
Warn if support is circular
alexlyttle Jul 19, 2021
933f788
Change back to remainder
alexlyttle Jul 19, 2021
138ccdb
Lint and format
alexlyttle Jul 19, 2021
6146a3e
Add raise_warnings flag to helpful_support_errors
alexlyttle Jul 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/distributions.rst
Expand Up @@ -578,6 +578,10 @@ boolean
^^^^^^^
.. autodata:: numpyro.distributions.constraints.boolean

circular
--------
.. autodata:: numpyro.distributions.constraints.circular

corr_cholesky
^^^^^^^^^^^^^
.. autodata:: numpyro.distributions.constraints.corr_cholesky
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reparam.rst
Expand Up @@ -51,3 +51,12 @@ Projected Normal Distributions
:show-inheritance:
:member-order: bysource
:special-members: __call__

Circular Distributions
----------------------
.. autoclass:: numpyro.infer.reparam.CircularReparam
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:special-members: __call__
4 changes: 4 additions & 0 deletions numpyro/distributions/constraints.py
Expand Up @@ -28,6 +28,7 @@

__all__ = [
"boolean",
"circular",
"corr_cholesky",
"corr_matrix",
"dependent",
Expand All @@ -53,6 +54,8 @@
"Constraint",
]

fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
import math

import numpy as np

import jax.numpy
Expand Down Expand Up @@ -454,6 +457,7 @@ def feasible_like(self, prototype):
# See https://github.com/pytorch/pytorch/issues/50616

boolean = _Boolean()
circular = _Interval(-math.pi, math.pi)
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent = _Dependent()
Expand Down
19 changes: 18 additions & 1 deletion numpyro/distributions/directional.py
Expand Up @@ -20,9 +20,26 @@


class VonMises(Distribution):
"""
The von Mises distribution, also known as the circular normal distribution.

This distribution is supported by a circular constraint from -pi to +pi. By
default, the circular support behaves like
``constraints.interval(-math.pi, math.pi)``. To avoid issues at the
boundaries of this interval during sampling, you should reparameterize this
distribution using ``handlers.reparam`` with a
:class:`~numpyro.infer.reparam.CircularReparam` reparametrizer in
the model, e.g.::

@handlers.reparam(config={"direction": CircularReparam()})
def model():
direction = numpyro.sample("direction", VonMises(0.0, 4.0))
...
"""

arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
reparametrized_params = ["loc"]
support = constraints.interval(-math.pi, math.pi)
support = constraints.circular

def __init__(self, loc, concentration, validate_args=None):
"""von Mises distribution for sampling directions.
Expand Down
30 changes: 30 additions & 0 deletions numpyro/infer/reparam.py
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
import math

import jax.numpy as jnp

Expand Down Expand Up @@ -288,3 +289,32 @@ def transform_sample(self, latent):
"""
x_unconstrained = self.transform(latent)
return self.guide._unpack_and_constrain(x_unconstrained, self.params)


class CircularReparam(Reparam):
"""
Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
variables.
"""

def __call__(self, name, fn, obs):
# Support must be circular
support = fn.support
if isinstance(support, constraints.independent):
support = fn.support.base_constraint
assert support is constraints.circular

# Draw parameter-free noise.
new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
value = numpyro.sample(
f"{name}_unwrapped",
new_fn,
obs=obs,
)

# Differentiably transform.
value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

# Simulate a pyro.deterministic() site.
numpyro.factor(f"{name}_factor", fn.log_prob(value))
return None, value
24 changes: 20 additions & 4 deletions numpyro/infer/util.py
Expand Up @@ -397,7 +397,7 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
)
else:
support = v["fn"].support
with helpful_support_errors(v):
with helpful_support_errors(v, raise_warnings=True):
inv_transforms[k] = biject_to(support)
# XXX: the following code filters out most situations with dynamic supports
args = ()
Expand Down Expand Up @@ -961,12 +961,28 @@ def single_loglik(samples):


@contextmanager
def helpful_support_errors(site):
def helpful_support_errors(site, raise_warnings=False):
name = site["name"]
support = getattr(site["fn"], "support", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for completeness, you can add

if isinstance(support, constraints.independent):
    support = support.base_constraint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I've added that now

if isinstance(support, constraints.independent):
support = support.base_constraint

# Warnings
if raise_warnings:
if support is constraints.circular:
msg = (
f"Continuous inference poorly handles circular sample site '{name}'. "
+ "Consider using VonMises distribution together with "
+ "a reparameterizer, e.g. "
+ f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})."
)
warnings.warn(msg, UserWarning)

# Exceptions
try:
yield
except NotImplementedError as e:
name = site["name"]
support_name = repr(site["fn"].support).lower()
support_name = repr(support).lower()
if "integer" in support_name or "boolean" in support_name:
# TODO: mention enumeration when it is supported in SVI
raise ValueError(
Expand Down
58 changes: 58 additions & 0 deletions test/infer/test_reparam.py
Expand Up @@ -15,6 +15,7 @@
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.reparam import (
CircularReparam,
LocScaleReparam,
NeuTraReparam,
ProjectedNormalReparam,
Expand All @@ -37,6 +38,26 @@ def get_moments(x):
return jnp.stack([m1, m2, m3, m4])


# Helper functions to get central circular moments
def mean_vector(x, m, n):
s = jnp.mean(jnp.sin(n * (x - m)), axis=0)
c = jnp.mean(jnp.cos(n * (x - m)), axis=0)
return s, c


def circular_moment(x, n):
m = jnp.arctan2(*mean_vector(x, 0.0, n)) # circular mean
s, c = mean_vector(x, m, n)
# direction = jnp.arctan2(s, c)
length = jnp.hypot(s, c)
return length
# return jnp.array([direction, length])


def get_circular_moments(x):
return jnp.stack([circular_moment(x, i) for i in range(1, 5)])


def test_syntax():
loc = np.random.uniform(-1.0, 1.0, ())
scale = np.random.uniform(0.5, 1.5, ())
Expand Down Expand Up @@ -268,3 +289,40 @@ def get_actual_probe(concentration):
expected_grad = jacobian(get_expected_probe)(concentration)
actual_grad = jacobian(get_actual_probe)(concentration)
assert_allclose(actual_grad, expected_grad, atol=0.05)


@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
def test_circular(shape):
# Define two models which should return the same distributions
# This model is the expected distribution
def model_exp(loc, concentration):
with numpyro.plate_stack("plates", shape):
with numpyro.plate("particles", 10000):
numpyro.sample("x", dist.VonMises(loc, concentration))

# This model is for inference
reparam = CircularReparam()

@numpyro.handlers.reparam(config={"x": reparam})
def model_act(loc, concentration):
numpyro.sample("x", dist.VonMises(loc, concentration))

def get_expected_probe(loc, concentration):
with numpyro.handlers.trace() as trace:
with numpyro.handlers.seed(rng_seed=0):
model_exp(loc, concentration)
return get_circular_moments(trace["x"]["value"])

def get_actual_probe(loc, concentration):
kernel = NUTS(model_act)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1)
mcmc.run(random.PRNGKey(0), loc, concentration)
samples = mcmc.get_samples()
return get_circular_moments(samples["x"])

loc = np.random.uniform(-np.pi, np.pi, shape)
concentration = np.random.lognormal(1.0, 1.0, shape)
expected_probe = get_expected_probe(loc, concentration)
actual_probe = get_actual_probe(loc, concentration)

assert_allclose(actual_probe, expected_probe, atol=0.1)