Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve torch.fft n-dimensional transforms #46911

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b648bd6
Improve torch.fft n-dimensional transforms
peterbell10 Oct 27, 2020
9da2b15
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 27, 2020
e6b0d74
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 27, 2020
a1e6ca6
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 29, 2020
a6e36e5
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 29, 2020
fc080a5
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 29, 2020
e0946ec
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 29, 2020
6cc5853
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 30, 2020
3d93379
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Oct 30, 2020
8d370bf
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Nov 9, 2020
5d162c1
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Nov 11, 2020
786a17f
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Nov 11, 2020
36e9339
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Nov 21, 2020
7a22ecc
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Nov 27, 2020
0efb859
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Dec 2, 2020
f9b5cd2
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Dec 2, 2020
4f26762
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Dec 6, 2020
5b6f748
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Dec 8, 2020
e42964b
Update on "Improve torch.fft n-dimensional transforms"
peterbell10 Dec 8, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
126 changes: 24 additions & 102 deletions aten/src/ATen/native/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,12 @@ Tensor fft_c2r(Tensor input, c10::optional<int64_t> n_opt,
if (n_opt) {
input = resize_fft_input(input, dim, n/2 + 1);
}
// _fft only operates on the last dim, so transpose the selected dim to the end
const bool must_transpose = (dim != input_dim - 1);
if (must_transpose) {
input = at::transpose(input, -1, dim);
}
const auto norm = norm_from_string(norm_str, forward);
if (forward) {
// FIXME: _fft does not support complex_output=false with inverse=false
input = at::conj(input);
}
auto out = _fft(at::view_as_real(input),
/*signal_ndim=*/1, /*complex_input=*/true,
/*complex_output=*/false, /*inverse=*/true,
/*signal_sizes=*/{n}, /*normalization=*/norm,
/*onesided=*/true);
if (must_transpose) {
out = at::transpose(out, -1, dim);
}
return out;
return at::_fft_c2r(input, dim, static_cast<int64_t>(norm), n);
}

// Real to complex FFT
Expand All @@ -153,22 +140,11 @@ Tensor fft_r2c(Tensor input, c10::optional<int64_t> n_opt,
if (n_opt) {
input = resize_fft_input(input, dim, n);
}
// _fft only operates on the last dim, so transpose the selected dim to the end
const bool must_transpose = (dim != input_dim - 1);
if (must_transpose) {
input = at::transpose(input, -1, dim);
}

const auto norm = norm_from_string(norm_str, forward);
auto out = _fft(input, /*signal_ndim=*/1, /*complex_input=*/false,
/*complex_output=*/true, /*inverse=*/false,
/*signal_sizes=*/{n}, /*normalization=*/norm,
/*onesided=*/onesided);
out = at::view_as_complex(out);
if (must_transpose) {
out = at::transpose(out, -1, dim);
}
auto out = at::_fft_r2c(input, dim, static_cast<int64_t>(norm), onesided);
if (!forward) {
// FIXME: _fft does not support complex_input=false with inverse=true
// FIXME: _fft_r2c doesn't support native r2c IFFT
out = at::conj(out);
}
return out;
Expand All @@ -186,22 +162,8 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
if (n_opt) {
input = resize_fft_input(input, dim, n);
}
// _fft only operates on the last dim, so transpose the selected dim to the end
const bool must_transpose = (dim != input_dim - 1);
if (must_transpose) {
input = at::transpose(input, -1, dim);
}
const auto norm = norm_from_string(norm_str, forward);
auto out = _fft(at::view_as_real(input),
/*signal_ndim=*/1, /*complex_input=*/true,
/*complex_output=*/true, /*inverse=*/!forward,
/*signal_sizes=*/{}, /*normalization=*/norm,
/*onesided=*/false);
out = at::view_as_complex(out);
if (must_transpose) {
out = at::transpose(out, -1, dim);
}
return out;
return at::_fft_c2c(input, dim, static_cast<int64_t>(norm), forward);
}

