Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Independent constraint #50547

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions test/distributions/test_constraints.py
Expand Up @@ -7,6 +7,7 @@

CONSTRAINTS = [
(constraints.real,),
(constraints.real_vector,),
(constraints.positive,),
(constraints.greater_than, [-10., -2, 0, 2, 10]),
(constraints.greater_than, 0),
Expand Down
5 changes: 4 additions & 1 deletion test/distributions/test_distributions.py
Expand Up @@ -3915,7 +3915,10 @@ def test_support_constraints(self):
constraint = dist.support
message = '{} example {}/{} sample = {}'.format(
Dist.__name__, i + 1, len(params), value)
self.assertTrue(constraint.check(value).all(), msg=message)
self.assertEqual(constraint.event_dim, len(dist.event_shape), msg=message)
ok = constraint.check(value)
self.assertEqual(ok.shape, dist.batch_shape, msg=message)
self.assertTrue(ok.all(), msg=message)


class TestNumericalStability(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torch/distributions/binomial.py
Expand Up @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self.total_count)

Expand Down
4 changes: 2 additions & 2 deletions torch/distributions/categorical.py
Expand Up @@ -38,7 +38,7 @@ class Categorical(Distribution):
logits (Tensor): event log-odds
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
'logits': constraints.real_vector}
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
Expand Down Expand Up @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

Expand Down
12 changes: 10 additions & 2 deletions torch/distributions/constraint_registry.py
Expand Up @@ -153,13 +153,21 @@ def __call__(self, constraint):
################################################################################

@biject_to.register(constraints.real)
@biject_to.register(constraints.real_vector)
@transform_to.register(constraints.real)
@transform_to.register(constraints.real_vector)
def _transform_to_real(constraint):
return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
return biject_to(constraint.base_constraint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this will be enough to address #50496 since the event dim information will be lost when we call log_abs_det_jacobian, i.e. it seems to me that the transform returned for both the independent as well as the base constraint is the same. One way would be to post-hoc modify the output event dim of a transform to handle that use case. @feynmanliang - please correct me if I missed something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just think that we no longer need input_event_dim, output_event_dim in transforms. All we need is to define a correct domain, codomain. For composed transform, we still need to handle the logic properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, yes that's one of the differences w.r.t. numpyro. That's probably worth another discussion. We do have domain, co-domain. I see what you mean - we need to adjust the event dims appropriately and with this change we can remove input/output event dims altogether. That sounds nice actually.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fritzo - I like @fehiepsi's suggestion of modifying the domain/codomain's event dim (and remove input/output event dims since that wasn't part of last release) to handle this, but regardless, all the proposed changes in this PR look great to me, so please feel free to defer this to later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds right. I think we can follow @feynmanliang's suggestion in a future PR and say wrap with an IndependentTransform or set the Transform.event_dim. I'll demote this PR from "Fixes" to "Addresses".



@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
return transform_to(constraint.base_constraint)


@biject_to.register(constraints.positive)
@transform_to.register(constraints.positive)
def _transform_to_positive(constraint):
Expand Down
123 changes: 109 additions & 14 deletions torch/distributions/constraints.py
Expand Up @@ -7,18 +7,19 @@
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.real``
- ``constraints.positive``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.stack``
- ``constraints.unit_interval``
Expand All @@ -35,6 +36,7 @@
'dependent_property',
'greater_than',
'greater_than_eq',
'independent',
'integer_interval',
'interval',
'half_open_interval',
Expand All @@ -61,7 +63,8 @@ class Constraint(object):
A constraint object represents a region over which a variable is valid,
e.g. within which a variable can be optimized.
"""
is_discrete = False
is_discrete = False # Default to continuous.
event_dim = 0 # Default to univariate.

def check(self, value):
"""
Expand All @@ -79,6 +82,23 @@ class _Dependent(Constraint):
Placeholder for variables whose support depends on other variables.
These variables obey no simple coordinate-wise constraints.
"""
def __init__(self, *, is_discrete=False, event_dim=0):
self.is_discrete = is_discrete
self.event_dim = event_dim
super().__init__()

def __call__(self, *, is_discrete=None, event_dim=None):
"""
Support for syntax to customize static attributes::

constraints.dependent(is_discrete=True, event_dim=1)
"""
if is_discrete is None:
is_discrete = self.is_discrete
if event_dim is None:
event_dim = self.event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

def check(self, x):
raise ValueError('Cannot determine validity of dependent constraint')

Expand All @@ -102,7 +122,49 @@ def __init__(self, low, high):
def support(self):
return constraints.interval(self.low, self.high)
"""
pass
def __init__(self, fn=None, *, is_discrete=False, event_dim=0):
self.is_discrete = is_discrete
self.event_dim = event_dim
super().__init__(fn)

def __call__(self, fn):
"""
Support for syntax to customize static attributes::

@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
...
"""
return _DependentProperty(fn, is_discrete=self.is_discrete, event_dim=self.event_dim)


class _IndependentConstraint(Constraint):
"""
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
dims in :meth:`check`, so that an event is valid only if all its
independent entries are valid.
"""
def __init__(self, base_constraint, reinterpreted_batch_ndims):
assert isinstance(base_constraint, Constraint)
assert isinstance(reinterpreted_batch_ndims, int)
neerajprad marked this conversation as resolved.
Show resolved Hide resolved
assert reinterpreted_batch_ndims >= 0
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()

@property
def is_discrete(self):
return self.base_constraint.is_discrete

@property
def event_dim(self):
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

def check(self, value):
result = self.base_constraint.check(value)
fritzo marked this conversation as resolved.
Show resolved Hide resolved
result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,))
result = result.all(-1)
return result


