Skip to content

Commit

Permalink
Fix kthvalue error for scalar input
Browse files Browse the repository at this point in the history
ghstack-source-id: 393880d63a051d07d83378b960124106f0561f6b
Pull Request resolved: #47600
  • Loading branch information
heitorschueroff committed Nov 10, 2020
1 parent 22d2141 commit 1a3de87
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Sorting.cu
Expand Up @@ -250,7 +250,7 @@ void kthvalue_cuda_template(
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
int64_t slicesize = self.size(dim);
int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
// FIXME: This seems bogus, I only do this because it was the old behaviour.
// The reductions are fine, as long as the axis being reduced along
// isn't of 0 elements (and the output has elements).
Expand Down
5 changes: 5 additions & 0 deletions test/test_torch.py
Expand Up @@ -9322,6 +9322,11 @@ def test_kthvalue(self, device, dtype):
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)

# Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818)
# Tests that passing a scalar tensor or 1D tensor with 1 element work either way
x = torch.tensor([2], device=device, dtype=dtype)
self.assertEqual(x.squeeze().kthvalue(1), x.kthvalue(1))

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
Expand Down
8 changes: 3 additions & 5 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -867,11 +867,9 @@ def method_tests():
('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]),
('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]),
('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]),
# TODO: https://github.com/pytorch/pytorch/issues/30818
('kthvalue', (), (1,), 'scalar', (), (), [expectedFailureCUDA]),
('kthvalue', (), (1, 0,), 'scalar_dim', (), [1], [expectedFailureCUDA]),
('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1], [expectedFailureCUDA]),
# END TODO
('kthvalue', (), (1,), 'scalar', (), ()),
('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]),
('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]),
('quantile', (S, S, S), (0.5,)),
('quantile', (S, S, S), (0.5, 0), 'dim', (), [1]),
('quantile', (S, S, S), (0.5, None, True), 'keepdim'),
Expand Down

0 comments on commit 1a3de87

Please sign in to comment.