Skip to content

Commit

Permalink
Update on "Improve torch.fft n-dimensional transforms"
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
peterbell10 committed Dec 8, 2020
2 parents 5b6f748 + 3e9686f commit e42964b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,
Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s,
c10::optional<IntArrayRef> dim,
c10::optional<std::string> norm_str) {
TORCH_CHECK(!self.is_complex(), "rfftn expects a real input tensor, but got ", self.scalar_type());
TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type());
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis");
Tensor input = promote_tensor_fft(self, /*require_complex=*/false);
Expand Down
6 changes: 3 additions & 3 deletions test/test_spectral_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_fft_invalid_dtypes(self, device):
with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
torch.fft.rfft(t)

with self.assertRaisesRegex(RuntimeError, "rfftn expects a real input tensor"):
with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"):
torch.fft.rfftn(t)

with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"):
Expand Down Expand Up @@ -479,7 +479,7 @@ def test_fftn_invalid(self, device):
func(a, s=(10, 10, 10, 10))

c = torch.complex(a, a)
with self.assertRaisesRegex(RuntimeError, "rfftn expects a real input"):
with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
torch.fft.rfftn(c)

# 2d-fft tests
Expand Down Expand Up @@ -591,7 +591,7 @@ def test_fft2_invalid(self, device):
func(a, dim=(2, 3))

c = torch.complex(a, a)
with self.assertRaisesRegex(RuntimeError, "rfftn expects a real input"):
with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
torch.fft.rfft2(c)

# Helper functions
Expand Down

0 comments on commit e42964b

Please sign in to comment.