Skip to content

Commit

Permalink
Verify result of biject_to satisfies the constraint. (#1770)
Browse files Browse the repository at this point in the history
* Verify result of `biject_to` satisfies the constraint.

* Replace global `jax.numpy` arrays by `numpy` arrays.
  • Loading branch information
tillahoffmann committed Mar 30, 2024
1 parent 68eb218 commit 4dbc625
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 18 deletions.
4 changes: 2 additions & 2 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ class _CorrMatrix(_SingletonConstraint):
def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
# check for symmetric
symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1))
# check for the smallest eigenvalue is positive
positive = jnp.linalg.eigh(x)[0][..., 0] > 0
positive = jnp.linalg.eigvalsh(x)[..., 0] > 0
# check for diagonal equal to 1
unit_variance = jnp.all(
jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1
Expand Down
67 changes: 51 additions & 16 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from functools import partial
import math

import numpy as np
import pytest

from jax import jacfwd, jit, random, tree_map, vmap
import jax.numpy as jnp

from numpyro.distributions import constraints
from numpyro.distributions.flows import (
BlockNeuralAutoregressiveTransform,
InverseAutoregressiveTransform,
Expand Down Expand Up @@ -49,9 +51,6 @@ def _unpack(x):
return (x,)


_a = jnp.asarray


def _smoke_neural_network():
return None, None

Expand All @@ -61,31 +60,29 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):


TRANSFORMS = {
"affine": T(
AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), dict()
),
"affine": T(AffineTransform, (np.array([1.0, 2.0]), np.array([3.0, 4.0])), dict()),
"compose": T(
ComposeTransform,
(
[
AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),
AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])),
ExpTransform(),
],
),
dict(),
),
"independent": T(
IndependentTransform,
(AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),),
(AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])),),
dict(reinterpreted_batch_ndims=1),
),
"lower_cholesky_affine": T(
LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)), dict()
LowerCholeskyAffine, (np.array([1.0, 2.0]), np.eye(2)), dict()
),
"permute": T(PermuteTransform, (jnp.array([1, 0]),), dict()),
"permute": T(PermuteTransform, (np.array([1, 0]),), dict()),
"power": T(
PowerTransform,
(_a(2.0),),
(np.array(2.0),),
dict(),
),
"rfft": T(
Expand All @@ -95,12 +92,12 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
),
"recursive_linear": T(
RecursiveLinearTransform,
(jnp.eye(5),),
(np.eye(5),),
dict(),
),
"simplex_to_ordered": T(
SimplexToOrderedTransform,
(_a(1.0),),
(np.array(1.0),),
dict(),
),
"unpack": T(UnpackTransform, (), dict(unpack_fn=_unpack)),
Expand All @@ -124,7 +121,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
# autoregressive_nn is a non-jittable arg, which does not fit well with
# the current test pipeline, which assumes jittable args, and non-jittable kwargs
partial(InverseAutoregressiveTransform, _smoke_neural_network),
(_a(-1.0), _a(1.0)),
(np.array(-1.0), np.array(1.0)),
dict(),
),
"bna": T(
Expand Down Expand Up @@ -279,10 +276,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
(IdentityTransform(), ()),
(IndependentTransform(ExpTransform(), 2), (3, 4)),
(L1BallTransform(), (9,)),
(LowerCholeskyAffine(jnp.ones(3), jnp.eye(3)), (3,)),
(LowerCholeskyAffine(np.ones(3), np.eye(3)), (3,)),
(LowerCholeskyTransform(), (10,)),
(OrderedTransform(), (5,)),
(PermuteTransform(jnp.roll(jnp.arange(7), 2)), (7,)),
(PermuteTransform(np.roll(np.arange(7), 2)), (7,)),
(PowerTransform(2.5), ()),
(RealFastFourierTransform(7), (7,)),
(RealFastFourierTransform((8, 9), 2), (8, 9)),
Expand Down Expand Up @@ -354,3 +351,41 @@ def test_batched_recursive_linear_transform():
y = transform(x)
assert y.shape == x.shape
assert jnp.allclose(x, transform.inv(y), atol=1e-6)


@pytest.mark.parametrize(
"constraint, shape",
[
(constraints.circular, (3,)),
(constraints.complex, (3,)),
(constraints.corr_cholesky, (10, 10)),
(constraints.corr_matrix, (21,)),
(constraints.greater_than(3), ()),
(constraints.interval(8, 13), (17,)),
(constraints.l1_ball, (4,)),
(constraints.less_than(-1), ()),
(constraints.lower_cholesky, (21,)),
(constraints.open_interval(3, 4), ()),
(constraints.ordered_vector, (5,)),
(constraints.positive_definite, (6,)),
(constraints.positive_ordered_vector, (7,)),
(constraints.positive, (7,)),
(constraints.real_matrix, (17,)),
(constraints.real_vector, (18,)),
(constraints.real, (3,)),
(constraints.scaled_unit_lower_cholesky, (21,)),
(constraints.simplex, (3,)),
(constraints.softplus_lower_cholesky, (21,)),
(constraints.softplus_positive, (2,)),
(constraints.unit_interval, (4,)),
],
ids=str,
)
def test_biject_to(constraint, shape):
batch_shape = (13, 19)
unconstrained = random.normal(random.key(93), batch_shape + shape)
constrained = biject_to(constraint)(unconstrained)
passed = constraint.check(constrained)
expected_shape = constrained.shape[: constrained.ndim - constraint.event_dim]
assert passed.shape == expected_shape
assert jnp.all(passed)

0 comments on commit 4dbc625

Please sign in to comment.