Skip to content

Commit

Permalink
Zero sized tensor support for repeat_interleave (#23717)
Browse files Browse the repository at this point in the history
Summary:
Fixes #22753
Pull Request resolved: #23717

Differential Revision: D16623598

Pulled By: mrshenli

fbshipit-source-id: 297a3274fb5a5b2fcc0c3ad601337d7eb29fdca2
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Aug 5, 2019
1 parent f87a4cc commit 520982d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/native/Repeat.h
Expand Up @@ -9,6 +9,9 @@ static inline Tensor repeat_interleave_common(const Tensor &repeats) {
TORCH_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
TORCH_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor");
TORCH_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
if (repeats.size(0) == 0) {
return at::empty_like(repeats);
}
Tensor repeats_ = repeats.contiguous();
Tensor cumsum = repeats.cumsum(0);
int64_t total = cumsum[-1].item<int64_t>();
Expand Down
9 changes: 9 additions & 0 deletions test/test_torch.py
Expand Up @@ -9121,6 +9121,15 @@ def test_repeat_interleave(self):
with self.assertRaises(RuntimeError):
torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0)

# test zero sized dimension
x = torch.zeros((5, 0))
y = torch.repeat_interleave(x, repeats=3, dim=1)
self.assertEqual(y, x.new_zeros(5, 0))

x = torch.tensor([], dtype=torch.int64)
y = torch.repeat_interleave(x, x)
self.assertEqual(y, x)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_repeat_tile(self):

Expand Down

0 comments on commit 520982d

Please sign in to comment.