Implement model transform that freezes RV dims and Data
ricardoV94 committed Dec 4, 2023
1 parent 3c5158e commit 28ded31
Showing 5 changed files with 116 additions and 26 deletions.
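
In brief: the new freeze_existing_rv_dims_and_data transform recreates a model with its mutable dim lengths and Data values replaced by constants, so that variable shapes become static. A minimal sketch of the effect, mirroring the tests added in this commit:

    import pymc as pm
    from pymc.model.transform.conditioning import freeze_existing_rv_dims_and_data

    with pm.Model(coords={"test_dim": range(5)}) as m:
        x = pm.HalfNormal("x", dims=("test_dim",))

    assert x.type.shape == (None,)  # dim length is still mutable, shape is dynamic

    frozen_m = freeze_existing_rv_dims_and_data(m)
    assert frozen_m["x"].type.shape == (5,)  # dim length frozen at its current value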
3 changes: 2 additions & 1 deletion docs/source/api/model/conditioning.rst
@@ -5,7 +5,8 @@ Model Conditioning
 .. autosummary::
    :toctree: generated/
 
+   change_value_transforms
    do
    observe
-   change_value_transforms
    remove_value_transforms
+   freeze_existing_rv_dims_and_data
37 changes: 17 additions & 20 deletions pymc/model/core.py
@@ -1130,8 +1130,7 @@ def set_data(
 
         for d, dname in enumerate(dims):
             length_tensor = self.dim_lengths[dname]
-            with pytensor.config.change_flags(cxx=""):
-                old_length = length_tensor.eval()
+            old_length = length_tensor.eval()
             new_length = values.shape[d]
             original_coords = self.coords.get(dname, None)
             new_coords = coords.get(dname, None)
@@ -1399,24 +1398,22 @@ def create_value_var(
         else:
             transform = _default_transform(rv_var.owner.op, rv_var)
 
-        if value_var is not None:
-            if transform is not None:
-                raise ValueError("Cannot use transform when providing a pre-defined value_var")
-        elif transform is None:
-            # Create value variable with the same type as the RV
-            value_var = rv_var.type()
-            value_var.name = rv_var.name
-            if pytensor.config.compute_test_value != "off":
-                value_var.tag.test_value = rv_var.tag.test_value
-        else:
-            # Create value variable with the same type as the transformed RV
-            value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
-            value_var.name = f"{rv_var.name}_{transform.name}__"
-            value_var.tag.transform = transform
-            if pytensor.config.compute_test_value != "off":
-                value_var.tag.test_value = transform.forward(
-                    rv_var, *rv_var.owner.inputs
-                ).tag.test_value
+        if value_var is None:
+            if transform is None:
+                # Create value variable with the same type as the RV
+                value_var = rv_var.type()
+                value_var.name = rv_var.name
+                if pytensor.config.compute_test_value != "off":
+                    value_var.tag.test_value = rv_var.tag.test_value
+            else:
+                # Create value variable with the same type as the transformed RV
+                value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
+                value_var.name = f"{rv_var.name}_{transform.name}__"
+                value_var.tag.transform = transform
+                if pytensor.config.compute_test_value != "off":
+                    value_var.tag.test_value = transform.forward(
+                        rv_var, *rv_var.owner.inputs
+                    ).tag.test_value
 
         _add_future_warning_tag(value_var)
         rv_var.tag.value_var = value_var
4 changes: 1 addition & 3 deletions pymc/model/fgraph.py
@@ -321,9 +321,7 @@ def first_non_model_var(var):
         var, value, *dims = model_var.owner.inputs
         transform = model_var.owner.op.transform
         model.free_RVs.append(var)
-        # PyMC does not allow setting transform when we pass a value_var. Why?
-        model.create_value_var(var, transform=None, value_var=value)
-        model.rvs_to_transforms[var] = transform
+        model.create_value_var(var, transform=transform, value_var=value)
         model.set_initval(var, initval=None)
     elif isinstance(model_var.owner.op, ModelObservedRV):
         var, value, *dims = model_var.owner.inputs
64 changes: 62 additions & 2 deletions pymc/model/transform/conditioning.py
@@ -15,8 +15,10 @@
 
 from typing import Any, Mapping, Optional, Sequence, Union
 
-from pytensor.graph import ancestors
-from pytensor.tensor import TensorVariable
+from pytensor import clone_replace
+from pytensor.compile import SharedVariable
+from pytensor.graph import FunctionGraph, ancestors
+from pytensor.tensor import TensorVariable, constant
 
 from pymc import Model
 from pymc.logprob.transforms import Transform
@@ -349,9 +351,67 @@ def remove_value_transforms(
     return change_value_transforms(model, {var: None for var in vars})
 
 
+def freeze_existing_rv_dims_and_data(model: Model) -> Model:
+    """Recreate a Model with fixed RV dimensions and Data values.
+
+    The dimensions of the pre-existing RVs will no longer follow changes to the
+    coordinates. Likewise, it will not be possible to update pre-existing Data
+    in the new model.
+
+    Note that any new RVs and Data created after calling this function will
+    still be "unfrozen".
+
+    This transformation may allow more performant sampling, or compiling model
+    functions to backends that are more restrictive about dynamic shapes, such
+    as JAX.
+    """
+    fg, memo = fgraph_from_model(model)
+
+    # Replace mutable dim lengths and data by constants
+    frozen_vars = {
+        memo[dim_length]: constant(
+            dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
+        )
+        for dim_length in model.dim_lengths.values()
+        if isinstance(dim_length, SharedVariable)
+    }
+    frozen_vars |= {
+        memo[data_var].owner.inputs[0]: constant(
+            data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype
+        )
+        for data_var in model.named_vars.values()
+        if isinstance(data_var, SharedVariable)
+    }
+
+    old_outs, coords = fg.outputs, fg._coords
+    # rebuild_strict=False forces the recreation of RV nodes with updated static types
+    new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False)
+    for old_out, new_out in zip(old_outs, new_outs):
+        new_out.name = old_out.name
+    fg = FunctionGraph(outputs=new_outs, clone=False)
+    fg._coords = coords
+
+    # Recreate value variables from new RVs to propagate static types to logp graphs
+    replacements = {}
+    for node in fg.apply_nodes:
+        if not isinstance(node.op, ModelFreeRV):
+            continue
+        rv, old_value, *dims = node.inputs
+        if dims is None:
+            continue
+        transform = node.op.transform
+        if transform is None:
+            new_value = rv.type()
+        else:
+            new_value = transform.forward(rv, *rv.owner.inputs).type()
+        new_value.name = old_value.name
+        replacements[old_value] = new_value
+    fg.replace_all(tuple(replacements.items()), import_missing=True)
+
+    return model_from_fgraph(fg)
+
+
 __all__ = (
     "change_value_transforms",
     "do",
+    "freeze_existing_rv_dims_and_data",
     "observe",
     "remove_value_transforms",
 )
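
For context, a sketch of how a frozen model might be used downstream. This is illustrative and not part of the commit; nuts_sampler="numpyro" assumes PyMC's optional JAX-based numpyro sampler is available:

    import pymc as pm
    from pymc.model.transform.conditioning import freeze_existing_rv_dims_and_data

    with pm.Model(coords={"obs": range(100)}) as m:
        mu = pm.Normal("mu")
        y = pm.Normal("y", mu, sigma=1.0, observed=[0.0] * 100, dims="obs")

    # Freeze dim lengths and Data so shapes are static before handing
    # the graph to a shape-strict backend such as JAX.
    frozen_m = freeze_existing_rv_dims_and_data(m)
    with frozen_m:
        idata = pm.sample(nuts_sampler="numpyro")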
34 changes: 34 additions & 0 deletions tests/model/transform/test_conditioning.py
@@ -16,13 +16,15 @@
 import pytest
 
 from pytensor import config
+from pytensor.graph import Constant
 
 import pymc as pm
 
 from pymc.distributions.transforms import logodds
 from pymc.model.transform.conditioning import (
     change_value_transforms,
     do,
+    freeze_existing_rv_dims_and_data,
     observe,
     remove_value_transforms,
 )
@@ -308,3 +310,35 @@ def test_remove_value_transforms():
     new_p = new_m["p"]
     new_q = new_m["q"]
     assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}
+
+
+def test_freeze_existing_rv_dims_and_data():
+    with pm.Model(coords={"test_dim": range(5)}) as m:
+        std = pm.Data("std", [1])
+        x = pm.HalfNormal("x", std, dims=("test_dim",))
+        y = pm.Normal("y", shape=x.shape[0] + 1)
+
+    x_logp, y_logp = m.logp(sum=False)
+
+    assert not isinstance(std, Constant)
+    assert x.type.shape == (None,)
+    assert y.type.shape == (None,)
+    assert x_logp.type.shape == (None,)
+    assert y_logp.type.shape == (None,)
+
+    frozen_m = freeze_existing_rv_dims_and_data(m)
+    std, x, y = frozen_m["std"], frozen_m["x"], frozen_m["y"]
+    x_logp, y_logp = frozen_m.logp(sum=False)
+    assert isinstance(std, Constant)
+    assert x.type.shape == (5,)
+    assert y.type.shape == (6,)
+    assert x_logp.type.shape == (5,)
+    assert y_logp.type.shape == (6,)
+
+
+def test_freeze_rv_dims_nothing_to_change():
+    with pm.Model(coords={"test_dim": range(5)}) as m:
+        x = pm.HalfNormal("x", shape=(5,))
+        y = pm.Normal("y", shape=x.shape[0] + 1)
+
+    assert m.point_logps() == freeze_existing_rv_dims_and_data(m).point_logps()