Skip to content

Commit

Permalink
Add complex constraint and real Fourier transform. (#1762)
Browse files Browse the repository at this point in the history
* Add `complex` constraint.

* Add real fast Fourier transform.

* Remove redundant domain and codomain definitions.

* Use numpy for `isreal` check.

* Return `log_abs_det_jacobian` with correct batch shape and add test.

* Update parameter names for `RealFastFourierTransform`.

* Remove `isreal` check.

* Broadcast shapes in `log_abs_det_jacobian`.

* Update construction of domain and codomain for `RealFastFourierTransform`.

* Fix incorrect Jacobian.
  • Loading branch information
tillahoffmann committed Mar 16, 2024
1 parent c988ad0 commit 136f7c0
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 1 deletion.
11 changes: 11 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
__all__ = [
"boolean",
"circular",
"complex",
"corr_cholesky",
"corr_matrix",
"dependent",
Expand Down Expand Up @@ -629,6 +630,15 @@ def feasible_like(self, prototype):
)


class _Complex(_SingletonConstraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
return (x == x) & (x != float("inf")) & (x != float("-inf"))

def feasible_like(self, prototype):
return jax.numpy.zeros_like(prototype)


class _Real(_SingletonConstraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
Expand Down Expand Up @@ -692,6 +702,7 @@ def feasible_like(self, prototype):

boolean = _Boolean()
circular = _Circular()
complex = _Complex()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent = _Dependent()
Expand Down
88 changes: 87 additions & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"LowerCholeskyAffine",
"PermuteTransform",
"PowerTransform",
"RealFastFourierTransform",
"ReshapeTransform",
"SigmoidTransform",
"SimplexToOrderedTransform",
Expand Down Expand Up @@ -1190,7 +1191,7 @@ def _inverse(self, y):
return jnp.reshape(y, self.inverse_shape(jnp.shape(y)))

def log_abs_det_jacobian(self, x, y, intermediates=None):
return 0.0
return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)])

def tree_flatten(self):
aux_data = {
Expand All @@ -1207,6 +1208,86 @@ def __eq__(self, other):
)


def _normalize_rfft_shape(input_shape, shape):
if shape is None:
return input_shape
return input_shape[: len(input_shape) - len(shape)] + shape


class RealFastFourierTransform(Transform):
"""
N-dimensional discrete fast Fourier transform for real input.
:param transform_shape: Length of each transformed axis to use from the input,
defaults to the input size.
:param transform_ndims: Number of trailing dimensions to transform.
"""

def __init__(
self,
transform_shape=None,
transform_ndims=1,
) -> None:
if isinstance(transform_shape, int):
transform_shape = (transform_shape,)
if transform_shape is not None and len(transform_shape) != transform_ndims:
raise ValueError(
f"Length of transform shape ({transform_shape}) does not match number "
f"of dimensions to transform ({transform_ndims})."
)
self.transform_shape = transform_shape
self.transform_ndims = transform_ndims

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
axes = tuple(range(-self.transform_ndims, 0))
return jnp.fft.rfftn(x, self.transform_shape, axes)

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
axes = tuple(range(-self.transform_ndims, 0))
return jnp.fft.irfftn(y, self.transform_shape, axes)

def forward_shape(self, shape: tuple) -> tuple:
# Dimensions remain unchanged except the last transformed dimension.
shape = _normalize_rfft_shape(shape, self.transform_shape)
return shape[:-1] + (shape[-1] // 2 + 1,)

def inverse_shape(self, shape: tuple) -> tuple:
if self.transform_shape:
return _normalize_rfft_shape(shape, self.transform_shape)
size = 2 * (shape[-1] - 1)
return shape[:-1] + (size,)

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
shape = jnp.broadcast_shapes(
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
)
return jnp.zeros_like(x, shape=shape)

def tree_flatten(self):
aux_data = {
"transform_shape": self.transform_shape,
"transform_ndims": self.transform_ndims,
}
return (), ((), aux_data)

@property
def domain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, self.transform_ndims)

@property
def codomain(self) -> constraints.Constraint:
return constraints.independent(constraints.complex, self.transform_ndims)

def __eq__(self, other):
return (
isinstance(other, RealFastFourierTransform)
and self.transform_ndims == other.transform_ndims
and self.transform_shape == other.transform_shape
)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down Expand Up @@ -1334,6 +1415,11 @@ def _transform_to_positive_ordered_vector(constraint):
return ComposeTransform([OrderedTransform(), ExpTransform()])


