Skip to content

Commit

Permalink
Address google/jax#19885 for numpyro. (#1743)
Browse files Browse the repository at this point in the history
* Address google/jax#19885 for numpyro.

* Implement function to add constant or batch of vectors to diagonal.

* Use `add_diag` helper function in `distributions` module.

* Move `add_diag` to `distributions.util` module.
  • Loading branch information
tillahoffmann committed Feb 28, 2024
1 parent 00f43d0 commit e6c187c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 21 deletions.
22 changes: 7 additions & 15 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
SigmoidTransform,
)
from numpyro.distributions.util import (
add_diag,
betainc,
betaincinv,
cholesky_of_inverse,
Expand Down Expand Up @@ -1081,10 +1082,7 @@ def _onion(self, key, size):
# correct the diagonal
# NB: beta_sample = sum(w ** 2) because norm 2 of u is 1.
diag = jnp.ones(cholesky.shape[:-1]).at[..., 1:].set(jnp.sqrt(1 - beta_sample))
cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(
self.dimension
)
return cholesky
return add_diag(cholesky, diag)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down Expand Up @@ -1860,7 +1858,7 @@ def _batch_capacitance_tril(W, D):
Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
K = jnp.matmul(Wt_Dinv, W)
# could be inefficient
return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
return jnp.linalg.cholesky(add_diag(K, 1))


def _batch_lowrank_logdet(W, D, capacitance_tril):
Expand Down Expand Up @@ -1957,17 +1955,15 @@ def scale_tril(self):
cov_diag_sqrt_unsqueeze = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
K = jnp.matmul(Dinvsqrt_W, jnp.swapaxes(Dinvsqrt_W, -1, -2))
K = jnp.add(K, jnp.identity(K.shape[-1]))
K = add_diag(K, 1)
scale_tril = cov_diag_sqrt_unsqueeze * jnp.linalg.cholesky(K)
return scale_tril

@lazy_property
def covariance_matrix(self):
# TODO: find a better solution to create a diagonal matrix
new_diag = self.cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
covariance_matrix = new_diag + jnp.matmul(
covariance_matrix = add_diag(jnp.matmul(
self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)
)
), self.cov_diag)
return covariance_matrix

