Add missing complex support for torch.norm and torch.linalg.norm #48284
Conversation
💊 CI failures summary (Dr. CI): as of commit df1244f, 💚 Looks good so far! There are no failures yet.
Force-pushed from ba0ae51 to 1408afd
Force-pushed from 1408afd to d279241
Force-pushed from d279241 to 684bbd6
ROCm test failure seems to be real.
@mruberry, in NumPy, if we give norm a complex input, it returns the downgraded real type, but in PyTorch we match the input type. Is it alright if we BC-break to fix this?
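For illustration (example mine, not from the thread), the discrepancy looks like this:

```python
import numpy as np
import torch

x = [3 + 4j, 1 - 2j]

# NumPy returns the downgraded real dtype for a complex input:
print(np.linalg.norm(np.array(x)).dtype)  # float64

# PyTorch (before this PR) kept the input's complex dtype;
# after this PR it matches NumPy and returns the real dtype.
print(torch.norm(torch.tensor(x)).dtype)  # torch.float32 post-PR
```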
Codecov Report
@@            Coverage Diff             @@
##           master   #48284      +/-   ##
==========================================
- Coverage   80.74%   80.73%   -0.01%
==========================================
  Files        1869     1869
  Lines      201654   201650       -4
==========================================
- Hits       162818   162802      -16
- Misses      38836    38848      +12
Force-pushed from d80710e to bb5d498
@@ -174,61 +174,72 @@ static void norm_kernel_tensor_iterator_impl(
   if (p.isIntegral(false)) {
     val = p.to<int64_t>();
   } else if (p.isFloatingPoint()) {
-    val = p.to<float>();
+    val = p.to<double>();
Why the change from float to double here?
To me, it seems better to cast to double to allow more precision for `p`.
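A quick illustration (mine, not from the thread) of the precision lost when `p` is narrowed to a 32-bit float:

```python
import numpy as np

p = 2.1
# Round-tripping p through float32 perturbs its value, which can shift
# pow(|x|, p) reductions; converting to double preserves the exact value.
print(float(np.float32(p)))  # 2.0999999046325684
print(float(np.float64(p)))  # 2.1
```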
-  inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
-    return acc + data * data;
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    acc_t abs_data = std::abs(data);
What's the difference between the previous expression and the new one with `std::abs`?
The `std::abs()` call is needed when `scalar_t` is complex. For real numbers, it's not needed. I could add an overload for complex numbers so that we avoid calling `std::abs()` for real numbers.
Adding an overload sounds good as long as it doesn't add too much code complexity.
I added `abs_if_complex()`. I'm not too sure if this is the most efficient way to implement it. Let me know if you think this adds too much complexity.
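For intuition, a small Python sketch (mine, mirroring the C++ reduction) of why the absolute value is required for complex inputs:

```python
import torch

x = torch.tensor([3 + 4j, 1 + 2j])

# For real data, acc + data * data accumulates squared magnitudes.
# For complex data, data * data stays complex and is not |data|**2,
# so the reduction must take the absolute value first.
wrong = (x * x).sum()                # tensor(-10.+28.j), not usable for a norm
right = (x.abs() ** 2).sum().sqrt()  # sqrt(25 + 5)
print(right)                         # tensor(5.4772)
```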
Hey @kurtamohler! This PR looks great. It has surgical precision. I made a few comments that are mostly me asking questions to better understand what's going on. I think this PR should update the documentation, too, to more accurately describe torch.norm's and torch.linalg.norm's complex support. Also, how are we testing complex autograd for torch.linalg.norm? Should something like this be updated? Also, the lint build should be fixed if you rebase.
Thanks for the review @mruberry! Yes, more accurate documentation is a good idea, and I will look into autograd testing.
Force-pushed from bb5d498 to a96a3c3
Awesome! Thanks @kurtamohler.
Two small cleanup suggestions in the docs and one question about using torch.promote_types() to simplify a test. Just let me know when this is ready to merge.
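For reference, `torch.promote_types` computes the dtype that PyTorch's type-promotion rules would produce, which can replace a hand-written dtype table in a test (usage sketch, mine):

```python
import torch

print(torch.promote_types(torch.float32, torch.float64))    # torch.float64
print(torch.promote_types(torch.int64, torch.float16))      # torch.float16
print(torch.promote_types(torch.float64, torch.complex64))  # torch.complex128
```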
Force-pushed from c5feedf to 3669243
@mruberry, I think this is ready to merge.
Looks like the return dtype change (return real when given complex) is breaking one of the tests for …. So I can fix the failing errors and add ….
Force-pushed from 73deca9 to 41df1fa
I'm not sure what's causing the … failure.
It doesn't look related. We can probably ignore it (unless the same failure happens again).
Force-pushed from 41df1fa to 4e944f4
In that case, I think this is ready to merge, @mruberry, unless another test fails after the rebase.
All the failures are due to an upstream mypy issue.
torch/linalg/__init__.py (outdated)
will be returned. Its data type must be either a floating point or complex type. For complex
inputs, the norm is calculated on the absolute values of each element. If the input is
complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will be
the corresponding downgraded real number type.
This sentence needs to be updated to be consistent with the torch/functional.py documentation.
Does the torch.linalg.cond documentation need a similar update, too?
Updated this and `torch.linalg.cond`.
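Assuming the updated docs describe the same dtype rule, `torch.linalg.cond` on a complex input should now report a real result (sketch mine):

```python
import torch

A = torch.eye(2, dtype=torch.complex128)
# The condition number of a complex matrix is real-valued; with this
# change the result dtype is float64 rather than complex128.
print(torch.linalg.cond(A))  # tensor(1., dtype=torch.float64)
```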
Hey @kurtamohler! Just a few comments/questions. Overall things look great.
Force-pushed from 4e944f4 to df1244f
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@@ -1679,7 +1677,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio
   // when the input contains extreme values (like nan or +/-inf) or if the input
   // size is degenerate (like size(0), size(0, N), etc)
   case_was_overridden = true;
-  self_ = self.abs();
+  self_ = self_.abs();
Why don't we just declare `self_` without assignment above?
The conditional dtype conversion `opt_dtype.has_value() ? self.to(opt_dtype.value()) : self` needs to be performed for all cases, which is why `self_` is defined with an assignment.
BC-breaking note:
Previously, when given a complex input, `torch.linalg.norm` and `torch.norm` would return a complex output. `torch.linalg.cond` would sometimes return a complex output and sometimes a real output when given a complex input, depending on its `p` argument. This PR changes the behavior to match `numpy.linalg.norm` and `numpy.linalg.cond`, so that a complex input now yields the corresponding downgraded real number type, consistent with NumPy.
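Concretely (example mine), the dtype change looks like this:

```python
import torch

x = torch.tensor([1 + 1j, 2 - 2j], dtype=torch.complex64)
# Previously this returned a complex64 scalar; now it returns the
# downgraded real type, matching numpy.linalg.norm.
print(torch.linalg.norm(x).dtype)  # torch.float32
```

PR Summary: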
The following cases were previously unsupported for complex inputs, and this commit adds support:
Part of #47833