From 5c0ce872a4ec27cdb594c45335f5f220919c342f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 24 Jul 2019 12:52:45 -0700 Subject: [PATCH] Fix categorical sample view --- torch/distributions/categorical.py | 7 ++----- torch/distributions/transforms.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 14cd7d0142cda..685673feeb810 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -103,12 +103,9 @@ def sample(self, sample_shape=torch.Size()): sample_shape = self._extended_shape(sample_shape) param_shape = sample_shape + torch.Size((self._num_events,)) probs = self.probs.expand(param_shape) - if self.probs.dim() == 1 or self.probs.size(0) == 1: - probs_2d = probs.view(-1, self._num_events) - else: - probs_2d = probs.contiguous().view(-1, self._num_events) + probs_2d = probs.reshape(-1, self._num_events) sample_2d = torch.multinomial(probs_2d, 1, True) - return sample_2d.contiguous().view(sample_shape) + return sample_2d.reshape(sample_shape) def log_prob(self, value): if self._validate_args: diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index c7e99e188ac3a..8ce6464860032 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -555,7 +555,7 @@ def _call(self, x): return torch.stack([self._call_on_event(flat_x[i]) for i in range(flat_x.size(0))]).view(x.shape) def _inverse(self, y): - flat_y = y.contiguous().view((-1,) + y.shape[-2:]) + flat_y = y.reshape((-1,) + y.shape[-2:]) return torch.stack([self._inverse_on_event(flat_y[i]) for i in range(flat_y.size(0))]).view(y.shape)