// Dimensions to transform, and the signal shape in those dimensions
Expand Down Expand Up @@ -277,44 +239,12 @@ Tensor fftn_c2c(
const Tensor& input, IntArrayRef shape, IntArrayRef dim,
c10::optional<std::string> norm_str, bool forward) {
TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
const auto input_dim = input.dim();

Tensor x = resize_fft_input(input, dim, shape);
x = at::view_as_real(x);

const int64_t transform_ndim = dim.size();
const auto norm = norm_from_string(norm_str, forward);
// _fft_with_size only supports 3 dimensions being transformed at a time.
// This limit is inherited from cuFFT.
constexpr int64_t max_signal_ndim = 3;

// Transform n dimensions, up to 3 at a time
// TODO: rewrite _fft_with_size to transform more than 3 dimensions at once.
for (int64_t i = 0; i < transform_ndim; i += max_signal_ndim) {
const int64_t signal_ndim = std::min(transform_ndim - i, max_signal_ndim);
DimVector source_dim(signal_ndim);
DimVector dest_dim(signal_ndim);

for (int64_t j = 0; j < signal_ndim; ++j) {
source_dim[j] = dim[i + j];
dest_dim[j] = j + (input_dim - signal_ndim);
}

// _fft operates on up-to the last 3 dims, so move selected dims to the end
x = at::movedim(x, source_dim, dest_dim);

x = _fft(x, signal_ndim, /*complex_input=*/true, /*complex_output=*/true,
/*inverse=*/!forward, /*signal_sizes=*/{}, /*normalization=*/norm,
/*onesided=*/false);

// Move transform dims back to their original order
x = at::movedim(x, dest_dim, source_dim);
}

return at::view_as_complex(x);
return at::_fft_c2c(x, dim, static_cast<int64_t>(norm), forward);
}

}
} // namespace (anonymous)

// torch.fft.fft, analogous to NumPy's numpy.fft.fft
Tensor fft_fft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
Expand Down Expand Up @@ -370,44 +300,36 @@ Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,

Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s,
c10::optional<IntArrayRef> dim,
c10::optional<std::string> norm) {
c10::optional<std::string> norm_str) {
TORCH_CHECK(!self.is_complex(), "Expected a real input tensor to rfftn");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Expected a real input tensor to rfftn" -> "rfftn expects a real-valued input tensor, but got {dtype}!"

auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis");

const auto last_dim = desc.dim.back();
const auto last_shape = desc.shape.back();
desc.shape.pop_back();
desc.dim.pop_back();

// rfft on last dim to get hermitian complex shape
auto x = native::fft_rfft(self, last_shape, last_dim, norm);
// Normal fft on remaining dims
return fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/true);
Tensor input = promote_tensor_fft(self, /*require_complex=*/false);
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
const auto norm = norm_from_string(norm_str, /*forward=*/true);
return at::_fft_r2c(x, desc.dim, static_cast<int64_t>(norm), /*onesided=*/true);
}

Tensor fft_irfftn(const Tensor& self, c10::optional<IntArrayRef> s,
c10::optional<IntArrayRef> dim,
c10::optional<std::string> norm) {
c10::optional<std::string> norm_str) {
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis");

const auto last_dim = desc.dim.back();
const auto last_shape = [&]() -> c10::optional<int64_t> {
// If shape is defaulted in the last dimension,
// pass nullopt to irfft and let it calculate the default size
const auto last_dim_size = [&] {
// Fixup default shape handling in the last dimension,
if (!s.has_value() || (s->back() == -1)) {
return c10::nullopt;
const auto last_dim = desc.dim.back();
return 2 * (self.sizes()[last_dim] - 1);
}
return desc.shape.back();
}();
desc.shape.pop_back();
desc.dim.pop_back();

// Normal ifft for all but last dim
Tensor x = promote_tensor_fft(self, /*require_complex=*/true);
x = fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/false);
// Then 1d irfft on last dim to get real output
return native::fft_irfft(x, last_shape, last_dim, norm);
desc.shape.back() = last_dim_size / 2 + 1;

Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
const auto norm = norm_from_string(norm_str, /*forward=*/false);
return at::_fft_c2r(x, desc.dim, static_cast<int64_t>(norm), last_dim_size);
}

Tensor fft_fft2(const Tensor& self, c10::optional<IntArrayRef> s,
Expand Down