@biject_to.register(constraints.complex)
def _transform_to_complex(constraint):
return IdentityTransform()


@biject_to.register(constraints.real)
def _transform_to_real(constraint):
return IdentityTransform()
Expand Down
1 change: 1 addition & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SINGLETON_CONSTRAINTS = {
"boolean": constraints.boolean,
"circular": constraints.circular,
"complex": constraints.complex,
"corr_cholesky": constraints.corr_cholesky,
"corr_matrix": constraints.corr_matrix,
"l1_ball": constraints.l1_ball,
Expand Down
85 changes: 85 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
OrderedTransform,
PermuteTransform,
PowerTransform,
RealFastFourierTransform,
ReshapeTransform,
ScaledUnitLowerCholeskyTransform,
SigmoidTransform,
Expand All @@ -37,6 +38,7 @@
SoftplusTransform,
StickBreakingTransform,
UnpackTransform,
biject_to,
)


Expand Down Expand Up @@ -83,6 +85,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
(_a(2.0),),
dict(),
),
"rfft": T(
RealFastFourierTransform,
(),
dict(transform_shape=(3, 4, 5), transform_ndims=3),
),
"simplex_to_ordered": T(
SimplexToOrderedTransform,
(_a(1.0),),
Expand Down Expand Up @@ -228,3 +235,81 @@ def test_reshape_transform_invalid():

with pytest.raises(TypeError, match="cannot reshape array"):
ReshapeTransform((2, 3), (6,))(jnp.arange(2))


@pytest.mark.parametrize(
"input_shape, shape, ndims",
[
((10,), None, 1),
((11,), 11, 1),
((10, 18), None, 2),
((10, 19), (7, 8), 2),
],
)
def test_real_fast_fourier_transform(input_shape, shape, ndims):
x1 = random.normal(random.key(17), input_shape)
transform = RealFastFourierTransform(shape, ndims)
y = transform(x1)
assert transform.codomain(y).all()
assert y.shape == transform.forward_shape(x1.shape)
x2 = transform.inv(y)
assert transform.domain(x2).all()
if x1.shape == x2.shape:
assert jnp.allclose(x2, x1, atol=1e-6)


@pytest.mark.parametrize(
"transform, shape",
[
(AffineTransform(3, 2.5), ()),
(CholeskyTransform(), (10,)),
(ComposeTransform([SoftplusTransform(), SigmoidTransform()]), ()),
(CorrCholeskyTransform(), (15,)),
(CorrMatrixCholeskyTransform(), (15,)),
(ExpTransform(), ()),
(IdentityTransform(), ()),
(IndependentTransform(ExpTransform(), 2), (3, 4)),
(L1BallTransform(), (9,)),
(LowerCholeskyAffine(jnp.ones(3), jnp.eye(3)), (3,)),
(LowerCholeskyTransform(), (10,)),
(OrderedTransform(), (5,)),
(PermuteTransform(jnp.roll(jnp.arange(7), 2)), (7,)),
(PowerTransform(2.5), ()),
(RealFastFourierTransform(7), (7,)),
(RealFastFourierTransform((8, 9), 2), (8, 9)),
(ReshapeTransform((5, 2), (10,)), (10,)),
(ReshapeTransform((15,), (3, 5)), (3, 5)),
(ScaledUnitLowerCholeskyTransform(), (6,)),
(SigmoidTransform(), ()),
(SimplexToOrderedTransform(), (5,)),
(SoftplusLowerCholeskyTransform(), (10,)),
(SoftplusTransform(), ()),
(StickBreakingTransform(), (11,)),
],
)
def test_bijective_transforms(transform, shape):
if isinstance(transform, type):
pytest.skip()
# Get a sample from the support of the distribution.
batch_shape = (13,)
unconstrained = random.normal(random.key(17), batch_shape + shape)
x1 = biject_to(transform.domain)(unconstrained)

# Transform forward and backward, checking shapes, values, and Jacobian shape.
y = transform(x1)
assert y.shape == transform.forward_shape(x1.shape)

x2 = transform.inv(y)
assert x2.shape == transform.inverse_shape(y.shape)
# Some transforms are a bit less stable; we give them larger tolerances.
atol = 1e-6
less_stable_transforms = (
CorrCholeskyTransform,
L1BallTransform,
StickBreakingTransform,
)
if isinstance(transform, less_stable_transforms):
atol = 1e-2
assert jnp.allclose(x1, x2, atol=atol)

assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape

0 comments on commit 136f7c0

Please sign in to comment.