New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix TransformedDistribution shaping logic #50581
Conversation
Could you merge in changes from the earlier PR? |
3ccc720
to
4e1a93f
Compare
The changes to the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report
@@ Coverage Diff @@
## master #50581 +/- ##
==========================================
- Coverage 80.99% 80.99% -0.01%
==========================================
Files 1916 1916
Lines 209590 209748 +158
==========================================
+ Hits 169767 169884 +117
- Misses 39823 39864 +41 |
reinterpreted_batch_ndims = domain_event_dim - base_event_dim | ||
if reinterpreted_batch_ndims > 0: | ||
base_distribution = Independent(base_distribution, reinterpreted_batch_ndims) | ||
self.base_dist = base_distribution |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretty clean!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fritzo You missed the forward_shape
and inverse_shape
for ExponentTransform and SoftmaxTransform. The former needs broadcasting to include exponent
shape. The latter should be like StickBreakingTransform. Other changes look great to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding detailed tests! LGTM, pending comments from @fehipsi's.
@fehiepsi @neerajprad thanks for your detailed review!
I don't think PyTorch has an |
Oops... I meant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes #50496 Fixes #34859 Fixes #21596 This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute. ## Summary of changes - Fixes `TransformDistribution` and `ComposeTransform` shape errors. - Fixes a behavior bug in `LogisticNormal`. - Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)` - Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`. - Adds an `IndependentTransform`. - Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb flowtorch. - Fixes incorrect default values in `constraints.dependent.event_dim`. - Documents the `.event_dim` and `.is_discrete` attributes. ## Changes planned for follow-up PRs - Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often. ## Tested - [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree. - [x] refactoring is covered by existing tests - [x] add test cases for `ReshapedTransform` - [x] add a test for `TransformedDistribution` on a wide grid of input shapes - [x] added a regression test for #34859 cc fehiepsi feynmanliang stefanwebb Pull Request resolved: #50581 Reviewed By: ezyang, glaringlee, jpchen Differential Revision: D26024247 Pulled By: neerajprad fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
Fixes #50496
Fixes #34859
Fixes #21596
This fixes many bugs involving
TransformedDistribution
andComposeTransform
when the component transforms changed their event shapes. Part of the fix is to introduce anIndependentTransform
analogous todistributions.Independent
andconstraints.independent
, and to introduce methodsTransform.forward_shape()
and.inverse_shape()
. I have followed @fehiepsi's suggestion and replaced.input_event_dim
->.domain.event_dim
and.output_event_dim
->.codomain.event_dim
. This allows us to deprecate.event_dim
as an attribute.Summary of changes
TransformDistribution
andComposeTransform
shape errors.LogisticNormal
.kl_divergence(TransformedDistribution, TransformedDistribution)
Transform.forward_shape()
,.inverse_shape()
which are required for correct shape computations inTransformedDistribution
andComposeTransform
.IndependentTransform
.ReshapeTransform
which is invaluable in testing shape logic inComposeTransform
andTransformedDistribution
and which will be used by @stefanwebb flowtorch.constraints.dependent.event_dim
..event_dim
and.is_discrete
attributes.Changes planned for follow-up PRs
@constraints.dependent_property
as we do with@lazy_property
, since we now consult those properties much more often.Tested
Dist.support
vsDist(**params).support
to ensure static and dynamic attributes agree.ReshapedTransform
TransformedDistribution
on a wide grid of input shapescc @fehiepsi @feynmanliang @stefanwebb