class _Boolean(Constraint):
Expand All @@ -120,6 +182,7 @@ class _OneHot(Constraint):
Constrain to one-hot vectors.
"""
is_discrete = True
event_dim = 1

def check(self, value):
is_boolean = (value == 0) | (value == 1)
Expand All @@ -136,6 +199,7 @@ class _IntegerInterval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
Expand All @@ -154,6 +218,7 @@ class _IntegerLessThan(Constraint):

def __init__(self, upper_bound):
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (value <= self.upper_bound)
Expand All @@ -172,6 +237,7 @@ class _IntegerGreaterThan(Constraint):

def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (value >= self.lower_bound)
Expand All @@ -196,6 +262,7 @@ class _GreaterThan(Constraint):
"""
def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return self.lower_bound < value
Expand All @@ -212,6 +279,7 @@ class _GreaterThanEq(Constraint):
"""
def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return self.lower_bound <= value
Expand All @@ -228,6 +296,7 @@ class _LessThan(Constraint):
"""
def __init__(self, upper_bound):
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return value < self.upper_bound
Expand All @@ -245,6 +314,7 @@ class _Interval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (self.lower_bound <= value) & (value <= self.upper_bound)
Expand All @@ -262,6 +332,7 @@ class _HalfOpenInterval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (self.lower_bound <= value) & (value < self.upper_bound)
Expand All @@ -277,6 +348,8 @@ class _Simplex(Constraint):
Constrain to the unit simplex in the innermost (rightmost) dimension.
Specifically: `x >= 0` and `x.sum(-1) == 1`.
"""
event_dim = 1

def check(self, value):
return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)

Expand All @@ -285,6 +358,8 @@ class _LowerTriangular(Constraint):
"""
Constrain to lower-triangular square matrices.
"""
event_dim = 2

def check(self, value):
value_tril = value.tril()
return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
Expand All @@ -294,6 +369,8 @@ class _LowerCholesky(Constraint):
"""
Constrain to lower-triangular square matrices with positive diagonals.
"""
event_dim = 2

def check(self, value):
value_tril = value.tril()
lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
Expand All @@ -307,6 +384,8 @@ class _CorrCholesky(Constraint):
Constrain to lower-triangular square matrices with positive diagonals and each
row vector being of unit length.
"""
event_dim = 2

def check(self, value):
tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor
row_norm = torch.linalg.norm(value.detach(), dim=-1)
Expand All @@ -318,6 +397,8 @@ class _PositiveDefinite(Constraint):
"""
Constrain to positive-definite matrices.
"""
event_dim = 2

def check(self, value):
matrix_shape = value.shape[-2:]
batch_shape = value.unsqueeze(0).shape[:-2]
Expand All @@ -328,15 +409,6 @@ def check(self, value):
for v in flattened_value]).view(batch_shape)


