Skip to content

Commit

Permalink
Add RecursiveLinearTransform for linear state space models. (#1766)
Browse files Browse the repository at this point in the history
* Format reparam module to comply with style guide.

* Add `RealFastFourierTransform` to documentation.

* Ignore `venv` directory for `update_headers.py` script.

* Ignore autogenerated documentation sources.

* Add numerical Jacobian check for bijective transforms.

* Add `RecursiveLinearTransform`.

* Use matrix multiplication operator and fix Jacobian.

* Use non-trivial transition matrix in test.

* Specify that transition matrices must (batches of) square matrices.

* Fix `scan` implementation for batched transition matrices and add test.
  • Loading branch information
tillahoffmann committed Mar 25, 2024
1 parent 4c2a559 commit ad6861a
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ numpyro/examples/.data
# docs
docs/build
docs/.DS_Store
docs/source/examples
docs/source/tutorials
docs/source/getting_started.rst
20 changes: 19 additions & 1 deletion docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ EulerMaruyama
:undoc-members:
:show-inheritance:
:member-order: bysource

Exponential
^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Exponential
Expand Down Expand Up @@ -948,6 +948,24 @@ PowerTransform
:show-inheritance:
:member-order: bysource

RealFastFourierTransform
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.RealFastFourierTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

RecursiveLinearTransform
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.RecursiveLinearTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

ScaledUnitLowerCholeskyTransform
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.ScaledUnitLowerCholeskyTransform
Expand Down
92 changes: 92 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,98 @@ def __eq__(self, other):
)


class RecursiveLinearTransform(Transform):
"""
Apply a linear transformation recursively such that
:math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t`
are vectors and :math:`A` is a square transition matrix. The series is initialized
by :math:`y_0 = 0`.
:param transition_matrix: Squared transition matrix :math:`A` for successive states
or a batch of transition matrices.
**Example:**
.. doctest::
>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import distributions as dist
>>>
>>> def cauchy_random_walk():
... return numpyro.sample(
... "x",
... dist.TransformedDistribution(
... dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
... dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
... ),
... )
>>>
>>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape
(10, 1)
>>>
>>> def rocket_trajectory():
... scale = numpyro.sample(
... "scale",
... dist.HalfCauchy(1).expand([2]).to_event(1),
... )
... transition_matrix = jnp.array([[1, 1], [0, 1]])
... return numpyro.sample(
... "x",
... dist.TransformedDistribution(
... dist.Normal(0, scale).expand([10, 2]).to_event(1),
... dist.transforms.RecursiveLinearTransform(transition_matrix),
... ),
... )
>>>
>>> numpyro.handlers.seed(rocket_trajectory, 0)().shape
(10, 2)
"""

domain = constraints.real_matrix
codomain = constraints.real_matrix

def __init__(self, transition_matrix: jnp.ndarray) -> None:
self.transition_matrix = transition_matrix

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# Move the time axis to the first position so we can scan over it.
x = jnp.moveaxis(x, -2, 0)

def f(y, x):
y = jnp.einsum("...ij,...j->...i", self.transition_matrix, y) + x
return y, y

_, y = lax.scan(f, jnp.zeros_like(x, shape=x.shape[1:]), x)
return jnp.moveaxis(y, 0, -2)

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
# Move the time axis to the first position so we can scan over it in reverse.
y = jnp.moveaxis(y, -2, 0)

def f(y, prev):
x = y - jnp.einsum("...ij,...j->...i", self.transition_matrix, prev)
return prev, x

_, x = lax.scan(f, y[-1], jnp.roll(y, 1, axis=0).at[0].set(0), reverse=True)
return jnp.moveaxis(x, 0, -2)

def log_abs_det_jacobian(self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None):
return jnp.zeros_like(x, shape=x.shape[:-2])

def tree_flatten(self):
return (self.transition_matrix,), (
("transition_matrix",),
{},
)

