228 changes: 226 additions & 2 deletions tests/distributions/test_distribution.py
@@ -23,7 +23,7 @@
import pytest
import scipy.stats as st

from pytensor import scan
from pytensor import scan, shared
from pytensor.tensor import TensorVariable

import pymc as pm
@@ -42,14 +42,16 @@
CustomDist,
CustomDistRV,
CustomSymbolicDistRV,
PartialObservedRV,
SymbolicRandomVariable,
_moment,
create_partial_observed_rv,
moment,
)
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
from pymc.distributions.transforms import log
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.basic import logcdf, logp
from pymc.logprob.basic import conditional_logp, logcdf, logp
from pymc.model import Deterministic, Model
from pymc.pytensorf import collect_default_updates
from pymc.sampling import draw, sample
@@ -700,3 +702,225 @@ def test_dtype(self, floatX):
assert pm.DiracDelta.dist(2**16).dtype == "int32"
assert pm.DiracDelta.dist(2**32).dtype == "int64"
assert pm.DiracDelta.dist(2.0).dtype == floatX


class TestPartialObservedRV:
@pytest.mark.parametrize("symbolic_rv", (False, True))
def test_univariate(self, symbolic_rv):
data = np.array([0.25, 0.5, 0.25])
mask = np.array([False, False, True])

rv = pm.Normal.dist([1, 2, 3])
if symbolic_rv:
# We use a Censored Normal so that PartialObservedRV is needed,
# but don't use the bounds for testing the logp
rv = pm.Censored.dist(rv, lower=-100, upper=100)
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
if symbolic_rv:
assert isinstance(obs_rv.owner.op, PartialObservedRV)
assert isinstance(unobs_rv.owner.op, PartialObservedRV)
else:
assert isinstance(obs_rv.owner.op, Normal)
assert isinstance(unobs_rv.owner.op, Normal)

# Test shapes
assert tuple(obs_rv.shape.eval()) == (2,)
assert tuple(unobs_rv.shape.eval()) == (1,)
assert tuple(joined_rv.shape.eval()) == (3,)

# Test logp
logp = conditional_logp(
{obs_rv: pt.as_tensor(data[~mask]), unobs_rv: pt.as_tensor(data[mask])}
)
obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()
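# The univariate logp factorizes elementwise, so each partition should match
# the corresponding independent scipy normal densities.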
np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))

@pytest.mark.parametrize("obs_component_selected", (True, False))
def test_multivariate_constant_mask_separable(self, obs_component_selected):
if obs_component_selected:
mask = np.zeros((1, 4), dtype=bool)
else:
mask = np.ones((1, 4), dtype=bool)
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
assert isinstance(obs_rv.owner.op, pm.Dirichlet)
assert isinstance(unobs_rv.owner.op, pm.Dirichlet)

# Test shapes
if obs_component_selected:
expected_obs_shape = (1, 4)
expected_unobs_shape = (0, 4)
else:
expected_obs_shape = (0, 4)
expected_unobs_shape = (1, 4)
assert tuple(obs_rv.shape.eval()) == expected_obs_shape
assert tuple(unobs_rv.shape.eval()) == expected_unobs_shape
assert tuple(joined_rv.shape.eval()) == (1, 4)

# Test logp
logp = conditional_logp(
{
obs_rv: pt.as_tensor(obs_data)[obs_mask],
unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
}
)
obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()
if obs_component_selected:
expected_obs_logp = pm.logp(rv, obs_data).eval()
expected_unobs_logp = []
else:
expected_obs_logp = []
expected_unobs_logp = pm.logp(rv, unobs_data).eval()
np.testing.assert_allclose(obs_logp, expected_obs_logp)
np.testing.assert_allclose(unobs_logp, expected_unobs_logp)

def test_multivariate_constant_mask_unseparable(self):
mask = pt.constant(np.array([[True, True, False, False]]))
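# Within the single support row, some entries are observed and some are not,
# so the Dirichlet cannot be split into two independent sub-distributions.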
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
assert isinstance(obs_rv.owner.op, PartialObservedRV)
assert isinstance(unobs_rv.owner.op, PartialObservedRV)

# Test shapes
assert tuple(obs_rv.shape.eval()) == (2,)
assert tuple(unobs_rv.shape.eval()) == (2,)
assert tuple(joined_rv.shape.eval()) == (1, 4)

# Test logp
logp = conditional_logp(
{
obs_rv: pt.as_tensor(obs_data)[obs_mask],
unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
}
)
obs_logp, unobs_logp = pytensor.function([], list(logp.values()))()

# For non-separable cases the logp always shows up in the observed variable
expected_logp = pm.logp(rv, [[0.1, 0.4, 0.4, 0.1]]).eval()
np.testing.assert_almost_equal(obs_logp, expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

def test_multivariate_shared_mask_separable(self):
mask = shared(np.array([True]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
# A shared mask that does not mix entries within the support dimension keeps the RV separable.
assert isinstance(obs_rv.owner.op, pm.Dirichlet)
assert isinstance(unobs_rv.owner.op, pm.Dirichlet)

# Test shapes
assert tuple(obs_rv.shape.eval()) == (0, 4)
assert tuple(unobs_rv.shape.eval()) == (1, 4)
assert tuple(joined_rv.shape.eval()) == (1, 4)

# Test logp
logp = conditional_logp(
{
obs_rv: pt.as_tensor(obs_data)[obs_mask],
unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
}
)
logp_fn = pytensor.function([], list(logp.values()))
obs_logp, unobs_logp = logp_fn()
expected_logp = pm.logp(rv, unobs_data).eval()
np.testing.assert_almost_equal(obs_logp, [])
np.testing.assert_array_equal(unobs_logp, expected_logp)

# Test that we can update a shared mask
mask.set_value(np.array([False]))
assert tuple(obs_rv.shape.eval()) == (1, 4)
assert tuple(unobs_rv.shape.eval()) == (0, 4)

new_expected_logp = pm.logp(rv, obs_data).eval()
assert not np.isclose(expected_logp, new_expected_logp) # Otherwise test is weak
obs_logp, unobs_logp = logp_fn()
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

def test_multivariate_shared_mask_unseparable(self):
# Even if the mask is initially not mixing support dims,
# it could later be changed in a way that does!
mask = shared(np.array([[True, True, True, True]]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
# Multivariate RVs with shared masks on the last component are always unseparable.
assert isinstance(obs_rv.owner.op, PartialObservedRV)
assert isinstance(unobs_rv.owner.op, PartialObservedRV)

# Test shapes
assert tuple(obs_rv.shape.eval()) == (0,)
assert tuple(unobs_rv.shape.eval()) == (4,)
assert tuple(joined_rv.shape.eval()) == (1, 4)

# Test logp
logp = conditional_logp(
{
obs_rv: pt.as_tensor(obs_data)[obs_mask],
unobs_rv: pt.as_tensor(unobs_data)[unobs_mask],
}
)
logp_fn = pytensor.function([], list(logp.values()))
obs_logp, unobs_logp = logp_fn()
# For non-separable cases the logp always shows up in the observed variable
# Even in this case where all entries come from an unobserved component
expected_logp = pm.logp(rv, unobs_data).eval()
np.testing.assert_almost_equal(obs_logp, expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

# Test that we can update a shared mask
mask.set_value(np.array([[False, False, True, True]]))

assert tuple(obs_rv.shape.eval()) == (2,)
assert tuple(unobs_rv.shape.eval()) == (2,)

new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
assert not np.isclose(expected_logp, new_expected_logp) # Otherwise test is weak
obs_logp, unobs_logp = logp_fn()
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

def test_moment(self):
x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
ref_moment = moment(x).eval()
assert not np.allclose(ref_moment[::2], ref_moment[1::2]) # Otherwise test is weak

(obs_x, _), (unobs_x, _), _ = create_partial_observed_rv(
x, mask=np.array([False, True] * 5)
)
np.testing.assert_allclose(moment(obs_x).eval(), ref_moment[::2])
np.testing.assert_allclose(moment(unobs_x).eval(), ref_moment[1::2])

def test_wrong_mask(self):
rv = pm.Normal.dist(shape=(5,))

invalid_mask = np.array([0, 2, 4])
with pytest.raises(ValueError, match="mask must be an array or tensor of boolean dtype"):
create_partial_observed_rv(rv, invalid_mask)

invalid_mask = np.zeros((1, 5), dtype=bool)
with pytest.raises(ValueError, match="mask can't have more dims than rv"):
create_partial_observed_rv(rv, invalid_mask)
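
Taken together, these tests pin down the public contract of `create_partial_observed_rv`: it splits an RV into observed and unobserved partitions according to a boolean mask, and `conditional_logp` factorizes across the two parts. A minimal standalone sketch of that flow, assuming a PyMC version where both helpers live at the import paths used above:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.distributions.distribution import create_partial_observed_rv
from pymc.logprob.basic import conditional_logp

# Split a length-3 normal RV: first two entries observed, last one unobserved.
mask = np.array([False, False, True])
rv = pm.Normal.dist([1, 2, 3])
(obs_rv, _), (unobs_rv, _), joined_rv = create_partial_observed_rv(rv, mask)

# The conditional logp factorizes across the two partitions.
logp_terms = conditional_logp(
    {obs_rv: pt.as_tensor([0.25, 0.5]), unobs_rv: pt.as_tensor([0.25])}
)
obs_logp, unobs_logp = pytensor.function([], list(logp_terms.values()))()
# obs_logp matches st.norm([1, 2]).logpdf([0.25, 0.5]);
# unobs_logp matches st.norm(3).logpdf(0.25).
```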
8 changes: 8 additions & 0 deletions tests/test_data.py
@@ -437,6 +437,14 @@ def test_data_mutable_default_warning(self):
assert isinstance(data, pt.TensorConstant)
pass

def test_masked_array_error(self):
with pm.Model():
with pytest.raises(
NotImplementedError,
match="Masked arrays or arrays with `nan` entries are not supported.",
):
pm.ConstantData("x", [0, 1, np.nan, 2])


def test_data_naming():
"""
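The `*_missing` → `*_unobserved` renames throughout the model tests below come from PyMC's automatic imputation machinery: passing `observed` data containing `np.nan` splits the variable into observed and unobserved parts, while `ConstantData` now rejects such data outright. A short sketch of both behaviours, assuming a PyMC version with this naming:

```python
import numpy as np
import pymc as pm

data = np.array([0.5, np.nan, 1.5])

with pm.Model() as m:
    # Emits an ImputationWarning; creates x_observed and x_unobserved
    # (previously x_missing), plus the joined variable "x".
    x = pm.Normal("x", 0.0, 1.0, observed=data)

print(sorted(m.named_vars))  # includes "x", "x_observed", "x_unobserved"

with pm.Model():
    # ConstantData refuses masked or nan-containing arrays.
    try:
        pm.ConstantData("bad", data)
    except NotImplementedError as err:
        print(err)
```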
123 changes: 73 additions & 50 deletions tests/test_model.py
@@ -41,9 +41,10 @@
from pymc import Deterministic, Potential
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.transforms import log
from pymc.distributions.distribution import PartialObservedRV
from pymc.distributions.transforms import log, simplex
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.basic import conditional_logp, transformed_conditional_logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Point, ValueGradFunction, modelcontext
from pymc.testing import SeededTest
@@ -356,13 +357,13 @@ def test_missing_data(self):
gf = m.logp_dlogp_function()
gf._extra_are_set = True

assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type
assert m["x2_unobserved"].type == gf._extra_vars_shared["x2_unobserved"].type

# The dtype of the merged observed/missing deterministic should match the RV dtype
assert m.deterministics[0].type.dtype == x2.type.dtype

point = m.initial_point(random_seed=None).copy()
del point["x2_missing"]
del point["x2_unobserved"]

res = [gf(DictToArrayBijection.map(Point(point, model=m))) for i in range(5)]

@@ -565,7 +566,7 @@ def test_make_obs_var():
assert masked_output != fake_distribution
assert not isinstance(masked_output, RandomVariable)
# Ensure it has missing values
assert {"testing_inputs_missing"} == {v.name for v in fake_model.value_vars}
assert {"testing_inputs_unobserved"} == {v.name for v in fake_model.value_vars}
assert {"testing_inputs", "testing_inputs_observed"} == {
v.name for v in fake_model.observed_RVs
}
@@ -1220,7 +1221,7 @@ def test_missing_basic(self, missing_data):
with pytest.warns(ImputationWarning):
_ = pm.Normal("y", x, 1, observed=missing_data)

assert "y_missing" in model.named_vars
assert "y_unobserved" in model.named_vars

test_point = model.initial_point()
assert not np.isnan(model.compile_logp()(test_point))
@@ -1237,7 +1238,7 @@ def test_missing_with_predictors(self):
with pytest.warns(ImputationWarning):
y = pm.Normal("y", x * predictors, 1, observed=data)

assert "y_missing" in model.named_vars
assert "y_unobserved" in model.named_vars

test_point = model.initial_point()
assert not np.isnan(model.compile_logp()(test_point))
@@ -1277,17 +1278,19 @@ def test_interval_missing_observations(self):
with pytest.warns(ImputationWarning):
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2)

assert isinstance(model.rvs_to_transforms[model["theta1_missing"]], IntervalTransform)
assert isinstance(
model.rvs_to_transforms[model["theta1_unobserved"]], IntervalTransform
)
assert model.rvs_to_transforms[model["theta1_observed"]] is None

prior_trace = pm.sample_prior_predictive(random_seed=rng, return_inferencedata=False)
assert set(prior_trace.keys()) == {
"theta1",
"theta1_observed",
"theta1_missing",
"theta1_unobserved",
"theta2",
"theta2_observed",
"theta2_missing",
"theta2_unobserved",
}

# Make sure the observed + missing combined deterministics have the
@@ -1302,14 +1305,16 @@ def test_interval_missing_observations(self):
# Make sure the missing parts of the combined deterministic matches the
# sampled missing and observed variable values
assert (
np.mean(prior_trace["theta1"][:, obs1.mask] - prior_trace["theta1_missing"]) == 0.0
np.mean(prior_trace["theta1"][:, obs1.mask] - prior_trace["theta1_unobserved"])
== 0.0
)
assert (
np.mean(prior_trace["theta1"][:, ~obs1.mask] - prior_trace["theta1_observed"])
== 0.0
)
assert (
np.mean(prior_trace["theta2"][:, obs2.mask] - prior_trace["theta2_missing"]) == 0.0
np.mean(prior_trace["theta2"][:, obs2.mask] - prior_trace["theta2_unobserved"])
== 0.0
)
assert (
np.mean(prior_trace["theta2"][:, ~obs2.mask] - prior_trace["theta2_observed"])
@@ -1325,18 +1330,22 @@ def test_interval_missing_observations(self):
)
assert set(trace.varnames) == {
"theta1",
"theta1_missing",
"theta1_missing_interval__",
"theta1_unobserved",
"theta1_unobserved_interval__",
"theta2",
"theta2_missing",
"theta2_unobserved",
}

# Make sure that the missing values are newly generated samples and that
# the observed and deterministic match
assert np.all(0 < trace["theta1_missing"].mean(0))
assert np.all(0 < trace["theta2_missing"].mean(0))
assert np.isclose(np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_missing"]), 0)
assert np.isclose(np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_missing"]), 0)
assert np.all(0 < trace["theta1_unobserved"].mean(0))
assert np.all(0 < trace["theta2_unobserved"].mean(0))
assert np.isclose(
np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_unobserved"]), 0
)
assert np.isclose(
np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_unobserved"]), 0
)

# Make sure that the observed values are unchanged
assert np.allclose(np.var(trace["theta1"][:, ~obs1.mask], 0), 0.0)
@@ -1377,7 +1386,7 @@ def test_missing_logp1(self):
with pytest.warns(ImputationWarning):
x = pm.Gamma("x", 1, 1, observed=[1, 1, 1, np.nan])

logp_val = m2.compile_logp()({"x_missing_log__": np.array([0])})
logp_val = m2.compile_logp()({"x_unobserved_log__": np.array([0])})
assert logp_val == -4.0

def test_missing_logp2(self):
@@ -1393,37 +1402,48 @@ def test_missing_logp2(self):
"theta2", mu=theta1, observed=np.array([np.nan, np.nan, 2, np.nan, 4])
)
m_missing_logp = m_missing.compile_logp()(
{"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]}
{"theta1_unobserved": [2, 4], "theta2_unobserved": [0, 1, 3]}
)

assert m_logp == m_missing_logp

def test_missing_multivariate(self):
"""Test model with missing variables whose transform changes base shape still works"""
def test_missing_multivariate_separable(self):
with pm.Model() as m_miss:
with pytest.warns(ImputationWarning):
x = pm.Dirichlet(
"x",
a=[1, 2, 3],
observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]),
)
assert (m_miss["x_unobserved"].owner.op, pm.Dirichlet)
assert (m_miss["x_observed"].owner.op, pm.Dirichlet)

with pm.Model() as m_unobs:
x = pm.Dirichlet("x", a=[1, 2, 3], shape=(1, 3))

inp_vals = simplex.forward(np.array([[0.3, 0.3, 0.4]])).eval()
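# The observed row and the imputed row are both evaluated at the same simplex
# point [0.3, 0.3, 0.4], hence the factor of 2 below.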
np.testing.assert_allclose(
m_miss.compile_logp(jacobian=False)({"x_unobserved_simplex__": inp_vals}),
m_unobs.compile_logp(jacobian=False)({"x_simplex__": inp_vals}) * 2,
)

def test_missing_multivariate_unseparable(self):
with pm.Model() as m_miss:
with pytest.raises(
NotImplementedError,
match="Automatic inputation is only supported for univariate RandomVariables",
):
with pytest.warns(ImputationWarning):
x = pm.Dirichlet(
"x",
a=[1, 2, 3],
observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]),
)

# TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions
# from pymc.distributions.transforms import simplex
#
# with pm.Model() as m_unobs:
# x = pm.Dirichlet("x", a=[1, 2, 3])
#
# inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval()
# assert np.isclose(
# m_miss.compile_logp()({"x_missing_simplex__": inp_vals}),
# m_unobs.compile_logp(jacobian=False)({"x_simplex__": inp_vals}) * 2,
# )
with pytest.warns(ImputationWarning):
x = pm.Dirichlet(
"x",
a=[1, 2, 3],
observed=np.array([[0.3, 0.3, np.nan], [np.nan, np.nan, 0.4]]),
)

assert isinstance(m_miss["x_unobserved"].owner.op, PartialObservedRV)
assert isinstance(m_miss["x_observed"].owner.op, PartialObservedRV)

inp_values = np.array([0.3, 0.3, 0.4])
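# With x_unobserved = [0.4, 0.3, 0.3], both rows merge to [0.3, 0.3, 0.4],
# so the joint logp is twice the Dirichlet density at that single point.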
np.testing.assert_allclose(
m_miss.compile_logp()({"x_unobserved": [0.4, 0.3, 0.3]}),
st.dirichlet.logpdf(inp_values, [1, 2, 3]) * 2,
)

def test_missing_vector_parameter(self):
with pm.Model() as m:
@@ -1439,7 +1459,7 @@
assert np.all(x_draws[:, 0] < 0)
assert np.all(x_draws[:, 1] > 0)
assert np.isclose(
m.compile_logp()({"x_missing": np.array([-10, 10, -10, 10])}),
m.compile_logp()({"x_unobserved": np.array([-10, 10, -10, 10])}),
st.norm(scale=0.1).logpdf(0) * 6,
)

@@ -1458,7 +1478,7 @@ def test_missing_symmetric(self):
x_obs_rv = m["x_observed"]
x_obs_vv = m.rvs_to_values[x_obs_rv]

x_unobs_rv = m["x_missing"]
x_unobs_rv = m["x_unobserved"]
x_unobs_vv = m.rvs_to_values[x_unobs_rv]

logp = transformed_conditional_logp(
@@ -1482,18 +1502,21 @@ def test_dims(self):
x = pm.Normal("x", observed=data, dims=("observed",))
assert model.named_vars_to_dims == {"x": ("observed",)}

def test_error_non_random_variable(self):
def test_symbolic_random_variable(self):
data = np.array([np.nan] * 3 + [0] * 7)
with pm.Model() as model:
msg = "x of type <class 'pymc.distributions.censored.CensoredRV'> is not supported"
with pytest.raises(NotImplementedError, match=msg):
with pytest.warns(ImputationWarning):
x = pm.Censored(
"x",
pm.Normal.dist(),
lower=0,
upper=10,
observed=data,
)
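# All ten values (seven observed zeros plus three imputed zeros) sit at the
# lower censoring bound, where the censored logp reduces to the normal log CDF at 0.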
np.testing.assert_almost_equal(
model.compile_logp()({"x_unobserved": [0] * 3}),
st.norm.logcdf(0) * 10,
)


class TestShared(SeededTest):
6 changes: 3 additions & 3 deletions tests/test_model_graph.py
@@ -145,13 +145,13 @@ def model_with_imputations():

compute_graph = {
"a": set(),
"L_missing": {"a"},
"L_unobserved": {"a"},
"L_observed": {"a"},
"L": {"L_missing", "L_observed"},
"L": {"L_unobserved", "L_observed"},
}
plates = {
"": {"a"},
"2": {"L_missing"},
"2": {"L_unobserved"},
"10": {"L_observed"},
"12": {"L"},
}
6 changes: 3 additions & 3 deletions tests/tuning/test_starting.py
@@ -146,9 +146,9 @@ def test_find_MAP_issue_4488():
y = pm.Deterministic("y", x + 1)
map_estimate = find_MAP()

assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
assert not set.difference({"x_unobserved", "x_unobserved_log__", "y"}, set(map_estimate.keys()))
np.testing.assert_allclose(map_estimate["x_unobserved"], 0.2, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_unobserved"][0] + 1])


def test_find_MAP_warning_non_free_RVs():