Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes #50496
Fixes #34859
Fixes #21596

This fixes many bugs in `TransformedDistribution` and `ComposeTransform` that arise when the component transforms change their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` with `.domain.event_dim` and `.output_event_dim` with `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.
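
For concreteness, a minimal sketch of the intended shape API (shapes here are illustrative only; this snippet is not part of the commit's tests):

```python
import torch
from torch.distributions.transforms import ExpTransform, ReshapeTransform

# forward_shape()/inverse_shape() map tensor shapes without evaluating the
# transform on data, which is what TransformedDistribution and
# ComposeTransform need for their shape arithmetic.
reshape = ReshapeTransform((4, 5), (20,))
assert reshape.forward_shape((7, 4, 5)) == torch.Size([7, 20])
assert reshape.inverse_shape((7, 20)) == torch.Size([7, 4, 5])

# Event dims now live on the constraint objects rather than on the transform:
# .domain.event_dim replaces .input_event_dim and .codomain.event_dim
# replaces .output_event_dim.
assert reshape.domain.event_dim == 2
assert reshape.codomain.event_dim == 1
assert ExpTransform().domain.event_dim == 0
```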

## Summary of changes

- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`.
- Adds methods `Transform.forward_shape()` and `.inverse_shape()`, which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform` (see the sketch after this list).
- Adds a `ReshapeTransform`, which is invaluable for testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch.
- Fixes incorrect default values of `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.
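
As referenced above, a minimal sketch of how `IndependentTransform` and the new shape logic fit into `TransformedDistribution` (illustrative shapes; not part of this commit's tests):

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, IndependentTransform

# Reinterpret the rightmost dim of an elementwise affine map as an event dim,
# analogous to how distributions.Independent reinterprets batch dims.
affine = AffineTransform(torch.zeros(5), torch.ones(5))
transform = IndependentTransform(affine, reinterpreted_batch_ndims=1)
assert transform.domain.event_dim == 1
assert transform.codomain.event_dim == 1

# TransformedDistribution now consults domain/codomain event dims to split
# the base shape into batch_shape and event_shape.
d = TransformedDistribution(Normal(torch.zeros(2, 5), 1.), [transform])
assert d.batch_shape == (2,)
assert d.event_shape == (5,)
assert d.log_prob(d.sample()).shape == (2,)  # one log-density per batch element
```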

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.
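
A possible shape for that memoization, purely as a sketch (the `memoized_dependent_property` name and details below are hypothetical, not the actual torch implementation):

```python
class memoized_dependent_property:
    """Hypothetical sketch: cache the constraint computed by a dependent
    property on the instance, the way torch.distributions.utils.lazy_property
    caches its value."""

    def __init__(self, fn=None, *, is_discrete=False, event_dim=0):
        self.fn = fn
        self.is_discrete = is_discrete
        self.event_dim = event_dim

    def __call__(self, fn):
        # Supports the decorator-with-arguments form used by distributions:
        # @memoized_dependent_property(is_discrete=True, event_dim=0)
        return type(self)(fn, is_discrete=self.is_discrete, event_dim=self.event_dim)

    def __get__(self, instance, objtype=None):
        if instance is None:
            return self  # class-level access keeps the static metadata readable
        value = self.fn(instance)
        # Shadow this (non-data) descriptor with the computed constraint so
        # later reads of `instance.support` skip recomputation.
        setattr(instance, self.fn.__name__, value)
        return value
```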

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for #34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
fritzo authored and facebook-github-bot committed Jan 26, 2021
1 parent 250c711 commit a347c74
Showing 15 changed files with 557 additions and 124 deletions.
2 changes: 1 addition & 1 deletion test/distributions/test_constraints.py
@@ -68,7 +68,7 @@ def test_biject_to(constraint_fn, args, is_cuda):
assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)

j = t.log_abs_det_jacobian(x, y)
assert j.shape == x.shape[:x.dim() - t.input_event_dim]
assert j.shape == x.shape[:x.dim() - t.domain.event_dim]


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
30 changes: 28 additions & 2 deletions test/distributions/test_distributions.py
@@ -878,6 +878,22 @@ def test_has_examples(self):
self.assertIn(Dist, distributions_with_examples,
"Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))

def test_support_attributes(self):
for Dist, params in EXAMPLES:
for param in params:
d = Dist(**param)
event_dim = len(d.event_shape)
self.assertEqual(d.support.event_dim, event_dim)
try:
self.assertEqual(Dist.support.event_dim, event_dim)
except NotImplementedError:
pass
is_discrete = d.support.is_discrete
try:
self.assertEqual(Dist.support.is_discrete, is_discrete)
except NotImplementedError:
pass

def test_distribution_expand(self):
shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
for Dist, params in EXAMPLES:
@@ -1620,8 +1636,8 @@ def test_logisticnormal(self):
self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,))
self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,))
self.assertEqual(LogisticNormal(0.2, .6).sample().size(), (2,))
self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))

# sample check for extreme value of mean, std
set_rng_seed(1)
@@ -3832,6 +3848,16 @@ def test_kl_shape(self):
'Actual {}'.format(kl.shape),
]))

def test_kl_transformed(self):
# Regression test for https://github.com/pytorch/pytorch/issues/34859
scale = torch.ones(2, 3)
loc = torch.zeros(2, 3)
normal = Normal(loc=loc, scale=scale)
diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
trans_dist = TransformedDistribution(diag_normal, AffineTransform(loc=0., scale=2.))
self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))

def test_entropy_monte_carlo(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
for Dist, params in EXAMPLES:
122 changes: 114 additions & 8 deletions test/distributions/test_transforms.py
@@ -4,10 +4,10 @@

import torch
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
CorrCholeskyTransform, ExpTransform,
LowerCholeskyTransform, PowerTransform,
CorrCholeskyTransform, ExpTransform, IndependentTransform,
LowerCholeskyTransform, PowerTransform, ReshapeTransform,
SigmoidTransform, TanhTransform, SoftmaxTransform,
StickBreakingTransform, identity_transform, Transform,
_InverseTransform)
@@ -22,6 +22,8 @@ def get_transforms(cache_size):
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
@@ -57,6 +59,12 @@ def get_transforms(cache_size):
torch.randn(4, 5),
cache_size=cache_size),
]),
ReshapeTransform((4, 5), (2, 5, 2)),
IndependentTransform(
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
1),
]
transforms += [t.inv for t in transforms]
return transforms
@@ -92,7 +100,16 @@ def transform_id(x):

def generate_data(transform):
torch.manual_seed(1)
while isinstance(transform, IndependentTransform):
transform = transform.base_transform
if isinstance(transform, ReshapeTransform):
return torch.randn(transform.in_shape)
if isinstance(transform.inv, ReshapeTransform):
return torch.randn(transform.inv.out_shape)
domain = transform.domain
while (isinstance(domain, constraints.independent) and
domain is not constraints.real_vector):
domain = domain.base_constraint
codomain = transform.codomain
x = torch.empty(4, 5)
if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
@@ -170,13 +187,15 @@ def test_forward_inverse(transform, test_cached):
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
assert y.shape == transform.forward_shape(x.shape)
if test_cached:
x2 = transform.inv(y) # should be implemented at least by caching
else:
try:
x2 = transform.inv(y.clone()) # bypass cache
except NotImplementedError:
pytest.skip('Not implemented.')
assert x2.shape == transform.inverse_shape(y.shape)
y2 = transform(x2)
if transform.bijective:
# verify function inverse
@@ -316,25 +335,29 @@ def test_jacobian(transform):
except NotImplementedError:
pytest.skip('Not implemented.')
# Test shape
target_shape = x.shape[:x.dim() - transform.input_event_dim]
target_shape = x.shape[:x.dim() - transform.domain.event_dim]
assert actual.shape == target_shape

# Expand if required
transform = reshape_transform(transform, x.shape)
ndims = len(x.shape)
event_dim = ndims - transform.input_event_dim
event_dim = ndims - transform.domain.event_dim
x_ = x.view((-1,) + x.shape[event_dim:])
n = x_.shape[0]
# Reshape to squash batch dims to a single batch dim
transform = reshape_transform(transform, x_.shape)

# 1. Transforms with 0 off-diagonal elements
if transform.input_event_dim == 0:
# 1. Transforms with unit jacobian
if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
# 2. Transforms with 0 off-diagonal elements
elif transform.domain.event_dim == 0:
jac = jacobian(transform, x_)
# assert off-diagonal elements are zero
assert torch.allclose(jac, jac.diagonal().diag_embed())
expected = jac.diagonal().abs().log().reshape(x.shape)
# 2. Transforms with non-0 off-diagonal elements
# 3. Transforms with non-0 off-diagonal elements
else:
if isinstance(transform, CorrCholeskyTransform):
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
@@ -361,5 +384,88 @@
assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize("event_dims",
[(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
ids=str)
def test_compose_affine(event_dims):
transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == max(event_dims)
assert transform.domain.event_dim == max(event_dims)

base_dist = Normal(0, 1)
if transform.domain.event_dim:
base_dist = base_dist.expand((1,) * transform.domain.event_dim)
dist = TransformedDistribution(base_dist, transform.parts)
assert dist.support.event_dim == max(event_dims)

base_dist = Dirichlet(torch.ones(5))
if transform.domain.event_dim > 1:
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
dist = TransformedDistribution(base_dist, transforms)
assert dist.support.event_dim == max(1, max(event_dims))


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
transforms = [ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3))]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == 2
assert transform.domain.event_dim == 2
data = torch.randn(batch_shape + (3, 2))
assert transform(data).shape == batch_shape + (2, 3)

dist = TransformedDistribution(Normal(data, 1), transforms)
assert dist.batch_shape == batch_shape
assert dist.event_shape == (2, 3)
assert dist.support.event_dim == 2


@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
num_transforms, sample_shape):
shape = torch.Size([2, 3, 4, 5])
base_dist = Normal(0, 1)
base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
if base_event_dim:
base_dist = Independent(base_dist, base_event_dim)
transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
ReshapeTransform((4, 5), (20,)),
ReshapeTransform((3, 20), (6, 10))]
transforms = transforms[:num_transforms]
transform = ComposeTransform(transforms)

# Check validation in .__init__().
if base_batch_dim + base_event_dim < transform.domain.event_dim:
with pytest.raises(ValueError):
TransformedDistribution(base_dist, transforms)
return
d = TransformedDistribution(base_dist, transforms)

# Check sampling is sufficiently expanded.
x = d.sample(sample_shape)
assert x.shape == sample_shape + d.batch_shape + d.event_shape
num_unique = len(set(x.reshape(-1).tolist()))
assert num_unique >= 0.9 * x.numel()

# Check log_prob shape on full samples.
log_prob = d.log_prob(x)
assert log_prob.shape == sample_shape + d.batch_shape

# Check log_prob shape on partial samples.
y = x
while y.dim() > len(d.event_shape):
y = y[0]
log_prob = d.log_prob(y)
assert log_prob.shape == d.batch_shape


if __name__ == '__main__':
pytest.main([__file__])
2 changes: 1 addition & 1 deletion torch/distributions/binomial.py
@@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True)
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self.total_count)

2 changes: 1 addition & 1 deletion torch/distributions/categorical.py
@@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True)
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

8 changes: 6 additions & 2 deletions torch/distributions/constraint_registry.py
@@ -160,12 +160,16 @@ def _transform_to_real(constraint):

@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
return biject_to(constraint.base_constraint)
base_transform = biject_to(constraint.base_constraint)
return transforms.IndependentTransform(
base_transform, constraint.reinterpreted_batch_ndims)


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
return transform_to(constraint.base_constraint)
base_transform = transform_to(constraint.base_constraint)
return transforms.IndependentTransform(
base_transform, constraint.reinterpreted_batch_ndims)


@biject_to.register(constraints.positive)
