diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index c8eb3cc99a01..7f5211be9095 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -469,18 +469,20 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop
   const bool return_complex = return_complexOpt.value_or(
       self.is_complex() || (window.defined() && window.is_complex()));
   if (!return_complex) {
-    TORCH_CHECK(return_complexOpt.has_value(),
-        "stft requires the return_complex parameter be given for real inputs."
-        "You should pass return_complex=True to opt-in to complex dtype returns "
-        "(which will be required in a future pytorch release). "
+    if (!return_complexOpt.has_value()) {
+      TORCH_WARN_ONCE(
+        "stft will soon require the return_complex parameter be given for real inputs, "
+        "and will further require that return_complex=True in a future PyTorch release."
       );
+    }
 
-    TORCH_WARN_ONCE(
-        "stft with return_complex=False is deprecated. In a future pytorch "
-        "release, stft will return complex tensors for all inputs, and "
-        "return_complex=False will raise an error.\n"
-        "Note: you can still call torch.view_as_real on the complex output to "
-        "recover the old return format.");
+
+    // TORCH_WARN_ONCE(
+    //     "stft with return_complex=False is deprecated. In a future pytorch "
+    //     "release, stft will return complex tensors for all inputs, and "
+    //     "return_complex=False will raise an error.\n"
+    //     "Note: you can still call torch.view_as_real on the complex output to "
+    //     "recover the old return format.");
   }
 
   if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py
index 6192d6c4d6b6..9082668e8596 100644
--- a/test/test_spectral_ops.py
+++ b/test/test_spectral_ops.py
@@ -1066,10 +1066,12 @@ def test_complex_stft_onesided(self, device):
         with self.assertRaisesRegex(RuntimeError, 'complex'):
             x.stft(10, pad_mode='constant', onesided=True)
 
+    # stft is currently warning that it requires return-complex while an upgrader is written
     def test_stft_requires_complex(self, device):
         x = torch.rand(100)
-        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
-            y = x.stft(10, pad_mode='constant')
+        y = x.stft(10, pad_mode='constant')
+        # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
+        #     y = x.stft(10, pad_mode='constant')
 
     @skipCUDAIfRocm
     @skipCPUIfNoMkl