Skip to content

Commit

Permalink
stft: Change require_complex warning to an error
Browse files Browse the repository at this point in the history
ghstack-source-id: 0c3cabe6a4bbd883273aa888e0bdb94a1687e1bc
Pull Request resolved: pytorch#49022
  • Loading branch information
peterbell10 committed Dec 9, 2020
1 parent b6f210a commit 743aee4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 16 deletions.
12 changes: 6 additions & 6 deletions aten/src/ATen/native/SpectralOps.cpp
Expand Up @@ -468,12 +468,12 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
auto win_length = win_lengthOpt.value_or(n_fft);
const bool return_complex = return_complexOpt.value_or(
self.is_complex() || (window.defined() && window.is_complex()));
if (!return_complexOpt && !return_complex) {
TORCH_WARN_ONCE("stft will require the return_complex parameter be explicitly "
" specified in a future PyTorch release. Use return_complex=False "
" to preserve the current behavior or return_complex=True to return "
" a complex output.");
}
TORCH_CHECK(
return_complexOpt.has_value() || return_complex,
"stft requires the return_complex parameter be explicitly "
"specified for real inputs. Use return_complex=True to return "
"a complex-valued tensor, or return_complex=True to return "
"a real-valued tensor with an extra complex dimension.");

if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
std::ostringstream ss;
Expand Down
6 changes: 3 additions & 3 deletions test/test_jit.py
Expand Up @@ -8830,7 +8830,7 @@ def test_pack_unpack_state(self):
def test_torch_functional(self):
def stft(input, n_fft):
# type: (Tensor, int) -> Tensor
return torch.stft(input, n_fft)
return torch.stft(input, n_fft, return_complex=True)

inps = (torch.randn(10), 7)
self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
Expand All @@ -8839,8 +8839,8 @@ def istft(input, n_fft):
# type: (Tensor, int) -> Tensor
return torch.istft(input, n_fft)

inps2 = (torch.stft(*inps), inps[1])
self.assertEqual(torch.istft(*inps2), torch.jit.script(torch.istft)(*inps2))
inps2 = (stft(*inps), inps[1])
self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

def lu(x):
# type: (Tensor) -> Tuple[Tensor, Tensor]
Expand Down
15 changes: 11 additions & 4 deletions test/test_spectral_ops.py
Expand Up @@ -843,7 +843,8 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
else:
window = None
if expected_error is None:
result = x.stft(n_fft, hop_length, win_length, window, center=center)
result = x.stft(n_fft, hop_length, win_length, window,
center=center, return_complex=False)
# NB: librosa defaults to np.complex64 output, no matter what
# the input dtype
ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
Expand Down Expand Up @@ -1055,7 +1056,8 @@ def test_complex_stft_onesided(self, device):
with self.assertRaisesRegex(RuntimeError, 'complex'):
x.stft(10, window=window, pad_mode='constant', onesided=True)
else:
y = x.stft(10, window=window, pad_mode='constant', onesided=True)
y = x.stft(10, window=window, pad_mode='constant', onesided=True,
return_complex=False)
self.assertEqual(y.dtype, torch.double)
self.assertEqual(y.size(), (6, 51, 2))

Expand All @@ -1064,6 +1066,11 @@ def test_complex_stft_onesided(self, device):
with self.assertRaisesRegex(RuntimeError, 'complex'):
x.stft(10, pad_mode='constant', onesided=True)

def test_stft_requires_complex(self, device):
x = torch.rand(100)
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')

@skipCUDAIfRocm
@skipCPUIfNoMkl
def test_fft_input_modification(self, device):
Expand Down Expand Up @@ -1091,7 +1098,7 @@ def test_fft_input_modification(self, device):
def test_istft_round_trip_simple_cases(self, device, dtype):
"""stft -> istft should recover the original signale"""
def _test(input, n_fft, length):
stft = torch.stft(input, n_fft=n_fft)
stft = torch.stft(input, n_fft=n_fft, return_complex=False)
inverse = torch.istft(stft, n_fft=n_fft, length=length)
self.assertEqual(input, inverse, exact_dtype=True)

Expand All @@ -1113,7 +1120,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs):
for sizes in data_sizes:
for i in range(num_trials):
original = torch.randn(*sizes, dtype=dtype, device=device)
stft = torch.stft(original, **stft_kwargs)
stft = torch.stft(original, return_complex=False, **stft_kwargs)
inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)

# trim the original for case when constructed signal is shorter than original
Expand Down
7 changes: 4 additions & 3 deletions torch/functional.py
Expand Up @@ -464,9 +464,10 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
r"""Short-time Fourier transform (STFT).
.. warning::
Setting :attr:`return_complex` explicitly will be required in a future
PyTorch release. Set it to False to preserve the current behavior or
True to return a complex output.
From version 1.8.0, :attr:`return_complex` must be given explicitly for
real inputs. Set to True to return a complex output, or False to
preserve the legacy behavior of returning a real tensor with an extra
last dimension for the real and imaginary components.
The STFT computes the Fourier transform of short overlapping windows of the
input. This giving frequency components of the signal as they change over
Expand Down

0 comments on commit 743aee4

Please sign in to comment.