Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Distribution.infer_shapes() for static shape analysis #901

Merged
merged 8 commits into from
Feb 2, 2021

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Jan 31, 2021

Addresses pyro-ppl/funsor#412
Ports .infer_shape() from pyro-ppl/pyro#2739
Ports dependent_property from pytorch/pytorch#50581

This adds a static method Disribution.infer_shapes() to statically infer (batch_shape, event_shape) from the shapes of distribution parameters. This should speed up funsor computations as described in pyro-ppl/funsor#412. Note this works for only a subset of distributions; shapes cannot currently be statically inferred for distributions with non-tensor inputs such as TransformedDistribution which takes a distribution, or Delta which takes an event_dim: int.

This generic logic requires static .event_dim attributes on dependent supports. Because NumPyro uses Constraint.__call__() rather than Constraint.check() I've added a type check to determine behavior in _Dependent.__call__().

Tested

  • added a unit test to test_distributions.py

@fritzo
Copy link
Member Author

fritzo commented Jan 31, 2021

@fehiepsi any idea why tests are failing? The failures all seem to be numerical precision issues, which is weird because this PR only modifies type metadata 😕

EDIT Oops, it was probably my incorrect LogNormal.support 😊

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Thanks, @fritzo!

@@ -853,7 +894,8 @@ def tree_flatten(self):


class Delta(Distribution):
arg_constraints = {'v': real, 'log_density': real}
# FIXME v and log_density should be constraints.independent(constraints.real, ???)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what is a good solution here...

Copy link
Member Author

@fritzo fritzo Feb 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess to be completely correct we'd want something like Uniform.arg_constraints:

arg_constraints = {'v': constraints.dependent(is_discrete=False),
                   'log_density': constraints.real}

but I was worried that might break something and seemed out of scope for this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shapes.
:rtype: tuple
"""
if cls.support.event_dim > 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is surprising to me that this works for @constraints.dependent_property. I have been thinking that we can't make infer_shapes a classmethod for many distributions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is a kind of cool trick. The two versions

self.support.event_dim
type(self).support.event_dim

follow entirely different code paths: the former runs the property method to compute a Constraint instance then extracts the .event_dim of that instance. The latter cannot create a Constraint instance but instead extracts a class attribute of the dependent_property itself. In PyTorch I've added tests that these two code paths agree.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally make sense to me! Thanks for explaning and being creative!

@fehiepsi fehiepsi merged commit 13c76dc into master Feb 2, 2021
@fehiepsi fehiepsi deleted the infer-shapes branch February 2, 2021 16:10
@fritzo
Copy link
Member Author

fritzo commented Feb 2, 2021

Thanks for reviewing @fehiepsi !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants