From fb8274f32db5370a7e808bb9e6c33279eb8dd1e1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 11:58:36 -0500 Subject: [PATCH 1/7] Add constraints.independent and Constraint.event_dim --- test/distributions/test_distributions.py | 4 +- torch/distributions/categorical.py | 2 +- torch/distributions/constraint_registry.py | 12 ++- torch/distributions/constraints.py | 78 ++++++++++++++++--- .../lowrank_multivariate_normal.py | 8 +- torch/distributions/multinomial.py | 2 +- torch/distributions/multivariate_normal.py | 2 +- torch/distributions/one_hot_categorical.py | 2 +- 8 files changed, 87 insertions(+), 23 deletions(-) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index a196169be142..bed01815573d 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3915,7 +3915,9 @@ def test_support_constraints(self): constraint = dist.support message = '{} example {}/{} sample = {}'.format( Dist.__name__, i + 1, len(params), value) - self.assertTrue(constraint.check(value).all(), msg=message) + ok = constraint.check(value).all() + assert ok.shape == dist.batch_shape + self.assertTrue(ok.all(), msg=message) class TestNumericalStability(TestCase): diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 319d2dd01b66..eebfbffdffc7 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -38,7 +38,7 @@ class Categorical(Distribution): logits (Tensor): event log-odds """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 4675b8ceaca8..63fd4b8bf9ce 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -153,13 +153,21 @@ def __call__(self, constraint): ################################################################################ @biject_to.register(constraints.real) -@biject_to.register(constraints.real_vector) @transform_to.register(constraints.real) -@transform_to.register(constraints.real_vector) def _transform_to_real(constraint): return transforms.identity_transform +@biject_to.register(constraints.independent) +def _biject_to_independent(constraint): + return biject_to(constraint.base_constraint) + + +@transform_to.register(constraints.independent) +def _transform_to_independent(constraint): + return transform_to(constraint.base_constraint) + + @biject_to.register(constraints.positive) @transform_to.register(constraints.positive) def _transform_to_positive(constraint): diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 87d72d52d26b..7cbaff92d6cb 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -7,6 +7,7 @@ - ``constraints.dependent`` - ``constraints.greater_than(lower_bound)`` - ``constraints.greater_than_eq(lower_bound)`` +- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` - ``constraints.integer_interval(lower_bound, upper_bound)`` - ``constraints.interval(lower_bound, upper_bound)`` - ``constraints.less_than(upper_bound)`` @@ -14,11 +15,11 @@ - ``constraints.lower_triangular`` - ``constraints.nonnegative_integer`` - ``constraints.one_hot`` -- ``constraints.positive`` - ``constraints.positive_definite`` - ``constraints.positive_integer`` -- ``constraints.real`` +- ``constraints.positive`` - ``constraints.real_vector`` +- ``constraints.real`` - ``constraints.simplex`` - ``constraints.stack`` - ``constraints.unit_interval`` @@ -35,6 +36,7 @@ 'dependent_property', 'greater_than', 'greater_than_eq', + 'independent', 'integer_interval', 'interval', 'half_open_interval', @@ -62,6 +64,7 @@ class Constraint(object): e.g. within which a variable can be optimized. """ is_discrete = False + event_dim = 0 def check(self, value): """ @@ -105,6 +108,34 @@ def support(self): pass +class _IndependentConstraint(Constraint): + """ + Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many + dims in :meth:`check`, so that an event is valid only if all its + independent entries are valid. + """ + def __init__(self, base_constraint, reinterpreted_batch_ndims): + assert isinstance(base_constraint, Constraint) + assert isinstance(reinterpreted_batch_ndims, int) + assert reinterpreted_batch_ndims >= 0 + self.base_constraint = base_constraint + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + + @property + def is_discrete(self): + return self.base_dist.is_discrete + + @property + def event_dim(self): + return self.base_dist.event_dim + self.reinterpreted_batch_ndims + + def check(self, value): + result = self.base_constraint.check(value) + result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,)) + result = result.all(-1) + return result + + class _Boolean(Constraint): """ Constrain to the two values `{0, 1}`. @@ -120,6 +151,7 @@ class _OneHot(Constraint): Constrain to one-hot vectors. """ is_discrete = True + event_dim = 1 def check(self, value): is_boolean = (value == 0) | (value == 1) @@ -277,6 +309,8 @@ class _Simplex(Constraint): Constrain to the unit simplex in the innermost (rightmost) dimension. Specifically: `x >= 0` and `x.sum(-1) == 1`. """ + event_dim = 1 + def check(self, value): return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) @@ -285,6 +319,8 @@ class _LowerTriangular(Constraint): """ Constrain to lower-triangular square matrices. """ + event_dim = 2 + def check(self, value): value_tril = value.tril() return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] @@ -294,6 +330,8 @@ class _LowerCholesky(Constraint): """ Constrain to lower-triangular square matrices with positive diagonals. """ + event_dim = 2 + def check(self, value): value_tril = value.tril() lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] @@ -307,6 +345,8 @@ class _CorrCholesky(Constraint): Constrain to lower-triangular square matrices with positive diagonals and each row vector being of unit length. """ + event_dim = 2 + def check(self, value): tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor row_norm = torch.linalg.norm(value.detach(), dim=-1) @@ -318,6 +358,8 @@ class _PositiveDefinite(Constraint): """ Constrain to positive-definite matrices. """ + event_dim = 2 + def check(self, value): matrix_shape = value.shape[-2:] batch_shape = value.unsqueeze(0).shape[:-2] @@ -328,15 +370,6 @@ def check(self, value): for v in flattened_value]).view(batch_shape) -class _RealVector(Constraint): - """ - Constrain to real-valued vectors. This is the same as `constraints.real`, - but additionally reduces across the `event_shape` dimension. - """ - def check(self, value): - return torch.all(value == value, dim=-1) # False for NANs. - - class _Cat(Constraint): """ Constraint functor that applies a sequence of constraints @@ -352,6 +385,14 @@ def __init__(self, cseq, dim=0, lengths=None): assert len(self.lengths) == len(self.cseq) self.dim = dim + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + return max(c.event_dim for c in self.cseq) + def check(self, value): assert -value.dim() <= self.dim < value.dim() checks = [] @@ -374,22 +415,35 @@ def __init__(self, cseq, dim=0): self.cseq = list(cseq) self.dim = dim + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + dim = max(c.event_dim for c in self.cseq) + if self.dim + dim < 0: + dim += 1 + return dim + def check(self, value): assert -value.dim() <= self.dim < value.dim() vs = [value.select(self.dim, i) for i in range(value.size(self.dim))] return torch.stack([constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim) + # Public interface. dependent = _Dependent() dependent_property = _DependentProperty +independent = _IndependentConstraint boolean = _Boolean() one_hot = _OneHot() nonnegative_integer = _IntegerGreaterThan(0) positive_integer = _IntegerGreaterThan(1) integer_interval = _IntegerInterval real = _Real() -real_vector = _RealVector() +real_vector = independent(real, 1) positive = _GreaterThan(0.) greater_than = _GreaterThan greater_than_eq = _GreaterThanEq diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 8b1e76c175a6..b184a6c485d6 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -73,10 +73,10 @@ class LowRankMultivariateNormal(Distribution): capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor """ - arg_constraints = {"loc": constraints.real, - "cov_factor": constraints.real, - "cov_diag": constraints.positive} - support = constraints.real + arg_constraints = {"loc": constraints.real_vector, + "cov_factor": constraints.independent(constraints.real, 2), + "cov_diag": constraints.independent(constraints.positive, 1)} + support = constraints.real_vector has_rsample = True def __init__(self, loc, cov_factor, cov_diag, validate_args=None): diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 9162dd4713d4..5ccc67c8a32a 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -38,7 +38,7 @@ class Multinomial(Distribution): logits (Tensor): event log probabilities """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} total_count: int @property diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 4845d4742dfc..fe6e91286fde 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -113,7 +113,7 @@ class MultivariateNormal(Distribution): 'covariance_matrix': constraints.positive_definite, 'precision_matrix': constraints.positive_definite, 'scale_tril': constraints.lower_cholesky} - support = constraints.real + support = constraints.real_vector has_rsample = True def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None): diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 64f696802d76..1ab53ccc7061 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -28,7 +28,7 @@ class OneHotCategorical(Distribution): logits (Tensor): event log probabilities """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} support = constraints.one_hot has_enumerate_support = True From 1d75d76bca27d8d47cb9845fd2c96424cc8ae8f3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 13:00:00 -0500 Subject: [PATCH 2/7] Fix tests, support static attrs of dependent --- test/distributions/test_constraints.py | 1 + test/distributions/test_distributions.py | 2 +- torch/distributions/binomial.py | 2 +- torch/distributions/categorical.py | 2 +- torch/distributions/constraints.py | 51 +++++++++++++++++++++--- torch/distributions/independent.py | 3 +- torch/distributions/multinomial.py | 5 ++- torch/distributions/uniform.py | 3 +- 8 files changed, 57 insertions(+), 12 deletions(-) diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py index d4dd9239920d..ffff932dfa37 100644 --- a/test/distributions/test_constraints.py +++ b/test/distributions/test_constraints.py @@ -7,6 +7,7 @@ CONSTRAINTS = [ (constraints.real,), + (constraints.real_vector,), (constraints.positive,), (constraints.greater_than, [-10., -2, 0, 2, 10]), (constraints.greater_than, 0), diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index bed01815573d..077412e963fc 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3915,7 +3915,7 @@ def test_support_constraints(self): constraint = dist.support message = '{} example {}/{} sample = {}'.format( Dist.__name__, i + 1, len(params), value) - ok = constraint.check(value).all() + ok = constraint.check(value) assert ok.shape == dist.batch_shape self.assertTrue(ok.all(), msg=message) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index dc2e8fc5bad6..caafcfacb166 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True) def support(self): return constraints.integer_interval(0, self.total_count) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index eebfbffdffc7..ec22a7e4f802 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True) def support(self): return constraints.integer_interval(0, self._num_events - 1) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 7cbaff92d6cb..9017abc1cc38 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -63,8 +63,8 @@ class Constraint(object): A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized. """ - is_discrete = False - event_dim = 0 + is_discrete = False # Default to continuous. + event_dim = 0 # Default to univariate. def check(self, value): """ @@ -82,6 +82,23 @@ class _Dependent(Constraint): Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. """ + def __init__(self, *, is_discrete=False, event_dim=0): + self.is_discrete = is_discrete + self.event_dim = event_dim + super().__init__() + + def __call__(self, *, is_discrete=None, event_dim=None): + """ + Support for syntax to customize static attributes:: + + constraints.dependent(is_discrete=True, event_dim=1) + """ + if is_discrete is None: + is_discrete = self.is_discrete + if event_dim is None: + event_dim = self.event_dim + return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + def check(self, x): raise ValueError('Cannot determine validity of dependent constraint') @@ -105,7 +122,20 @@ def __init__(self, low, high): def support(self): return constraints.interval(self.low, self.high) """ - pass + def __init__(self, fn=None, *, is_discrete=False, event_dim=0): + self.is_discrete = is_discrete + self.event_dim = event_dim + super().__init__(fn) + + def __call__(self, fn): + """ + Support for syntax to customize static attributes:: + + @constraints.dependent_dependent(is_discrete=True, event_dim=1) + def support(self): + ... + """ + return _DependentProperty(fn, is_discrete=self.is_discrete, event_dim=self.event_dim) class _IndependentConstraint(Constraint): @@ -120,14 +150,15 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims): assert reinterpreted_batch_ndims >= 0 self.base_constraint = base_constraint self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__() @property def is_discrete(self): - return self.base_dist.is_discrete + return self.base_constraint.is_discrete @property def event_dim(self): - return self.base_dist.event_dim + self.reinterpreted_batch_ndims + return self.base_constraint.event_dim + self.reinterpreted_batch_ndims def check(self, value): result = self.base_constraint.check(value) @@ -168,6 +199,7 @@ class _IntegerInterval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) @@ -186,6 +218,7 @@ class _IntegerLessThan(Constraint): def __init__(self, upper_bound): self.upper_bound = upper_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (value <= self.upper_bound) @@ -204,6 +237,7 @@ class _IntegerGreaterThan(Constraint): def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return (value % 1 == 0) & (value >= self.lower_bound) @@ -228,6 +262,7 @@ class _GreaterThan(Constraint): """ def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return self.lower_bound < value @@ -244,6 +279,7 @@ class _GreaterThanEq(Constraint): """ def __init__(self, lower_bound): self.lower_bound = lower_bound + super().__init__() def check(self, value): return self.lower_bound <= value @@ -260,6 +296,7 @@ class _LessThan(Constraint): """ def __init__(self, upper_bound): self.upper_bound = upper_bound + super().__init__() def check(self, value): return value < self.upper_bound @@ -277,6 +314,7 @@ class _Interval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (self.lower_bound <= value) & (value <= self.upper_bound) @@ -294,6 +332,7 @@ class _HalfOpenInterval(Constraint): def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound + super().__init__() def check(self, value): return (self.lower_bound <= value) & (value < self.upper_bound) @@ -384,6 +423,7 @@ def __init__(self, cseq, dim=0, lengths=None): self.lengths = list(lengths) assert len(self.lengths) == len(self.cseq) self.dim = dim + super().__init__() @property def is_discrete(self): @@ -414,6 +454,7 @@ def __init__(self, cseq, dim=0): assert all(isinstance(c, Constraint) for c in cseq) self.cseq = list(cseq) self.dim = dim + super().__init__() @property def is_discrete(self): diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index de34bb604774..0776ca6f67a7 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -68,7 +68,8 @@ def has_enumerate_support(self): @constraints.dependent_property def support(self): - return self.base_dist.support + return constraints.independent(self.base_dist.support, + self.reinterpreted_batch_ndims) @property def mean(self): diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 5ccc67c8a32a..cbce1b3cf186 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -70,9 +70,10 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._categorical._new(*args, **kwargs) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): - return constraints.integer_interval(0, self.total_count) + return constraints.independent( + constraints.integer_interval(0, self.total_count), 1) @property def logits(self): diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index edaf5abf77a5..8912de0c8bca 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -22,7 +22,8 @@ class Uniform(Distribution): high (float or Tensor): upper range (exclusive). """ # TODO allow (loc,scale) parameterization to allow independent constraints. - arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent} + arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0), + 'high': constraints.dependent(is_discrete=False, event_dim=0)} has_rsample = True @property From 1a93eda77d75eb14824a9efe98bfa81c10da2cf0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 13:10:47 -0500 Subject: [PATCH 3/7] Add check for constraint.event_dim --- test/distributions/test_distributions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 077412e963fc..3dacfc542e0d 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3913,10 +3913,11 @@ def test_support_constraints(self): dist = Dist(**param) value = dist.sample() constraint = dist.support + self.assertEqual(constraint.event_dim, len(dist.event_shape)) message = '{} example {}/{} sample = {}'.format( Dist.__name__, i + 1, len(params), value) ok = constraint.check(value) - assert ok.shape == dist.batch_shape + self.assertEqual(ok.shape, dist.batch_shape) self.assertTrue(ok.all(), msg=message) From 7ceabb7cab1a2bd75e27468da7710fd80f50f3f8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 13:11:30 -0500 Subject: [PATCH 4/7] Fix test printing --- test/distributions/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 3dacfc542e0d..0c84ff0e7058 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3913,11 +3913,11 @@ def test_support_constraints(self): dist = Dist(**param) value = dist.sample() constraint = dist.support - self.assertEqual(constraint.event_dim, len(dist.event_shape)) message = '{} example {}/{} sample = {}'.format( Dist.__name__, i + 1, len(params), value) + self.assertEqual(constraint.event_dim, len(dist.event_shape), msg=message) ok = constraint.check(value) - self.assertEqual(ok.shape, dist.batch_shape) + self.assertEqual(ok.shape, dist.batch_shape, msg=message) self.assertTrue(ok.all(), msg=message) From d50b99270848c7186d45f674431aaab882b4a6b3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 13:25:33 -0500 Subject: [PATCH 5/7] Simplify --- torch/distributions/constraints.py | 2 +- torch/distributions/uniform.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 9017abc1cc38..be36cb5b880e 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -131,7 +131,7 @@ def __call__(self, fn): """ Support for syntax to customize static attributes:: - @constraints.dependent_dependent(is_discrete=True, event_dim=1) + @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): ... """ diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index 8912de0c8bca..edaf5abf77a5 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -22,8 +22,7 @@ class Uniform(Distribution): high (float or Tensor): upper range (exclusive). """ # TODO allow (loc,scale) parameterization to allow independent constraints. - arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0), - 'high': constraints.dependent(is_discrete=False, event_dim=0)} + arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent} has_rsample = True @property From 20dafb7457935743be0763c18f6da24dbf1ef763 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 17:28:23 -0500 Subject: [PATCH 6/7] Fix shape of batched Multinomial.total_count --- torch/distributions/multinomial.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index cbce1b3cf186..6c94e3413b50 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -72,8 +72,11 @@ def _new(self, *args, **kwargs): @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): + total_count = self.total_count + if hasattr(total_count, "unsqueeze"): + total_count = total_count.unsqueeze(-1) return constraints.independent( - constraints.integer_interval(0, self.total_count), 1) + constraints.integer_interval(0, total_count), 1) @property def logits(self): From 0e58ebe6be52a31467d1d1fc3e80baffd484b967 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Jan 2021 18:45:32 -0500 Subject: [PATCH 7/7] Address review comments --- torch/distributions/constraints.py | 24 ++++++++++++++++++++++++ torch/distributions/multinomial.py | 6 +----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index be36cb5b880e..9a8c9bcebdfa 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -13,6 +13,7 @@ - ``constraints.less_than(upper_bound)`` - ``constraints.lower_cholesky`` - ``constraints.lower_triangular`` +- ``constraints.multinomial`` - ``constraints.nonnegative_integer`` - ``constraints.one_hot`` - ``constraints.positive_definite`` @@ -44,6 +45,7 @@ 'less_than', 'lower_cholesky', 'lower_triangular', + 'multinomial', 'nonnegative_integer', 'positive', 'positive_definite', @@ -162,6 +164,9 @@ def event_dim(self): def check(self, value): result = self.base_constraint.check(value) + if result.dim() < self.reinterpreted_batch_ndims: + expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims + raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}") result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,)) result = result.all(-1) return result @@ -354,6 +359,24 @@ def check(self, value): return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) +class _Multinomial(Constraint): + """ + Constrain to nonnegative integer values summing to at most an upper bound. + + Note due to limitations of the Multinomial distribution, this currently + checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future + this may be strengthened to ``value.sum(-1) == upper_bound``. + """ + is_discrete = True + event_dim = 1 + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + + def check(self, x): + return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) + + class _LowerTriangular(Constraint): """ Constrain to lower-triangular square matrices. @@ -489,6 +512,7 @@ def check(self, value): greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan +multinomial = _Multinomial unit_interval = _Interval(0., 1.) interval = _Interval half_open_interval = _HalfOpenInterval diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 6c94e3413b50..b25560b0895e 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -72,11 +72,7 @@ def _new(self, *args, **kwargs): @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): - total_count = self.total_count - if hasattr(total_count, "unsqueeze"): - total_count = total_count.unsqueeze(-1) - return constraints.independent( - constraints.integer_interval(0, total_count), 1) + return constraints.multinomial(self.total_count) @property def logits(self):