Skip to content

Commit

Permalink
Make stft (temporarily) warn (#50102)
Browse files Browse the repository at this point in the history
Summary:
When continuing the deprecation process for stft it was made to throw an error when `use_complex` was not explicitly set by the user. Unfortunately this PR missed a model relying on the historic stft functionality. Before re-enabling the error we'll need to write an upgrader for that model.

This PR turns the error back into a warning to allow that model to continue running as before.

Pull Request resolved: #50102

Reviewed By: ngimel

Differential Revision: D25784325

Pulled By: mruberry

fbshipit-source-id: 825fb38af39b423ce11b376ad3c4a8b21c410b95
  • Loading branch information
Mike Ruberry authored and facebook-github-bot committed Jan 5, 2021
1 parent 4a6c178 commit 5e1c8f2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
22 changes: 12 additions & 10 deletions aten/src/ATen/native/SpectralOps.cpp
Expand Up @@ -469,18 +469,20 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
const bool return_complex = return_complexOpt.value_or(
self.is_complex() || (window.defined() && window.is_complex()));
if (!return_complex) {
TORCH_CHECK(return_complexOpt.has_value(),
"stft requires the return_complex parameter be given for real inputs."
"You should pass return_complex=True to opt-in to complex dtype returns "
"(which will be required in a future pytorch release). "
if (!return_complexOpt.has_value()) {
TORCH_WARN_ONCE(
"stft will soon require the return_complex parameter be given for real inputs, "
"and will further require that return_complex=True in a future PyTorch release."
);
}

TORCH_WARN_ONCE(
"stft with return_complex=False is deprecated. In a future pytorch "
"release, stft will return complex tensors for all inputs, and "
"return_complex=False will raise an error.\n"
"Note: you can still call torch.view_as_real on the complex output to "
"recover the old return format.");

// TORCH_WARN_ONCE(
// "stft with return_complex=False is deprecated. In a future pytorch "
// "release, stft will return complex tensors for all inputs, and "
// "return_complex=False will raise an error.\n"
// "Note: you can still call torch.view_as_real on the complex output to "
// "recover the old return format.");
}

if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
Expand Down
6 changes: 4 additions & 2 deletions test/test_spectral_ops.py
Expand Up @@ -1066,10 +1066,12 @@ def test_complex_stft_onesided(self, device):
with self.assertRaisesRegex(RuntimeError, 'complex'):
x.stft(10, pad_mode='constant', onesided=True)

# stft is currently warning that it requires return-complex while an upgrader is written
def test_stft_requires_complex(self, device):
x = torch.rand(100)
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')
y = x.stft(10, pad_mode='constant')
# with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
# y = x.stft(10, pad_mode='constant')

@skipCUDAIfRocm
@skipCPUIfNoMkl
Expand Down

0 comments on commit 5e1c8f2

Please sign in to comment.