class _RealVector(Constraint):
"""
Constrain to real-valued vectors. This is the same as `constraints.real`,
but additionally reduces across the `event_shape` dimension.
"""
def check(self, value):
return torch.all(value == value, dim=-1) # False for NANs.


class _Cat(Constraint):
"""
Constraint functor that applies a sequence of constraints
Expand All @@ -351,6 +423,15 @@ def __init__(self, cseq, dim=0, lengths=None):
self.lengths = list(lengths)
assert len(self.lengths) == len(self.cseq)
self.dim = dim
super().__init__()

@property
def is_discrete(self):
return any(c.is_discrete for c in self.cseq)

@property
def event_dim(self):
return max(c.event_dim for c in self.cseq)

def check(self, value):
assert -value.dim() <= self.dim < value.dim()
Expand All @@ -373,23 +454,37 @@ def __init__(self, cseq, dim=0):
assert all(isinstance(c, Constraint) for c in cseq)
self.cseq = list(cseq)
self.dim = dim
super().__init__()

@property
def is_discrete(self):
return any(c.is_discrete for c in self.cseq)

@property
def event_dim(self):
dim = max(c.event_dim for c in self.cseq)
if self.dim + dim < 0:
dim += 1
return dim

def check(self, value):
assert -value.dim() <= self.dim < value.dim()
vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
return torch.stack([constr.check(v)
for v, constr in zip(vs, self.cseq)], self.dim)


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = _RealVector()
real_vector = independent(real, 1)
positive = _GreaterThan(0.)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
Expand Down
3 changes: 2 additions & 1 deletion torch/distributions/independent.py
Expand Up @@ -68,7 +68,8 @@ def has_enumerate_support(self):

@constraints.dependent_property
def support(self):
return self.base_dist.support
return constraints.independent(self.base_dist.support,
self.reinterpreted_batch_ndims)

@property
def mean(self):
Expand Down
8 changes: 4 additions & 4 deletions torch/distributions/lowrank_multivariate_normal.py
Expand Up @@ -73,10 +73,10 @@ class LowRankMultivariateNormal(Distribution):

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
arg_constraints = {"loc": constraints.real,
"cov_factor": constraints.real,
"cov_diag": constraints.positive}
support = constraints.real
arg_constraints = {"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),
"cov_diag": constraints.independent(constraints.positive, 1)}
support = constraints.real_vector
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the internal failures are due to checks like if dist.support is constraints.real or isinstance(dist.support, constraints._Real). For some of the distributions, this check will need to be updated (e.g. multinomial and I don't see a way around that) in client code. Constraints wrapped within Independent will be harder to check this way. e.g. if we have independent(real, 2), what would be the recommended way to check the base constraint without the event dim information?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This points to a more general issue that I have observed, which is what is the recommended way for inferring constraint type. e.g. the most general way would be something like isinstance(constraint, _IntegerInterval) but that requires peeking into the non-public interface.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neerajprad for constraints.independent(constraints.real, 2) I suppose we could brute force and

assert isinstance(c, constraints.independent)
assert isinstance(c.base_constraint, constraints.real)
assert isinstance(c.event_dim, 2)

or we could define .__eq__() to enable syntax like:

assert c == constraints.independent(constraints.real, 2)

(we already do this for transforms).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are, I think, two issues - I created #50616 to discuss one of these. The other issue is checking for the constraint type (disregarding event dim and the wrapping independent constraint) which is where most of the usage is rather than a hard equality check. So something like, isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real)) will work but only if the independents aren't nested.

Update: I think __eq__ will also be very useful, but that's a separate feature, I am merely looking to support existing usage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to use something like the following internally to handle these use cases, and we can consider providing this as a utility in PyTorch itself so as to avoid usage of non-public classes for these kind of checks. I am not sure how much usage like this we are going to find in the wild, so this may still end up breaking some code, which is my only (although minor) concern.

def _unwrap(constraint):
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    return constraint if isinstance(constraint, type) else constraint.__class__


def constraint_type_eq(constraint1, constraint2):
    return _unwrap(constraint1) == _unwrap(constraint2)

Then we can do:

>>> constraint_eq(constraints.independent(constraints.real, 1), constraints.real)
True

instead of isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real))

has_rsample = True

def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
Expand Down