Implement default_transform and transform argument for distributions #7207

Merged · 24 commits · Apr 19, 2024
6 changes: 4 additions & 2 deletions pymc/distributions/distribution.py
@@ -333,6 +333,7 @@ def __new__(
observed=None,
total_size=None,
transform=UNSET,
default_transform=UNSET,
**kwargs,
) -> TensorVariable:
"""Adds a tensor variable corresponding to a PyMC distribution to the current model.
@@ -414,10 +415,11 @@ def __new__(
rv_out = model.register_rv(
rv_out,
name,
observed,
total_size,
observed=observed,
total_size=total_size,
dims=dims,
transform=transform,
default_transform=default_transform,
initval=initval,
)

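A minimal usage sketch of the API this diff enables (not part of the PR itself, and assuming a PyMC build that includes this change): `default_transform` controls the automatic transform a distribution would normally receive, while `transform` layers an extra one on top.

import pymc as pm

with pm.Model():
    # HalfNormal normally gets an automatic log transform of its
    # positive support; leaving default_transform unset keeps it.
    x = pm.HalfNormal("x")

    # default_transform=None disables the automatic transform, so the
    # sampler works directly on the constrained space.
    y = pm.HalfNormal("y", default_transform=None)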
80 changes: 64 additions & 16 deletions pymc/model/core.py
@@ -22,7 +22,6 @@
from sys import modules
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
TypeVar,
@@ -48,7 +47,7 @@

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, is_minibatch
from pymc.distributions.transforms import _default_transform
from pymc.distributions.transforms import ChainedTransform, _default_transform
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
@@ -58,6 +57,7 @@
)
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
from pymc.model_graph import model_to_graphviz
from pymc.pytensorf import (
@@ -1214,7 +1214,16 @@ def set_data(
shared_object.set_value(values)

def register_rv(
self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
self,
rv_var,
name,
*,
observed=None,
total_size=None,
dims=None,
default_transform=UNSET,
[Review comment] Missing description in the docstrings
transform=UNSET,
initval=None,
):
"""Register an (un)observed random variable with the model.

@@ -1229,8 +1238,10 @@ def register_rv(
upscales logp of variable with ``coef = total_size/var.shape[0]``
dims : tuple
Dimension names for the variable.
default_transform
A default transform for the random variable in log-likelihood space.
[Review comment] Nitpick: show default_transform before transform (also in the signature)?
transform
A transform for the random variable in log-likelihood space.
Additional transform which may be applied after the default transform.
initval
The initial value of the random variable.

@@ -1255,7 +1266,7 @@
if total_size is not None:
raise ValueError("total_size can only be passed to observed RVs")
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.create_value_var(rv_var, transform=transform, default_transform=default_transform)
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
@@ -1278,7 +1289,9 @@

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
rv_var = self.make_obs_var(
rv_var, observed, dims, default_transform, transform, total_size
)

return rv_var

@@ -1287,7 +1300,8 @@ def make_obs_var(
rv_var: TensorVariable,
data: np.ndarray,
dims,
transform: Any | None,
default_transform: Transform | None,
transform: Transform | None,
total_size: int | None,
) -> TensorVariable:
"""Create a `TensorVariable` for an observed random variable.
@@ -1301,8 +1315,10 @@
The observed data.
dims : tuple
Dimension names for the variable.
transform : int, optional
default_transform
A transform for the random variable in log-likelihood space.
transform
Additional transform which may be applied after the default transform.

Returns
-------
@@ -1339,12 +1355,19 @@

# Register ObservedRV corresponding to observed component
observed_rv.name = f"{name}_observed"
self.create_value_var(observed_rv, transform=None, value_var=observed_data)
self.create_value_var(
observed_rv, transform=transform, default_transform=None, value_var=observed_data
)
self.add_named_variable(observed_rv)
self.observed_RVs.append(observed_rv)

# Register FreeRV corresponding to unobserved components
self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)
self.register_rv(
unobserved_rv,
f"{name}_unobserved",
transform=transform,
default_transform=default_transform,
)

# Register Deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
@@ -1363,14 +1386,21 @@
rv_var.name = name

rv_var.tag.observations = data
self.create_value_var(rv_var, transform=None, value_var=data)
self.create_value_var(
rv_var, transform=transform, default_transform=None, value_var=data
)
self.add_named_variable(rv_var, dims)
self.observed_RVs.append(rv_var)

return rv_var

def create_value_var(
self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None
self,
rv_var: TensorVariable,
*,
default_transform: Transform,
transform: Transform,
value_var: Variable | None = None,
) -> TensorVariable:
"""Create a ``TensorVariable`` that will be used as the random
variable's "value" in log-likelihood graphs.
@@ -1385,7 +1415,11 @@ def create_value_var(
----------
rv_var : TensorVariable

transform : Any
default_transform: Transform
A transform for the random variable in log-likelihood space.

transform: Transform
Additional transform which may be applied after the default transform.

value_var : Variable, optional

@@ -1396,11 +1430,25 @@

# Make the value variable a transformed value variable,
# if there's an applicable transform
if transform is UNSET:
if transform is None and default_transform is UNSET:
default_transform = None
warnings.warn(
"To disable default transform, please use default_transform=None"
" instead of transform=None. Setting transform to None will"
" not have any effect in future.",
UserWarning,
)

if default_transform is UNSET:
if rv_var.owner is None:
transform = None
default_transform = None
else:
transform = _default_transform(rv_var.owner.op, rv_var)
default_transform = _default_transform(rv_var.owner.op, rv_var)

if transform is UNSET:
transform = default_transform
elif transform is not None and default_transform is not None:
transform = ChainedTransform([default_transform, transform])

if value_var is None:
if transform is None:
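To summarize the resolution logic in `create_value_var` above: `transform=None` without an explicit `default_transform` now only emits a UserWarning; an unset `default_transform` is looked up via `_default_transform`; an unset `transform` falls back to the default; and when both are given they are combined as `ChainedTransform([default_transform, transform])`, applying the default first. A short sketch of the chaining behavior (not part of the PR; it mirrors the `test_transform_order` test added below):

import pymc as pm
from pymc.distributions.transforms import ChainedTransform, Interval, log

with pm.Model() as model:
    # log (the default) is applied first, then the user-supplied
    # Interval transform, i.e. ChainedTransform([log, Interval(0, 1)]).
    x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)

assert isinstance(model.rvs_to_transforms[x], ChainedTransform)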
6 changes: 4 additions & 2 deletions pymc/model/fgraph.py
@@ -320,12 +320,14 @@ def first_non_model_var(var):
var, value, *dims = model_var.owner.inputs
transform = model_var.owner.op.transform
model.free_RVs.append(var)
model.create_value_var(var, transform=transform, value_var=value)
model.create_value_var(
var, transform=transform, default_transform=None, value_var=value
)
model.set_initval(var, initval=None)
elif isinstance(model_var.owner.op, ModelObservedRV):
var, value, *dims = model_var.owner.inputs
model.observed_RVs.append(var)
model.create_value_var(var, transform=None, value_var=value)
model.create_value_var(var, transform=None, default_transform=None, value_var=value)
elif isinstance(model_var.owner.op, ModelPotential):
var, *dims = model_var.owner.inputs
model.potentials.append(var)
6 changes: 3 additions & 3 deletions tests/distributions/test_mixture.py
@@ -1359,17 +1359,17 @@ def test_warning(self):

with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)

with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)

# Case where the appropriate default transform is None
comp_dists = [Normal.dist(), Normal.dist()]
with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists)
Mixture("mix7", w=[0.5, 0.5], comp_dists=comp_dists)


class TestZeroInflatedMixture:
4 changes: 2 additions & 2 deletions tests/distributions/test_transform.py
@@ -619,7 +619,7 @@ def test_transform_univariate_dist_logp_shape():

def test_univariate_transform_multivariate_dist_raises():
with pm.Model() as m:
pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
pm.Dirichlet("x", [1, 1, 1], default_transform=tr.log)

for jacobian_val in (True, False):
with pytest.raises(
@@ -645,7 +645,7 @@ def log_jac_det(self, value, *inputs):
buggy_transform = BuggyTransform()

with pm.Model() as m:
pm.Uniform("x", shape=(4, 3), transform=buggy_transform)
pm.Uniform("x", shape=(4, 3), default_transform=buggy_transform)

for jacobian_val in (True, False):
with pytest.raises(
8 changes: 4 additions & 4 deletions tests/logprob/test_utils.py
@@ -218,11 +218,11 @@ def test_interdependent_transformed_rvs(self, reversed):
transform = pm.distributions.transforms.Interval(
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
x = pm.Uniform("x", lower=0, upper=1, default_transform=transform)
# Operation between the variables provides a regression test for #7054
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
y = pm.Uniform("y", lower=0, upper=pt.exp(x), default_transform=transform)
z = pm.Uniform("z", lower=0, upper=y, default_transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), default_transform=transform)

rvs = [x, y, z, w]
if reversed:
47 changes: 41 additions & 6 deletions tests/model/test_core.py
@@ -42,7 +42,14 @@
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.distribution import PartialObservedRV
from pymc.distributions.transforms import log, simplex
from pymc.distributions.transforms import (
ChainedTransform,
Interval,
LogTransform,
log,
ordered,
simplex,
)
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.transforms import IntervalTransform
@@ -527,6 +534,34 @@ def test_model_var_maps():
assert model.rvs_to_transforms[x] is None


class TestTransformArgs:
def test_transform_warning(self):
with pm.Model():
with pytest.warns(
UserWarning,
match="To disable default transform,"
" please use default_transform=None"
" instead of transform=None. Setting transform to"
" None will not have any effect in future.",
):
a = pm.Normal("a", transform=None)

def test_transform_order(self):
with pm.Model() as model:
x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)
assert isinstance(model.rvs_to_transforms[x], ChainedTransform)
[Review comment] Nitpick (feel free to ignore): Save the transform in a separate variable so you don't need to write model.rvs_to_transforms[x] three times.
assert isinstance(model.rvs_to_transforms[x].transform_list[0], LogTransform)
assert isinstance(model.rvs_to_transforms[x].transform_list[1], Interval)

def test_default_transform_is_applied(self):
with pm.Model() as model1:
x1 = pm.LogNormal("x1", [0, 0], [1, 1], transform=ordered, default_transform=None)
with pm.Model() as model3:
x2 = pm.LogNormal("x2", [0, 0], [1, 1], transform=ordered, default_transform=log)
assert np.isinf(model1.compile_logp()({"x1_ordered__": (-1, -1)}))
assert np.isfinite(model3.compile_logp()({"x2_chain__": (-1, -1)}))


def test_make_obs_var():
"""
Check returned values for `data` given known inputs to `as_tensor()`.
Expand All @@ -549,26 +584,26 @@ def test_make_obs_var():

# The function requires data and RV dimensionality to be compatible
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None)
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None, None)

# Check function behavior using the various inputs
# dense, sparse: Ensure that the missing values are appropriately set to None
# masked: a deterministic variable is returned

dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None)
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None, None)
assert dense_output == fake_distribution
assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
del fake_model.named_vars[fake_distribution.name]

sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None)
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None, None)
assert sparse_output == fake_distribution
assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
del fake_model.named_vars[fake_distribution.name]

# Here the RandomVariable is split into observed/imputed and a Deterministic is returned
with pytest.warns(ImputationWarning):
masked_output = fake_model.make_obs_var(
fake_distribution, masked_array_input, None, None, None
fake_distribution, masked_array_input, None, None, None, None
)
assert masked_output != fake_distribution
assert not isinstance(masked_output, RandomVariable)
@@ -581,7 +616,7 @@

# Test that setting total_size returns a MinibatchRandomVariable
scaled_outputs = fake_model.make_obs_var(
fake_distribution, dense_input, None, None, total_size=100
fake_distribution, dense_input, None, None, None, total_size=100
)
assert scaled_outputs != fake_distribution
assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable)
4 changes: 2 additions & 2 deletions tests/model/transform/test_conditioning.py
@@ -286,8 +286,8 @@ def test_change_value_transforms_error():

def test_remove_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", transform=logodds)
q = pm.Uniform("q", transform=logodds)
p = pm.Uniform("p", transform=logodds, default_transform=None)
q = pm.Uniform("q", transform=logodds, default_transform=None)

new_m = remove_value_transforms(base_m)
new_p = new_m["p"]
2 changes: 1 addition & 1 deletion tests/sampling/test_mcmc.py
@@ -303,7 +303,7 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
transform = pm.distributions.transforms.Interval(
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
y = pm.Uniform("y", lower=0, upper=x, transform=transform, default_transform=None)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)