Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Independent constraint #50547

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions test/distributions/test_constraints.py
Expand Up @@ -7,6 +7,7 @@

CONSTRAINTS = [
(constraints.real,),
(constraints.real_vector,),
(constraints.positive,),
(constraints.greater_than, [-10., -2, 0, 2, 10]),
(constraints.greater_than, 0),
Expand Down
5 changes: 4 additions & 1 deletion test/distributions/test_distributions.py
Expand Up @@ -3915,7 +3915,10 @@ def test_support_constraints(self):
constraint = dist.support
message = '{} example {}/{} sample = {}'.format(
Dist.__name__, i + 1, len(params), value)
self.assertTrue(constraint.check(value).all(), msg=message)
self.assertEqual(constraint.event_dim, len(dist.event_shape), msg=message)
ok = constraint.check(value)
self.assertEqual(ok.shape, dist.batch_shape, msg=message)
self.assertTrue(ok.all(), msg=message)


class TestNumericalStability(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torch/distributions/binomial.py
Expand Up @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self.total_count)

Expand Down
4 changes: 2 additions & 2 deletions torch/distributions/categorical.py
Expand Up @@ -38,7 +38,7 @@ class Categorical(Distribution):
logits (Tensor): event log-odds
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
'logits': constraints.real_vector}
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
Expand Down Expand Up @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property
@constraints.dependent_property(is_discrete=True)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

Expand Down
12 changes: 10 additions & 2 deletions torch/distributions/constraint_registry.py
Expand Up @@ -153,13 +153,21 @@ def __call__(self, constraint):
################################################################################

@biject_to.register(constraints.real)
@biject_to.register(constraints.real_vector)
@transform_to.register(constraints.real)
@transform_to.register(constraints.real_vector)
def _transform_to_real(constraint):
return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
return biject_to(constraint.base_constraint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this will be enough to address #50496 since the event dim information will be lost when we call log_abs_det_jacobian, i.e. it seems to me that the transform returned for both the independent as well as the base constraint is the same. One way would be to post-hoc modify the output event dim of a transform to handle that use case. @feynmanliang - please correct me if I missed something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just think that we no longer need input_event_dim, output_event_dim in transforms. All we need is to define a correct domain, codomain. For composed transform, we still need to handle the logic properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, yes that's one of the differences w.r.t. numpyro. That's probably worth another discussion. We do have domain, co-domain. I see what you mean - we need to adjust the event dims appropriately and with this change we can remove input/output event dims altogether. That sounds nice actually.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fritzo - I like @fehiepsi's suggestion of modifying the domain/codomain's event dim (and remove input/output event dims since that wasn't part of last release) to handle this, but regardless, all the proposed changes in this PR look great to me, so please feel free to defer this to later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds right. I think we can follow @feynmanliang's suggestion in a future PR and say wrap with an IndependentTransform or set the Transform.event_dim. I'll demote this PR from "Fixes" to "Addresses".



@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
return transform_to(constraint.base_constraint)


@biject_to.register(constraints.positive)
@transform_to.register(constraints.positive)
def _transform_to_positive(constraint):
Expand Down