[bug] MixtureOfDiagNormals #3274

Closed
cyianor opened this issue Oct 5, 2023 · 0 comments · Fixed by #3277
cyianor commented Oct 5, 2023

Problem description

I recently tried to use MixtureOfDiagNormals and ran across two issues while using it together with an AutoNormal guide.

  1. The distributions MixtureOfDiagNormals and MixtureOfDiagNormalsSharedCovariance do not implement the support property, which leads to problems when the distribution is used for latent variables with, e.g., the AutoNormal guide. This can be fixed easily by adding support = constraints.real_vector at the beginning of both classes. I assume that real_vector is the correct constraint here, since D=1 is not supported anyway.
  2. There seems to be a bug in the forward method of _MixDiagNormalSample. Specifically, the white noise is created with
        white = locs.new(noise_shape).normal_()
    which produces a tensor of size (len(noise_shape),) instead of the expected size noise_shape. _MixDiagNormalSharedCovarianceSample gets this right by calling torch.randn instead. One correct way to write it would be
        white = locs.new(*noise_shape).normal_()
    However, since torch.Tensor.new does not appear in the PyTorch documentation (at least I couldn't find it), it is probably preferable to use a documented method such as torch.Tensor.new_empty:
        white = locs.new_empty(noise_shape).normal_()
    This works as expected. Alternatively, the same torch.randn approach as in MixtureOfDiagNormalsSharedCovariance is possible. (A short standalone demonstration of the shape mismatch follows this list.)
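
To make the shape mismatch concrete, here is a small standalone snippet; the shapes are made up purely for illustration and do not come from the Pyro code:

    import torch

    # Hypothetical shapes, chosen only to illustrate the issue.
    locs = torch.zeros(5, 3, 2)   # e.g. (K components, batch, D dims)
    noise_shape = (4, 2)          # the shape the white noise should have

    # Buggy pattern: passing a tuple to torch.Tensor.new treats it as *data*,
    # producing a 1-D tensor of shape (len(noise_shape),) = (2,).
    print(locs.new(noise_shape).shape)                  # torch.Size([2])

    # Working alternatives: unpack the tuple, use new_empty, or use torch.randn.
    print(locs.new(*noise_shape).normal_().shape)       # torch.Size([4, 2])
    print(locs.new_empty(noise_shape).normal_().shape)  # torch.Size([4, 2])
    print(torch.randn(noise_shape, dtype=locs.dtype, device=locs.device).shape)  # torch.Size([4, 2])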

Package versions

PyTorch: version 2.0.1
Pyro: version 1.8.6

@fritzo added the bug label Oct 5, 2023
@fritzo added this to the 1.9 release milestone Oct 5, 2023
fritzo pushed a commit that referenced this issue Oct 5, 2023
* Added `support = constraints.real_vector` to `MixtureOfDiagNormals`
  and `MixtureOfDiagNormalsSharedCovariance`
* Fixed the white noise sampling bug in the `forward` method of
  `_MixDiagNormalSample`
* Use the same call in `_MixDiagNormalSample` and
  `_MixDiagNormalSharedCovarianceSample` to generate white noise
* Harmonized tensor shape error messages between `MixtureOfDiagNormals`
  and `MixtureOfDiagNormalsSharedCovariance`
* Added the correct class name in tensor shape errors for
  `MixtureOfDiagNormalsSharedCovariance`
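
For reference, the support fix amounts to a one-line class attribute. A minimal sketch (not the actual Pyro source; the rest of the class body is elided) could look like this:

    from torch.distributions import constraints
    from pyro.distributions.torch_distribution import TorchDistribution

    class MixtureOfDiagNormals(TorchDistribution):
        # Declaring the support lets autoguides such as AutoNormal pick the
        # right transform for latent sites that use this distribution.
        support = constraints.real_vector
        ...  # constructor, log_prob, rsample, etc. unchanged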