TransformedDistribution and event_shape #21596
Labels
- module: distributions (related to torch.distributions)
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Comments
umanwizard added the `triaged` label on Jun 12, 2019.
@fritzo, any idea how to fix this? One idea I have is to add a … with …
It's tough to say without more examples. My first reaction would be to define a method

```python
import torch


class Transform:
    # Default: a transform leaves the event shape unchanged.
    def transform_event_shape(self, event_shape):
        return event_shape


class StickBreakingTransform(Transform):
    # Stick-breaking changes the length of its single event dimension by one.
    def transform_event_shape(self, event_shape):
        return torch.Size((event_shape[0] - 1,))
```

and use that in
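A minimal plain-Python sketch (no PyTorch needed) of the shape bookkeeping that a per-transform shape method makes possible. The names `StickBreakingShapes`, `forward_shape` and `transformed_shapes` are hypothetical stand-ins, not library API, and the model assumes a transform that consumes and produces the same number of event dimensions. Note that the forward stick-breaking map takes a length-(K-1) vector onto the K-simplex, so in the forward direction the last dimension grows by one. The key point is that when the transform acts on more event dimensions than the base distribution has, trailing batch dimensions of the base must be reinterpreted as event dimensions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


class StickBreakingShapes:
    """Hypothetical shape-only stand-in for a stick-breaking transform.

    Forward stick-breaking maps a length-(K-1) vector onto the K-simplex,
    so it acts on one trailing dimension (event_dim = 1) and grows it by one.
    """

    event_dim = 1

    def forward_shape(self, shape: Shape) -> Shape:
        if len(shape) < 1:
            raise ValueError("need at least one dimension")
        return tuple(shape[:-1]) + (shape[-1] + 1,)


def transformed_shapes(batch_shape: Shape, event_shape: Shape, transform) -> Tuple[Shape, Shape]:
    """Return (batch_shape, event_shape) after applying `transform`.

    Simplified model: assumes the transform consumes and produces the same
    number of event dimensions.
    """
    full_shape = tuple(batch_shape) + tuple(event_shape)
    new_shape = tuple(transform.forward_shape(full_shape))
    # The result has at least as many event dims as the transform acts on;
    # any extra dims it needs are taken from the base's batch dims.
    event_dim = max(transform.event_dim, len(event_shape))
    cut = len(new_shape) - event_dim
    return new_shape[:cut], new_shape[cut:]


# A base with batch_shape (3,) and event_shape (), like three scalar Normals.
print(transformed_shapes((3,), (), StickBreakingShapes()))  # ((), (4,))
# A batch of 2 base vectors of length 3.
print(transformed_shapes((2,), (3,), StickBreakingShapes()))  # ((2,), (4,))
```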
facebook-github-bot pushed a commit that referenced this issue on Jan 26, 2021:
Summary: Fixes #50496
Fixes #34859
Fixes #21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes
- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`.
- Adds methods `Transform.forward_shape()` and `.inverse_shape()`, which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform`, which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.

## Changes planned for follow-up PRs
- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for #34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50581
Reviewed By: ezyang, glaringlee, jpchen
Differential Revision: D26024247
Pulled By: neerajprad
fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
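A short check of the behavior after this change, assuming a PyTorch version that includes #50581 (so `Transform.forward_shape()` and `.inverse_shape()` exist). The reporter's original reproduction was lost from this page, so the base distribution below is an assumption: three independent standard Normals, i.e. `batch_shape` `(3,)` and an empty `event_shape`.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import StickBreakingTransform

t = StickBreakingTransform()

# Stick-breaking maps a length-(K-1) real vector onto the K-simplex,
# so the forward shape gains one trailing element and the inverse loses one.
print(tuple(t.forward_shape(torch.Size([3]))))  # (4,)
print(tuple(t.inverse_shape(torch.Size([4]))))  # (3,)

# Base: three scalar Normals (batch_shape (3,), event_shape ()).
base = Normal(torch.zeros(3), torch.ones(3))
d = TransformedDistribution(base, [t])

# The transform acts on one event dimension, so the base's batch dim is
# reinterpreted as an event dim: the result is one point on the 4-simplex.
print(tuple(d.batch_shape), tuple(d.event_shape))  # () (4,)

x = d.sample()
print(tuple(x.shape))  # (4,)
print(float(x.sum()))  # close to 1.0, since x lies on the simplex
```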
🐛 Bug
`TransformedDistribution` does not seem to have the correct `event_shape` when the transform changes the number of dimensions (e.g. with `StickBreakingTransform`).
To Reproduce
Steps to reproduce the behavior:
Produces the results of
Expected behavior
The last print statement should (seemingly) give:
Environment
Installation method (conda, pip, source): conda