From 5c25f8faf3d0d125aa5d642a23b24af2293ade7f Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 20 Dec 2020 14:40:16 -0800 Subject: [PATCH] stft: Change require_complex warning to an error (#49022) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49022 **BC-breaking note**: Previously torch.stft took an optional `return_complex` parameter that indicated whether the output would be a floating point tensor or a complex tensor. By default `return_complex` was False to be consistent with the previous behavior of torch.stft. This PR changes this behavior so `return_complex` is a required argument. **PR Summary**: * **#49022 stft: Change require_complex warning to an error** Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25658906 Pulled By: mruberry fbshipit-source-id: 11932d1102e93f8c7bd3d2d0b2a607fd5036ec5e --- aten/src/ATen/native/SpectralOps.cpp | 18 +++++++++++++----- test/test_jit.py | 6 +++--- test/test_spectral_ops.py | 23 +++++++++++++++-------- torch/functional.py | 10 +++++++--- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 4ae2ee326b88..c8eb3cc99a01 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -468,11 +468,19 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop auto win_length = win_lengthOpt.value_or(n_fft); const bool return_complex = return_complexOpt.value_or( self.is_complex() || (window.defined() && window.is_complex())); - if (!return_complexOpt && !return_complex) { - TORCH_WARN_ONCE("stft will require the return_complex parameter be explicitly " - " specified in a future PyTorch release. Use return_complex=False " - " to preserve the current behavior or return_complex=True to return " - " a complex output."); + if (!return_complex) { + TORCH_CHECK(return_complexOpt.has_value(), + "stft requires the return_complex parameter be given for real inputs." + "You should pass return_complex=True to opt-in to complex dtype returns " + "(which will be required in a future pytorch release). " + ); + + TORCH_WARN_ONCE( + "stft with return_complex=False is deprecated. In a future pytorch " + "release, stft will return complex tensors for all inputs, and " + "return_complex=False will raise an error.\n" + "Note: you can still call torch.view_as_real on the complex output to " + "recover the old return format."); } if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) { diff --git a/test/test_jit.py b/test/test_jit.py index 07da591c2228..f002c86b630c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8874,7 +8874,7 @@ def test_pack_unpack_state(self): def test_torch_functional(self): def stft(input, n_fft): # type: (Tensor, int) -> Tensor - return torch.stft(input, n_fft) + return torch.stft(input, n_fft, return_complex=True) inps = (torch.randn(10), 7) self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps)) @@ -8883,8 +8883,8 @@ def istft(input, n_fft): # type: (Tensor, int) -> Tensor return torch.istft(input, n_fft) - inps2 = (torch.stft(*inps), inps[1]) - self.assertEqual(torch.istft(*inps2), torch.jit.script(torch.istft)(*inps2)) + inps2 = (stft(*inps), inps[1]) + self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2)) def lu(x): # type: (Tensor) -> Tuple[Tensor, Tensor] diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 04365a5828d4..6192d6c4d6b6 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -843,7 +843,9 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, else: window = None if expected_error is None: - result = x.stft(n_fft, hop_length, win_length, window, center=center) + with self.maybeWarnsRegex(UserWarning, "stft with return_complex=False"): + result = x.stft(n_fft, hop_length, win_length, window, + center=center, return_complex=False) # NB: librosa defaults to np.complex64 output, no matter what # the input dtype ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) @@ -1055,15 +1057,20 @@ def test_complex_stft_onesided(self, device): with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, window=window, pad_mode='constant', onesided=True) else: - y = x.stft(10, window=window, pad_mode='constant', onesided=True) - self.assertEqual(y.dtype, torch.double) - self.assertEqual(y.size(), (6, 51, 2)) + y = x.stft(10, window=window, pad_mode='constant', onesided=True, + return_complex=True) + self.assertEqual(y.dtype, torch.cdouble) + self.assertEqual(y.size(), (6, 51)) - y = torch.rand(100, device=device, dtype=torch.double) - window = torch.randn(10, device=device, dtype=torch.cdouble) + x = torch.rand(100, device=device, dtype=torch.cdouble) with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, pad_mode='constant', onesided=True) + def test_stft_requires_complex(self, device): + x = torch.rand(100) + with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): + y = x.stft(10, pad_mode='constant') + @skipCUDAIfRocm @skipCPUIfNoMkl def test_fft_input_modification(self, device): @@ -1091,7 +1098,7 @@ def test_fft_input_modification(self, device): def test_istft_round_trip_simple_cases(self, device, dtype): """stft -> istft should recover the original signale""" def _test(input, n_fft, length): - stft = torch.stft(input, n_fft=n_fft) + stft = torch.stft(input, n_fft=n_fft, return_complex=True) inverse = torch.istft(stft, n_fft=n_fft, length=length) self.assertEqual(input, inverse, exact_dtype=True) @@ -1113,7 +1120,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): for sizes in data_sizes: for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) - stft = torch.stft(original, **stft_kwargs) + stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) # trim the original for case when constructed signal is shorter than original diff --git a/torch/functional.py b/torch/functional.py index 25b0c1fb3b19..10fb6b1e41b7 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -464,9 +464,13 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, r"""Short-time Fourier transform (STFT). .. warning:: - Setting :attr:`return_complex` explicitly will be required in a future - PyTorch release. Set it to False to preserve the current behavior or - True to return a complex output. + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. + + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. The STFT computes the Fourier transform of short overlapping windows of the input. This giving frequency components of the signal as they change over