Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stft: Change require_complex warning to an error #49022

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions aten/src/ATen/native/SpectralOps.cpp
Expand Up @@ -468,11 +468,19 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
auto win_length = win_lengthOpt.value_or(n_fft);
const bool return_complex = return_complexOpt.value_or(
self.is_complex() || (window.defined() && window.is_complex()));
if (!return_complexOpt && !return_complex) {
TORCH_WARN_ONCE("stft will require the return_complex parameter be explicitly "
" specified in a future PyTorch release. Use return_complex=False "
" to preserve the current behavior or return_complex=True to return "
" a complex output.");
if (!return_complex) {
TORCH_CHECK(return_complexOpt.has_value(),
"stft requires the return_complex parameter be given for real inputs."
"You should pass return_complex=True to opt-in to complex dtype returns "
"(which will be required in a future pytorch release). "
);

TORCH_WARN_ONCE(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning needs to be gated on the value actually being false, though. We don't want people who set return_complex=True to see it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait... I understand how this will work now. OK. That's cool.

So the conditional only triggers if return_complex is unspecified or False and it needs to be specified. Then the check will capture the case where it's unspecified. If the check is hit this function will stop (since it threw an exception). If not, then the value is specified but it's False, so the warning is thrown.

Cool.

"stft with return_complex=False is deprecated. In a future pytorch "
"release, stft will return complex tensors for all inputs, and "
"return_complex=False will raise an error.\n"
"Note: you can still call torch.view_as_real on the complex output to "
"recover the old return format.");
}

if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
Expand Down
6 changes: 3 additions & 3 deletions test/test_jit.py
Expand Up @@ -8780,7 +8780,7 @@ def test_pack_unpack_state(self):
def test_torch_functional(self):
def stft(input, n_fft):
    # type: (Tensor, int) -> Tensor
    # Scriptable wrapper for the jit-equivalence check below.
    # return_complex=True is passed explicitly: calling stft on a real
    # input without it is deprecated (and now errors), and the complex
    # return is the forward-compatible format. The stale pre-diff call
    # without return_complex has been removed — it made this line
    # unreachable.
    return torch.stft(input, n_fft, return_complex=True)

inps = (torch.randn(10), 7)
self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
def istft(input, n_fft):
    # type: (Tensor, int) -> Tensor
    # Scriptable wrapper around torch.istft for the jit comparison below.
    result = torch.istft(input, n_fft)
    return result

inps2 = (torch.stft(*inps), inps[1])
self.assertEqual(torch.istft(*inps2), torch.jit.script(torch.istft)(*inps2))
inps2 = (stft(*inps), inps[1])
self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

def lu(x):
# type: (Tensor) -> Tuple[Tensor, Tensor]
Expand Down
23 changes: 15 additions & 8 deletions test/test_spectral_ops.py
Expand Up @@ -843,7 +843,9 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
else:
window = None
if expected_error is None:
result = x.stft(n_fft, hop_length, win_length, window, center=center)
with self.maybeWarnsRegex(UserWarning, "stft with return_complex=False"):
result = x.stft(n_fft, hop_length, win_length, window,
center=center, return_complex=False)
# NB: librosa defaults to np.complex64 output, no matter what
# the input dtype
ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
Expand Down Expand Up @@ -1055,15 +1057,20 @@ def test_complex_stft_onesided(self, device):
with self.assertRaisesRegex(RuntimeError, 'complex'):
x.stft(10, window=window, pad_mode='constant', onesided=True)
else:
y = x.stft(10, window=window, pad_mode='constant', onesided=True)
self.assertEqual(y.dtype, torch.double)
self.assertEqual(y.size(), (6, 51, 2))
y = x.stft(10, window=window, pad_mode='constant', onesided=True,
return_complex=True)
self.assertEqual(y.dtype, torch.cdouble)
self.assertEqual(y.size(), (6, 51))

y = torch.rand(100, device=device, dtype=torch.double)
window = torch.randn(10, device=device, dtype=torch.cdouble)
x = torch.rand(100, device=device, dtype=torch.cdouble)
with self.assertRaisesRegex(RuntimeError, 'complex'):
x.stft(10, pad_mode='constant', onesided=True)

def test_stft_requires_complex(self, device):
x = torch.rand(100)
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')

@skipCUDAIfRocm
@skipCPUIfNoMkl
def test_fft_input_modification(self, device):
Expand Down Expand Up @@ -1091,7 +1098,7 @@ def test_fft_input_modification(self, device):
def test_istft_round_trip_simple_cases(self, device, dtype):
"""stft -> istft should recover the original signale"""
def _test(input, n_fft, length):
    # Round trip: stft then istft must reproduce the input exactly,
    # with the same dtype. The diff's leftover call without
    # return_complex was pure dead work (its result was immediately
    # overwritten) and would also hit the deprecation/error path.
    stft = torch.stft(input, n_fft=n_fft, return_complex=True)
    inverse = torch.istft(stft, n_fft=n_fft, length=length)
    self.assertEqual(input, inverse, exact_dtype=True)

Expand All @@ -1113,7 +1120,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs):
for sizes in data_sizes:
for i in range(num_trials):
original = torch.randn(*sizes, dtype=dtype, device=device)
stft = torch.stft(original, **stft_kwargs)
stft = torch.stft(original, return_complex=True, **stft_kwargs)
inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)

# trim the original for case when constructed signal is shorter than original
Expand Down
10 changes: 7 additions & 3 deletions torch/functional.py
Expand Up @@ -433,9 +433,13 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
r"""Short-time Fourier transform (STFT).

.. warning::
Setting :attr:`return_complex` explicitly will be required in a future
PyTorch release. Set it to False to preserve the current behavior or
True to return a complex output.
From version 1.8.0, :attr:`return_complex` must always be given
explicitly for real inputs and `return_complex=False` has been
deprecated. Strongly prefer `return_complex=True` as in a future
pytorch release, this function will only return complex tensors.

Note that :func:`torch.view_as_real` can be used to recover a real
tensor with an extra last dimension for real and imaginary components.

The STFT computes the Fourier transform of short overlapping windows of the
input. This gives the frequency components of the signal as they change over
Expand Down