Wrap torch.distributions.Normal for use in Pyro #607
Conversation
```python
def batch_shape(self, x=None):
    x_shape = [] if x is None else x.size()
    shape = torch.Size(broadcast_shape(x_shape, self._param_shape, strict=True))
```
This might lead to some hard-to-find bugs like #414 if the event dimensions do not match between the data and the parameters. Should we allow this kind of broadcasting, or limit it to the batch dimensions only (i.e. any `x`'s rightmost sizes must exactly agree with `sample_shape + event_shape`)?
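To make the failure mode concrete (the shapes here are illustrative, not taken from #414): standard broadcasting silently stretches a size-1 data dimension to match the parameters instead of raising.

```python
import torch

loc = torch.zeros(3, 2)   # parameters: batch size 3, event size 2
x = torch.ones(3, 1)      # data whose rightmost (event) size disagrees
print((x + loc).size())   # torch.Size([3, 2]) -- silently stretched, no error
```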
I see that you have added a `strict` argument, which should take care of this.
Right, that was actually needed to pass some of the expected-error tests.
```python
    :returns: broadcasted shape
    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop('strict', False)
```
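The fragment above shows only the tail of the docstring and the `strict` pop. For context, here is a minimal sketch of what the full utility might look like, with the semantics of `strict` inferred from the discussion above; this is an assumption, not the PR's actual implementation.

```python
def broadcast_shape(*shapes, **kwargs):
    """Compute the broadcasted shape of the given shapes (sketch).

    With strict=False, a size of 1 may be stretched to match any other
    size (standard broadcasting). With strict=True, right-aligned sizes
    must agree exactly, though shapes may still differ in length.
    """
    strict = kwargs.pop('strict', False)
    reversed_shape = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                # A dimension missing on the left is always allowed.
                reversed_shape.append(size)
            elif reversed_shape[i] == 1 and not strict:
                # Non-strict mode: stretch a size-1 dim to match.
                reversed_shape[i] = size
            elif reversed_shape[i] != size and (size != 1 or strict):
                raise ValueError('shape mismatch: {}'.format(' vs '.join(map(str, shapes))))
    return tuple(reversed(reversed_shape))
```

Under these semantics, `broadcast_shape((3, 2), (2,), strict=True)` still succeeds because missing left dimensions are allowed, while `broadcast_shape((3, 1), (3, 2), strict=True)` raises a `ValueError`.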
It would be nice to extend the tests for this utility function by specifying `strict` as another parameter.
Good idea, will do.
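A hypothetical sketch of what that parametrization could look like; the test names and the import path `pyro.distributions.util` are assumptions, not the PR's actual tests.

```python
import pytest

from pyro.distributions.util import broadcast_shape  # assumed import path


@pytest.mark.parametrize('strict', [False, True])
def test_broadcast_shape_exact_match(strict):
    # Exactly matching shapes broadcast to themselves in both modes.
    assert broadcast_shape((3, 2), (3, 2), strict=strict) == (3, 2)


@pytest.mark.parametrize('shapes', [((3, 1), (3, 2)), ((1,), (5,))])
def test_broadcast_shape_strict_rejects_stretching(shapes):
    # Stretching a size-1 dim is allowed by default but rejected under strict.
    broadcast_shape(*shapes)  # does not raise
    with pytest.raises(ValueError):
        broadcast_shape(*shapes, strict=True)
```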
Looks good!
Addresses #606

This creates an optional wrapper to use `torch.distributions.Normal` in Pyro. The torch version is only used if all of the following are satisfied:

- `PYRO_USE_TORCH_DISTRIBUTIONS=1` is set
- the `torch.distributions` module exists (it is missing in the PyTorch 0.2 release)
- the `torch.distributions.Normal` class exists
- `torch.distributions.Normal` is reparameterized (it is used only if `reparameterized=True`; also, `log_pdf_mask` is not supported)

If any of these conditions is not satisfied, Pyro falls back to the standard implementation.
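A rough sketch of how these checks might combine into a single guard; the helper name `_use_torch_normal` and its signature are made up for illustration, and the PR's actual wiring may differ.

```python
import os

import torch


def _use_torch_normal(reparameterized, log_pdf_mask=None):
    """Return True only when every opt-in condition holds."""
    if os.environ.get('PYRO_USE_TORCH_DISTRIBUTIONS') != '1':
        return False  # the opt-in environment flag is not set
    if not hasattr(torch, 'distributions'):
        return False  # module is missing, e.g. in the PyTorch 0.2 release
    if not hasattr(torch.distributions, 'Normal'):
        return False  # class is missing in this PyTorch build
    if not reparameterized:
        return False  # the torch version is reparameterized only
    if log_pdf_mask is not None:
        return False  # log_pdf_mask is not supported by the wrapper
    return True
```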
Tested
The torch distribution will not be exercised on Travis. Tests pass locally, except for an unrelated bug in PyTorch master (0.4.0a0+709fcfd).