diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index c6688b286914..889ccf606152 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -250,7 +250,7 @@ void kthvalue_cuda_template( int64_t dim_, bool keepdim) { int64_t dim = maybe_wrap_dim(dim_, self.dim()); - int64_t slicesize = self.size(dim); + int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim); // FIXME: This seems bogus, I only do this because it was the old behaviour. // The reductions are fine, as long as the axis being reduced along // isn't of 0 elements (and the output has elements). diff --git a/test/test_torch.py b/test/test_torch.py index b1dd0ec92d9a..cf25ccb3bde5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9322,6 +9322,11 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) + # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) + # Tests that passing a scalar tensor or 1D tensor with 1 element work either way + x = torch.tensor([2], device=device, dtype=dtype) + self.assertEqual(x.squeeze().kthvalue(1), x.kthvalue(1)) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index c61f6e709afe..54fbe810d7e1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -867,11 +867,9 @@ def method_tests(): ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]), ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]), ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]), - # TODO: https://github.com/pytorch/pytorch/issues/30818 - ('kthvalue', (), (1,), 'scalar', (), (), [expectedFailureCUDA]), - ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1], [expectedFailureCUDA]), - ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1], [expectedFailureCUDA]), - # END TODO + ('kthvalue', (), (1,), 'scalar', (), ()), + ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]), + ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]), ('quantile', (S, S, S), (0.5,)), ('quantile', (S, S, S), (0.5, 0), 'dim', (), [1]), ('quantile', (S, S, S), (0.5, None, True), 'keepdim'),