[fix] torch.kthvalue : handle non-contiguous CUDA tensor #45802
Conversation
@@ -241,7 +241,7 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
     bool keepdim) {
   auto result = [&]() {
     NoNamesGuard guard;
-    return kthvalue_out_impl_cuda(values, indices, self, k, dim, keepdim);
+    return kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim);
Would you just add a comment here or above the kthvalue_out_impl_cuda() that the function expects (and only works properly on) contiguous tensors?
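The contiguity requirement can be illustrated with a small pure-Python sketch (a simulation of strided access, not PyTorch's actual kernel): a view with stride 2 over the storage [0, -1, 1, -2, 2] logically holds [0, 1, 2], but a kernel that wrongly assumes stride 1 reads [0, -1, 1] and returns the wrong kth value.

```python
def kthvalue(storage, offset, stride, length, k):
    """k-th smallest (1-based) of a strided 1-D view over flat storage."""
    view = [storage[offset + i * stride] for i in range(length)]
    ordered = sorted(view)
    val = ordered[k - 1]
    return val, view.index(val)

storage = [0, -1, 1, -2, 2]

# Correct: honor the view's stride of 2 -> logical elements [0, 1, 2]
print(kthvalue(storage, 0, 2, 3, k=2))  # (1, 1)

# Buggy: assuming stride 1 reads [0, -1, 1] from storage instead
print(kthvalue(storage, 0, 1, 3, k=2))  # (0, 0)
```

Calling `.contiguous()` before the kernel materializes the view into stride-1 storage, which is why the one-line fix above is sufficient.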
@@ -9010,6 +9010,13 @@ def test_kthvalue(self, device, dtype):
         self.assertEqual(res1val, res2val, atol=0, rtol=0)
         self.assertEqual(res1ind, res2ind, atol=0, rtol=0)

+        # non-contiguous [Reference: https://github.com/pytorch/pytorch/issues/45721]
+        non_contig_t = torch.tensor([0, -1, 1, -2, 2], dtype=dtype, device=device)[::2]
+        expected_val, expected_ind = non_contig_t.contiguous().kthvalue(2)
Would you also compare it with the CPU result? Something like:
expected_val_cpu, expected_ind_cpu = non_contig_t.kthvalue(2)
...
self.assertEqual(expected_val_cpu, out_val.cpu(), atol=0, rtol=0)
self.assertEqual(expected_ind_cpu, out_ind.cpu(), atol=0, rtol=0)
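The pattern suggested here, checking a device kernel against an independent reference path, can be sketched generically in plain Python (hypothetical helper names, not the actual PyTorch test):

```python
def kthvalue_reference(values, k):
    """Reference result: k-th smallest (1-based) and its original index."""
    ordered = sorted(range(len(values)), key=values.__getitem__)
    idx = ordered[k - 1]
    return values[idx], idx

def check_against_reference(kernel, values, k):
    """Assert a candidate kernel agrees with the independent reference path."""
    assert kernel(values, k) == kthvalue_reference(values, k)

# The non-contiguous view in the test holds [0, 1, 2];
# its 2nd smallest element is 1, at index 1.
print(kthvalue_reference([0, 1, 2], 2))  # (1, 1)
```

The value of a cross-device comparison is that the CPU and CUDA paths are implemented independently, so a silent correctness bug in one is unlikely to be mirrored in the other.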
Hey @kshitij12345 Thanks for fixing this high priority issue! This is a great, simple fix. Maybe in the future we should consider if the kth value kernel should be updated to handle discontiguous inputs, too.
I just made two small suggestions, and it looks like this PR needs to be rebased, too (sorry about that). Just ping me when they're done!
Since this fixes a silent correctness issue I think we'll want to put it in the 1.7 release, too.
Have addressed the comments. PTAL :)
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report
@@           Coverage Diff           @@
##           master   #45802   +/-  ##
=======================================
  Coverage   68.20%   68.20%
=======================================
  Files         410      410
  Lines       53453    53453
=======================================
+ Hits        36457    36458    +1
+ Misses      16996    16995    -1
Continue to review full report at Codecov.
Fixes #45721
TODO