diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 4ae2ee326b88..dc680c382ab3 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -468,12 +468,12 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop auto win_length = win_lengthOpt.value_or(n_fft); const bool return_complex = return_complexOpt.value_or( self.is_complex() || (window.defined() && window.is_complex())); - if (!return_complexOpt && !return_complex) { - TORCH_WARN_ONCE("stft will require the return_complex parameter be explicitly " - " specified in a future PyTorch release. Use return_complex=False " - " to preserve the current behavior or return_complex=True to return " - " a complex output."); - } + TORCH_CHECK( + return_complexOpt.has_value() || return_complex, + "stft requires the return_complex parameter be explicitly " + "specified for real inputs. Use return_complex=True to return " + "a complex-valued tensor, or return_complex=True to return " + "a real-valued tensor with an extra complex dimension."); if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) { std::ostringstream ss; diff --git a/test/test_jit.py b/test/test_jit.py index c85fcbd19747..836066a7f84b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8830,7 +8830,7 @@ def test_pack_unpack_state(self): def test_torch_functional(self): def stft(input, n_fft): # type: (Tensor, int) -> Tensor - return torch.stft(input, n_fft) + return torch.stft(input, n_fft, return_complex=True) inps = (torch.randn(10), 7) self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps)) @@ -8839,8 +8839,8 @@ def istft(input, n_fft): # type: (Tensor, int) -> Tensor return torch.istft(input, n_fft) - inps2 = (torch.stft(*inps), inps[1]) - self.assertEqual(torch.istft(*inps2), torch.jit.script(torch.istft)(*inps2)) + inps2 = (stft(*inps), inps[1]) + self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2)) def lu(x): # type: (Tensor) -> Tuple[Tensor, Tensor] diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 04365a5828d4..8a35e71c035c 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -843,7 +843,8 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, else: window = None if expected_error is None: - result = x.stft(n_fft, hop_length, win_length, window, center=center) + result = x.stft(n_fft, hop_length, win_length, window, + center=center, return_complex=False) # NB: librosa defaults to np.complex64 output, no matter what # the input dtype ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) @@ -1055,7 +1056,8 @@ def test_complex_stft_onesided(self, device): with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, window=window, pad_mode='constant', onesided=True) else: - y = x.stft(10, window=window, pad_mode='constant', onesided=True) + y = x.stft(10, window=window, pad_mode='constant', onesided=True, + return_complex=False) self.assertEqual(y.dtype, torch.double) self.assertEqual(y.size(), (6, 51, 2)) @@ -1064,6 +1066,11 @@ def test_complex_stft_onesided(self, device): with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, pad_mode='constant', onesided=True) + def test_stft_requires_complex(self, device): + x = torch.rand(100) + with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): + y = x.stft(10, pad_mode='constant') + @skipCUDAIfRocm @skipCPUIfNoMkl def test_fft_input_modification(self, device): @@ -1091,7 +1098,7 @@ def test_fft_input_modification(self, device): def test_istft_round_trip_simple_cases(self, device, dtype): """stft -> istft should recover the original signale""" def _test(input, n_fft, length): - stft = torch.stft(input, n_fft=n_fft) + stft = torch.stft(input, n_fft=n_fft, return_complex=False) inverse = torch.istft(stft, n_fft=n_fft, length=length) self.assertEqual(input, inverse, exact_dtype=True) @@ -1113,7 +1120,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): for sizes in data_sizes: for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) - stft = torch.stft(original, **stft_kwargs) + stft = torch.stft(original, return_complex=False, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) # trim the original for case when constructed signal is shorter than original diff --git a/torch/functional.py b/torch/functional.py index cdecee2d3e61..cbdbdae66823 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -464,9 +464,10 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, r"""Short-time Fourier transform (STFT). .. warning:: - Setting :attr:`return_complex` explicitly will be required in a future - PyTorch release. Set it to False to preserve the current behavior or - True to return a complex output. + From version 1.8.0, :attr:`return_complex` must be given explicitly for + real inputs. Set to True to return a complex output, or False to + preserve the legacy behavior of returning a real tensor with an extra + last dimension for the real and imaginary components. The STFT computes the Fourier transform of short overlapping windows of the input. This giving frequency components of the signal as they change over