
Bug: torch.distributions.mixture_same_family._pad_mixture_dimensions #73792

Open

CJMenart opened this issue Mar 4, 2022 · 2 comments

Labels: high priority; module: correctness (silent) (issue that returns an incorrect result silently); module: distributions (related to torch.distributions); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

@CJMenart (Contributor) commented Mar 4, 2022

🐛 Describe the bug

I already tracked down what's going on with this one; there is a bug in mixture_same_family._pad_mixture_dimensions. It calls numel() on Size objects to get the number of dimensions of its tensors. However, Size.numel() multiplies all the entries of the size together (it reports how many elements are in the tensor whose size it represents!), so it returns what is clearly the wrong value.
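For context, here is a minimal illustration of the two semantics. It uses a plain tuple as a stand-in for torch.Size (which subclasses tuple), so it runs without torch installed:

```python
from math import prod

# torch.Size subclasses tuple; Size.numel() is the product of its entries
# (the element count of the tensor it describes), while len() gives the
# number of dimensions.
size = (10, 10, 10)  # stand-in for torch.Size([10, 10, 10])

element_count = prod(size)  # what numel() computes: 10 * 10 * 10 = 1000
ndims = len(size)           # what the code actually wants: 3

print(element_count, ndims)
```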

This is the current code (present in both 1.8, which I was on originally, and master):

    def _pad_mixture_dimensions(self, x):
        dist_batch_ndims = self.batch_shape.numel()
        cat_batch_ndims = self.mixture_distribution.batch_shape.numel()
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
                      xs[-1:] + torch.Size(self._event_ndims * [1]))
        return x

It should be this:

    def _pad_mixture_dimensions(self, x):
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
                      xs[-1:] + torch.Size(self._event_ndims * [1]))
        return x

I didn't see this error anywhere else, but I haven't done a thorough search.
I don't know which branch is the right one to open a pull request against (master?), so I thought I'd report it here first.

Versions

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.8.0
[pip3] torchvision==0.9.0
[conda] blas 2.109 mkl conda-forge
[conda] blas-devel 3.9.0 9_mkl conda-forge
[conda] cudatoolkit 10.2.89 h8f6ccaa_8 conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libblas 3.9.0 9_mkl conda-forge
[conda] libcblas 3.9.0 9_mkl conda-forge
[conda] liblapack 3.9.0 9_mkl conda-forge
[conda] liblapacke 3.9.0 9_mkl conda-forge
[conda] mkl 2021.2.0 h726a3e6_389 conda-forge
[conda] mkl-devel 2021.2.0 ha770c72_390 conda-forge
[conda] mkl-include 2021.2.0 h726a3e6_389 conda-forge
[conda] numpy 1.20.3 py38h9894fe3_1 conda-forge
[conda] pytorch 1.8.0 py3.8_cuda10.2_cudnn7.6.5_0 pytorch
[conda] torchvision 0.9.0 py38_cu102 pytorch

cc @ezyang @gchanan @zou3519 @fritzo @neerajprad @alicanb @nikitaved

@mruberry added the high priority, module: correctness (silent), and module: distributions labels on Mar 7, 2022
@mruberry (Collaborator) commented Mar 7, 2022

Would you extend this issue with a small code snippet that demonstrates the bug and shows the before/after with the proposed fix?

Yep, just open a pull request for your branch to be merged with master and we'll be able to review it.

@mruberry added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label on Mar 7, 2022
@CJMenart (Contributor, Author) commented Mar 7, 2022

Example given:

Under the current version, the snippet below throws "RuntimeError: only tensors with up to 64 dims are supported", because the buggy code tries to create a tensor with roughly a thousand dimensions. Under the fix, it correctly prints the mean of the Gaussian mixture (100 zeros).

import torch
from torch.distributions import Categorical, Normal, MixtureSameFamily

def bug_example():
    # A 3D tensor of Gaussian distributions
    component_distribution = Normal(torch.zeros([10, 10, 10]), torch.ones([10, 10, 10]))
    # A 3D tensor of uniform distributions over 10 classes
    mix_distribution = Categorical(torch.ones([10, 1, 10]))
    # This SHOULD be 10 mixtures of 10 10-dimensional isotropic Gaussians
    gmm = MixtureSameFamily(mix_distribution, component_distribution)
    mean = gmm.mean  # This kicks off the call to _pad_mixture_dimensions
    print(mean)

if __name__ == '__main__':
    bug_example()
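To see where the huge dimension count comes from, here is a sketch of the pad_ndims arithmetic for the shapes in this example, with plain tuples standing in for the two batch_shape values above:

```python
from math import prod

dist_batch_shape = (10, 10, 10)  # component Normal's batch shape
cat_batch_shape = (10, 1, 10)    # mixture Categorical's batch shape

# Buggy version: numel() multiplies the shape's entries together,
# yielding element counts rather than dimension counts.
buggy_dist_ndims = prod(dist_batch_shape)  # 1000, not 3
buggy_cat_ndims = prod(cat_batch_shape)    # 100, not 3
buggy_pad = 0 if buggy_cat_ndims == 1 else buggy_dist_ndims - buggy_cat_ndims

# Fixed version: len() counts the dimensions.
fixed_dist_ndims = len(dist_batch_shape)   # 3
fixed_cat_ndims = len(cat_batch_shape)     # 3
fixed_pad = 0 if fixed_cat_ndims == 1 else fixed_dist_ndims - fixed_cat_ndims

print(buggy_pad, fixed_pad)  # 900 vs. 0
```

The buggy version then asks reshape for 900 extra singleton dimensions, which blows past PyTorch's 64-dimension limit and triggers the RuntimeError.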

pytorchmergebot pushed a commit that referenced this issue Feb 6, 2024

Fixes issue #73792.

This is a duplicate of pull request #73864. It's a small bugfix that should have happened a long time ago, but it didn't because I didn't follow up with the pull request after originally submitting it. That's my bad; trying to remedy the error.

This contains a fix to _pad_mixture_dimensions, which is meant to count the number of dimensions of its referent tensors but accidentally counts the number of elements (and can thus end up creating tensors with potentially thousands of dimensions by mistake). It also contains a single test for the fixed behavior.

Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
Pull Request resolved: #118947
Approved by: https://github.com/soulitzer
pytorch-bot, vfdev-5, and clee2000 also pushed commits referencing this issue (Feb 8–14, 2024) with the same commit message as above.