
TransformedDistribution and event_shape #21596

Closed
justindomke opened this issue Jun 10, 2019 · 2 comments
Labels
module: distributions (Related to torch.distributions)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@justindomke

🐛 Bug

TransformedDistribution does not seem to have the correct event_shape when the transform changes the number of dimensions (e.g. with StickBreakingTransform).

To Reproduce

Steps to reproduce the behavior:

import torch as tt

# Dirichlet over the 3-simplex: event_shape is (3,)
dist = tt.distributions.Dirichlet(tt.ones(3))
support = dist.support
# biject_to(simplex) yields a StickBreakingTransform; its inverse maps the
# 3-simplex to an unconstrained 2-vector
tform = tt.distributions.constraint_registry.biject_to(support)
dist_unconstrained = tt.distributions.TransformedDistribution(dist, tform.inv)

print(tform)
print(dist_unconstrained.sample())
print(dist_unconstrained.event_shape)

Running this produces:

StickBreakingTransform()
tensor([ 0.3628, -1.9573])
torch.Size([3])

Expected behavior

The inverse stick-breaking transform maps the 3-simplex to an unconstrained 2-vector (consistent with the 2-element sample above), so the last print statement should (seemingly) give:

torch.Size([2])

Environment

  • PyTorch Version (e.g., 1.0): 1.1.0
  • OS (e.g., Linux): Mac
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): n/a
  • Python version: 3.6.4
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: n/a
@vishwakftw added the module: distributions label on Jun 11, 2019
@umanwizard added the triaged label on Jun 12, 2019
@alicanb
Collaborator

alicanb commented Dec 12, 2019

@fritzo any idea how to fix this? One idea I have is to add an event_shape_modifier to all transforms and replace

shape = self.base_dist.batch_shape + self.base_dist.event_shape

with shape = self.base_dist.batch_shape + torch.Size([s + event_shape_modifier for s in self.base_dist.event_shape])
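
For concreteness, here is a minimal sketch of that idea (the helper name transformed_shape and the default value of 0 are illustrative assumptions, not existing torch.distributions API):

import torch

# Hypothetical shape computation using the proposed event_shape_modifier
# attribute; a shape-preserving transform would keep the default of 0.
def transformed_shape(base_dist, transform):
    modifier = getattr(transform, "event_shape_modifier", 0)
    return base_dist.batch_shape + torch.Size(
        [s + modifier for s in base_dist.event_shape]
    )

# e.g. the inverse stick-breaking transform would carry event_shape_modifier = -1,
# turning the Dirichlet event_shape (3,) from this issue into (2,).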

@fritzo
Collaborator

fritzo commented Dec 12, 2019

It's tough to say without more examples. My first reaction would be to define a method

import torch

class Transform(object):
    def transform_event_shape(self, event_shape):
        # Default: the transform does not change the event shape.
        return event_shape

class StickBreakingTransform(Transform):
    def transform_event_shape(self, event_shape):
        # e.g. map a simplex event shape (n,) to the unconstrained shape (n - 1,).
        return torch.Size((event_shape[0] - 1,))

and use that in TransformedDistribution. This might not work with jitting.
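
For illustration, TransformedDistribution could thread shapes through such a method roughly like this (a sketch only; transform_event_shape is the proposed hook above, not an existing API):

import torch

def transformed_event_shape(base_dist, transforms):
    # Pass the base event_shape through each transform in order, letting
    # shape-changing transforms such as StickBreakingTransform adjust it.
    event_shape = base_dist.event_shape
    for t in transforms:
        event_shape = t.transform_event_shape(event_shape)
    return event_shape

# For the example in this issue, the inverse stick-breaking transform would
# map torch.Size([3]) to torch.Size([2]).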

@fritzo fritzo self-assigned this Jan 22, 2021
facebook-github-bot pushed a commit that referenced this issue Jan 26, 2021
Summary:
Fixes #50496
Fixes #34859
Fixes #21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes

- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()` and `.inverse_shape()`, which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform` (illustrated after this list).
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.
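
As a rough illustration of the new hooks (assuming PyTorch 1.8 or later, where this change landed), the shape methods and the original reproduction from this issue behave as follows:

import torch
from torch.distributions import Dirichlet, TransformedDistribution
from torch.distributions.transforms import StickBreakingTransform

tform = StickBreakingTransform()
print(tform.forward_shape((2,)))   # torch.Size([3]): unconstrained -> simplex
print(tform.inverse_shape((3,)))   # torch.Size([2]): simplex -> unconstrained

# The reproduction from this issue now reports the unconstrained event shape.
dist = TransformedDistribution(Dirichlet(torch.ones(3)), tform.inv)
print(dist.event_shape)            # torch.Size([2])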

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes (a single-case version is sketched after this list)
- [x] added a regression test for #34859
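
A minimal, single-case version of such a check, specialized to the reproduction in this issue (a sketch assuming PyTorch 1.8 or later, not the actual test added by this PR):

import torch
from torch.distributions import Dirichlet, TransformedDistribution, biject_to

def check_event_shape_matches_sample():
    # event_shape should agree with the shape of a sample even when the
    # transform changes the number of event dimensions.
    dist = Dirichlet(torch.ones(3))
    unconstrained = TransformedDistribution(dist, biject_to(dist.support).inv)
    assert unconstrained.event_shape == torch.Size([2])
    assert unconstrained.sample().shape == unconstrained.event_shape

check_event_shape_matches_sample()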

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b