New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Independent constraint #50547
Independent constraint #50547
Changes from 6 commits
fb8274f
1d75d76
1a93eda
7ceabb7
d50b992
20dafb7
0e58ebe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,10 +73,10 @@ class LowRankMultivariateNormal(Distribution): | |
|
||
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor | ||
""" | ||
arg_constraints = {"loc": constraints.real, | ||
"cov_factor": constraints.real, | ||
"cov_diag": constraints.positive} | ||
support = constraints.real | ||
arg_constraints = {"loc": constraints.real_vector, | ||
"cov_factor": constraints.independent(constraints.real, 2), | ||
"cov_diag": constraints.independent(constraints.positive, 1)} | ||
support = constraints.real_vector | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that the internal failures are due to checks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This points to a more general issue that I have observed, which is what is the recommended way for inferring constraint type. e.g. the most general way would be something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @neerajprad for assert isinstance(c, constraints.independent)
assert isinstance(c.base_constraint, constraints.real)
assert isinstance(c.event_dim, 2) or we could define assert c == constraints.independent(constraints.real, 2) (we already do this for transforms). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are, I think, two issues - I created #50616 to discuss one of these. The other issue is checking for the constraint type (disregarding event dim and the wrapping Update: I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I plan to use something like the following internally to handle these use cases, and we can consider providing this as a utility in PyTorch itself so as to avoid usage of non-public classes for these kind of checks. I am not sure how much usage like this we are going to find in the wild, so this may still end up breaking some code, which is my only (although minor) concern. def _unwrap(constraint):
if isinstance(constraint, constraints.independent):
return _unwrap(constraint.base_constraint)
return constraint if isinstance(constraint, type) else constraint.__class__
def constraint_type_eq(constraint1, constraint2):
return _unwrap(constraint1) == _unwrap(constraint2) Then we can do: >>> constraint_eq(constraints.independent(constraints.real, 1), constraints.real)
True instead of |
||
has_rsample = True | ||
|
||
def __init__(self, loc, cov_factor, cov_diag, validate_args=None): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if this will be enough to address #50496 since the event dim information will be lost when we call
log_abs_det_jacobian
, i.e. it seems to me that the transform returned for both the independent as well as the base constraint is the same. One way would be to post-hoc modify the output event dim of a transform to handle that use case. @feynmanliang - please correct me if I missed something.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just think that we no longer need
input_event_dim
,output_event_dim
in transforms. All we need is to define a correctdomain
,codomain
. For composed transform, we still need to handle the logic properly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, yes that's one of the differences w.r.t. numpyro. That's probably worth another discussion.We do have domain, co-domain. I see what you mean - we need to adjust the event dims appropriately and with this change we can remove input/output event dims altogether. That sounds nice actually.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fritzo - I like @fehiepsi's suggestion of modifying the domain/codomain's event dim (and remove input/output event dims since that wasn't part of last release) to handle this, but regardless, all the proposed changes in this PR look great to me, so please feel free to defer this to later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds right. I think we can follow @feynmanliang's suggestion in a future PR and say wrap with an
IndependentTransform
or set theTransform.event_dim
. I'll demote this PR from "Fixes" to "Addresses".