Skip to content

Commit

Permalink
Add ZeroSumNormal distribution (#1751)
Browse files Browse the repository at this point in the history
* added zerosumnormal and tests

* added edge case handling for support shape

* removed commented out functions

* added zerosumnormal to docs

* fixed zerosumnormal support shape default

* Added v1 of docstrings for zerosumnormal

* updated zsn docstring

* improved init shape handling for zerosumnormal

* improved docstrings

* added ZeroSumTransform

* made n_zerosum_axes an attribute for the zerosumtransform

* removed commented out lines

* added zerosumtransform class

* switched zsn from ParameterFreeTransform to Transform

* changed ZeroSumNormal to transformed distibutrion

* changed input to tuple for _transform_to_zero_sum

* added forward and inverse shape to transform, fixed zero_sum constraint handling

* fixed failing zsn tests

* added docstring, removed whitespace, fixed missing import

* fixed allclose to be assert allclose

* linted and formatted

* added sample code to docstring for zsn

* updated docstring

* removed list from ZeroSum constraint call

* removed unneeded iteration, updated docstring

* updated constraint code

* added ZeroSumTransform to docs

* fixed transform shapes

* added doctest example for zsn

* added constraint test

* added zero_sum constraint to docs

* added type hinting to transforms file

* fixed docs formatting

* moved skip zsn from test_gof earlier

* reversed zerosumtransform

* broadcasted mean and var of zsn

* added stricter zero_sum constraint tol, improved mean and var functions

* fixed _transform_to_zero_sum

* removed shape promote from zsn, changed broadcast to zeros_like

* chose better zsn test cases

* Update zero_sum constraint feasible_like

Co-authored-by: Till Hoffmann <tillahoffmann@gmail.com>

* fixed docstring for doctests

---------

Co-authored-by: Till Hoffmann <tillahoffmann@gmail.com>
  • Loading branch information
kylejcaron and tillahoffmann committed Mar 30, 2024
1 parent 84973a9 commit 68eb218
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 1 deletion.
19 changes: 19 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ Weibull
:show-inheritance:
:member-order: bysource

ZeroSumNormal
^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Discrete Distributions
----------------------
Expand Down Expand Up @@ -820,6 +827,9 @@ unit_interval
^^^^^^^^^^^^^
.. autodata:: numpyro.distributions.constraints.unit_interval

zero_sum
^^^^^^^^
.. autodata:: numpyro.distributions.constraints.zero_sum

Transforms
----------
Expand Down Expand Up @@ -1014,6 +1024,15 @@ StickBreakingTransform
:show-inheritance:
:member-order: bysource

ZeroSumTransform
^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.ZeroSumTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource


Flows
-----
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
StudentT,
Uniform,
Weibull,
ZeroSumNormal,
)
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
from numpyro.distributions.directional import (
Expand Down Expand Up @@ -196,4 +197,5 @@
"ZeroInflatedDistribution",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial2",
"ZeroSumNormal",
]
25 changes: 25 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"softplus_lower_cholesky",
"softplus_positive",
"unit_interval",
"zero_sum",
"Constraint",
]

Expand Down Expand Up @@ -697,6 +698,29 @@ def feasible_like(self, prototype):
return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5))


class _ZeroSum(Constraint):
def __init__(self, event_dim=1):
self.event_dim = event_dim
super().__init__()

def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10
zerosum_true = True
for dim in range(-self.event_dim, 0):
zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol)
return zerosum_true

def __eq__(self, other):
return type(self) is type(other) and self.event_dim == other.event_dim

def feasible_like(self, prototype):
return jax.numpy.zeros_like(prototype)

def tree_flatten(self):
return (self.event_dim,), (("event_dim",), dict())


# TODO: Make types consistent
# See https://github.com/pytorch/pytorch/issues/50616

