
Add missing complex support for torch.norm and torch.linalg.norm #48284

Closed

Conversation


@kurtamohler kurtamohler commented Nov 20, 2020

BC-breaking note:

Previously, when given a complex input, torch.linalg.norm and torch.norm would return a complex output, and torch.linalg.cond would return either a complex or a real output depending on its p argument. This PR changes that behavior to match numpy.linalg.norm and numpy.linalg.cond: a complex input now produces a result of the corresponding downgraded real number type (for example, a complex64 input yields a float32 result).
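
For illustration, a minimal sketch of the new behavior (a hypothetical session; the literal below defaults to complex64):

>>> import torch
>>> x = torch.tensor([1 + 1j, 2 - 2j])
>>> x.dtype
torch.complex64
>>> torch.linalg.norm(x).dtype   # real result after this PR, matching NumPy
torch.float32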

PR Summary:

The following cases were previously unsupported for complex inputs, and this commit adds support:

  • Frobenius norm
  • Norm order 2 (vector and matrix)
  • CUDA vector norm

Part of #47833
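
For illustration, the kinds of calls this enables for complex inputs (a hedged sketch, not taken from the PR's test suite):

>>> import torch
>>> a = torch.randn(3, 3, dtype=torch.complex128)
>>> torch.linalg.norm(a, ord='fro')      # Frobenius norm of a complex matrix
>>> torch.linalg.norm(a, ord=2)          # matrix 2-norm
>>> torch.linalg.norm(a[0], ord=2)       # vector 2-norm (also supported on CUDA by this PR)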


dr-ci bot commented Nov 20, 2020

💊 CI failures summary and remediations

As of commit df1244f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@kurtamohler kurtamohler force-pushed the norm-complex-support-47833 branch 2 times, most recently from ba0ae51 to 1408afd on November 24, 2020 at 00:08
test/test_linalg.py (outdated review thread, resolved)
@kurtamohler kurtamohler changed the title WIP: Add missing complex support for torch.norm and torch.linalg.norm Add missing complex support for torch.norm and torch.linalg.norm Nov 24, 2020
@kurtamohler kurtamohler marked this pull request as ready for review November 24, 2020 18:15
@kurtamohler
Collaborator Author

The ROCm test failure seems to be real.

@kurtamohler
Collaborator Author

@mruberry, in numpy, if we give norm a complex number, it returns the downgraded real type:

>>> numpy.linalg.norm([1+1j]).dtype
dtype('float64')

But in pytorch, we match the input type:

>>> torch.linalg.norm(torch.tensor([1+1j])).dtype
torch.complex64

Is it alright if we BC-break to fix this?


codecov bot commented Nov 24, 2020

Codecov Report

Merging #48284 (df1244f) into master (274ce26) will decrease coverage by 0.00%.
The diff coverage is 93.33%.

@@            Coverage Diff             @@
##           master   #48284      +/-   ##
==========================================
- Coverage   80.74%   80.73%   -0.01%     
==========================================
  Files        1869     1869              
  Lines      201654   201650       -4     
==========================================
- Hits       162818   162802      -16     
- Misses      38836    38848      +12     

@kurtamohler kurtamohler force-pushed the norm-complex-support-47833 branch 2 times, most recently from d80710e to bb5d498 on November 25, 2020 at 22:14
@ngimel ngimel added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Nov 26, 2020
@@ -174,61 +174,72 @@ static void norm_kernel_tensor_iterator_impl(
   if (p.isIntegral(false)) {
     val = p.to<int64_t>();
   } else if (p.isFloatingPoint()) {
-    val = p.to<float>();
+    val = p.to<double>();
Collaborator

Why the change from float to double here?

Collaborator Author

To me, it seems better to cast to double to allow more precision for p.
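
For context, a small illustration of the kind of non-integral p this affects (hypothetical values):

>>> import torch
>>> x = torch.randn(10)
>>> torch.norm(x, p=2.0000001)   # the order is forwarded to the kernel as a double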

-  inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const {
-    return acc + data * data;
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    acc_t abs_data = std::abs(data);
Collaborator

What's the difference between the previous expression and the new one with std::abs?

Collaborator Author

The std::abs() call is needed when scalar_t is complex. For real numbers, it's not needed. I could add an overload for complex numbers so that we avoid calling std::abs() for real numbers.
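
For context, a small Python illustration of why the absolute value matters for complex data (hypothetical values):

>>> import torch
>>> z = torch.tensor([3 + 4j])
>>> (z * z).sum()          # -7+24j: squaring a complex value directly is wrong for a norm
>>> (z.abs() ** 2).sum()   # 25.0: the 2-norm accumulates |z|**2, hence the abs
>>> torch.linalg.norm(z)   # 5.0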

Collaborator

Adding an overload sounds good as long as it doesn't add too much code complexity.

Collaborator Author

I added abs_if_complex(). Not too sure if this is the most efficient way to implement it. Let me know if you think this adds too much complexity.


mruberry commented Nov 27, 2020

Hey @kurtamohler!

This PR looks great. It has surgical precision. I made a few comments that are mostly me asking questions to better understand what's going on.

I think this PR should update the documentation, too, to more accurately describe torch.norm's and torch.linalg.norm's complex support.

Also, how are we testing complex autograd for torch.linalg.norm? Should something like this be updated?

https://github.com/pytorch/pytorch/blob/bb5d4984b912f3f9f775b00b31de198fd3d01a7f/test/test_linalg.py#L937

Also, the lint build should be fixed if you rebase.
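
For reference, complex autograd for an op like this is typically exercised with gradcheck; a hedged sketch (not the PR's actual test):

>>> import torch
>>> from torch.autograd import gradcheck
>>> x = torch.randn(5, dtype=torch.complex128, requires_grad=True)
>>> gradcheck(torch.linalg.norm, (x,))   # numerically verifies the complex backward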

@kurtamohler
Collaborator Author

Thanks for the review @mruberry! Yes, more accurate documentation is a good idea, and I will look into autograd testing

torch/linalg/__init__.py (outdated review thread, resolved)
Collaborator

@mruberry mruberry left a comment

Awesome! Thanks @kurtamohler.

Two small cleanup suggestions in the docs and one question about using torch.promote_types() to simplify a test. Just let me know when this is ready to merge.

@kurtamohler
Collaborator Author

@mruberry, I think this is ready to merge


kurtamohler commented Dec 4, 2020

Looks like the return dtype change (return real when given complex) is breaking one of the tests for torch.linalg.cond because it depends on torch.linalg.norm. In numpy, cond has the same return type behavior as norm:

>>> a = numpy.random.rand(10, 10) + 1j * numpy.random.rand(10, 10)
>>> a.dtype
dtype('complex128')
>>> numpy.linalg.cond(a).dtype
dtype('float64')

So I will fix the failing tests and add cond to the BC-breaking note above.
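
For comparison, a sketch of the corresponding PyTorch behavior once cond is updated the same way (hypothetical session):

>>> import torch
>>> a = torch.randn(10, 10, dtype=torch.complex128)
>>> torch.linalg.cond(a).dtype   # real result, mirroring numpy.linalg.cond
torch.float64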

@kurtamohler kurtamohler force-pushed the norm-complex-support-47833 branch 3 times, most recently from 73deca9 to 41df1fa on December 4, 2020 at 23:31
@kurtamohler
Collaborator Author

I'm not sure what's causing the pytorch-linux-bionic-rocm3.9-py3.6 failure. I'll look into it

@kurtamohler kurtamohler mentioned this pull request Dec 6, 2020

mruberry commented Dec 6, 2020

I'm not sure what's causing the pytorch-linux-bionic-rocm3.9-py3.6 failure. I'll look into it

It doesn't look related. We can probably ignore it (unless the same failure happens again).

@kurtamohler
Collaborator Author

In that case, I think this is ready to merge, @mruberry, unless another test fails after the rebase.

@kurtamohler
Collaborator Author

All the failures are due to an upstream mypy issue

@mruberry mruberry added the "module: bc-breaking" label (Related to a BC-breaking change) on Dec 8, 2020
torch/functional.py (outdated review thread, resolved)
will be returned. Its data type must be either a floating point or complex type. For complex
inputs, the norm is calculated on the absolute values of each element. If the input is
complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will be
the corresponding downgraded real number type.
Collaborator

This sentence needs to be updated to be consistent with the torch/functional.py documentation.

Collaborator

Does the torch.linalg.cond documentation need a similar update, too?

Collaborator Author

Updated this and torch.linalg.cond


mruberry commented Dec 8, 2020

Hey @kurtamohler! Just a few comments/questions. Overall things look great.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 54f0556.

@@ -1679,7 +1677,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio
     // when the input contains extreme values (like nan or +/-inf) or if the input
     // size is degenerate (like size(0), size(0, N), etc)
     case_was_overridden = true;
-    self_ = self.abs();
+    self_ = self_.abs();
Contributor

Why don't we just declare self_ without assignment above?

Collaborator Author

The conditional dtype conversion (opt_dtype.has_value() ? self.to(opt_dtype.value()) : self) needs to be performed in all cases, which is why self_ is declared with an initial assignment.

Labels
cla signed, Merged, module: bc-breaking (Related to a BC-breaking change), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

Support torch.linalg.norm for complex tensors on both CPU and CUDA
6 participants