Rebase and squash
fritzo committed Jan 22, 2021
1 parent b2e5617 commit 4e1a93f
Showing 13 changed files with 419 additions and 108 deletions.
test/distributions/test_constraints.py: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def test_biject_to(constraint_fn, args, is_cuda):
assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)

j = t.log_abs_det_jacobian(x, y)
- assert j.shape == x.shape[:x.dim() - t.input_event_dim]
+ assert j.shape == x.shape[:x.dim() - t.domain.event_dim]


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
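A minimal usage sketch, not part of the diff, assuming a PyTorch build that includes this change: a transform's event dimensionality is now read from its domain constraint, so the log-abs-det-Jacobian drops transform.domain.event_dim rightmost dimensions of the input shape.

    import torch
    from torch.distributions import constraints, transform_to

    t = transform_to(constraints.simplex)   # StickBreakingTransform; domain event_dim == 1
    x = torch.randn(3, 4)
    y = t(x)
    j = t.log_abs_det_jacobian(x, y)
    assert j.shape == x.shape[:x.dim() - t.domain.event_dim]   # torch.Size([3])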
test/distributions/test_distributions.py: 18 additions & 2 deletions
@@ -878,6 +878,22 @@ def test_has_examples(self):
self.assertIn(Dist, distributions_with_examples,
"Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))

def test_support_attributes(self):
for Dist, params in EXAMPLES:
for param in params:
d = Dist(**param)
event_dim = len(d.event_shape)
self.assertEqual(d.support.event_dim, event_dim)
try:
self.assertEqual(Dist.support.event_dim, event_dim)
except NotImplementedError:
pass
is_discrete = d.support.is_discrete
try:
self.assertEqual(Dist.support.is_discrete, is_discrete)
except NotImplementedError:
pass

def test_distribution_expand(self):
shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
for Dist, params in EXAMPLES:
@@ -1620,8 +1636,8 @@ def test_logisticnormal(self):
self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
- self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,))
- self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,))
+ self.assertEqual(LogisticNormal(0.2, .6).sample().size(), (2,))
+ self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))

# sample check for extreme value of mean, std
set_rng_seed(1)
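A short sketch of the attributes exercised by the new test_support_attributes, assuming this commit: every distribution support now reports is_discrete and an event_dim that matches the distribution's event_shape.

    import torch
    from torch.distributions import Categorical, MultivariateNormal

    d = Categorical(torch.ones(5))
    assert d.support.is_discrete
    assert d.support.event_dim == 0          # scalar events

    mvn = MultivariateNormal(torch.zeros(3), torch.eye(3))
    assert not mvn.support.is_discrete
    assert mvn.support.event_dim == 1        # vector-valued events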
test/distributions/test_transforms.py: 57 additions & 7 deletions
@@ -6,8 +6,8 @@
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
- CorrCholeskyTransform, ExpTransform,
- LowerCholeskyTransform, PowerTransform,
+ CorrCholeskyTransform, ExpTransform, IndependentTransform,
+ LowerCholeskyTransform, PowerTransform, ReshapeTransform,
SigmoidTransform, TanhTransform, SoftmaxTransform,
StickBreakingTransform, identity_transform, Transform,
_InverseTransform)
@@ -57,6 +57,7 @@ def get_transforms(cache_size):
torch.randn(4, 5),
cache_size=cache_size),
]),
+ ReshapeTransform((4, 5), (2, 5, 2)),
]
transforms += [t.inv for t in transforms]
return transforms
@@ -92,7 +93,16 @@ def transform_id(x):

def generate_data(transform):
torch.manual_seed(1)
+ while isinstance(transform, IndependentTransform):
+ transform = transform.base_transform
+ if isinstance(transform, ReshapeTransform):
+ return torch.randn(transform.in_shape)
+ if isinstance(transform.inv, ReshapeTransform):
+ return torch.randn(transform.inv.out_shape)
domain = transform.domain
+ while (isinstance(domain, constraints.independent) and
+ domain.reinterpreted_batch_ndims == 0):
+ domain = domain.base_constraint
codomain = transform.codomain
x = torch.empty(4, 5)
if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
@@ -170,8 +180,10 @@ def test_forward_inverse(transform, test_cached):
y = transform(x)
except NotImplementedError:
pytest.skip('Not implemented.')
+ assert y.shape == transform.forward_shape(x.shape)
if test_cached:
x2 = transform.inv(y) # should be implemented at least by caching
+ x2.shape == transform.inverse_shape(y.shape)
else:
try:
x2 = transform.inv(y.clone()) # bypass cache
@@ -316,25 +328,29 @@ def test_jacobian(transform):
except NotImplementedError:
pytest.skip('Not implemented.')
# Test shape
- target_shape = x.shape[:x.dim() - transform.input_event_dim]
+ target_shape = x.shape[:x.dim() - transform.domain.event_dim]
assert actual.shape == target_shape

# Expand if required
transform = reshape_transform(transform, x.shape)
ndims = len(x.shape)
- event_dim = ndims - transform.input_event_dim
+ event_dim = ndims - transform.domain.event_dim
x_ = x.view((-1,) + x.shape[event_dim:])
n = x_.shape[0]
# Reshape to squash batch dims to a single batch dim
transform = reshape_transform(transform, x_.shape)

- # 1. Transforms with 0 off-diagonal elements
- if transform.input_event_dim == 0:
+ # 1. Transforms with unit jacobian
+ if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
+ expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
+ expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
+ # 2. Transforms with 0 off-diagonal elements
+ elif transform.domain.event_dim == 0:
jac = jacobian(transform, x_)
# assert off-diagonal elements are zero
assert torch.allclose(jac, jac.diagonal().diag_embed())
expected = jac.diagonal().abs().log().reshape(x.shape)
- # 2. Transforms with non-0 off-diagonal elements
+ # 3. Transforms with non-0 off-diagonal elements
else:
if isinstance(transform, CorrCholeskyTransform):
jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
@@ -361,5 +377,39 @@ def test_jacobian(transform):
assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize("event_dims",
[(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
ids=str)
def test_compose_affine(event_dims):
transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == max(event_dims)
assert transform.domain.event_dim == max(event_dims)

dist = TransformedDistribution(Normal(0, 1), transform.parts)
assert dist.support.event_dim == max(event_dims)

dist = TransformedDistribution(Dirichlet(torch.ones(5)), transforms)
assert dist.support.event_dim == max(1, max(event_dims))


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
transforms = [ReshapeTransform((), ()),
ReshapeTransform((2,), (1, 2)),
ReshapeTransform((3, 1, 2), (6,)),
ReshapeTransform((6,), (2, 3))]
transform = ComposeTransform(transforms)
assert transform.codomain.event_dim == 2
assert transform.domain.event_dim == 2
data = torch.randn(batch_shape + (3, 2))
assert transform(data).shape == batch_shape + (2, 3)

dist = TransformedDistribution(Normal(data, 1), transforms)
assert dist.batch_shape == batch_shape
assert dist.event_shape == (2, 3)
assert dist.support.event_dim == 2


if __name__ == '__main__':
pytest.main([__file__])
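A brief sketch of the shape-propagation helpers these tests rely on (forward_shape and inverse_shape), assuming the Transform API introduced by this commit: shapes can be propagated through a composition without drawing samples.

    import torch
    from torch.distributions.transforms import ComposeTransform, ExpTransform, ReshapeTransform

    t = ComposeTransform([ExpTransform(), ReshapeTransform((6,), (2, 3))])
    assert t.forward_shape((10, 6)) == (10, 2, 3)    # batch dims pass through unchanged
    assert t.inverse_shape((10, 2, 3)) == (10, 6)

    x = torch.randn(10, 6)
    assert t(x).shape == (10, 2, 3)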
torch/distributions/binomial.py: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

- @constraints.dependent_property(is_discrete=True)
+ @constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self.total_count)

torch/distributions/categorical.py: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

- @constraints.dependent_property(is_discrete=True)
+ @constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

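A sketch of what the explicit event_dim=0 adds here and in binomial.py above, assuming this commit: the static metadata of a parameter-dependent support becomes queryable without constructing the actual bound, on the instance and, where it was declared statically, on the class itself.

    import torch
    from torch.distributions import Binomial

    d = Binomial(total_count=10, probs=torch.tensor(0.3))
    assert d.support.is_discrete and d.support.event_dim == 0
    # Because is_discrete/event_dim were declared statically, the same metadata
    # should be readable on the class, before any parameters exist.
    assert Binomial.support.is_discrete and Binomial.support.event_dim == 0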
torch/distributions/constraint_registry.py: 6 additions & 2 deletions
@@ -160,12 +160,16 @@ def _transform_to_real(constraint):

@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
- return biject_to(constraint.base_constraint)
+ base_transform = biject_to(constraint.base_constraint)
+ return transforms.IndependentTransform(
+ base_transform, constraint.reinterpreted_batch_ndims)


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
- return transform_to(constraint.base_constraint)
+ base_transform = transform_to(constraint.base_constraint)
+ return transforms.IndependentTransform(
+ base_transform, constraint.reinterpreted_batch_ndims)


@biject_to.register(constraints.positive)
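A sketch of the new registry behavior, assuming this commit: a transform requested for an independent (reinterpreted) constraint comes back wrapped in IndependentTransform, so its event dimensions and Jacobian reduction match the constraint.

    import torch
    from torch.distributions import biject_to, constraints
    from torch.distributions.transforms import IndependentTransform

    t = biject_to(constraints.independent(constraints.positive, 1))
    assert isinstance(t, IndependentTransform)
    assert t.codomain.event_dim == 1

    x = torch.randn(4, 5)
    y = t(x)
    assert t.log_abs_det_jacobian(x, y).shape == (4,)   # the event dim is summed out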
torch/distributions/constraints.py: 56 additions & 16 deletions
@@ -64,13 +64,20 @@ class Constraint(object):
A constraint object represents a region over which a variable is valid,
e.g. within which a variable can be optimized.
"""
+ Attributes:
+ is_discrete (bool): Whether constrained space is discrete.
+ Defaults to False.
+ event_dim (int): Number of rightmost dimensions that together define
+ an event. The :meth:`check` method will remove this many dimensions
+ when computing validity.
+ """
+ is_discrete = False # Default to continuous.
+ event_dim = 0 # Default to univariate.

def check(self, value):
"""
- Returns a byte tensor of `sample_shape + batch_shape` indicating
+ Returns a byte tensor of ``sample_shape + batch_shape`` indicating
whether each event in value satisfies this constraint.
"""
raise NotImplementedError
@@ -83,22 +90,42 @@ class _Dependent(Constraint):
"""
Placeholder for variables whose support depends on other variables.
These variables obey no simple coordinate-wise constraints.
"""
def __init__(self, *, is_discrete=False, event_dim=0):
self.is_discrete = is_discrete
self.event_dim = event_dim
Args:
is_discrete (bool): Optional value of ``.is_discrete`` in case this
can be computed statically. If not provided, access to the
``.is_discrete`` attribute will raise a NotImplementedError.
event_dim (int): Optional value of ``.event_dim`` in case this
can be computed statically. If not provided, access to the
``.event_dim`` attribute will raise a NotImplementedError.
"""
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()

- def __call__(self, *, is_discrete=None, event_dim=None):
+ @property
+ def is_discrete(self):
+ if self._is_discrete is NotImplemented:
+ raise NotImplementedError(".is_discrete cannot be determined statically")
+ return self._is_discrete
+
+ @property
+ def event_dim(self):
+ if self._event_dim is NotImplemented:
+ raise NotImplementedError(".event_dim cannot be determined statically")
+ return self._event_dim
+
+ def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
"""
Support for syntax to customize static attributes::
constraints.dependent(is_discrete=True, event_dim=1)
"""
- if is_discrete is None:
- is_discrete = self.is_discrete
- if event_dim is None:
- event_dim = self.event_dim
+ if is_discrete is NotImplemented:
+ is_discrete = self._is_discrete
+ if event_dim is NotImplemented:
+ event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

def check(self, x):
@@ -120,14 +147,23 @@ class Uniform(Distribution):
def __init__(self, low, high):
self.low = low
self.high = high
- @constraints.dependent_property
+ @constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.interval(self.low, self.high)
+ Args:
+ fn (callable): The function to be decorated.
+ is_discrete (bool): Optional value of ``.is_discrete`` in case this
+ can be computed statically. If not provided, access to the
+ ``.is_discrete`` attribute will raise a NotImplementedError.
+ event_dim (int): Optional value of ``.event_dim`` in case this
+ can be computed statically. If not provided, access to the
+ ``.event_dim`` attribute will raise a NotImplementedError.
"""
- def __init__(self, fn=None, *, is_discrete=False, event_dim=0):
- self.is_discrete = is_discrete
- self.event_dim = event_dim
+ def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented):
super().__init__(fn)
+ self._is_discrete = is_discrete
+ self._event_dim = event_dim

def __call__(self, fn):
"""
@@ -137,7 +173,7 @@ def __call__(self, fn):
def support(self):
...
"""
- return _DependentProperty(fn, is_discrete=self.is_discrete, event_dim=self.event_dim)
+ return _DependentProperty(fn, is_discrete=self._is_discrete, event_dim=self._event_dim)


class _IndependentConstraint(Constraint):
@@ -171,6 +207,10 @@ def check(self, value):
result = result.all(-1)
return result

+ def __repr__(self):
+ return "{}({}, {})".format(self.__class__.__name__[1:], repr(self.base_constraint),
+ self.reinterpreted_batch_ndims)


class _Boolean(Constraint):
"""
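A sketch of the new _Dependent semantics shown above, assuming this commit: the static attributes default to NotImplemented and raise when queried, unless supplied via constraints.dependent(...).

    from torch.distributions import constraints

    c = constraints.dependent            # no static metadata declared
    try:
        _ = c.event_dim
    except NotImplementedError:
        pass                             # cannot be determined statically

    c2 = constraints.dependent(is_discrete=True, event_dim=1)
    assert c2.is_discrete and c2.event_dim == 1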
torch/distributions/logistic_normal.py: 2 additions & 3 deletions
@@ -1,4 +1,3 @@
- import torch
from torch.distributions import constraints
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
@@ -33,11 +32,11 @@ class LogisticNormal(TransformedDistribution):

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale)
+ if not base_dist.batch_shape:
+ base_dist = base_dist.expand([1])
super(LogisticNormal, self).__init__(base_dist,
StickBreakingTransform(),
validate_args=validate_args)
- # Adjust event shape since StickBreakingTransform adds 1 dimension
- self._event_shape = torch.Size([s + 1 for s in self._event_shape])

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogisticNormal, _instance)
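A sketch of the resulting shapes, assuming this commit: scalar parameters are expanded to a length-1 base batch, which the stick-breaking transform absorbs into an event of size 2, so a single draw already lies on the simplex.

    import torch
    from torch.distributions import LogisticNormal

    d = LogisticNormal(0.2, 0.6)
    assert d.batch_shape == torch.Size([])
    assert d.event_shape == torch.Size([2])
    sample = d.sample()
    assert sample.shape == (2,)
    assert torch.allclose(sample.sum(), torch.tensor(1.0))   # simplex-valued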
torch/distributions/pareto.py: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def variance(self):
a = self.alpha.clamp(min=2)
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))

- @constraints.dependent_property
+ @constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.greater_than(self.scale)

torch/distributions/relaxed_categorical.py: 3 additions & 3 deletions
@@ -30,8 +30,8 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al, 2017)
"""
arg_constraints = {'probs': constraints.simplex,
- 'logits': constraints.real}
- support = constraints.real
+ 'logits': constraints.real_vector}
+ support = constraints.real_vector # The true support is actually a submanifold of this.
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
@@ -104,7 +104,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
logits (Tensor): the log probability of each event.
"""
arg_constraints = {'probs': constraints.simplex,
- 'logits': constraints.real}
+ 'logits': constraints.real_vector}
support = constraints.simplex
has_rsample = True

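A sketch of what the switch to real_vector means for validation, assuming this commit: logits are treated as whole vectors (event_dim == 1), so constraint checks reduce over the category dimension instead of being elementwise.

    import torch
    from torch.distributions import RelaxedOneHotCategorical, constraints

    d = RelaxedOneHotCategorical(torch.tensor(0.5), logits=torch.randn(2, 4))
    assert d.arg_constraints["logits"] is constraints.real_vector
    assert constraints.real_vector.event_dim == 1
    assert constraints.real_vector.check(torch.randn(2, 4)).shape == (2,)   # one result per batch element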