Independent constraint #50547

Closed
wants to merge 7 commits into from

fritzo
Collaborator

@fritzo fritzo commented Jan 14, 2021

Addresses #50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.

  • Adds a constraints.independent and replaces real_vector with independent(real, 1). (this pattern has long been used in Pyro)
  • Adds an .event_dim attribute to all constraints.
  • Tests that constraint.check(data) has the correct shape. (Previously the shapes were incorrect).
  • Adds machinery to set static .is_discrete and .event_dim for constraints.dependent.
  • Fixes constraints for a number of distributions.
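
As a minimal sketch of the first three bullets, assuming a PyTorch release that already includes this change, the snippet below compares the shape returned by `check` for the elementwise `real` constraint and for `independent(real, 1)`. The tensor `x` and the variable names are illustrative only.

```python
import torch
from torch.distributions import constraints

# A batch of 5 vectors, each of length 3 (illustrative data).
x = torch.randn(5, 3)

# The elementwise constraint validates every entry separately.
elementwise = constraints.real.check(x)

# Wrapping with independent(..., 1) treats the last dimension as the event,
# so check() reduces over it and returns one boolean per vector.
vector = constraints.independent(constraints.real, 1)
per_event = vector.check(x)

print(elementwise.shape, constraints.real.event_dim)
print(per_event.shape, vector.event_dim)
```

The reduced shape is what lets callers validate a vector-valued parameter with a single boolean per batch element.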

Tested

  • added a new check to the constraints tests
  • added a new check for .event_dim

cc @fehiepsi @feynmanliang @stefanwebb

@fritzo fritzo added the module: distributions Related to torch.distributions label Jan 14, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Jan 14, 2021

💊 CI failures summary and remediations

As of commit d50b992 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@feynmanliang
Contributor

feynmanliang commented Jan 14, 2021

Fixes #50496

Edit: I should have read the description first... thanks for this! Once this merges, we can think about enriching the constraint registry, e.g. constraint_registry.biject_to(independent(positive_real, event_dim=X)) == Compose(ExpTransform, AffineTransform(event_dim=X))

Contributor

@fehiepsi fehiepsi left a comment

I really like the proposals:

  • adding an independent constraint to neatly correct the constraints/supports of Independent and other parameters (such as those in the LowRank distribution)
  • adding an is_discrete attribute to those supports. This would simplify users' code for checking whether a distribution is discrete or continuous.

I don't quite understand the logic of fn, is_discrete, and event_dim in the dependent and dependent_property decorators, so I would like to defer to @neerajprad to review it. :)

@fritzo
Collaborator Author

fritzo commented Jan 14, 2021

... the logic of fn, is_discrete, event_dim in those dependent, dependent_property decorators

To clarify, I have added .__call__() methods to dependent and dependent_property because I would like to support overriding the default .is_discrete and .event_dim using syntax like

arg_constraints = {"x": constraints.dependent(event_dim=1)}

@constraints.dependent_property(is_discrete=True)
def support(self):
    ...

But to stay compatible with the old syntax without parens, we can't simply add those new parameters to the constructor. Instead, I've allowed creating a temporary dependent constraint and then using its .__call__() method to create objects with the right parameters. Note @fehiepsi we use this same trick in Pyro, e.g. to create PyroSample objects using syntax either with or without parens:

@PyroSample
def x():
    ...

@PyroSample(infer={...})
def x():
    ...
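
The with-or-without-parens trick described above can be sketched in plain Python. This is a toy illustration, not the actual PyTorch or Pyro implementation: the class name `tagged` and its default values are hypothetical, chosen only to show how a temporary object plus `__call__` supports both decorator spellings.

```python
class tagged:
    """Hypothetical decorator usable as @tagged or @tagged(is_discrete=True)."""

    def __init__(self, fn=None, *, is_discrete=False, event_dim=0):
        # fn is None when the class is called with keyword options only,
        # which produces a temporary object holding just the options.
        self.fn = fn
        self.is_discrete = is_discrete
        self.event_dim = event_dim

    def __call__(self, fn):
        # Invoked on the temporary object when it is applied to the function:
        # build the final object carrying both the function and the options.
        return tagged(fn, is_discrete=self.is_discrete, event_dim=self.event_dim)


@tagged
def plain():
    return "plain"


@tagged(is_discrete=True, event_dim=1)
def configured():
    return "configured"


print(plain.fn(), plain.is_discrete, plain.event_dim)
print(configured.fn(), configured.is_discrete, configured.event_dim)
```

The old spelling without parens passes the function straight to the constructor, so existing code keeps working while the keyword form goes through the temporary object.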

@fehiepsi
Contributor

Thanks @fritzo, that makes sense to me! Luckily we don't use dependent or dependent_property in numpyro (where it might not be easy to override the __call__ method, which has the same functionality as .check), except for Uniform parameters (which should have the default event_dim and is_discrete properties).

torch/distributions/constraints.py
def _transform_to_real(constraint):
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    return biject_to(constraint.base_constraint)
Contributor

I'm not sure if this will be enough to address #50496 since the event dim information will be lost when we call log_abs_det_jacobian, i.e. it seems to me that the transform returned for both the independent as well as the base constraint is the same. One way would be to post-hoc modify the output event dim of a transform to handle that use case. @feynmanliang - please correct me if I missed something.

Contributor

I just think that we no longer need input_event_dim, output_event_dim in transforms. All we need is to define a correct domain, codomain. For composed transform, we still need to handle the logic properly.

Contributor

Interesting, yes that's one of the differences w.r.t. numpyro. That's probably worth another discussion. We do have domain, co-domain. I see what you mean - we need to adjust the event dims appropriately and with this change we can remove input/output event dims altogether. That sounds nice actually.

Contributor

@fritzo - I like @fehiepsi's suggestion of modifying the domain/codomain's event dim (and remove input/output event dims since that wasn't part of last release) to handle this, but regardless, all the proposed changes in this PR look great to me, so please feel free to defer this to later.

Collaborator Author

That sounds right. I think we can follow @feynmanliang's suggestion in a future PR and, say, wrap with an IndependentTransform or set the Transform.event_dim. I'll demote this PR from "Fixes" to "Addresses".

if hasattr(total_count, "unsqueeze"):
    total_count = total_count.unsqueeze(-1)
return constraints.independent(
    constraints.integer_interval(0, total_count), 1)
Contributor

I think it is better to have constraints.multinomial.

Collaborator Author

Good idea. I've added a constraints.multinomial in this PR. However note that it checks the slightly weaker condition

- (0 <= x).all(-1) & (x.sum(-1) == total_count)
+ (0 <= x).all(-1) & (x.sum(-1) <= total_count)

because we allow passing a bogus total_count.max() to Multinomial when it is used as a likelihood. I plan to try to address this in the future while fixing #42407.
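
To make the difference between the two conditions concrete, here is a small sketch that evaluates both expressions from the diff above on made-up data. The counts and the shared `total_count` of 5 are illustrative assumptions, standing in for the case where a maximum total count is passed for a batch whose rows have different totals.

```python
import torch

# Two count vectors whose true totals are 3 and 5 (illustrative data).
x = torch.tensor([[1, 2, 0],
                  [2, 2, 1]])
# A single upper bound shared by the whole batch, e.g. the maximum total.
total_count = 5

# Strict condition: every row must sum to exactly total_count.
strict = (0 <= x).all(-1) & (x.sum(-1) == total_count)

# Weaker condition: rows may sum to anything up to total_count.
relaxed = (0 <= x).all(-1) & (x.sum(-1) <= total_count)

print(strict.tolist())   # [False, True]
print(relaxed.tolist())  # [True, True]
```

Under the strict check the first row would be rejected even though it is a valid observation for its own total, which is why the weaker condition is the one being checked for now.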

Contributor

@facebook-github-bot facebook-github-bot left a comment

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neerajprad
Contributor

There are some internal test failures that I'm blocked on and still looking into.

arg_constraints = {"loc": constraints.real_vector,
                   "cov_factor": constraints.independent(constraints.real, 2),
                   "cov_diag": constraints.independent(constraints.positive, 1)}
support = constraints.real_vector
Contributor

Note that the internal failures are due to checks like if dist.support is constraints.real or isinstance(dist.support, constraints._Real). For some of the distributions (e.g. multinomial), this check will need to be updated in client code, and I don't see a way around that. Constraints wrapped within Independent will be harder to check this way. For example, if we have independent(real, 2), what would be the recommended way to check the base constraint without the event dim information?

Contributor

This points to a more general issue that I have observed: what is the recommended way to infer the constraint type? The most general way would be something like isinstance(constraint, _IntegerInterval), but that requires peeking into the non-public interface.

Collaborator Author

@neerajprad for constraints.independent(constraints.real, 2) I suppose we could brute force and

assert isinstance(c, constraints.independent)
assert c.base_constraint is constraints.real
assert c.event_dim == 2

or we could define .__eq__() to enable syntax like:

assert c == constraints.independent(constraints.real, 2)

(we already do this for transforms).

Contributor

There are, I think, two issues - I created #50616 to discuss one of these. The other issue is checking for the constraint type (disregarding event dim and the wrapping independent constraint) which is where most of the usage is rather than a hard equality check. So something like, isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real)) will work but only if the independents aren't nested.

Update: I think __eq__ will also be very useful, but that's a separate feature, I am merely looking to support existing usage.

Contributor

I plan to use something like the following internally to handle these use cases, and we can consider providing this as a utility in PyTorch itself to avoid using non-public classes for these kinds of checks. I am not sure how much usage like this we are going to find in the wild, so this may still end up breaking some code, which is my only (although minor) concern.

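from torch.distributions import constraints


def _unwrap(constraint):
    # Recursively strip independent wrappers, so nested wrappers are handled too.
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    # Normalize to a class so that instances and classes can be compared.
    return constraint if isinstance(constraint, type) else constraint.__class__


def constraint_type_eq(constraint1, constraint2):
    # Compare constraint types, ignoring event dims and independent wrapping.
    return _unwrap(constraint1) == _unwrap(constraint2)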

Then we can do:

>>> constraint_type_eq(constraints.independent(constraints.real, 1), constraints.real)
True

instead of isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real))

facebook-github-bot pushed a commit to facebookresearch/beanmachine that referenced this pull request Jan 21, 2021
Summary:
Pull Request resolved: #542

This adds an equality testing utility for `torch.constraint` objects/classes, so that these comparisons can be consolidated within a single utility function `is_constraint_eq`.

Usage:
```
is_constraint_eq(dist.support, (constraints.real, constraints.greater_than))
```

, instead of:

```
isinstance(dist.support, (constraints._Real, constraints._GreaterThan))
```
, or the more obfuscatory

```
dist.support is constraints.real or isinstance(dist.support, constraints.greater_than)
```

See pytorch/pytorch#50616 for more details (note that the changes suggested in the issue are complementary).

 - This avoids usage of the non-public constraint classes (like `constraints._Real`, `constraints._Interval`).
 - Makes it possible to consolidate future changes (e.g. those arising out of the introduction of an `Independent` constraint - pytorch/pytorch#50547) within a single function.

This is a pre-requisite to some other fixes that are currently blocking D25918330. I will add these small fixes when I merge D25918330.

Reviewed By: feynmanliang, jpchen

Differential Revision: D25935106

fbshipit-source-id: 76296464f51e01a9cac4506fb927c531145212ce
facebook-github-bot pushed a commit that referenced this pull request Jan 22, 2021
Summary:
Addresses #50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro)
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50547

Reviewed By: VitalyFedyunin

Differential Revision: D25918330

Pulled By: neerajprad

fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db
facebook-github-bot pushed a commit that referenced this pull request Feb 2, 2021
Summary:
This fix ensures
```py
Dirichlet.arg_constraints["concentration"].event_dim == 1
```
which was missed in #50547

## Tested
- [x] added a regression test, covering all distributions

Pull Request resolved: #51369

Reviewed By: H-Huang

Differential Revision: D26160644

Pulled By: neerajprad

fbshipit-source-id: 1bb44c79480a1f0052b0ef9d4605e750ab07bea1
@github-actions github-actions bot deleted the independent-constraint branch February 10, 2024 01:53

6 participants