def __eq__(self, other):
if not isinstance(other, RecursiveLinearTransform):
return False
return jnp.array_equal(self.transition_matrix, other.transition_matrix)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down
2 changes: 1 addition & 1 deletion scripts/update_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys

root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
blacklist = ["/build/", "/dist/", "/pyro_api.egg"]
blacklist = ["/build/", "/dist/", "/pyro_api.egg", "/venv/"]
file_types = [("*.py", "# {}"), ("*.cpp", "// {}")]

parser = argparse.ArgumentParser()
Expand Down
12 changes: 12 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3222,3 +3222,15 @@ def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
assert x.shape == (sample_size, batch_size, event_size)
log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
assert log_prob.shape == (sample_size, batch_size)


def test_gaussian_random_walk_linear_recursive_equivalence():
dist1 = dist.GaussianRandomWalk(3.7, 15)
dist2 = dist.TransformedDistribution(
dist.Normal(0, 3.7).expand([15, 1]).to_event(2),
dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
)
x1 = dist1.sample(random.PRNGKey(7))
x2 = dist2.sample(random.PRNGKey(7))
assert jnp.allclose(x1, x2.squeeze())
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))
42 changes: 40 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from collections import namedtuple
from functools import partial
import math

import pytest

from jax import jit, random, tree_map, vmap
from jax import jacfwd, jit, random, tree_map, vmap
import jax.numpy as jnp

from numpyro.distributions.flows import (
Expand All @@ -30,6 +31,7 @@
PermuteTransform,
PowerTransform,
RealFastFourierTransform,
RecursiveLinearTransform,
ReshapeTransform,
ScaledUnitLowerCholeskyTransform,
SigmoidTransform,
Expand Down Expand Up @@ -90,6 +92,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
(),
dict(transform_shape=(3, 4, 5), transform_ndims=3),
),
"recursive_linear": T(
RecursiveLinearTransform,
(jnp.eye(5),),
dict(),
),
"simplex_to_ordered": T(
SimplexToOrderedTransform,
(_a(1.0),),
Expand Down Expand Up @@ -277,6 +284,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
(PowerTransform(2.5), ()),
(RealFastFourierTransform(7), (7,)),
(RealFastFourierTransform((8, 9), 2), (8, 9)),
(
RecursiveLinearTransform(random.normal(random.key(17), (4, 4))),
(7, 4),
),
(ReshapeTransform((5, 2), (10,)), (10,)),
(ReshapeTransform((15,), (3, 5)), (3, 5)),
(ScaledUnitLowerCholeskyTransform(), (6,)),
Expand Down Expand Up @@ -312,4 +323,31 @@ def test_bijective_transforms(transform, shape):
atol = 1e-2
assert jnp.allclose(x1, x2, atol=atol)

assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape
log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y)
assert log_abs_det_jacobian.shape == batch_shape

# Also check the Jacobian numerically for transforms with the same input and output
# size, unless they are explicitly excluded. E.g., the upper triangular of the
# CholeskyTransform is zero, giving rise to a singular Jacobian.
skip_jacobian_check = (CholeskyTransform,)
size_x = int(x1.size / math.prod(batch_shape))
size_y = int(y.size / math.prod(batch_shape))
if size_x == size_y and not isinstance(transform, skip_jacobian_check):
jac = (
vmap(jacfwd(transform))(x1)
.reshape((-1,) + x1.shape[len(batch_shape) :])
.reshape(batch_shape + (size_y, size_x))
)
slogdet = jnp.linalg.slogdet(jac)
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)


def test_batched_recursive_linear_transform():
batch_shape = (4, 17)
x = random.normal(random.key(8), batch_shape + (10, 3))
# Get a batch of matrices with eigenvalues that don't blow up the sequence.
A = CorrCholeskyTransform()(random.normal(random.key(7), batch_shape + (3,)))
transform = RecursiveLinearTransform(A)
y = transform(x)
assert y.shape == x.shape
assert jnp.allclose(x, transform.inv(y), atol=1e-6)

0 comments on commit ad6861a

Please sign in to comment.