diff --git a/aten/src/ATen/native/Repeat.h b/aten/src/ATen/native/Repeat.h index e44e32b6bbdc..f57a5b7ff436 100644 --- a/aten/src/ATen/native/Repeat.h +++ b/aten/src/ATen/native/Repeat.h @@ -9,6 +9,9 @@ static inline Tensor repeat_interleave_common(const Tensor &repeats) { TORCH_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat"); TORCH_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor"); TORCH_CHECK((repeats >= 0).all().item(), "repeats can not be negative"); + if (repeats.size(0) == 0) { + return at::empty_like(repeats); + } Tensor repeats_ = repeats.contiguous(); Tensor cumsum = repeats.cumsum(0); int64_t total = cumsum[-1].item(); diff --git a/test/test_torch.py b/test/test_torch.py index 2de65c3b3b3a..87449fbab170 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9121,6 +9121,15 @@ def test_repeat_interleave(self): with self.assertRaises(RuntimeError): torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0) + # test zero sized dimension + x = torch.zeros((5, 0)) + y = torch.repeat_interleave(x, repeats=3, dim=1) + self.assertEqual(y, x.new_zeros(5, 0)) + + x = torch.tensor([], dtype=torch.int64) + y = torch.repeat_interleave(x, x) + self.assertEqual(y, x) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_repeat_tile(self):