Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement default_transform and transform argument for distributions #7207

Merged
merged 24 commits into from
Apr 19, 2024

Conversation

aerubanov
Copy link
Contributor

@aerubanov aerubanov commented Mar 21, 2024

Description

Related Issue

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7207.org.readthedocs.build/en/7207/

@aerubanov aerubanov marked this pull request as draft March 21, 2024 18:31
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, left some small user-quality of life suggestions

pymc/model/core.py Show resolved Hide resolved
pymc/model/core.py Outdated Show resolved Hide resolved
pymc/model/core.py Show resolved Hide resolved
pymc/model/core.py Show resolved Hide resolved
Copy link

codecov bot commented Mar 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.28%. Comparing base (034b9a4) to head (4c94e9e).
Report is 19 commits behind head on main.

❗ Current head 4c94e9e differs from pull request most recent head 704aac6. Consider uploading reports for the commit 704aac6 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7207      +/-   ##
==========================================
- Coverage   92.29%   92.28%   -0.01%     
==========================================
  Files         101      100       -1     
  Lines       16892    16906      +14     
==========================================
+ Hits        15590    15602      +12     
- Misses       1302     1304       +2     
Files Coverage Δ
pymc/distributions/distribution.py 93.96% <100.00%> (+0.03%) ⬆️
pymc/model/core.py 92.26% <100.00%> (+0.44%) ⬆️
pymc/model/fgraph.py 97.39% <100.00%> (-0.51%) ⬇️

... and 16 files with indirect coverage changes

@aerubanov
Copy link
Contributor Author

Looks like I fixed all failed test cases. Going add some new tests and changes to documentation as a next steps.

@ricardoV94
Copy link
Member

Looks like I fixed all failed test cases. Going add some new tests and changes to documentation as a next steps.

@mkusnetsov took the documentation initiative in #7232 so we should be good to go just with docstrings and tests

@aerubanov aerubanov marked this pull request as ready for review April 9, 2024 16:54
@aerubanov
Copy link
Contributor Author

@ricardoV94 I added some test cases to check warning it transform=None and transform order if both default_transform and transform are using

@@ -397,6 +398,15 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if transform is None and default_transform is UNSET:
Copy link
Member

@ricardoV94 ricardoV94 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning should be in the relevant section in pm.Model instead of Distribution

pymc/model/core.py Outdated Show resolved Hide resolved
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)

