
Fix bug when freezing model with partially observed RVs #7388

Merged: 1 commit, Jun 25, 2024
2 changes: 1 addition & 1 deletion pymc/pytensorf.py
@@ -1045,7 +1045,7 @@ def constant_fold(
     attempting constant folding, and any old non-shared inputs will not work with
     the returned outputs
     """
-    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
+    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)
Member Author commented:
This should be safe, as we can never mutate inputs accidentally (there is no such operation)


     # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
     folded_xs = rewrite_graph(fg).outputs
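
For context (not part of the diff), here is a minimal sketch of what `copy_inputs=False` changes, assuming current PyTensor import paths: the `FunctionGraph` keeps the caller's original input variables rather than clones, so any output that cannot be folded still refers back to the original graph.

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.shape import ShapeFeature

a = pt.scalar("a", dtype="int64")
out_shape = pt.empty((a,)).shape[0]

# With copy_inputs=False, the graph references the original variable `a`...
fg = FunctionGraph(outputs=[out_shape], features=[ShapeFeature()], copy_inputs=False, clone=True)
assert a in fg.inputs

# ...whereas the default (copy_inputs=True) clones the inputs, breaking identity.
fg_cloned = FunctionGraph(outputs=[out_shape], features=[ShapeFeature()], clone=True)
assert a not in fg_cloned.inputs

This is also why the author's comment above holds: the rewrites applied by `rewrite_graph` only build new outputs, and no rewrite mutates an input variable in place, so sharing the inputs is safe.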
15 changes: 15 additions & 0 deletions tests/model/transform/test_optimization.py
@@ -20,8 +20,10 @@
 from pymc import Deterministic, do
 from pymc.data import Data
 from pymc.distributions import HalfNormal, Normal
+from pymc.exceptions import NotConstantValueError
 from pymc.model import Model
 from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.pytensorf import constant_fold


 def test_freeze_dims_and_data():

@@ -144,3 +146,16 @@ def test_freeze_dim_after_do_intervention():

     frozen_do_m = freeze_dims_and_data(do_m)
     assert frozen_do_m["x"].type.shape == (5,)
+
+
+def test_freeze_dims_and_data_partially_observed_rv():
+    # Regression test for #7387
+
+    with Model(coords={"a": [0, 1, 2]}) as model:
+        y = Normal("y", 0, observed=[0, 0, np.nan], dims="a")
+
+    with pytest.raises(NotConstantValueError):
+        constant_fold([y.shape])
+
+    frozen_y = freeze_dims_and_data(model)["y"]
+    assert constant_fold([frozen_y.shape]) == (3,)
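
For illustration (a sketch, not part of the PR's diff), this is the user-facing behavior the regression test locks in: a model whose RV is partially observed (note the np.nan) can now be frozen, after which its shape folds to a constant.

import numpy as np
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import constant_fold

with pm.Model(coords={"a": [0, 1, 2]}) as model:
    # np.nan marks a missing observation, making "y" partially observed
    y = pm.Normal("y", 0, observed=[0, 0, np.nan], dims="a")

frozen_y = freeze_dims_and_data(model)["y"]  # failed before this fix (see #7387)
print(constant_fold([frozen_y.shape]))       # (array([3]),)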
46 changes: 27 additions & 19 deletions tests/test_pytensorf.py
@@ -646,25 +646,33 @@ def test_reseed_rngs():
     assert rng.get_value().bit_generator.state == bit_generator.state


-def test_constant_fold():
-    x = pt.random.normal(size=(5,))
-    y = pt.arange(x.size)
-
-    res = constant_fold((y, y.shape))
-    assert np.array_equal(res[0], np.arange(5))
-    assert tuple(res[1]) == (5,)
-
-
-def test_constant_fold_raises():
-    size = pytensor.shared(5)
-    x = pt.random.normal(size=(size,))
-    y = pt.arange(x.size)
-
-    with pytest.raises(NotConstantValueError):
-        constant_fold((y, y.shape))
-
-    res = constant_fold((y, y.shape), raise_not_constant=False)
-    assert tuple(res[1].eval()) == (5,)
+class TestConstantFold:
+    def test_constant_fold(self):
+        x = pt.random.normal(size=(5,))
+        y = pt.arange(x.size)
+
+        res = constant_fold((y, y.shape))
+        assert np.array_equal(res[0], np.arange(5))
+        assert tuple(res[1]) == (5,)
+
+    def test_constant_fold_raises(self):
+        size = pytensor.shared(5)
+        x = pt.random.normal(size=(size,))
+        y = pt.arange(x.size)
+
+        with pytest.raises(NotConstantValueError):
+            constant_fold((y, y.shape))
+
+        res = constant_fold((y, y.shape), raise_not_constant=False)
+        assert tuple(res[1].eval()) == (5,)
+
+    def test_inputs_preserved(self):
+        # Make sure constant_folded graph depends on original graph inputs (not copies)
+        # Regression test for #7387
+        a = pt.scalar("a", dtype="int")
+        out = pt.empty((a,))
+        (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
+        assert out_shape is a


 def test_replace_vars_in_graphs():
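
A closing usage note (a sketch, not from the PR): the identity assertion in `test_inputs_preserved` is what lets callers substitute values for the original inputs after folding, for example with PyTensor's `graph_replace`.

import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace
from pymc.pytensorf import constant_fold

a = pt.scalar("a", dtype="int64")
out = pt.empty((a,)).shape[0] * 2  # not constant-foldable: still depends on `a`

(folded,) = constant_fold((out,), raise_not_constant=False)

# Because the folded graph depends on the original `a` (not a copy of it),
# the replacement below actually takes effect:
(res,) = graph_replace([folded], {a: pt.constant(3, dtype="int64")})
print(res.eval())  # 6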