From 21c2542b6a9faafce0b6a3e1583a07b3fba9269d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Jan 2021 18:40:17 -0800 Subject: [PATCH] Independent constraint (#50547) Summary: Addresses https://github.com/pytorch/pytorch/issues/50496 This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions. - Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro) - Adds an `.event_dim` attribute to all constraints. - Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect). - Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`. - Fixes constraints for a number of distributions. ## Tested - added a new check to the constraints tests - added a new check for `.event_dim` cc fehiepsi feynmanliang stefanwebb Pull Request resolved: https://github.com/pytorch/pytorch/pull/50547 Reviewed By: VitalyFedyunin Differential Revision: D25918330 Pulled By: neerajprad fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db --- test/distributions/test_constraints.py | 1 + test/distributions/test_distributions.py | 5 +- torch/distributions/binomial.py | 2 +- torch/distributions/categorical.py | 4 +- torch/distributions/constraint_registry.py | 12 +- torch/distributions/constraints.py | 147 ++++++++++++++++-- torch/distributions/independent.py | 3 +- .../lowrank_multivariate_normal.py | 8 +- torch/distributions/multinomial.py | 6 +- torch/distributions/multivariate_normal.py | 2 +- torch/distributions/one_hot_categorical.py | 2 +- 11 files changed, 162 insertions(+), 30 deletions(-) diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py index d4dd9239920d..ffff932dfa37 100644 --- a/test/distributions/test_constraints.py +++ b/test/distributions/test_constraints.py @@ -7,6 +7,7 @@ CONSTRAINTS = [ (constraints.real,), + (constraints.real_vector,), (constraints.positive,), (constraints.greater_than, [-10., -2, 0, 2, 10]), (constraints.greater_than, 0), diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index a196169be142..0c84ff0e7058 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3915,7 +3915,10 @@ def test_support_constraints(self): constraint = dist.support message = '{} example {}/{} sample = {}'.format( Dist.__name__, i + 1, len(params), value) - self.assertTrue(constraint.check(value).all(), msg=message) + self.assertEqual(constraint.event_dim, len(dist.event_shape), msg=message) + ok = constraint.check(value) + self.assertEqual(ok.shape, dist.batch_shape, msg=message) + self.assertTrue(ok.all(), msg=message) class TestNumericalStability(TestCase): diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index dc2e8fc5bad6..caafcfacb166 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True) def support(self): return constraints.integer_interval(0, self.total_count) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 319d2dd01b66..ec22a7e4f802 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -38,7 +38,7 @@ class Categorical(Distribution): logits (Tensor): event log-odds """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True) def support(self): return constraints.integer_interval(0, self._num_events - 1) diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 4675b8ceaca8..63fd4b8bf9ce 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -153,13 +153,21 @@ def __call__(self, constraint): ################################################################################ @biject_to.register(constraints.real) -@biject_to.register(constraints.real_vector) @transform_to.register(constraints.real) -@transform_to.register(constraints.real_vector) def _transform_to_real(constraint): return transforms.identity_transform +@biject_to.register(constraints.independent) +def _biject_to_independent(constraint): + return biject_to(constraint.base_constraint) + + +@transform_to.register(constraints.independent) +def _transform_to_independent(constraint): + return transform_to(constraint.base_constraint) + + @biject_to.register(constraints.positive) @transform_to.register(constraints.positive) def _transform_to_positive(constraint): diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 87d72d52d26b..9a8c9bcebdfa 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -7,18 +7,20 @@ - ``constraints.dependent`` - ``constraints.greater_than(lower_bound)`` - ``constraints.greater_than_eq(lower_bound)`` +- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` - ``constraints.integer_interval(lower_bound, upper_bound)`` - ``constraints.interval(lower_bound, upper_bound)`` - ``constraints.less_than(upper_bound)`` - ``constraints.lower_cholesky`` - ``constraints.lower_triangular`` +- ``constraints.multinomial`` - ``constraints.nonnegative_integer`` - ``constraints.one_hot`` -- ``constraints.positive`` - ``constraints.positive_definite`` - ``constraints.positive_integer`` -- ``constraints.real`` +- ``constraints.positive`` - ``constraints.real_vector`` +- ``constraints.real`` - ``constraints.simplex`` - ``constraints.stack`` - ``constraints.unit_interval`` @@ -35,6 +37,7 @@ 'dependent_property', 'greater_than', 'greater_than_eq', + 'independent', 'integer_interval', 'interval', 'half_open_interval', @@ -42,6 +45,7 @@ 'less_than', 'lower_cholesky', 'lower_triangular', + 'multinomial', 'nonnegative_integer', 'positive', 'positive_definite', @@ -61,7 +65,8 @@ class Constraint(object): A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized. """ - is_discrete = False + is_discrete = False # Default to continuous. + event_dim = 0 # Default to univariate. def check(self, value): """ @@ -79,6 +84,23 @@ class _Dependent(Constraint): Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. """ + def __init__(self, *, is_discrete=False, event_dim=0): + self.is_discrete = is_discrete + self.event_dim = event_dim + super().__init__() + + def __call__(self, *, is_discrete=None, event_dim=None): + """ + Support for syntax to customize static attributes:: + + constraints.dependent(is_discrete=True, event_dim=1) + """ + if is_discrete is None: + is_discrete = self.is_discrete + if event_dim is None: + event_dim = self.event_dim + return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + def check(self, x): raise ValueError('Cannot determine validity of dependent constraint') @@ -102,7 +124,52 @@ def __init__(self, low, high): def support(self): return constraints.interval(self.low, self.high) """ - pass + def __init__(self, fn=None, *, is_discrete=False, event_dim=0): + self.is_discrete = is_discrete + self.event_dim = event_dim + super().__init__(fn) + + def __call__(self, fn): + """ + Support for syntax to customize static attributes:: + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): + ... + """ + return _DependentProperty(fn, is_discrete=self.is_discrete, event_dim=self.event_dim) + + +class _IndependentConstraint(Constraint): + """ + Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many + dims in :meth:`check`, so that an event is valid only if all its + independent entries are valid. + """ + def __init__(self, base_constraint, reinterpreted_batch_ndims): + assert isinstance(base_constraint, Constraint) + assert isinstance(reinterpreted_batch_ndims, int) + assert reinterpreted_batch_ndims >= 0 + self.base_constraint = base_constraint + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__() + + @property + def is_discrete(self): + return self.base_constraint.is_discrete + + @property + def event_dim(self): + return self.base_constraint.event_dim + self.reinterpreted_batch_ndims + + def check(self, value): + result = self.base_constraint.check(value) + if result.dim() < self.reinterpreted_batch_ndims: + expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims + raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}") + result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,)) + result = result.all(-1) + return result class _Boolean(Constraint): @@ -120,6 +187,7 @@ class _OneHot(Constraint): Constrain to one-hot vectors. """ is_discrete = True + event_dim = 1 def check(self, value): is_boolean = (value == 0) | (value == 1) @@ -136,6 +204,7 @@ class _IntegerInterval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) @@ -154,6 +223,7 @@ class _IntegerLessThan(Constraint): def __init__(self, upper_bound): self.upper_bound = upper_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (value <= self.upper_bound) @@ -172,6 +242,7 @@ class _IntegerGreaterThan(Constraint): def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (value >= self.lower_bound) @@ -196,6 +267,7 @@ class _GreaterThan(Constraint): """ def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return self.lower_bound < value @@ -212,6 +284,7 @@ class _GreaterThanEq(Constraint): """ def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return self.lower_bound <= value @@ -228,6 +301,7 @@ class _LessThan(Constraint): """ def __init__(self, upper_bound): self.upper_bound = upper_bound + super().__init__() def check(self, value): return value < self.upper_bound @@ -245,6 +319,7 @@ class _Interval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (self.lower_bound <= value) & (value <= self.upper_bound) @@ -262,6 +337,7 @@ class _HalfOpenInterval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (self.lower_bound <= value) & (value < self.upper_bound) @@ -277,14 +353,36 @@ class _Simplex(Constraint): Constrain to the unit simplex in the innermost (rightmost) dimension. Specifically: `x >= 0` and `x.sum(-1) == 1`. """ + event_dim = 1 + def check(self, value): return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) +class _Multinomial(Constraint): + """ + Constrain to nonnegative integer values summing to at most an upper bound. + + Note due to limitations of the Multinomial distribution, this currently + checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future + this may be strengthened to ``value.sum(-1) == upper_bound``. + """ + is_discrete = True + event_dim = 1 + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + + def check(self, x): + return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) + + class _LowerTriangular(Constraint): """ Constrain to lower-triangular square matrices. """ + event_dim = 2 + def check(self, value): value_tril = value.tril() return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] @@ -294,6 +392,8 @@ class _LowerCholesky(Constraint): """ Constrain to lower-triangular square matrices with positive diagonals. """ + event_dim = 2 + def check(self, value): value_tril = value.tril() lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] @@ -307,6 +407,8 @@ class _CorrCholesky(Constraint): Constrain to lower-triangular square matrices with positive diagonals and each row vector being of unit length. """ + event_dim = 2 + def check(self, value): tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor row_norm = torch.linalg.norm(value.detach(), dim=-1) @@ -318,6 +420,8 @@ class _PositiveDefinite(Constraint): """ Constrain to positive-definite matrices. """ + event_dim = 2 + def check(self, value): matrix_shape = value.shape[-2:] batch_shape = value.unsqueeze(0).shape[:-2] @@ -328,15 +432,6 @@ def check(self, value): for v in flattened_value]).view(batch_shape) -class _RealVector(Constraint): - """ - Constrain to real-valued vectors. This is the same as `constraints.real`, - but additionally reduces across the `event_shape` dimension. - """ - def check(self, value): - return torch.all(value == value, dim=-1) # False for NANs. - - class _Cat(Constraint): """ Constraint functor that applies a sequence of constraints @@ -351,6 +446,15 @@ def __init__(self, cseq, dim=0, lengths=None): self.lengths = list(lengths) assert len(self.lengths) == len(self.cseq) self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + return max(c.event_dim for c in self.cseq) def check(self, value): assert -value.dim() <= self.dim < value.dim() @@ -373,6 +477,18 @@ def __init__(self, cseq, dim=0): assert all(isinstance(c, Constraint) for c in cseq) self.cseq = list(cseq) self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + dim = max(c.event_dim for c in self.cseq) + if self.dim + dim < 0: + dim += 1 + return dim def check(self, value): assert -value.dim() <= self.dim < value.dim() @@ -380,20 +496,23 @@ def check(self, value): return torch.stack([constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim) + # Public interface. dependent = _Dependent() dependent_property = _DependentProperty +independent = _IndependentConstraint boolean = _Boolean() one_hot = _OneHot() nonnegative_integer = _IntegerGreaterThan(0) positive_integer = _IntegerGreaterThan(1) integer_interval = _IntegerInterval real = _Real() -real_vector = _RealVector() +real_vector = independent(real, 1) positive = _GreaterThan(0.) greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan +multinomial = _Multinomial unit_interval = _Interval(0., 1.) interval = _Interval half_open_interval = _HalfOpenInterval diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index de34bb604774..0776ca6f67a7 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -68,7 +68,8 @@ def has_enumerate_support(self): @constraints.dependent_property def support(self): - return self.base_dist.support + return constraints.independent(self.base_dist.support, + self.reinterpreted_batch_ndims) @property def mean(self): diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 8b1e76c175a6..b184a6c485d6 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -73,10 +73,10 @@ class LowRankMultivariateNormal(Distribution): capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor """ - arg_constraints = {"loc": constraints.real, - "cov_factor": constraints.real, - "cov_diag": constraints.positive} - support = constraints.real + arg_constraints = {"loc": constraints.real_vector, + "cov_factor": constraints.independent(constraints.real, 2), + "cov_diag": constraints.independent(constraints.positive, 1)} + support = constraints.real_vector has_rsample = True def __init__(self, loc, cov_factor, cov_diag, validate_args=None): diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 9162dd4713d4..b25560b0895e 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -38,7 +38,7 @@ class Multinomial(Distribution): logits (Tensor): event log probabilities """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} total_count: int @property @@ -70,9 +70,9 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._categorical._new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): - return constraints.integer_interval(0, self.total_count) + return constraints.multinomial(self.total_count) @property def logits(self): diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 4845d4742dfc..fe6e91286fde 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -113,7 +113,7 @@ class MultivariateNormal(Distribution): 'covariance_matrix': constraints.positive_definite, 'precision_matrix': constraints.positive_definite, 'scale_tril': constraints.lower_cholesky} - support = constraints.real + support = constraints.real_vector has_rsample = True def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None): diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 64f696802d76..1ab53ccc7061 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -28,7 +28,7 @@ class OneHotCategorical(Distribution): logits (Tensor): event log probabilities """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} support = constraints.one_hot has_enumerate_support = True