-
Notifications
You must be signed in to change notification settings - Fork 222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Distribution.infer_shapes() for static shape analysis #901
Conversation
@fehiepsi any idea why tests are failing? The failures all seem to be numerical precision issues, which is weird because this PR only modifies type metadata 😕 EDIT Oops, it was probably my incorrect |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me! Thanks, @fritzo!
@@ -853,7 +894,8 @@ def tree_flatten(self): | |||
|
|||
|
|||
class Delta(Distribution): | |||
arg_constraints = {'v': real, 'log_density': real} | |||
# FIXME v and log_density should be constraints.independent(constraints.real, ???) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what is a good solution here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess to be completely correct we'd want something like Uniform.arg_constraints
:
arg_constraints = {'v': constraints.dependent(is_discrete=False),
'log_density': constraints.real}
but I was worried that might break something and seemed out of scope for this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's a possible fix: https://github.com/pyro-ppl/numpyro/compare/fix-delta-constraint
shapes. | ||
:rtype: tuple | ||
""" | ||
if cls.support.event_dim > 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is surprising to me that this works for @constraints.dependent_property
. I have been thinking that we can't make infer_shapes
a classmethod for many distributions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is a kind of cool trick. The two versions
self.support.event_dim
type(self).support.event_dim
follow entirely different code paths: the former runs the property method to compute a Constraint
instance then extracts the .event_dim
of that instance. The latter cannot create a Constraint
instance but instead extracts a class attribute of the dependent_property
itself. In PyTorch I've added tests that these two code paths agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally make sense to me! Thanks for explaning and being creative!
Thanks for reviewing @fehiepsi ! |
Addresses pyro-ppl/funsor#412
Ports
.infer_shape()
from pyro-ppl/pyro#2739Ports
dependent_property
from pytorch/pytorch#50581This adds a static method
Disribution.infer_shapes()
to statically infer(batch_shape, event_shape)
from the shapes of distribution parameters. This should speed up funsor computations as described in pyro-ppl/funsor#412. Note this works for only a subset of distributions; shapes cannot currently be statically inferred for distributions with non-tensor inputs such asTransformedDistribution
which takes a distribution, orDelta
which takes anevent_dim: int
.This generic logic requires static
.event_dim
attributes on dependent supports. Because NumPyro usesConstraint.__call__()
rather thanConstraint.check()
I've added a type check to determine behavior in_Dependent.__call__()
.Tested