
Fix TransformedDistribution shaping logic #50581

Closed · wants to merge 12 commits into master from independent-transform
Conversation

@fritzo (Collaborator) commented Jan 15, 2021

Fixes #50496
Fixes #34859
Fixes #21596

This fixes many bugs involving TransformedDistribution and ComposeTransform when the component transforms change their event shapes. Part of the fix is to introduce an IndependentTransform analogous to distributions.Independent and constraints.independent, and to introduce methods Transform.forward_shape() and .inverse_shape(). I have followed @fehiepsi's suggestion and replaced .input_event_dim -> .domain.event_dim and .output_event_dim -> .codomain.event_dim. This allows us to deprecate .event_dim as an attribute.
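The idea behind the new shape methods can be illustrated with a minimal pure-Python sketch (these classes are hypothetical stand-ins, not the torch.distributions implementation): each transform maps an input shape to an output shape, and a composition folds those maps left-to-right for forward_shape and right-to-left for inverse_shape.

```python
# Hypothetical sketch of Transform.forward_shape() / .inverse_shape()
# shape propagation. Shapes are plain tuples here; the real API works
# on torch.Size, but the folding logic is the same idea.

class ShapePreservingTransform:
    """Most transforms (e.g. ExpTransform) leave the shape unchanged."""
    def forward_shape(self, shape):
        return tuple(shape)
    def inverse_shape(self, shape):
        return tuple(shape)

class StickBreakingLikeTransform:
    """A transform that changes the rightmost size, like StickBreakingTransform."""
    def forward_shape(self, shape):
        return tuple(shape[:-1]) + (shape[-1] + 1,)
    def inverse_shape(self, shape):
        return tuple(shape[:-1]) + (shape[-1] - 1,)

class ComposeLikeTransform:
    """Folds the component shape maps, analogous to ComposeTransform."""
    def __init__(self, parts):
        self.parts = parts
    def forward_shape(self, shape):
        for part in self.parts:            # apply in order
            shape = part.forward_shape(shape)
        return shape
    def inverse_shape(self, shape):
        for part in reversed(self.parts):  # undo in reverse order
            shape = part.inverse_shape(shape)
        return shape

t = ComposeLikeTransform([ShapePreservingTransform(), StickBreakingLikeTransform()])
print(t.forward_shape((2, 3)))  # (2, 4)
print(t.inverse_shape((2, 4)))  # (2, 3)
```

This is why TransformedDistribution can now compute its batch and event shapes correctly even when a component transform changes the event shape.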

Summary of changes

  • Fixes TransformedDistribution and ComposeTransform shape errors.
  • Fixes a behavior bug in LogisticNormal.
  • Fixes kl_divergence(TransformedDistribution, TransformedDistribution).
  • Adds methods Transform.forward_shape(), .inverse_shape() which are required for correct shape computations in TransformedDistribution and ComposeTransform.
  • Adds an IndependentTransform.
  • Adds a ReshapeTransform, which is invaluable in testing shape logic in ComposeTransform and TransformedDistribution and which will be used by @stefanwebb's flowtorch.
  • Fixes incorrect default values in constraints.dependent.event_dim.
  • Documents the .event_dim and .is_discrete attributes.

Changes planned for follow-up PRs

  • Memoize @constraints.dependent_property as we do with @lazy_property, since we now consult those properties much more often.

Tested

  • Added a test for Dist.support vs Dist(**params).support to ensure static and dynamic attributes agree.
  • Refactoring is covered by existing tests.
  • Added test cases for ReshapeTransform.
  • Added a test for TransformedDistribution on a wide grid of input shapes.
  • Added a regression test for #34859 (Incorrect shape from torch.distributions.kl.kl_divergence).

cc @fehiepsi @feynmanliang @stefanwebb

@fritzo added the module: distributions (Related to torch.distributions) label Jan 15, 2021
Resolved review threads on torch/distributions/transforms.py
@ngimel added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jan 17, 2021
@neerajprad (Contributor) commented:

Could you merge in changes from the earlier PR?

@fritzo fritzo changed the title Independent transform Fix distribution.Transform shaping logic Jan 22, 2021
@fritzo fritzo changed the title Fix distribution.Transform shaping logic Fix TransformedDistribution shaping logic Jan 22, 2021
@neerajprad (Contributor) commented:

The changes to the IndependentTransform look great. I'll defer to @stefanwebb and @fehiepsi for the ReshapeTransform related changes. I'm importing this internally to verify that this doesn't break anything on the client side, and will do another closer review later in the evening.

@facebook-github-bot left a comment:
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@fritzo fritzo requested a review from fehiepsi January 22, 2021 19:55
codecov bot commented Jan 23, 2021

Codecov Report

Merging #50581 (e8b121e) into master (b2e5617) will decrease coverage by 0.00%.
The diff coverage is 76.61%.

@@            Coverage Diff             @@
##           master   #50581      +/-   ##
==========================================
- Coverage   80.99%   80.99%   -0.01%     
==========================================
  Files        1916     1916              
  Lines      209590   209748     +158     
==========================================
+ Hits       169767   169884     +117     
- Misses      39823    39864      +41     

reinterpreted_batch_ndims = domain_event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
    base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
self.base_dist = base_distribution
A contributor commented:
pretty clean!
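The arithmetic in the diff snippet above can be illustrated standalone (a pure-Python sketch with made-up shapes, not the torch code): when the transform's domain expects more event dims than the base distribution provides, the extra rightmost batch dims are reinterpreted as event dims, which is exactly what Independent does.

```python
# Hypothetical shapes: a Normal-like base distribution with
# batch_shape=(5, 4, 3) and event_shape=() (event_dim 0), composed with
# a transform whose domain requires event_dim 2.
base_batch_shape = (5, 4, 3)
base_event_dim = 0
domain_event_dim = 2

# Same formula as the diff snippet:
reinterpreted_batch_ndims = domain_event_dim - base_event_dim
assert reinterpreted_batch_ndims == 2

# Independent(base, 2) moves the 2 rightmost batch dims into the event:
split = len(base_batch_shape) - reinterpreted_batch_ndims
new_batch_shape = base_batch_shape[:split]
new_event_shape = base_batch_shape[split:]
print(new_batch_shape, new_event_shape)  # (5,) (4, 3)
```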

fehiepsi previously approved these changes Jan 23, 2021
@fehiepsi fehiepsi dismissed their stale review January 23, 2021 05:27

Let me take another pass

@fehiepsi (Contributor) left a comment:

@fritzo You missed the forward_shape and inverse_shape for ExponentTransform and SoftmaxTransform. The former needs broadcasting to include exponent shape. The latter should be like StickBreakingTransform. Other changes look great to me!

@neerajprad (Contributor) left a comment:

Thanks for adding detailed tests! LGTM, pending comments from @fehiepsi.

@fritzo (Collaborator, Author) commented Jan 23, 2021

@fehiepsi @neerajprad thanks for your detailed review!

> You missed the forward_shape and inverse_shape for ExponentTransform and SoftmaxTransform

I don't think PyTorch has an ExponentTransform, that must be NumPyro only 😄 . I've added a check to SoftmaxTransform.forward_shape() and .inverse_shape() but note that transform preserves shape exactly, unlike StickBreakingTransform which changes the size of the rightmost axis.
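The distinction can be shown with plain shape arithmetic (an illustrative sketch, not the torch implementation): SoftmaxTransform maps a vector to a same-sized point on the simplex, while StickBreakingTransform maps n unconstrained values to n + 1 simplex coordinates.

```python
# Shape maps only (illustrative): softmax preserves the shape exactly;
# stick-breaking grows the last dim by one in the forward direction.

def softmax_forward_shape(shape):
    if len(shape) < 1:
        raise ValueError("too few dimensions")  # needs at least one event dim
    return tuple(shape)

def stick_breaking_forward_shape(shape):
    if len(shape) < 1:
        raise ValueError("too few dimensions")
    return tuple(shape[:-1]) + (shape[-1] + 1,)

assert softmax_forward_shape((2, 5)) == (2, 5)        # unchanged
assert stick_breaking_forward_shape((2, 5)) == (2, 6)  # last dim + 1
```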

@fehiepsi (Contributor) commented:

Oops... I meant PowerTransform @fritzo...

@facebook-github-bot left a comment:
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Jan 26, 2021
Summary: (duplicates the PR description above)

Pull Request resolved: #50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
@github-actions github-actions bot deleted the independent-transform branch February 10, 2024 01:54
Labels: cla signed · module: distributions (Related to torch.distributions) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects: None yet
6 participants