Bugfix to MixtureSameFamily's _pad_mixture_dimensions (#118947)
Fixes Issue #73792

This is a duplicate of pull request #73864. It's a small bugfix that should have landed a long time ago, but it didn't because I never followed up on the original pull request after submitting it. That's my bad; I'm trying to remedy the error now.

This contains a fix to _pad_mixture_dimensions, which is meant to count the number of dimensions in its referent tensors but accidentally counts the number of elements instead (and can thus end up creating tensors with potentially thousands of dimensions by mistake). It also contains a single test for the fixed behavior.
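
For context, a minimal sketch of the confusion (the shapes here are illustrative, not taken from the patch): on a `torch.Size`, `numel()` returns the product of the sizes, i.e. the number of elements a tensor of that shape would hold, while `len()` returns the number of dimensions.

```python
import torch

batch_shape = torch.Size([3, 3, 3])  # a 3-dimensional shape

# Buggy counting: numel() multiplies the sizes together.
print(batch_shape.numel())  # 27 -- the element count, not the dimension count

# Fixed counting: len() on a torch.Size is the number of dimensions.
print(len(batch_shape))  # 3
```

With the buggy version, the `pad_ndims` value in `_pad_mixture_dimensions` is computed from element counts, so the subsequent `reshape` can insert thousands of singleton dimensions.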

Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
Pull Request resolved: #118947
Approved by: https://github.com/soulitzer
CJMenart authored and pytorchmergebot committed Feb 6, 2024
1 parent 499040a commit a77be63
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions test/distributions/test_distributions.py
@@ -4048,6 +4048,12 @@ def test_continuous_bernoulli_shape_tensor_params(self):
         self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2)
         self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
+    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
+    def test_mixture_same_family_mean_shape(self):
+        mix_distribution = Categorical(torch.ones([3, 1, 3]))
+        component_distribution = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))
+        gmm = MixtureSameFamily(mix_distribution, component_distribution)
+        self.assertEqual(len(gmm.mean.shape), 2)
 
 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestKL(DistributionsTestCase):
4 changes: 2 additions & 2 deletions torch/distributions/mixture_same_family.py
@@ -195,8 +195,8 @@ def _pad(self, x):
         return x.unsqueeze(-1 - self._event_ndims)
 
     def _pad_mixture_dimensions(self, x):
-        dist_batch_ndims = self.batch_shape.numel()
-        cat_batch_ndims = self.mixture_distribution.batch_shape.numel()
+        dist_batch_ndims = len(self.batch_shape)
+        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
         pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
         xs = x.shape
         x = x.reshape(
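
As a hedged usage sketch of the behavior the fix restores, using the same distributions as the new test (the expected shape follows from the mixture's `(3, 3)` batch shape):

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

# A batch of Gaussian mixtures: 3 components per batch entry.
mix = Categorical(torch.ones([3, 1, 3]))  # batch_shape (3, 1), 3 components
comp = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))  # batch_shape (3, 3, 3)
gmm = MixtureSameFamily(mix, comp)  # batch_shape (3, 3)

# With the fix, the mean has one dimension per batch dimension.
print(gmm.mean.shape)  # torch.Size([3, 3])
```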
