
Support out argument in torch.fft ops #49335

Closed
peterbell10 wants to merge 2 commits into master from fft-out-argument

Conversation

peterbell10 (Collaborator)

Ref #42175

This adds out argument support to all functions in the torch.fft namespace except for fftshift and ifftshift because they rely on at::roll which doesn't have an out argument version.

Note that there's no general way to perform the transforms directly into the output, since both cuFFT and MKL-FFT only support a single batch dimension. At a minimum, the output may need to be re-strided, which I don't think is normally expected of out arguments. So, on CPU this just copies the result into the out tensor. On CUDA, the normalization calls at::mul_out instead of an in-place multiply.

If it's desirable, I could add a special case to transform into the output when out.numel() == 0 since there's no expectation to preserve the strides in that case anyway. But that would lead to the slightly odd situation where out having the correct shape follows a different code path from out.resize_(0).
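
(For illustration, a minimal usage sketch of the new out= argument; the tensor shapes and dtypes here are hypothetical, not taken from the PR.)

import torch

x = torch.randn(4, 8, dtype=torch.complex64)
out = torch.empty_like(x)
# The transform is computed as usual; per the description above, the result is
# copied into `out` on CPU, or written during normalization via at::mul_out on CUDA.
torch.fft.fft(x, out=out)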

@facebook-github-bot (Contributor)

facebook-github-bot commented Dec 14, 2020

💊 CI failures summary and remediations

As of commit 6d6569e (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



peterbell10 force-pushed the fft-out-argument branch 3 times, most recently from 35c7b24 to 42385c2 on December 15, 2020 03:08
@codecov (bot)

codecov bot commented Dec 15, 2020

Codecov Report

Merging #49335 (6d6569e) into master (1ac05cf) will increase coverage by 0.00%.
The diff coverage is 92.06%.

@@           Coverage Diff           @@
##           master   #49335   +/-   ##
=======================================
  Coverage   80.56%   80.56%           
=======================================
  Files        1884     1885    +1     
  Lines      204375   204453   +78     
=======================================
+ Hits       164660   164727   +67     
- Misses      39715    39726   +11     

mrshenli added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Dec 18, 2020
  if (!forward) {
    // FIXME: _fft_r2c doesn't support native r2c IFFT
-   out = at::conj(out);
+   return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
Collaborator

In the case that out is defined, ret == out, right? Would it be more natural to perform the conj inplace instead of using its out variant?

This question is mostly academic since conj doesn't have an in-place variant. cc @anjali411: we should probably add an in-place conj_ for consistency with our other operators?

Collaborator Author

In the case that out is defined, ret == out, right?

No, you can see above that _fft_r2c_out is only used when out.defined() && forward, but this is the !forward case. Using conj_out here avoids an explicit copy.
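
(For context, a hedged example of a call that should hit this !forward r2c path, assuming torch.fft.ihfft is the one-sided inverse real FFT that lowers to _fft_r2c with forward=false; shapes are illustrative.)

import torch

x = torch.randn(8)                           # real input
out = torch.empty(5, dtype=torch.complex64)  # one-sided length: 8 // 2 + 1
torch.fft.ihfft(x, out=out)                  # out.defined() case from the snippet above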

Contributor

@mruberry yeah, I agree; I don't see any reason not to add an in-place version of conj.

@@ -157,7 +174,12 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
     input = resize_fft_input(input, dim, n);
   }
   const auto norm = norm_from_string(norm_str, forward);
-  return at::_fft_c2c(input, dim, static_cast<int64_t>(norm), forward);
+  if (out.defined()) {
+    TORCH_CHECK(out.is_complex(), "fft expects a complex output tensor, but got ", out.scalar_type());
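
(Illustration of the check being added here, with hypothetical tensors: a real-valued out tensor should now be rejected.)

import torch

x = torch.randn(8, dtype=torch.complex64)
bad_out = torch.empty(8)              # real float32 output
torch.fft.fft(x, out=bad_out)         # expected to raise: fft expects a complex output tensor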
Collaborator

"fft" here is interesting because this warning could be triggered by many operations.

In the above warning a capital "FFT" is used, which doesn't seem much better.

Maybe this function should accept a string with the appropriate function name to display? So it could say, for example, "fft2 expects..." when torch.fft.fft2 is called?

I guess c2r and r2c already have enough logic to disambiguate which function called them, but they could adopt this approach for consistency, too.

Collaborator Author

I've added the name argument, but fft2 will still say fftn in error messages because I'm calling native::fftn rather than fftn_c2c, which expects slightly different arguments.

Collaborator

I think we can live with that.

}

Tensor& fft_fftfreq_out(Tensor& out, int64_t n, double d) {
  ScalarType dtype = out.scalar_type();
  TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype),
              "fftfreq requires a floating point or complex dtype");
Collaborator

What are the error messages like for other tensor creation ops, such as torch.arange and torch.zeros, in the out vs. tensor options case? What happens if both out= and some tensor options are specified?

Collaborator Author

What happens if both out= and some tensor options are specified?

The function I moved out of python_torch_functions.cpp is automatically called by codegen to ensure the tensor options arguments are valid. For some reason this happens even though I don't take any TensorOptions arguments here.

Collaborator

That's odd. Don't other functions, like those mentioned above, have this same issue? What's the implication of this function always being called even though no TensorOptions are in the signature?

Collaborator

Or are the fftfreq functions the first tensor factories to hit some unique aspect of this issue?

Collaborator Author

What's the implication of this function always being called even though no TensorOptions are in the signature?

The non-out variant uses the TensorOptions to create the result tensor; the out variant respects the TensorOptions of the out tensor. The Python API ensures that if both are given, they must be consistent. So everything is consistent, and there is always a single set of TensorOptions to honor.

Or are the fftfreq functions the first tensor factories to hit some unique aspect of this issue?

I don't think it's unique; I think it's just a quirk of the Python codegen. For example, I see that empty.out doesn't have a dtype argument, yet I can still do torch.empty(100, out=out, dtype=torch.float).

There is also no TensorOptions argument in the native function:

Tensor& empty_out(
    Tensor& result,
    IntArrayRef size,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
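
(A small sketch of the out=/dtype consistency behavior described above, using the torch.empty example; the exact error wording is assumed, not taken from this PR.)

import torch

out = torch.empty(0, dtype=torch.float32)
torch.empty(100, out=out, dtype=torch.float32)   # OK: dtype matches the out tensor
torch.empty(100, out=out, dtype=torch.float64)   # expected to raise: dtype doesn't match out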

Collaborator

That's really interesting. cc @bhosmer for yet another quirk of our current codegen.

@@ -2062,6 +2062,12 @@
CPU: _fft_r2c_mkl
CUDA: _fft_r2c_cufft

- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)
mruberry (Collaborator) Dec 20, 2020

I imagine these new registrations run afoul of #49510 and need an update.

cc @smessmer

Collaborator Author

When I wrote this I had to drop use_c10_dispatcher: full because the out argument can't go at the end since it doesn't have a default. Looks like hacky_wrapper_for_legacy_signatures works fine though and it compiles after rebasing on #49510.

namespace torch {
namespace utils {

void check_out_type_matches(const at::Tensor& result,
Collaborator

Add a comment describing this function & how it's intended to be used

mruberry (Collaborator) Dec 20, 2020

What's the thinking for moving this to a new file? As best I can tell it's still only used in that file, so moving this doesn't seem like a natural part of this PR?

Collaborator Author

It's called by the generated code whenever a factory function has an out argument. So, it needs to be available in python_fft_functions.cpp for the fftfreq functions.

return;
}
if (!scalarType_is_none && result.scalar_type() != scalarType) {
AT_ERROR(
Collaborator

AT_ERROR is deprecated in favor of TORCH_CHECK (for errors that a user input can trigger) or TORCH_INTERNAL_ASSERT (for internal logic failures that should never occur).

Collaborator Author

This is just moved from the existing code.

}
if (!scalarType_is_none && result.scalar_type() != scalarType) {
AT_ERROR(
"dtype ", scalarType,
mruberry (Collaborator) Dec 20, 2020

If this change is kept, this will probably also need to be passed a string to produce an error message associating the calling function with the error, although that's a bigger change.

Collaborator Author

Altering codegen seems unnecessary for this PR. The calling function is clearly visible in the stack trace anyway.

Collaborator

OK.

@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double)
def test_fftfreq_out(self, device, dtype):
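    # Hypothetical sketch of what such an out= test might check (not the actual test body):
    expect = torch.fft.fftfreq(100, d=0.25, device=device, dtype=dtype)
    actual = torch.empty(100, device=device, dtype=dtype)
    torch.fft.fftfreq(100, d=0.25, out=actual)
    self.assertEqual(expect, actual)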
Collaborator

Seems like we should make these OpInfos in the (near) future?

Collaborator Author

Are there any OpInfos for tensor creation functions yet?

Collaborator

Have to start somewhere.

test/test_spectral_ops.py (outdated review thread, resolved)
mruberry (Collaborator) left a comment

Hey @peterbell10! Overall this looks great, as usual. Just a few comments (inline) and (probably) a merge conflict to resolve.

Let me know if you have questions about updating the dispatch metadata. @smessmer, we should probably update the ATen/native README?

@mruberry (Collaborator)

mruberry commented Jan 3, 2021

Hey @peterbell10, just checking in on this. Is this gtg?

@peterbell10 (Collaborator Author)

Yes @mruberry, this should be gtg.

mruberry (Collaborator) left a comment

Cool! Thanks @peterbell10!

facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mruberry merged this pull request in 2639114.

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 14, 2021
Summary:
Ref pytorch#42175


Pull Request resolved: pytorch#49335

Reviewed By: mrshenli

Differential Revision: D25756635

Pulled By: mruberry

fbshipit-source-id: d29843f024942443c8857139a2abdde09affd7d6
facebook-github-bot pushed a commit that referenced this pull request Apr 25, 2021
Summary:
An oversight from #49335: the documentation was never updated to include `out` arguments.

Pull Request resolved: #56732

Reviewed By: ezyang

Differential Revision: D27960478

Pulled By: mruberry

fbshipit-source-id: a342a4f590369d6d2e17bed014fa64e49ee72936
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Labels
cla signed, Merged, open source, triaged

5 participants