Skip to content

Commit

Permalink
Independent constraint (#50547)
Browse files Browse the repository at this point in the history
Summary:
Addresses #50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro)
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50547

Reviewed By: VitalyFedyunin

Differential Revision: D25918330

Pulled By: neerajprad

fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db
  • Loading branch information
fritzo authored and facebook-github-bot committed Jan 22, 2021
1 parent 5016637 commit 21c2542
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 30 deletions.
1 change: 1 addition & 0 deletions test/distributions/test_constraints.py
Expand Up @@ -7,6 +7,7 @@

CONSTRAINTS = [
(constraints.real,),
(constraints.real_vector,),
(constraints.positive,),
(constraints.greater_than, [-10., -2, 0, 2, 10]),
(constraints.greater_than, 0),
Expand Down
5 changes: 4 additions & 1 deletion test/distributions/test_distributions.py
Expand Up @@ -3915,7 +3915,10 @@ def test_support_constraints(self):
constraint = dist.support
message = '{} example {}/{} sample = {}'.format(
Dist.__name__, i + 1, len(params), value)
self.assertTrue(constraint.check(value).all(), msg=message)
self.assertEqual(constraint.event_dim, len(dist.event_shape), msg=message)
ok = constraint.check(value)
self.assertEqual(ok.shape, dist.batch_shape, msg=message)
self.assertTrue(ok.all(), msg=message)


class TestNumericalStability(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torch/distributions/binomial.py
Expand Up @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self.total_count)

Expand Down
4 changes: 2 additions & 2 deletions torch/distributions/categorical.py
Expand Up @@ -38,7 +38,7 @@ class Categorical(Distribution):
logits (Tensor): event log-odds
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
'logits': constraints.real_vector}
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
Expand Down Expand Up @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

Expand Down
12 changes: 10 additions & 2 deletions torch/distributions/constraint_registry.py
Expand Up @@ -153,13 +153,21 @@ def __call__(self, constraint):
################################################################################

@biject_to.register(constraints.real)
@biject_to.register(constraints.real_vector)
@transform_to.register(constraints.real)
@transform_to.register(constraints.real_vector)
def _transform_to_real(constraint):
return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
return biject_to(constraint.base_constraint)


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
return transform_to(constraint.base_constraint)


@biject_to.register(constraints.positive)
@transform_to.register(constraints.positive)
def _transform_to_positive(constraint):
Expand Down

0 comments on commit 21c2542

Please sign in to comment.