Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap torch.distributions.Normal for use in Pyro #607

Merged
merged 5 commits into from
Nov 28, 2017

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Nov 28, 2017

Addresses #606

This creates an optional wrapper to use torch.distributions.Normal in Pyro. The torch version is only used if all of the following are satisfied:

  • The environment variable PYRO_USE_TORCH_DISTRIBUTIONS=1 is set
  • The torch.distributions module exists (it is missing in PyTorch 0.2 release)
  • The torch.distributions.Normal class exists
  • All requested features are available (e.g. torch.distributions.Normal is reparameterized if reparameterized=True, also log_pdf_mask is not supported).

If any of the previous conditions are not satisfied, Pyro falls back to the standard implementation.

Tested

The torch distribution will not be exercised on travis. Tests pass locally except for an unrelated bug in PyTorch master 0.4.0a0+709fcfd.


def batch_shape(self, x=None):
x_shape = [] if x is None else x.size()
shape = torch.Size(broadcast_shape(x_shape, self._param_shape, strict=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might lead to some hard to find bugs like #414 if the event dimensions do not match between the data and the parameters. Should we allow this kind of broadcasting, or limit it to the batch dimensions only (i.e. any x's rightmost sizes must exactly agree with sample_shape + event_shape)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you have added a strict argument which should take care of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that was actually needed to pass some of the expected-error tests.

:returns: broadcasted shape
:rtype: tuple
:raises: ValueError
"""
strict = kwargs.pop('strict', False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be nice to extend the tests for this utility function, by specifying this as another parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, will do.


def batch_shape(self, x=None):
x_shape = [] if x is None else x.size()
shape = torch.Size(broadcast_shape(x_shape, self._param_shape, strict=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you have added a strict argument which should take care of this.

Copy link
Member

@neerajprad neerajprad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@neerajprad neerajprad merged commit b772145 into dev Nov 28, 2017
@martinjankowiak martinjankowiak deleted the wrap-torch-distributions branch November 29, 2017 23:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants