Don't copy inputs in constant_fold
ricardoV94 committed Jun 25, 2024
1 parent 7af0a87 commit b496127
Showing 3 changed files with 43 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pymc/pytensorf.py
@@ -1045,7 +1045,7 @@ def constant_fold(
     attempting constant folding, and any old non-shared inputs will not work with
     the returned outputs
     """
-    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
+    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)
 
     # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
     folded_xs = rewrite_graph(fg).outputs
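
Note on the one-line change above: passing copy_inputs=False tells FunctionGraph to keep the original input variables instead of cloning them, so the constant-folded outputs stay connected to the caller's variables. A minimal sketch of the resulting behavior, mirroring the regression test added below (hypothetical usage; assumes PyTensor and this patched constant_fold):

    import pytensor.tensor as pt

    from pymc.pytensorf import constant_fold

    # A graph whose shape depends on a symbolic, non-constant input.
    a = pt.scalar("a", dtype="int64")
    out = pt.empty((a,))

    # The folded shape graph still references `a` itself rather than a copy,
    # so identity checks and later substitutions against `a` keep working.
    (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
    assert out_shape is a

Before this change the returned graph depended on clones of the inputs, the failure mode exercised by the #7387 regression tests below.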
15 changes: 15 additions & 0 deletions tests/model/transform/test_optimization.py
@@ -20,8 +20,10 @@
 from pymc import Deterministic, do
 from pymc.data import Data
 from pymc.distributions import HalfNormal, Normal
+from pymc.exceptions import NotConstantValueError
 from pymc.model import Model
 from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.pytensorf import constant_fold
 
 
 def test_freeze_dims_and_data():
@@ -144,3 +146,16 @@ def test_freeze_dim_after_do_intervention():
 
     frozen_do_m = freeze_dims_and_data(do_m)
     assert frozen_do_m["x"].type.shape == (5,)
+
+
+def test_freeze_dims_and_data_partially_observed_rv():
+    # Regression test for #7387
+
+    with Model(coords={"a": [0, 1, 2]}) as model:
+        y = Normal("y", 0, observed=[0, 0, np.nan], dims="a")
+
+    with pytest.raises(NotConstantValueError):
+        constant_fold([y.shape])
+
+    frozen_y = freeze_dims_and_data(model)["y"]
+    assert constant_fold([frozen_y.shape]) == (3,)
46 changes: 27 additions & 19 deletions tests/test_pytensorf.py
@@ -646,25 +646,33 @@ def test_reseed_rngs():
             assert rng.get_value().bit_generator.state == bit_generator.state
 
 
-def test_constant_fold():
-    x = pt.random.normal(size=(5,))
-    y = pt.arange(x.size)
-
-    res = constant_fold((y, y.shape))
-    assert np.array_equal(res[0], np.arange(5))
-    assert tuple(res[1]) == (5,)
-
-
-def test_constant_fold_raises():
-    size = pytensor.shared(5)
-    x = pt.random.normal(size=(size,))
-    y = pt.arange(x.size)
-
-    with pytest.raises(NotConstantValueError):
-        constant_fold((y, y.shape))
-
-    res = constant_fold((y, y.shape), raise_not_constant=False)
-    assert tuple(res[1].eval()) == (5,)
+class TestConstantFold:
+    def test_constant_fold(self):
+        x = pt.random.normal(size=(5,))
+        y = pt.arange(x.size)
+
+        res = constant_fold((y, y.shape))
+        assert np.array_equal(res[0], np.arange(5))
+        assert tuple(res[1]) == (5,)
+
+    def test_constant_fold_raises(self):
+        size = pytensor.shared(5)
+        x = pt.random.normal(size=(size,))
+        y = pt.arange(x.size)
+
+        with pytest.raises(NotConstantValueError):
+            constant_fold((y, y.shape))
+
+        res = constant_fold((y, y.shape), raise_not_constant=False)
+        assert tuple(res[1].eval()) == (5,)
+
+    def test_inputs_preserved(self):
+        # Make sure constant_folded graph depends on original graph inputs (not copies)
+        # Regression test for #7387
+        a = pt.scalar("a", dtype="int")
+        out = pt.empty((a,))
+        (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
+        assert out_shape is a
 
 
 def test_replace_vars_in_graphs():
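
A brief usage note on the new test_inputs_preserved case: because the folded graph retains the original inputs, it can be compiled against them directly. A hedged sketch under the same assumptions (standard pytensor.function API; the downstream compile step is illustrative, not part of the commit):

    import pytensor
    import pytensor.tensor as pt

    from pymc.pytensorf import constant_fold

    a = pt.scalar("a", dtype="int64")
    out = pt.empty((a,))
    (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)

    # `a` is still an input of the folded graph, so compiling a function
    # over it works; with copied inputs this would fail to find `a`.
    fn = pytensor.function([a], out_shape)
    assert fn(4) == 4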
