From 1a3de87466a32023700ff0b64614c527c8e69e98 Mon Sep 17 00:00:00 2001 From: Heitor Schueroff Date: Tue, 10 Nov 2020 08:00:05 -0800 Subject: [PATCH] Fix kthvalue error for scalar input ghstack-source-id: 393880d63a051d07d83378b960124106f0561f6b Pull Request resolved: https://github.com/pytorch/pytorch/pull/47600 --- aten/src/ATen/native/cuda/Sorting.cu | 2 +- test/test_torch.py | 5 +++++ torch/testing/_internal/common_methods_invocations.py | 8 +++----- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index c6688b286914..889ccf606152 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -250,7 +250,7 @@ void kthvalue_cuda_template( int64_t dim_, bool keepdim) { int64_t dim = maybe_wrap_dim(dim_, self.dim()); - int64_t slicesize = self.size(dim); + int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim); // FIXME: This seems bogus, I only do this because it was the old behaviour. // The reductions are fine, as long as the axis being reduced along // isn't of 0 elements (and the output has elements). diff --git a/test/test_torch.py b/test/test_torch.py index b1dd0ec92d9a..cf25ccb3bde5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9322,6 +9322,11 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) + # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) + # Tests that passing a scalar tensor or 1D tensor with 1 element work either way + x = torch.tensor([2], device=device, dtype=dtype) + self.assertEqual(x.squeeze().kthvalue(1), x.kthvalue(1)) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index c61f6e709afe..54fbe810d7e1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -867,11 +867,9 @@ def method_tests(): ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]), ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]), ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]), - # TODO: https://github.com/pytorch/pytorch/issues/30818 - ('kthvalue', (), (1,), 'scalar', (), (), [expectedFailureCUDA]), - ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1], [expectedFailureCUDA]), - ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1], [expectedFailureCUDA]), - # END TODO + ('kthvalue', (), (1,), 'scalar', (), ()), + ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]), + ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]), ('quantile', (S, S, S), (0.5,)), ('quantile', (S, S, S), (0.5, 0), 'dim', (), [1]), ('quantile', (S, S, S), (0.5, None, True), 'keepdim'),