@lazy_property
Expand All @@ -1979,12 +1975,8 @@ def precision_matrix(self):
self.cov_diag, axis=-2
)
A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
# TODO: find a better solution to create a diagonal matrix
inverse_cov_diag = jnp.reciprocal(self.cov_diag)
diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(
self.loc.shape[-1]
)
return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A)
return add_diag(- jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down
5 changes: 3 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from numpyro.distributions import constraints
from numpyro.distributions.util import (
add_diag,
matrix_to_tril_vec,
signed_stick_breaking_tril,
sum_rightmost,
Expand Down Expand Up @@ -753,7 +754,7 @@ def __call__(self, x):
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
diag = jnp.exp(x[..., -n:])
return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
return add_diag(z, diag)

def _inverse(self, y):
z = matrix_to_tril_vec(y, diagonal=-1)
Expand Down Expand Up @@ -792,7 +793,7 @@ def __call__(self, x):
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
diag = softplus(x[..., -n:])
return (z + jnp.identity(n)) * diag[..., None]
return add_diag(z, 1) * diag[..., None]

def _inverse(self, y):
diag = jnp.diagonal(y, axis1=-2, axis2=-1)
Expand Down
10 changes: 9 additions & 1 deletion numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def signed_stick_breaking_tril(t):
z1m_cumprod_sqrt_shifted = jnp.pad(
z1m_cumprod_sqrt[..., :-1], pad_width, mode="constant", constant_values=1.0
)
y = (r + jnp.identity(r.shape[-1])) * z1m_cumprod_sqrt_shifted
y = add_diag(r, 1) * z1m_cumprod_sqrt_shifted
return y


Expand Down Expand Up @@ -680,3 +680,11 @@ def wrapper(self, *args, **kwargs):
return log_prob

return wrapper


def add_diag(matrix: jnp.ndarray, diag: jnp.ndarray) -> jnp.ndarray:
    """
    Add `diag` to the trailing diagonal of `matrix`.

    The values are added with an indexed update at the positions ``(..., i, i)``
    instead of building an identity matrix and adding it to `matrix`.

    :param matrix: Array, or array-like such as a NumPy array, whose last two
        dimensions hold the matrices to update. The matrices are assumed to be
        square: the diagonal indices are derived from the size of the last
        dimension only.
    :param diag: Value to add to the diagonal. Either a scalar or an array that
        broadcasts against ``matrix.shape[:-1]``.
    :return: A new array with the same shape as `matrix`.
    """
    # The `.at[...]` indexed-update syntax only exists on JAX arrays, so convert
    # array-likes (NumPy arrays, nested lists) before using it.
    matrix = jnp.asarray(matrix)
    idx = jnp.arange(matrix.shape[-1])
    return matrix.at[..., idx, idx].add(diag)
58 changes: 55 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from itertools import product
import math
import os
from typing import Callable

import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
Expand Down Expand Up @@ -3070,9 +3071,11 @@ def sample(d: dist.Distribution):

for in_axes, out_axes in in_out_axes_cases:
batched_params = [
jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
if isinstance(ax, int)
else arg
(
jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
if isinstance(ax, int)
else arg
)
for arg, ax in zip(params, in_axes)
]
# Recreate the jax_dist to avoid side effects coming from `d.sample`
Expand Down Expand Up @@ -3169,3 +3172,52 @@ def test_sample_truncated_normal_in_tail():
def test_jax_custom_prng():
    """Samples drawn from ``Normal(0, 5)`` must not contain infinite values."""
    key = random.PRNGKey(0)
    samples = dist.Normal(0, 5).sample(key, sample_shape=(1000,))
    has_inf = jnp.isinf(samples).any()
    assert not has_inf


def _assert_not_jax_issue_19885(
    capfd: pytest.CaptureFixture, func: Callable, *args, **kwargs
):
    """
    Run `func` both jit-compiled and eagerly, asserting that neither run emits the
    slow-path warning on stderr.

    :param capfd: Capture fixture whose `readouterr()` provides the stderr text
        written while `func` executed.
    :param func: Callable under test.
    :param args: Positional arguments forwarded to `func`.
    :param kwargs: Keyword arguments forwarded to `func`.
    :return: The result of the last (eager) call, after `block_until_ready()` if the
        result offers that method.
    :raises AssertionError: If the warning text is found in the captured stderr; the
        message states whether the failing run was jit-compiled.
    """
    # jit-ing identity plus matrix multiplication leads to performance degradation as
    # discussed in https://github.com/google/jax/issues/19885. This assertion verifies
    # that the issue does not affect performance in numpyro.
    for jit in [True, False]:
        # Honor the flag so that both the compiled and the eager path are checked and
        # the `jit: ...` failure message reports the mode that actually ran.
        call = jax.jit(func) if jit else func
        result = call(*args, **kwargs)
        # Wait for asynchronous dispatch to finish so any warning has been written
        # before the captured output is read.
        block_until_ready = getattr(result, "block_until_ready", None)
        if block_until_ready:
            result = block_until_ready()
        _, err = capfd.readouterr()
        assert (
            "MatMul reference implementation being executed" not in err
        ), f"jit: {jit}"
    return result


@pytest.mark.xfail
def test_jax_issue_19885(capfd: pytest.CaptureFixture) -> None:
    """
    Reference case for https://github.com/google/jax/issues/19885: an identity
    matrix added to a batched matrix product. Marked as an expected failure.
    """

    def identity_plus_matmul(y) -> jnp.ndarray:
        size = y.shape[-1]
        return jnp.identity(size) + jnp.matmul(y, y)

    batch = jnp.ones((20, 100, 100))
    _assert_not_jax_issue_19885(capfd, identity_plus_matmul, batch)


def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
    """
    Construct a batched `LowRankMultivariateNormal`, sample from it and evaluate
    `log_prob`, checking each step with `_assert_not_jax_issue_19885` and verifying
    the resulting shapes.
    """
    # Problem dimensions.
    batch_size, event_size = 100, 200
    sample_size = rank = 40

    # Random parameters; the diagonal is exponentiated to make it positive.
    loc, log_cov_diag = random.normal(random.key(0), (2, batch_size, event_size))
    cov_diag = jnp.exp(log_cov_diag)
    cov_factor = random.normal(random.key(1), (batch_size, event_size, rank))

    distribution = _assert_not_jax_issue_19885(
        capfd, dist.LowRankMultivariateNormal, loc, cov_factor, cov_diag
    )

    # The dummy array only carries the sample shape into the call.
    def draw(shape_carrier):
        return distribution.sample(random.key(0), shape_carrier.shape)

    x = _assert_not_jax_issue_19885(capfd, draw, jnp.empty(sample_size))
    assert x.shape == (sample_size, batch_size, event_size)

    log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
    assert log_prob.shape == (sample_size, batch_size)
18 changes: 18 additions & 0 deletions test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jax.scipy.special import expit, xlog1py, xlogy

from numpyro.distributions.util import (
add_diag,
binary_cross_entropy_with_logits,
binomial,
categorical,
Expand Down Expand Up @@ -164,3 +165,20 @@ def test_safe_normalize(dim):
data = jnp.zeros((10, dim))
x = safe_normalize(data)
assert_allclose((x * x).sum(-1), jnp.ones(x.shape[:-1]), rtol=1e-6)


@pytest.mark.parametrize(
    "matrix_shape, diag_shape",
    [
        ((5, 5), ()),
        ((7, 7), (7,)),
        ((10, 3, 3), (10, 3)),
        ((7, 5, 9, 9), (5, 1)),
    ],
)
def test_add_diag(matrix_shape: tuple, diag_shape: tuple) -> None:
    """
    Compare `add_diag` against the explicit formulation that scales an identity
    matrix by `diag` and adds it to the matrix.
    """
    matrix_key, diag_key = random.key(0), random.key(1)
    matrix = random.normal(matrix_key, matrix_shape)
    diag = random.normal(diag_key, diag_shape)

    # Reference result: embed `diag` on the diagonal via a scaled identity.
    size = matrix.shape[-1]
    diag_embedded = jnp.expand_dims(diag, -1) * jnp.eye(size)
    expected = matrix + diag_embedded

    actual = add_diag(matrix, diag)
    np.testing.assert_allclose(actual, expected)

0 comments on commit e6c187c

Please sign in to comment.