Expand Down Expand Up @@ -731,3 +755,4 @@ def feasible_like(self, prototype):
sphere = _Sphere()
unit_interval = _UnitInterval()
open_interval = _OpenInterval
zero_sum = _ZeroSum
95 changes: 95 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
ExpTransform,
PowerTransform,
SigmoidTransform,
ZeroSumTransform,
)
from numpyro.distributions.util import (
add_diag,
Expand Down Expand Up @@ -2438,3 +2439,97 @@ def cdf(self, value):

def icdf(self, value):
return self._ald.icdf(value)


class ZeroSumNormal(TransformedDistribution):
r"""
Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or
more axes are constrained to sum to zero (the last axis by default).
.. math::
\begin{align*}
ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
n = \text{number of zero-sum axes}
\end{align*}
:param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is
enforced.
:param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero.
**Example:**
.. doctest::
>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> N = 1000
>>> n_categories = 20
>>> rng_key = random.PRNGKey(0)
>>> key1, key2, key3 = random.split(rng_key, 3)
>>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
>>> beta = random.normal(key2, shape=(n_categories,))
>>> beta -= beta.mean(-1)
>>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))
>>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
... N = len(category_ind)
... alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
... beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
... sigma = numpyro.sample("sigma", dist.Exponential(1))
... with numpyro.plate("observations", N):
... mu = alpha + beta[category_ind]
... obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
... return obs
>>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
>>> mcmc = MCMC(
... sampler=nuts_kernel,
... num_samples=1_000, num_warmup=1_000, num_chains=4
... )
>>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y)
>>> posterior_samples = mcmc.get_samples()
>>> # Confirm everything along last axis sums to zero
>>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)
**References**
[1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637
[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
[3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
"""

arg_constraints = {"scale": constraints.positive}
reparametrized_params = ["scale"]

def __init__(self, scale, event_shape, *, validate_args=None):
event_ndim = len(event_shape)
transformed_shape = tuple(size - 1 for size in event_shape)
self.scale = scale
super().__init__(
Normal(0, scale).expand(transformed_shape).to_event(event_ndim),
ZeroSumTransform(event_ndim),
validate_args=validate_args,
)

@constraints.dependent_property(is_discrete=False)
def support(self):
return constraints.zero_sum(len(self.event_shape))

@property
def mean(self):
return jnp.zeros(self.batch_shape + self.event_shape)

@property
def variance(self):
event_ndim = len(self.event_shape)
zero_sum_axes = tuple(range(-event_ndim, 0))
theoretical_var = jnp.square(self.scale)
for axis in zero_sum_axes:
theoretical_var *= 1 - 1 / self.event_shape[axis]

return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape)
93 changes: 93 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import weakref

import numpy as np
from numpy.core.numeric import normalize_axis_tuple

from jax import lax, vmap
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -50,6 +51,7 @@
"StickBreakingTransform",
"Transform",
"UnpackTransform",
"ZeroSumTransform",
]


Expand Down Expand Up @@ -1380,6 +1382,92 @@ def __eq__(self, other):
return jnp.array_equal(self.transition_matrix, other.transition_matrix)


class ZeroSumTransform(Transform):
"""A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]
:param transform_ndims: Number of trailing dimensions to transform.
**References**
[1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
[3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
"""

def __init__(self, transform_ndims: int = 1) -> None:
self.transform_ndims = transform_ndims

@property
def domain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, self.transform_ndims)

@property
def codomain(self) -> constraints.Constraint:
return constraints.zero_sum(self.transform_ndims)

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
zero_sum_axes = tuple(range(-self.transform_ndims, 0))
for axis in zero_sum_axes:
x = self.extend_axis(x, axis=axis)
return x

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
zero_sum_axes = tuple(range(-self.transform_ndims, 0))
for axis in zero_sum_axes:
y = self.extend_axis_rev(y, axis=axis)
return y

def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

n = array.shape[normalized_axis]
last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)

sum_vals = -last * jnp.sqrt(n)
norm = sum_vals / (jnp.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis
return array[(*slice_before, slice(None, -1))] + norm

def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
n = array.shape[axis] + 1

sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (jnp.sqrt(n) + n)
fill_val = norm - sum_vals / jnp.sqrt(n)

out = jnp.concatenate([array, fill_val], axis=axis)
return out - norm

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
shape = jnp.broadcast_shapes(
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
)
return jnp.zeros_like(x, shape=shape)

def forward_shape(self, shape: tuple) -> tuple:
return shape[: -self.transform_ndims] + tuple(
s + 1 for s in shape[-self.transform_ndims :]
)

def inverse_shape(self, shape: tuple) -> tuple:
return shape[: -self.transform_ndims] + tuple(
s - 1 for s in shape[-self.transform_ndims :]
)

def tree_flatten(self):
aux_data = {
"transform_ndims": self.transform_ndims,
}
return (), ((), aux_data)

def __eq__(self, other):
return (
isinstance(other, ZeroSumTransform)
and self.transform_ndims == other.transform_ndims
)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down Expand Up @@ -1530,3 +1618,8 @@ def _transform_to_softplus_lower_cholesky(constraint):
@biject_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
return StickBreakingTransform()


@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
return ZeroSumTransform(constraint.event_dim)
1 change: 1 addition & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
dict(),
),
"open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()),
"zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)),
}

# TODO: BijectorConstraint
Expand Down

0 comments on commit 68eb218

Please sign in to comment.