Skip to content

Commit

Permalink
Implement model transform that freezes RV dims and Data
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 27, 2024
1 parent bad219a commit 546d59e
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 24 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ jobs:
tests/model/test_fgraph.py
tests/model/transform/test_basic.py
tests/model/transform/test_conditioning.py
tests/model/transform/test_optimization.py
tests/test_model_graph.py
tests/ode/test_ode.py
tests/ode/test_utils.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@ Model Conditioning
.. autosummary::
:toctree: generated/

change_value_transforms
do
observe
change_value_transforms
remove_value_transforms


Model Optimization
------------------
.. currentmodule:: pymc.model.transform.optimization
.. autosummary::
:toctree: generated/

freeze_dims_and_data
37 changes: 17 additions & 20 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,8 +1121,7 @@ def set_data(

for d, dname in enumerate(dims):
length_tensor = self.dim_lengths[dname]
with pytensor.config.change_flags(cxx=""):
old_length = length_tensor.eval()
old_length = length_tensor.eval()
new_length = values.shape[d]
original_coords = self.coords.get(dname, None)
new_coords = coords.get(dname, None)
Expand Down Expand Up @@ -1404,24 +1403,22 @@ def create_value_var(
else:
transform = _default_transform(rv_var.owner.op, rv_var)

if value_var is not None:
if transform is not None:
raise ValueError("Cannot use transform when providing a pre-defined value_var")
elif transform is None:
# Create value variable with the same type as the RV
value_var = rv_var.type()
value_var.name = rv_var.name
if pytensor.config.compute_test_value != "off":
value_var.tag.test_value = rv_var.tag.test_value
else:
# Create value variable with the same type as the transformed RV
value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
value_var.name = f"{rv_var.name}_{transform.name}__"
value_var.tag.transform = transform
if pytensor.config.compute_test_value != "off":
value_var.tag.test_value = transform.forward(
rv_var, *rv_var.owner.inputs
).tag.test_value
if value_var is None:
if transform is None:
# Create value variable with the same type as the RV
value_var = rv_var.type()
value_var.name = rv_var.name
if pytensor.config.compute_test_value != "off":
value_var.tag.test_value = rv_var.tag.test_value
else:
# Create value variable with the same type as the transformed RV
value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
value_var.name = f"{rv_var.name}_{transform.name}__"
value_var.tag.transform = transform
if pytensor.config.compute_test_value != "off":
value_var.tag.test_value = transform.forward(
rv_var, *rv_var.owner.inputs
).tag.test_value

_add_future_warning_tag(value_var)
rv_var.tag.value_var = value_var
Expand Down
4 changes: 1 addition & 3 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,7 @@ def first_non_model_var(var):
var, value, *dims = model_var.owner.inputs
transform = model_var.owner.op.transform
model.free_RVs.append(var)
# PyMC does not allow setting transform when we pass a value_var. Why?
model.create_value_var(var, transform=None, value_var=value)
model.rvs_to_transforms[var] = transform
model.create_value_var(var, transform=transform, value_var=value)
model.set_initval(var, initval=None)
elif isinstance(model_var.owner.op, ModelObservedRV):
var, value, *dims = model_var.owner.inputs
Expand Down
80 changes: 80 additions & 0 deletions pymc/model/transform/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytensor import clone_replace
from pytensor.compile import SharedVariable
from pytensor.graph import FunctionGraph
from pytensor.tensor import constant

from pymc import Model
from pymc.model.fgraph import ModelFreeRV, fgraph_from_model, model_from_fgraph


def freeze_dims_and_data(model: Model) -> Model:
    """Recreate a Model with fixed RV dimensions and Data values.

    The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
    Likewise, it will not be possible to update pre-existing Data in the new model.

    Note that any new RVs and Data created after calling this function will still be "unfrozen".

    This transformation may allow more performant sampling, or compiling model functions to backends that
    are more restrictive about dynamic shapes such as JAX.

    Parameters
    ----------
    model : Model
        Model whose shared dim lengths and Data variables should be replaced by constants.

    Returns
    -------
    Model
        A new, equivalent model in which dim lengths and Data are constants and
        RV/value-variable types carry static shape information where possible.
    """
    fg, memo = fgraph_from_model(model)

    # Replace mutable dim lengths and data by constants.
    # Only SharedVariables are mutable; anything else is already fixed.
    frozen_vars = {
        memo[dim_length]: constant(
            dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
        )
        for dim_length in model.dim_lengths.values()
        if isinstance(dim_length, SharedVariable)
    }
    frozen_vars |= {
        memo[data_var].owner.inputs[0]: constant(
            data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype
        )
        for data_var in model.named_vars.values()
        if isinstance(data_var, SharedVariable)
    }

    old_outs, coords = fg.outputs, fg._coords
    # rebuild_strict=False forces the recreation of RV nodes with updated static types
    new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False)
    for old_out, new_out in zip(old_outs, new_outs):
        new_out.name = old_out.name
    fg = FunctionGraph(outputs=new_outs, clone=False)
    fg._coords = coords

    # Recreate value variables from new RVs to propagate static types to logp graphs
    replacements = {}
    for node in fg.apply_nodes:
        if not isinstance(node.op, ModelFreeRV):
            continue
        rv, old_value, *dims = node.inputs
        # NOTE: a previous `if dims is None: continue` guard here was dead code —
        # star-unpacking always binds `dims` to a (possibly empty) list, never None.
        transform = node.op.transform
        if transform is None:
            new_value = rv.type()
        else:
            # Transformed RVs get a value variable in the transformed space
            new_value = transform.forward(rv, *rv.owner.inputs).type()
        new_value.name = old_value.name
        replacements[old_value] = new_value
    fg.replace_all(tuple(replacements.items()), import_missing=True)

    return model_from_fgraph(fg)


__all__ = ("freeze_dims_and_data",)
51 changes: 51 additions & 0 deletions tests/model/transform/test_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytensor.graph import Constant

from pymc.data import Data
from pymc.distributions import HalfNormal, Normal
from pymc.model import Model
from pymc.model.transform.optimization import freeze_dims_and_data


def test_freeze_existing_rv_dims_and_data():
    """Freezing replaces shared Data by constants and makes RV/logp shapes static."""
    with Model(coords={"test_dim": range(5)}) as model:
        std = Data("std", [1])
        x = HalfNormal("x", std, dims=("test_dim",))
        y = Normal("y", shape=x.shape[0] + 1)

    x_logp, y_logp = model.logp(sum=False)

    # Before freezing: data is mutable and every shape is dynamic
    assert not isinstance(std, Constant)
    for dynamic_var in (x, y, x_logp, y_logp):
        assert dynamic_var.type.shape == (None,)

    frozen = freeze_dims_and_data(model)
    std, x, y = frozen["std"], frozen["x"], frozen["y"]
    x_logp, y_logp = frozen.logp(sum=False)

    # After freezing: data is constant and shapes are statically known
    assert isinstance(std, Constant)
    assert x.type.shape == (5,)
    assert x_logp.type.shape == (5,)
    assert y.type.shape == (6,)
    assert y_logp.type.shape == (6,)


def test_freeze_rv_dims_nothing_to_change():
    """A model without shared dims or Data is unaffected by freezing."""
    with Model(coords={"test_dim": range(5)}) as model:
        x = HalfNormal("x", shape=(5,))
        y = Normal("y", shape=x.shape[0] + 1)

    frozen = freeze_dims_and_data(model)
    # Log-probabilities at the initial point must be identical
    assert model.point_logps() == frozen.point_logps()

0 comments on commit 546d59e

Please sign in to comment.