with pytest.warns(
Copy link
Member

@ricardoV94 ricardoV94 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functionality shouldn't be tested here, since it's not specific to Mixture. Probably in model/test_core.py there should be related stuff already? I see you already have it there, so is this test needed?

tests/distributions/test_transform.py Outdated Show resolved Hide resolved
tests/distributions/test_transform.py Outdated Show resolved Hide resolved
Comment on lines 221 to 228
x = pm.Uniform("x", lower=0, upper=1, transform=transform, default_transform=None)
# Operation between the variables provides a regression test for #7054
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
y = pm.Uniform(
"y", lower=0, upper=pt.exp(x), transform=transform, default_transform=None
)
z = pm.Uniform("z", lower=0, upper=y, transform=transform, default_transform=None)
w = pm.Uniform(
"w", lower=0, upper=pt.square(z), transform=transform, default_transform=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass transform to default_transform


with pm.Model() as model:
x = pm.Normal("x", transform=DummyTransform(2), default_transform=DummyTransform(1))
assert transform_order == [1, 2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use regular transforms, and simply assert the obtained transform is a Chain, which as a property like transform_list that includes the transforms, and you can assert those are also the expected ones. The transform is available in models.rvs_to_transforms[x]

Also, would be nice to include a numerical example that would have led to nan or -inf probability before the change, like an ordered mixture of LogNormals evaluated at -1

@ricardoV94 ricardoV94 changed the title Implement default_transform and transform argument for distributions Implement default_transform and transform argument for distributions Apr 15, 2024
aerubanov and others added 5 commits April 15, 2024 20:17
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides making the test_transform_order simpler, just need some tweaks to the docstrings. They don't describe the parameter. We can also give better type hint than Any. If mypy complaints just revert to Any

total_size=None,
dims=None,
transform=UNSET,
default_transform=UNSET,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing description in the docstrings

@@ -1288,6 +1299,7 @@ def make_obs_var(
data: np.ndarray,
dims,
transform: Any | None,
default_transform: Any | None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing description in the docstrings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also type hint isn't great, should be RVTransform | None (or something like that, don't remember the exact class now)

rv_var: TensorVariable,
*,
transform: Any,
default_transform: Any,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update type hints and docstring

@aerubanov
Copy link
Contributor Author

@ricardoV94 Is there any way to get logp value after transform applied? I`m trying something like

    with pm.Model() as model:
        x1 = pm.LogNormal("x1", 0.0, 1.0)
        x2 = pm.LogNormal("x2", 0.0, 1.0, default_transform=LogTransform())
        assert pm.logp(x1, 1.0).eval() != pm.logp(x2, 1.0).eval()

and got assertion error

@ricardoV94
Copy link
Member

@ricardoV94 Is there any way to get logp value after transform applied? I`m trying something like

    with pm.Model() as model:
        x1 = pm.LogNormal("x1", 0.0, 1.0)
        x2 = pm.LogNormal("x2", 0.0, 1.0, default_transform=LogTransform())
        assert pm.logp(x1, 1.0).eval() != pm.logp(x2, 1.0).eval()

and got assertion error

You want to use model.compile_logp(), and you'll have to create two models, one with transform and one without

@aerubanov
Copy link
Contributor Author

@ricardoV94 I added new test case with numerical example for transform args. Please let me know what do you thunk about it.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @aerubanov

I left some minor comments, I think this is the last round!

tests/model/test_core.py Outdated Show resolved Hide resolved
with pm.Model() as model1:
x1 = pm.LogNormal("x1", 0, 1, transform=Interval(-2, 2), default_transform=None)
with pm.Model() as model3:
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine but is not very realistic. What about ordered transform + log which is the example that motivated this PR?

Suggested change
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log)
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the logic of the test is spot-on!

pymc/model/core.py Outdated Show resolved Hide resolved
@@ -1230,7 +1239,9 @@ def register_rv(
dims : tuple
Dimension names for the variable.
transform
A transform for the random variable in log-likelihood space.
Additianal transform which may be applied after default transform.
default_transform
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: show default_transform before transform (also in the signature)?

Comment on lines 1318 to 1320
transform
Additianal transform which may be applied after default transform.
default_transform
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Comment on lines 1418 to 1422
transform: Transform
Additianal transform which may be applied after default transform.

default_transform: Transform
A transform for the random variable in log-likelihood space.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

aerubanov and others added 3 commits April 18, 2024 16:13
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@aerubanov
Copy link
Contributor Author

@ricardoV94 please review again

tests/model/test_core.py Show resolved Hide resolved
tests/model/test_core.py Outdated Show resolved Hide resolved
tests/model/test_core.py Outdated Show resolved Hide resolved
def test_transform_order(self):
with pm.Model() as model:
x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)
assert isinstance(model.rvs_to_transforms[x], ChainedTransform)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick (feel free to ignore): Save the transform in a separate variable so you don't need to write 3 times model.rvs_to_transforms[x]

aerubanov and others added 2 commits April 18, 2024 20:24
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks great!

@ricardoV94
Copy link
Member

ricardoV94 commented Apr 18, 2024

We should follow up with an issue to provide informative warnings when doing prior/posterior predictive sampling of variables with custom non-default transforms, like we do with Potentials

@ricardoV94 ricardoV94 merged commit 23e418d into pymc-devs:main Apr 19, 2024
21 checks passed
@ricardoV94 ricardoV94 added the major Include in major changes release notes section label Apr 19, 2024
@aerubanov aerubanov deleted the 5674-default-transform-arg branch April 19, 2024 07:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements major Include in major changes release notes section request discussion
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement default_transform and transform argument for distributions
2 participants