Skip to content

Commit

Permalink
Intercept UserWarning on JAX random function tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 24, 2023
1 parent 93bfa1b commit 53b00ea
Showing 1 changed file with 45 additions and 51 deletions.
96 changes: 45 additions & 51 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import re

import numpy as np
import pytest
import scipy.stats as stats
Expand All @@ -22,6 +20,13 @@
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402


def random_function(*args, **kwargs):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
return function(*args, **kwargs)


def test_random_RandomStream():
"""Two successive calls of a compiled graph using `RandomStream` should
return different values.
Expand All @@ -30,11 +35,7 @@ def test_random_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()

with pytest.warns(
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = function([], out, mode=jax_mode)
fn = random_function([], out, mode=jax_mode)
jax_res_1 = fn()
jax_res_2 = fn()

Expand All @@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs

with pytest.warns(
UserWarning,
match=re.escape(
"The RandomType SharedVariables [original_rng] will not be used"
),
):
f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f()

# Check that original rng variable content was not overwritten when calling jax_typify
Expand Down Expand Up @@ -83,17 +78,14 @@ def test_random_updates_input_storage_order():

# This function replaces inp by input_shared in the update expression
# This is what caused the RNG to appear later than inp_shared in the input_storage
with pytest.warns(
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = pytensor.function(
inputs=[],
outputs=[],
updates={inp_shared: inp_update},
givens={inp: inp_shared},
mode="JAX",
)

fn = random_function(
inputs=[],
outputs=[],
updates={inp_shared: inp_update},
givens={inp: inp_shared},
mode="JAX",
)
fn()
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
fn()
Expand Down Expand Up @@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
else:
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
g_fn = function(dist_params, g, mode=jax_mode)
g_fn = random_function(dist_params, g, mode=jax_mode)
samples = g_fn(
*[
i.tag.test_value
Expand All @@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123))
g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

Expand All @@ -492,7 +484,7 @@ def test_random_mvnormal():
mu = np.ones(4)
cov = np.eye(4)
g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)

Expand All @@ -507,7 +499,7 @@ def test_random_mvnormal():
def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123))
g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

Expand All @@ -517,29 +509,29 @@ def test_random_choice():
num_samples = 10000
rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)

# `replace=False` produces unique results
rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
assert len(np.unique(samples)) == 99

# We can pass an array with probabilities
rng = shared(np.random.RandomState(123))
g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10))


def test_random_categorical():
rng = shared(np.random.RandomState(123))
g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)

Expand All @@ -548,7 +540,7 @@ def test_random_permutation():
array = np.arange(4)
rng = shared(np.random.RandomState(123))
g = at.random.permutation(array, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
permuted = g_fn()
with pytest.raises(AssertionError):
np.testing.assert_allclose(array, permuted)
Expand All @@ -558,7 +550,7 @@ def test_random_geometric():
rng = shared(np.random.RandomState(123))
p = np.array([0.3, 0.7])
g = at.random.geometric(p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
Expand All @@ -569,7 +561,7 @@ def test_negative_binomial():
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
np.testing.assert_allclose(
Expand All @@ -583,7 +575,7 @@ def test_binomial():
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
Expand All @@ -598,7 +590,7 @@ def test_beta_binomial():
a = np.array([1.5, 13])
b = np.array([0.5, 9])
g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
np.testing.assert_allclose(
Expand All @@ -616,7 +608,7 @@ def test_multinomial():
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
Expand All @@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle():
mu = np.array([-30, 40])
kappa = np.array([100, 10])
g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
Expand Down Expand Up @@ -678,7 +670,10 @@ def rng_fn(cls, rng, size):
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)

with pytest.raises(NotImplementedError):
compare_jax_and_py(fgraph, [])
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])


def test_random_custom_implementation():
Expand Down Expand Up @@ -709,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
rng = shared(np.random.RandomState(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, [])
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])


def test_random_concrete_shape():
Expand All @@ -726,19 +724,15 @@ def test_random_concrete_shape():
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(x_at, 1, rng=rng)
with pytest.warns(
UserWarning,
match="The RandomType SharedVariables \[.+\] will not be used"
):
jax_fn = function([x_at], out, mode=jax_mode)
jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


Expand All @@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor():
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (3,)


Expand All @@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
jax_fn = random_function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,)


Expand All @@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123))
size_at = at.scalar()
out = at.random.normal(0, 1, size=size_at, rng=rng)
jax_fn = function([size_at], out, mode=jax_mode)
jax_fn = random_function([size_at], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)

0 comments on commit 53b00ea

Please sign in to comment.