Independent constraint #50547
Fixes #50496 edit should have read first... thanks for this! After this merges we can think about enriching the constraint registry e.g. |
I really like the proposals:
- adding an `independent` constraint to beautifully correct the constraints/supports of Independent and other parameters (such as those in LowRank distribution)
- having an `is_discrete` attribute on those supports. This would simplify users' code to check if a distribution is discrete or continuous.

I don't quite understand the logic of `fn`, `is_discrete`, `event_dim` in those `dependent`, `dependent_property` decorators, hence would like to defer to @neerajprad to review it. :)
To clarify, I have added
```py
arg_constraints = {"x": constraints.dependent(event_dim=1)}

@constraints.dependent_property(is_discrete=True)
def support(self):
    ...
```
But to be compatible with the old syntax without parens, we can't exactly add those new parameters to the constructor. So instead I've allowed creating of a temporary
```py
@PyroSample
def x():
    ...

@PyroSample(infer={...})
def x():
    ...
```
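The "works with or without parens" decorator pattern mentioned above can be sketched roughly as follows. This is only an illustration under assumptions: `DependentProperty` and `Toy` are hypothetical names, and the actual `constraints.dependent_property` in this PR may differ in detail.

```python
class DependentProperty(property):
    """Hypothetical stand-in: a property that also carries static metadata."""

    def __init__(self, fn=None, *, is_discrete=None, event_dim=None):
        super().__init__(fn)
        self.is_discrete = is_discrete
        self.event_dim = event_dim

    def __call__(self, fn):
        # Reached only with the parenthesized syntax: the temporary object
        # created by DependentProperty(...) is called on the decorated function
        # and returns the final property, carrying the metadata along.
        return DependentProperty(
            fn, is_discrete=self.is_discrete, event_dim=self.event_dim
        )


class Toy:
    @DependentProperty  # old syntax, no parens
    def support(self):
        return "old-style support"

    @DependentProperty(is_discrete=True, event_dim=1)  # new syntax
    def discrete_support(self):
        return "new-style support"
```

Accessing `Toy.discrete_support` on the class returns the property object itself, so the static `is_discrete` and `event_dim` values can be read without instantiating anything.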
Thanks @fritzo, that makes sense to me! Luckily we don't use |
```py
def _transform_to_real(constraint):
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    return biject_to(constraint.base_constraint)
```
I'm not sure if this will be enough to address #50496 since the event dim information will be lost when we call `log_abs_det_jacobian`, i.e. it seems to me that the transform returned for both the independent as well as the base constraint is the same. One way would be to post-hoc modify the output event dim of a transform to handle that use case. @feynmanliang - please correct me if I missed something.
I just think that we no longer need `input_event_dim`, `output_event_dim` in transforms. All we need is to define a correct `domain`, `codomain`. For composed transforms, we still need to handle the logic properly.
Interesting, yes that's one of the differences w.r.t. numpyro. That's probably worth another discussion. We do have domain, co-domain. I see what you mean - we need to adjust the event dims appropriately and with this change we can remove input/output event dims altogether. That sounds nice actually.
That sounds right. I think we can follow @feynmanliang's suggestion in a future PR and, say, wrap with an `IndependentTransform` or set the `Transform.event_dim`. I'll demote this PR from "Fixes" to "Addresses".
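To make the event dim concern concrete, here is a small sketch using plain nested lists in place of tensors. The helper names `exp_ladj` and `sum_rightmost` are hypothetical. It shows only the general idea that wrapping a transform as "independent" sums the per-element log-determinant terms over the event dimension, and is not the behavior of any particular PyTorch class.

```python
def exp_ladj(batch):
    """Elementwise log|dy/dx| for y = exp(x), which equals x itself.

    The result has the same shape as the input (event_dim 0).
    """
    return [[x for x in row] for row in batch]


def sum_rightmost(ladj):
    """Treat the last dimension as an event dimension.

    The per-element terms in each row are summed, so the result has one
    entry per batch element (event_dim 1).
    """
    return [sum(row) for row in ladj]


batch = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]  # shape (2, 3)
per_element = exp_ladj(batch)               # shape (2, 3)
per_event = sum_rightmost(per_element)      # shape (2,)
```

If the transform returned for `independent(real, 1)` is identical to the one returned for `real`, a caller receives the `(2, 3)` result where the `(2,)` result was expected.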
torch/distributions/multinomial.py
Outdated
```py
if hasattr(total_count, "unsqueeze"):
    total_count = total_count.unsqueeze(-1)
return constraints.independent(
    constraints.integer_interval(0, total_count), 1)
```
I think it is better to have `constraints.multinomial`.
Good idea. I've added a `constraints.multinomial` in this PR. However note that it checks the slightly weaker condition
```diff
- (0 <= x).all(-1) & (x.sum(-1) == total_count)
+ (0 <= x).all(-1) & (x.sum(-1) <= total_count)
```
because we allow passing a bogus `total_count.max()` to `Multinomial` when it is used as a likelihood. I plan to try to address this in the future while fixing #42407.
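The difference between the two conditions can be seen in a small sketch that uses plain Python lists in place of tensors. The function names `strict_check` and `relaxed_check` are hypothetical and only mirror the two expressions quoted above.

```python
def strict_check(x, total_count):
    """All counts non-negative and summing exactly to total_count."""
    return all(v >= 0 for v in x) and sum(x) == total_count


def relaxed_check(x, total_count):
    """All counts non-negative and summing to at most total_count."""
    return all(v >= 0 for v in x) and sum(x) <= total_count


counts = [2, 0, 1]  # an observation whose true total is 3
bogus_total = 10    # e.g. the maximum total across a heterogeneous batch

strict_ok = strict_check(counts, bogus_total)
relaxed_ok = relaxed_check(counts, bogus_total)
```

With the bogus total, the strict condition rejects a perfectly valid observation while the relaxed one accepts it. The cost is that the relaxed check no longer catches observations whose counts sum to less than the true total.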
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There are some internal test failures that I'm blocked on and still looking into.
```py
arg_constraints = {"loc": constraints.real_vector,
                   "cov_factor": constraints.independent(constraints.real, 2),
                   "cov_diag": constraints.independent(constraints.positive, 1)}
support = constraints.real_vector
```
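As a rough illustration of what `independent(base, n)` implies for the shape of `check(data)`, the sketch below uses nested lists as stand-ins for tensors. All helper names (`depth`, `flat_all`, `check_real`, `check_positive`, `independent_check`) are hypothetical. The assumption is that an independent constraint applies the base check elementwise and then reduces the rightmost `n` dimensions with a logical "all".

```python
def depth(x):
    """Number of nested list levels, standing in for tensor.dim()."""
    return 1 + depth(x[0]) if isinstance(x, list) else 0


def flat_all(x):
    """True if every leaf of a nested list is truthy."""
    return all(flat_all(v) for v in x) if isinstance(x, list) else bool(x)


def check_real(x):
    """Elementwise check that rejects NaN (NaN is not equal to itself)."""
    if isinstance(x, list):
        return [check_real(v) for v in x]
    return x == x


def check_positive(x):
    """Elementwise check for strictly positive values."""
    if isinstance(x, list):
        return [check_positive(v) for v in x]
    return x > 0


def independent_check(base_check, x, ndims):
    """Apply base_check, then reduce the rightmost `ndims` dims with all()."""
    def reduce(mask):
        if depth(mask) == ndims:
            return flat_all(mask)
        return [reduce(m) for m in mask]
    return reduce(base_check(x))


nan = float("nan")
cov_factor = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, nan], [7.0, 8.0]]]  # shape (2, 2, 2)
cov_diag = [[1.0, 2.0], [3.0, -1.0]]                               # shape (2, 2)

factor_ok = independent_check(check_real, cov_factor, 2)    # one bool per batch element
diag_ok = independent_check(check_positive, cov_diag, 1)    # one bool per batch element
```

In both cases the result has batch shape only, which is the property the new constraint tests in this PR are described as checking.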
Note that the internal failures are due to checks like `if dist.support is constraints.real` or `isinstance(dist.support, constraints._Real)`. For some of the distributions, this check will need to be updated in client code (e.g. multinomial, and I don't see a way around that). Constraints wrapped within `Independent` will be harder to check this way. E.g. if we have `independent(real, 2)`, what would be the recommended way to check the base constraint without the `event dim` information?
This points to a more general issue that I have observed: what is the recommended way to infer a constraint's type? E.g. the most general way would be something like `isinstance(constraint, _IntegerInterval)`, but that requires peeking into the non-public interface.
@neerajprad for `constraints.independent(constraints.real, 2)` I suppose we could brute force and
```py
assert isinstance(c, constraints.independent)
assert c.base_constraint is constraints.real
assert c.event_dim == 2
```
or we could define `.__eq__()` to enable syntax like:
```py
assert c == constraints.independent(constraints.real, 2)
```
(we already do this for transforms).
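A structural `__eq__` of the kind proposed here might look roughly like the sketch below. The `Real` and `Independent` classes are hypothetical toy stand-ins, not the classes in `torch.distributions.constraints`, and the choice of fields to compare is an assumption.

```python
class Real:
    """Toy stand-in for a parameterless constraint."""

    event_dim = 0

    def __eq__(self, other):
        return type(other) is type(self)

    def __hash__(self):
        return hash(type(self))


class Independent:
    """Toy stand-in for a constraint that reinterprets batch dims as event dims."""

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    @property
    def event_dim(self):
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def __eq__(self, other):
        if not isinstance(other, Independent):
            return NotImplemented
        return (
            self.base_constraint == other.base_constraint
            and self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims
        )

    def __hash__(self):
        return hash((type(self), self.base_constraint, self.reinterpreted_batch_ndims))


c = Independent(Real(), 2)
```

Defining `__hash__` alongside `__eq__` keeps the objects usable as dictionary keys, which matters if constraints are ever used in registries keyed by instance.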
There are, I think, two issues - I created #50616 to discuss one of these. The other issue is checking for the constraint type (disregarding event dim and the wrapping `independent` constraint), which is where most of the usage is, rather than a hard equality check. So something like `isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real))` will work, but only if the `independent`s aren't nested.

Update: I think `__eq__` will also be very useful, but that's a separate feature; I am merely looking to support existing usage.
I plan to use something like the following internally to handle these use cases, and we can consider providing this as a utility in PyTorch itself so as to avoid usage of non-public classes for these kinds of checks. I am not sure how much usage like this we are going to find in the wild, so this may still end up breaking some code, which is my only (although minor) concern.
```py
def _unwrap(constraint):
    # Strip any (possibly nested) independent wrappers.
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    # Compare by class, whether given a class or an instance.
    return constraint if isinstance(constraint, type) else constraint.__class__


def constraint_type_eq(constraint1, constraint2):
    return _unwrap(constraint1) == _unwrap(constraint2)
```
Then we can do:
```py
>>> constraint_type_eq(constraints.independent(constraints.real, 1), constraints.real)
True
```
instead of `isinstance(constraint, constraints._Real) or (isinstance(constraint, constraints._Independent) and isinstance(constraint.base_constraint, constraints._Real))`.
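The unwrapping idea above also handles nested `independent` wrappers, which the one-level `isinstance` expression does not. The sketch below demonstrates this with hypothetical toy classes (`Real`, `Positive`, `Independent`) rather than the real PyTorch constraints, and an iterative variant named `unwrap`.

```python
class Real:
    """Toy stand-in for constraints.real's class."""


class Positive:
    """Toy stand-in for constraints.positive's class."""


class Independent:
    """Toy wrapper that marks some batch dims of a base constraint as event dims."""

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims


def unwrap(constraint):
    """Strip all Independent wrappers and return the underlying constraint class."""
    while isinstance(constraint, Independent):
        constraint = constraint.base_constraint
    return constraint if isinstance(constraint, type) else type(constraint)


def constraint_type_eq(constraint1, constraint2):
    """Compare two constraints by base type, ignoring event dims."""
    return unwrap(constraint1) is unwrap(constraint2)


nested = Independent(Independent(Real(), 1), 1)
```

Because only classes are compared, this deliberately ignores constraint parameters (such as interval bounds) as well as event dims, which matches the "type check rather than hard equality" use case described in the thread.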
Summary: Pull Request resolved: #542

This adds an equality testing utility for `torch.constraint` objects/classes, so that these comparisons can be consolidated within a single utility function `is_constraint_eq`. Usage:
```
is_constraint_eq(dist.support, (constraints.real, constraints.greater_than))
```
instead of:
```
isinstance(dist.support, (constraints._Real, constraints._GreaterThan))
```
or the more obfuscatory
```
dist.support is constraints.real or isinstance(dist.support, constraints.greater_than)
```
See pytorch/pytorch#50616 for more details (note that the changes suggested in the issue are complementary).
- This avoids usage of the non-public constraint classes (like `constraints._Real`, `constraints._Interval`).
- Makes it possible to consolidate future changes (e.g. those arising out of the introduction of an `Independent` constraint - pytorch/pytorch#50547) within a single function.

This is a pre-requisite to some other fixes that are currently blocking D25918330. I will add these small fixes when I merge D25918330.

Reviewed By: feynmanliang, jpchen
Differential Revision: D25935106
fbshipit-source-id: 76296464f51e01a9cac4506fb927c531145212ce

Summary: Addresses #50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro)
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50547
Reviewed By: VitalyFedyunin
Differential Revision: D25918330
Pulled By: neerajprad
fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db

Summary: This fix ensures
```py
Dirichlet.arg_constraints["concentration"].event_dim == 1
```
which was missed in #50547

## Tested
- [x] added a regression test, covering all distributions

Pull Request resolved: #51369
Reviewed By: H-Huang
Differential Revision: D26160644
Pulled By: neerajprad
fbshipit-source-id: 1bb44c79480a1f0052b0ef9d4605e750ab07bea1
Addresses #50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro)
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc @fehiepsi @feynmanliang @stefanwebb