Support out argument in torch.fft ops #49335
Conversation
💊 CI failures summary and remediations — as of commit 6d6569e: 💚 Looks good so far! There are no failures yet. 💚 (Automatically generated by Dr. CI.)
Codecov Report
@@            Coverage Diff            @@
##           master   #49335     +/-   ##
==========================================
  Coverage   80.56%   80.56%
==========================================
  Files        1884     1885      +1
  Lines      204375   204453     +78
==========================================
+ Hits       164660   164727     +67
- Misses      39715    39726     +11
if (!forward) {
  // FIXME: _fft_r2c doesn't support native r2c IFFT
  out = at::conj(out);
  return out.defined() ? at::conj_out(out, ret) : at::conj(ret);
In the case that out is defined, ret == out, right? Would it be more natural to perform the conj inplace instead of using its out variant?
This question is mostly academic since conj doesn't have an inplace variant. cc @anjali411 we should probably add an inplace conj_ for consistency with our other operators?
In the case that out is defined, ret == out, right?

No, you can see above that _fft_r2c_out is only used if out.defined() && forward, but this is the !forward case. This avoids doing an explicit copy.
@mruberry yeah I agree I don't see any reason to not add an in-place version of conj
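For context, here is a small Python sketch of the user-facing behavior tied to this path (illustrative only, not code from this PR; it assumes the !forward r2c case corresponds to torch.fft.ihfft, whose one-sided inverse transform equals the conjugate of the forward rfft up to normalization):

import torch

x = torch.randn(8)
out = torch.empty(5, dtype=torch.complex64)  # n//2 + 1 one-sided bins

# ihfft goes through the inverse-direction r2c path; the conjugate of the
# forward transform (scaled by 1/n) is what ends up written into `out`.
torch.fft.ihfft(x, out=out)
assert torch.allclose(out, torch.fft.rfft(x).conj() / 8)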
aten/src/ATen/native/SpectralOps.cpp
@@ -157,7 +174,12 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
    input = resize_fft_input(input, dim, n);
  }
  const auto norm = norm_from_string(norm_str, forward);
  return at::_fft_c2c(input, dim, static_cast<int64_t>(norm), forward);
  if (out.defined()) {
    TORCH_CHECK(out.is_complex(), "fft expects a complex output tensor, but got ", out.scalar_type());
"fft" here is interesting because this warning could be triggered by many operations.
In the above warning a capital "FFT" is used, which doesn't seem much better.
Maybe this function should accept a string with the appropriate function name to display? So it could say, for example, "fft2 expects..." when torch.fft.fft2 is called?
I guess c2r and r2c already have enough logic to disambiguate which function called them, but they could adopt this approach for consistency, too.
I've added the name argument, but fft2 will still say fftn in error messages because I'm calling native::fftn rather than fftn_c2c, which expects slightly different arguments.
I think we can live with that.
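As an aside, a hedged Python sketch of how the check above surfaces to users (illustrative only; the exact wording depends on the name argument discussed here):

import torch

x = torch.randn(16)
bad_out = torch.empty(16)  # real dtype, but fft produces a complex result

try:
    torch.fft.fft(x, out=bad_out)
except RuntimeError as e:
    # e.g. "fft expects a complex output tensor, but got Float"
    print(e)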
}

Tensor& fft_fftfreq_out(Tensor& out, int64_t n, double d) {
  ScalarType dtype = out.scalar_type();
  TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype),
              "fftfreq requires a floating point or complex dtype");
What are the error messages for other tensor creation ops, like torch.arange and torch.zeros, in the out vs tensor options case? What happens if both out= and some tensor options are specified?
What happens if both out= and some tensor options are specified?

The function I moved out of python_torch_functions.cpp is automatically called by codegen to ensure the tensor options arguments are valid. For some reason this happens even though I don't take any TensorOptions arguments here.
That's odd. Don't other functions, like those mentioned above, have this same issue? What's the implication of this function always being called even though no TensorOptions are in the signature?
Or are the fftfreq functions the first tensor factories to hit some unique aspect of this issue?
What's the implication of this function always being called even though no TensorOptions are in the signature?

The non-out variant uses TensorOptions to create the out tensor; the out variant respects the TensorOptions of the out tensor. And the Python API ensures that if both are given, they must be consistent. So, everything is consistent and there is always a single TensorOptions to take.

Or are the fftfreq functions the first tensor factories to hit some unique aspect of this issue?

I don't think it's unique, I think it's just a quirk of the Python codegen. For example, I see that empty.out doesn't have a dtype argument, yet I can still do torch.empty(100, out=out, dtype=torch.float).

There is also no TensorOptions argument in the native function:
pytorch/aten/src/ATen/native/TensorFactories.cpp
Lines 198 to 201 in 4883d39

Tensor& empty_out(
    Tensor& result,
    IntArrayRef size,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
That's really interesting. cc @bhosmer for yet another quirk of our current codegen.
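For illustration, a hedged sketch of the out/dtype consistency behavior described above (not code from this PR; assumes the usual out= semantics for factory functions):

import torch

out = torch.empty(8, dtype=torch.float64)

torch.fft.fftfreq(8, out=out)                        # out's dtype (float64) is respected
torch.fft.fftfreq(8, out=out, dtype=torch.float64)   # explicit dtype matches out: fine

try:
    torch.fft.fftfreq(8, out=out, dtype=torch.float32)  # mismatched dtype is rejected
except RuntimeError as e:
    print(e)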
@@ -2062,6 +2062,12 @@
    CPU: _fft_r2c_mkl
    CUDA: _fft_r2c_cufft

- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)
When I wrote this I had to drop use_c10_dispatcher: full because the out argument can't go at the end since it doesn't have a default. Looks like hacky_wrapper_for_legacy_signatures works fine though, and it compiles after rebasing on #49510.
namespace torch {
namespace utils {

void check_out_type_matches(const at::Tensor& result,
Add a comment describing this function & how it's intended to be used
What's the thinking for moving this to a new file? As best I can tell it's still only used in that file, so moving this doesn't seem like a natural part of this PR?
It's called by the generated code whenever a factory function has an out argument. So, it needs to be available in python_fft_functions.cpp for the fftfreq functions.
    return;
  }
  if (!scalarType_is_none && result.scalar_type() != scalarType) {
    AT_ERROR(
AT_ERROR is deprecated in favor of TORCH_CHECK (for errors that a user input can trigger) or TORCH_INTERNAL_ASSERT (for internal logic failures that should never occur).
This is just moved from the existing code.
  }
  if (!scalarType_is_none && result.scalar_type() != scalarType) {
    AT_ERROR(
        "dtype ", scalarType,
If this change is kept, this will probably also need to be passed a string to produce an error message associating the calling function with the error - although that's a bigger change
Altering codegen seems unnecessary for this PR. The calling function is clearly visible in the stack trace anyway.
OK.
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double)
def test_fftfreq_out(self, device, dtype):
Seems like we should make these OpInfos in the (near) future?
Are there any OpInfos for tensor creation functions yet?
Have to start somewhere.
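For reference, a minimal sketch of the kind of property such a test (or a future OpInfo) would check; this is a hypothetical helper, not the PR's actual test code:

import torch

def check_fftfreq_out(device="cpu", dtype=torch.float):
    # The out= variant should match the functional variant and write into `out`.
    expected = torch.fft.fftfreq(10, d=0.5, device=device, dtype=dtype)
    out = torch.empty(10, device=device, dtype=dtype)
    result = torch.fft.fftfreq(10, d=0.5, out=out)
    assert result.data_ptr() == out.data_ptr()
    assert torch.equal(result, expected)

check_fftfreq_out()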
Hey @peterbell10! Overall this looks great, as usual. Just a few comments (inline) and (probably) a merge conflict to resolve.
Let me know if you have questions about updating the dispatch metadata. @smessmer - we should probably update the ATen/native README?
Hey @peterbell10, just checking in on this. Is this gtg?
Yes @mruberry, this should be gtg.
Cool! Thanks @peterbell10!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Ref pytorch#42175

This adds out argument support to all functions in the `torch.fft` namespace except for `fftshift` and `ifftshift` because they rely on `at::roll`, which doesn't have an out argument version.

Note that there's no general way to do the transforms directly into the output since both cufft and mkl-fft only support single batch dimensions. At a minimum, the output may need to be re-strided, which I don't think is expected from `out` arguments normally. So, on cpu this just copies the result into the out tensor. On cuda, the normalization is changed to call `at::mul_out` instead of an inplace multiply.

If it's desirable, I could add a special case to transform into the output when `out.numel() == 0` since there's no expectation to preserve the strides in that case anyway. But that would lead to the slightly odd situation where `out` having the correct shape follows a different code path from `out.resize_(0)`.

Pull Request resolved: pytorch#49335
Reviewed By: mrshenli
Differential Revision: D25756635
Pulled By: mruberry
fbshipit-source-id: d29843f024942443c8857139a2abdde09affd7d6
Summary: An oversight from pytorch#49335, the documentation was never updated to include `out` arguments.
Pull Request resolved: pytorch#56732
Reviewed By: ezyang
Differential Revision: D27960478
Pulled By: mruberry
fbshipit-source-id: a342a4f590369d6d2e17bed014fa64e49ee72936
Ref #42175

This adds out argument support to all functions in the torch.fft namespace except for fftshift and ifftshift because they rely on at::roll, which doesn't have an out argument version.

Note that there's no general way to do the transforms directly into the output since both cufft and mkl-fft only support single batch dimensions. At a minimum, the output may need to be re-strided, which I don't think is expected from out arguments normally. So, on cpu this just copies the result into the out tensor. On cuda, the normalization is changed to call at::mul_out instead of an inplace multiply.

If it's desirable, I could add a special case to transform into the output when out.numel() == 0 since there's no expectation to preserve the strides in that case anyway. But that would lead to the slightly odd situation where out having the correct shape follows a different code path from out.resize_(0).
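A short usage sketch of the out= support described above (illustrative only):

import torch

x = torch.randn(4, 8)

# Preallocated output of the right shape and dtype: the result is copied into it.
out = torch.empty(4, 8, dtype=torch.complex64)
torch.fft.fft(x, out=out)

# An empty out tensor is resized to the result's shape, matching the
# out.numel() == 0 case mentioned above.
empty_out = torch.empty(0, dtype=torch.complex64)
torch.fft.fft(x, out=empty_out)
assert empty_out.shape == (4, 8)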