New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add complex support for torch.mean [CUDA] #47048
Add complex support for torch.mean [CUDA] #47048
Conversation
💊 CI failures summary and remediationsAs of commit de87d22 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages: pytorch_vulkan_linux_bionic_py3_6_clang9_build (1/1)Step: "Build" (full log | diagnosis details | 🔁 rerun)
|
d618111
to
b0de1be
Compare
Codecov Report
@@ Coverage Diff @@
## master #47048 +/- ##
===========================================
+ Coverage 35.95% 53.27% +17.31%
===========================================
Files 438 2747 +2309
Lines 55454 254304 +198850
===========================================
+ Hits 19939 135476 +115537
- Misses 35515 118828 +83313 |
@anjali411 Thanks so much for reviewing this PR. |
@@ -36,26 +36,40 @@ static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_s | |||
|
|||
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setting acc_t = typename c10::scalar_value_type<scalar_t>::type
should resolve the issue here
c10::scalar_value_type<scalar_t>::type
returns scalar_t
for all non-complex dtypes and returns T
for c10::complex<T>
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the tip. The latest code now manipulates c10::scalar_value_type<acc_t>::type
to get the type of the factor, the overload functions for complex numbers are not needed.
Hi @RockingJavaBean I think we shouldn't need to define overload functions for complex types, after the change I suggested in my comment. But this PR looks great overall, and should be ready to merge after that change! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@@ -514,7 +514,8 @@ def test_mean_dim(self): | |||
self._test_dim_ops( | |||
lambda t, d: t.mean(d), | |||
lambda n, d: n.mean(d), | |||
use_integral=False) | |||
use_integral=False, | |||
use_complex=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test doesn't run on CUDA. can you please extend the test for mean in tensor_op_tests
to also test complex dtypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out, the tests for complex dtypes are added to tensor_op_tests
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's update the CUDA test for mean to test complex dtypes as well
@anjali411 I'm really grateful for your tip on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm and the windows test failure is an upstream test failure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 thanks so much for reviewing this PR, the CUDA tests for
|
@RockingJavaBean can you please rebase? |
…h_mean_complex
@anjali411 thank you so much for the kind reminder, I just rebased this PR with the latest code. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in f90da88. |
Fixes #46982