From 5a27ec09b8a2e8f834a3a5a3371561d2e02ee669 Mon Sep 17 00:00:00 2001
From: moto <855818+mthrok@users.noreply.github.com>
Date: Fri, 24 Apr 2020 12:12:33 -0700
Subject: [PATCH] Add Inverse Short Time Fourier Transform in ATen native (#35569)

Summary:
Ported `torchaudio`'s implementation (tests and documentation as well) to ATen.

Note
- Batch packing/unpacking is performed in Python. The ATen implementation expects a 4D input tensor.
- `hop_length` is initialized in the same way as in the `stft` implementation. [Torchaudio's version tried to mimic the same behavior but was slightly different](https://github.com/pytorch/audio/blob/7da61a4beeec7b7ff9ff5f1532b2adf99220a9b1/torchaudio/functional.py#L152-L157).

Closes https://github.com/pytorch/pytorch/issues/34827
Relates to https://github.com/pytorch/pytorch/issues/3775

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35569

Differential Revision: D21178090

Pulled By: mthrok

fbshipit-source-id: 2701a8b241a36a6fb1b740c2fb2b07cb938185d4
---
 aten/src/ATen/core/aten_interned_strings.h |   1 +
 aten/src/ATen/native/SpectralOps.cpp       | 136 ++++++++++++
 aten/src/ATen/native/native_functions.yaml |   3 +
 docs/source/tensors.rst                    |   1 +
 docs/source/torch.rst                      |   1 +
 test/test_jit.py                           |  11 +-
 test/test_torch.py                         | 230 +++++++++++++++++++++
 tools/pyi/gen_pyi.py                       |   1 +
 torch/__init__.pyi.in                      |   2 +
 torch/_overrides.py                        |   2 +
 torch/_tensor_docs.py                      |   8 +
 torch/functional.py                        |  63 ++++++
 torch/jit/_builtins.py                     |   3 +-
 torch/tensor.py                            |   6 +
 14 files changed, 465 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 742f4f611a57..80970cc993e7 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -393,6 +393,7 @@ _(aten, is_set_to) \
 _(aten, is_signed) \
 _(aten, is_sparse) \
 _(aten, isclose) \
+_(aten, istft) \
 _(aten, kl_div) \
 _(aten, kl_div_backward) \
 _(aten, kthvalue) \
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index c801affc40d1..a7e6dfd38626 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -257,4 +257,140 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
   }
 }
 
+Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+             const optional<int64_t> win_lengthOpt, const Tensor& window,
+             const bool center, const bool normalized, const bool onesided,
+             const optional<int64_t> lengthOpt) {
+  #define REPR(SS) \
+    SS << "istft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
+       << ", hop_length=" << hop_length << ", win_length=" << win_length \
+       << ", window="; \
+    if (window.defined()) { \
+      SS << window.toString() << "{" << window.sizes() << "}"; \
+    } else { \
+      SS << "None"; \
+    } \
+    SS << ", center=" << center << ", normalized=" << normalized << ", onesided=" << onesided << ", length="; \
+    if (lengthOpt.has_value()) { \
+      SS << lengthOpt.value(); \
+    } else { \
+      SS << "None"; \
+    } \
+    SS << ")"
+
+  // default-initialize hop_length and win_length
+  const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
+  const auto win_length = win_lengthOpt.value_or(n_fft);
+
+  const auto input_dim = self.dim();
+  const auto n_frames = self.size(-2);
+  const auto fft_size = self.size(-3);
+
+  const auto expected_output_signal_len = n_fft + hop_length * (n_frames - 1);
+
+  const auto options = at::device(self.device()).dtype(self.dtype());
+  if (self.numel() == 0) {
+    std::ostringstream ss;
+    REPR(ss) << ": input tensor cannot be empty.";
+    AT_ERROR(ss.str());
+  }
+  if (input_dim != 3 && input_dim != 4) {
+    std::ostringstream ss;
+    REPR(ss) << ": expected a tensor with 3 or 4 dimensions, but got " << input_dim;
+    AT_ERROR(ss.str());
+  }
+  if (self.size(-1) != 2) {
+    std::ostringstream ss;
+    REPR(ss) << ": expected the last dimension to be 2 (corresponding to real and imaginary parts), but got " << self.size(-1);
+    AT_ERROR(ss.str());
+  }
+
+  if (onesided) {
+    if (n_fft / 2 + 1 != fft_size) {
+      std::ostringstream ss;
+      REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onesided=True, but got " << fft_size;
+      AT_ERROR(ss.str());
+    }
+  } else {
+    if (n_fft != fft_size) {
+      std::ostringstream ss;
+      REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got " << fft_size;
+      AT_ERROR(ss.str());
+    }
+  }
+
+  if (!(0 < hop_length && hop_length <= win_length)) {
+    std::ostringstream ss;
+    REPR(ss) << ": expected 0 < hop_length <= win_length";
+    AT_ERROR(ss.str());
+  }
+
+  if (!(0 < win_length && win_length <= n_fft)) {
+    std::ostringstream ss;
+    REPR(ss) << ": expected 0 < win_length <= n_fft";
+    AT_ERROR(ss.str());
+  }
+  if (window.defined()) {
+    if (window.dim() != 1 || window.size(0) != win_length) {
+      std::ostringstream ss;
+      REPR(ss) << ": Invalid window shape. window has to be 1D and of length `win_length`";
+      AT_ERROR(ss.str());
+    }
+  }
+
+  Tensor window_tmp = window.defined() ? window : at::ones({win_length,}, options);
+  if (win_length != n_fft) {
+    // center the window by padding zeros on the right and left side
+    int64_t left = (n_fft - win_length) / 2;
+    window_tmp = at::constant_pad_nd(window_tmp, {left, n_fft - win_length - left}, 0);
+    TORCH_INTERNAL_ASSERT(window_tmp.size(0) == n_fft);
+  }
+
+  Tensor input = self;
+  if (input_dim == 3) {
+    input = input.unsqueeze(0);
+  }
+
+  input = input.transpose(1, 2);  // size: (channel, n_frames, fft_size, 2)
+  input = at::native::irfft(input, 1, normalized, onesided, {n_fft,});  // size: (channel, n_frames, n_fft)
+  TORCH_INTERNAL_ASSERT(input.size(2) == n_fft);
+
+  Tensor y_tmp = input * window_tmp.view({1, 1, n_fft});  // size: (channel, n_frames, n_fft)
+  y_tmp = y_tmp.transpose(1, 2);  // size: (channel, n_fft, n_frames)
+
+  const Tensor eye = at::native::eye(n_fft, options).unsqueeze(1);
+  Tensor y = at::conv_transpose1d(y_tmp, eye,
+                                  /*bias*/ Tensor(),
+                                  /*stride*/ {hop_length,},
+                                  /*padding*/ {0,});  // size: (channel, 1, expected_output_signal_len)
+  window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0);  // size: (1, n_fft, n_frames)
+  Tensor window_envelop = at::conv_transpose1d(window_tmp, eye,
+                                               /*bias*/ Tensor(),
+                                               /*stride*/ {hop_length,},
+                                               /*padding*/ {0,});  // size: (1, 1, expected_output_signal_len)
+  TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2));
+  TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2));
+
+  // We need to trim the front padding away if centered
+  const auto start = center ? n_fft / 2 : 0;
+  const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : -n_fft / 2;
+
+  y = y.slice(2, start, end, 1);
+  window_envelop = window_envelop.slice(2, start, end, 1);
+  const auto window_envelop_lowest = window_envelop.abs().min().item().toDouble();
+  if (window_envelop_lowest < 1e-11) {
+    std::ostringstream ss;
+    REPR(ss) << ": window overlap add min: " << window_envelop_lowest;
+    AT_ERROR(ss.str());
+  }
+
+  y = (y / window_envelop).squeeze(1);  // size: (channel, expected_output_signal_len)
+  if (input_dim == 3) {
+    y = y.squeeze(0);
+  }
+  return y;
+
+  #undef REPR
+}
+
 }} // at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dfda9ae6ce0c..0d227929e489 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2683,6 +2683,9 @@
 - func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool onesided=True) -> Tensor
   variants: function, method
 
+- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool onesided=True, int? length=None) -> Tensor
+  variants: function, method
+
 - func: stride.int(Tensor self, int dim) -> int
   use_c10_dispatcher: full
   variants: function, method
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 3e3f1382f0eb..98cb2effeb59 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -323,6 +323,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: is_shared
    .. automethod:: is_signed
    .. autoattribute:: is_sparse
+   .. automethod:: istft
    .. automethod:: item
    .. automethod:: kthvalue
    .. automethod:: le
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 6224fe2f54a0..b17387ad6384 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -305,6 +305,7 @@ Spectral Ops
 .. autofunction:: rfft
 .. autofunction:: irfft
 .. autofunction:: stft
+.. autofunction:: istft
 .. autofunction:: bartlett_window
 .. autofunction:: blackman_window
 .. autofunction:: hamming_window
diff --git a/test/test_jit.py b/test/test_jit.py
index d293f88f0d3b..b8190e082984 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9942,12 +9942,19 @@ def test_pack_unpack_state(self):
 
     @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
     def test_torch_functional(self):
-        def foo(input, n_fft):
+        def stft(input, n_fft):
             # type: (Tensor, int) -> Tensor
             return torch.stft(input, n_fft)
 
         inps = (torch.randn(10), 7)
-        self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps))
+        self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
+
+        def istft(input, n_fft):
+            # type: (Tensor, int) -> Tensor
+            return torch.istft(input, n_fft)
+
+        inps2 = (torch.stft(*inps), inps[1])
+        self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
 
         def lu(x):
             # type: (Tensor) -> Tuple[Tensor, Tensor]
diff --git a/test/test_torch.py b/test/test_torch.py
index 596def69909e..a806a38782ca 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -12392,6 +12392,236 @@ def test_fft_input_modification(self, device):
         _ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2))
         self.assertEqual(half_spectrum, half_spectrum_copy)
 
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
+    @dtypes(torch.double)
+    def test_istft_round_trip_simple_cases(self, device, dtype):
+        """stft -> istft should recover the original signal"""
+        def _test(input, n_fft, length):
+            stft = torch.stft(input, n_fft=n_fft)
+            inverse = torch.istft(stft, n_fft=n_fft, length=length)
+            self.assertEqual(input, inverse, exact_dtype=True)
+
+        _test(torch.ones(4, dtype=dtype, device=device), 4, 4)
+        _test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
+
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
+    @dtypes(torch.double)
+    def test_istft_round_trip_various_params(self, device, dtype):
+        """stft -> istft should recover the original signal"""
+        def _test_istft_is_inverse_of_stft(stft_kwargs):
+            # generates a random sound signal for each trial and then does the stft/istft
+            # operation to check whether we can reconstruct the signal
+            data_sizes = [(2, 20), (3, 15), (4, 10)]
+            num_trials = 100
+            istft_kwargs = stft_kwargs.copy()
+            del istft_kwargs['pad_mode']
+            for sizes in data_sizes:
+                for i in range(num_trials):
+                    original = torch.randn(*sizes, dtype=dtype, device=device)
+                    stft = torch.stft(original, **stft_kwargs)
+                    inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)
+
+                    # trim the original for the case when the constructed signal is shorter than the original
+                    original = original[..., :inversed.size(-1)]
+                    self.assertEqual(
+                        inversed, original, 'istft comparison against original', atol=7e-6, exact_dtype=True)
+
+        patterns = [
+            # hann_window, centered, normalized, onesided
+            {
+                'n_fft': 12,
+                'hop_length': 4,
+                'win_length': 12,
+                'window': torch.hann_window(12, dtype=dtype, device=device),
+                'center': True,
+                'pad_mode': 'reflect',
+                'normalized': True,
+                'onesided': True,
+            },
+            # hann_window, centered, not normalized, not onesided
+            {
+                'n_fft': 12,
+                'hop_length': 2,
+                'win_length': 8,
+                'window': torch.hann_window(8, dtype=dtype, device=device),
+                'center': True,
+                'pad_mode': 'reflect',
+                'normalized': False,
+                'onesided': False,
+            },
+            # hamming_window, centered, normalized, not onesided
+            {
+                'n_fft': 15,
+                'hop_length': 3,
+                'win_length': 11,
+                'window': torch.hamming_window(11, dtype=dtype, device=device),
+                'center': True,
+                'pad_mode': 'constant',
+                'normalized': True,
+                'onesided': False,
+            },
+            # hamming_window, not centered, not normalized, onesided
+            # window same size as n_fft
+            {
+                'n_fft': 5,
+                'hop_length': 2,
+                'win_length': 5,
+                'window': torch.hamming_window(5, dtype=dtype, device=device),
+                'center': False,
+                'pad_mode': 'constant',
+                'normalized': False,
+                'onesided': True,
+            },
+            # hamming_window, not centered, not normalized, not onesided
+            # window same size as n_fft
+            {
+                'n_fft': 3,
+                'hop_length': 2,
+                'win_length': 3,
+                'window': torch.hamming_window(3, dtype=dtype, device=device),
+                'center': False,
+                'pad_mode': 'reflect',
+                'normalized': False,
+                'onesided': False,
+            },
+        ]
+        for i, pattern in enumerate(patterns):
+            _test_istft_is_inverse_of_stft(pattern)
+
+    @onlyOnCPUAndCUDA
+    def test_istft_throws(self, device):
+        """istft should throw an exception for invalid parameters"""
+        stft = torch.zeros((3, 5, 2), device=device)
+        # the window is size 1 but it hops 20, so there is a gap, which throws an error
+        self.assertRaises(
+            RuntimeError, torch.istft, stft, n_fft=4,
+            hop_length=20, win_length=1, window=torch.ones(1))
+        # A window of zeros does not meet NOLA
+        invalid_window = torch.zeros(4, device=device)
+        self.assertRaises(
+            RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
+        # Input cannot be empty
+        self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
+        self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
+
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double)
+    def test_istft_of_sine(self, device, dtype):
+        def _test(amplitude, L, n):
+            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
+            x = torch.arange(2 * L + 1, dtype=dtype)
+            original = amplitude * torch.sin(2 * math.pi / L * x * n)
+            # stft = torch.stft(original, L, hop_length=L, win_length=L,
+            #                   window=torch.ones(L), center=False, normalized=False)
+            stft = torch.zeros((L // 2 + 1, 2, 2), dtype=dtype)
+            stft_largest_val = (amplitude * L) / 2.0
+            if n < stft.size(0):
+                stft[n, :, 1] = -stft_largest_val
+
+            if 0 <= L - n < stft.size(0):
+                # symmetric about L // 2
+                stft[L - n, :, 1] = stft_largest_val
+
+            inverse = torch.istft(
+                stft, L, hop_length=L, win_length=L,
+                window=torch.ones(L, dtype=dtype), center=False, normalized=False)
+            # There is a larger error due to the scaling of amplitude
+            original = original[..., :inverse.size(-1)]
+            self.assertEqual(inverse, original, 1e-3)
+
+        _test(amplitude=123, L=5, n=1)
+        _test(amplitude=150, L=5, n=2)
+        _test(amplitude=111, L=5, n=3)
+        _test(amplitude=160, L=7, n=4)
+        _test(amplitude=145, L=8, n=5)
+        _test(amplitude=80, L=9, n=6)
+        _test(amplitude=99, L=10, n=7)
+
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double)
+    def test_istft_linearity(self, device, dtype):
+        num_trials = 100
+
+        def _test(data_size, kwargs):
+            for i in range(num_trials):
+                tensor1 = torch.randn(data_size, device=device, dtype=dtype)
+                tensor2 = torch.randn(data_size, device=device, dtype=dtype)
+                a, b = torch.rand(2, dtype=dtype, device=device)
+                istft1 = torch.istft(tensor1, **kwargs)
+                istft2 = torch.istft(tensor2, **kwargs)
+                istft = a * istft1 + b * istft2
+                estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
+                self.assertEqual(istft, estimate, 1e-5)
+
+        patterns = [
+            # hann_window, centered, normalized, onesided
+            (
+                (2, 7, 7, 2),
+                {
+                    'n_fft': 12,
+                    'window': torch.hann_window(12, device=device, dtype=dtype),
+                    'center': True,
+                    'normalized': True,
+                    'onesided': True,
+                },
+            ),
+            # hann_window, centered, not normalized, not onesided
+            (
+                (2, 12, 7, 2),
+                {
+                    'n_fft': 12,
+                    'window': torch.hann_window(12, device=device, dtype=dtype),
+                    'center': True,
+                    'normalized': False,
+                    'onesided': False,
+                },
+            ),
+            # hamming_window, centered, normalized, not onesided
+            (
+                (2, 12, 7, 2),
+                {
+                    'n_fft': 12,
+                    'window': torch.hamming_window(12, device=device, dtype=dtype),
+                    'center': True,
+                    'normalized': True,
+                    'onesided': False,
+                },
+            ),
+            # hamming_window, not centered, not normalized, onesided
+            (
+                (2, 7, 3, 2),
+                {
+                    'n_fft': 12,
+                    'window': torch.hamming_window(12, device=device, dtype=dtype),
+                    'center': False,
+                    'normalized': False,
+                    'onesided': True,
+                },
+            )
+        ]
+        for data_size, kwargs in patterns:
+            _test(data_size, kwargs)
+
+    @onlyOnCPUAndCUDA
+    @skipCUDAIfRocm
+    def test_batch_istft(self, device):
+        original = torch.tensor([
+            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
+            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
+        ], device=device)
+
+        single = original.repeat(1, 1, 1, 1)
+        multi = original.repeat(4, 1, 1, 1)
+
+        i_original = torch.istft(original, n_fft=4, length=4)
+        i_single = torch.istft(single, n_fft=4, length=4)
+        i_multi = torch.istft(multi, n_fft=4, length=4)
+
+        self.assertEqual(i_original.repeat(1, 1), i_single, 1e-6, exact_dtype=True)
+        self.assertEqual(i_original.repeat(4, 1), i_multi, 1e-6, exact_dtype=True)
+
     @skipCUDAIfRocm
     def test_blas_empty(self, device):
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index dc99181c0d39..96413dd48b9f 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -75,6 +75,7 @@
     'norm',
     'chain_matmul',
     'stft',
+    'istft',
     'tensordot',
     'norm',
     'split',
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in
index 40ea0b9e9ebb..0afb5347e6f1 100644
--- a/torch/__init__.pyi.in
+++ b/torch/__init__.pyi.in
@@ -133,6 +133,8 @@ class Tensor:
     def norm(self, p="fro", dim=None, keepdim=False): ...
     def stft(self, n_fft, hop_length=None, win_length=None, window=None, center=True,
              pad_mode='reflect', normalized=False, onesided=True): ...
+    def istft(self, n_fft, hop_length=None, win_length=None, window=None,
+              center=True, normalized=False, onesided=True, length=None): ...
     def split(self, split_size, dim=0): ...
     def unique(self, sorted=True, return_inverse=False, dim=None): ...
     def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 1b6e2adc97aa..3778da34c5c7 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -331,6 +331,8 @@ def get_testing_overrides():
     torch.is_signed: lambda input: -1,
     torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
     torch.isnan: lambda input: -1,
+    torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
+                  normalized=False, onesided=True, length=None: -1),
     torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
     torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
     torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 0652ccd658b4..0a3d228b66e2 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -3455,6 +3455,14 @@ def callable(a, b) -> number
 See :func:`torch.stft`
 """)
 
+add_docstr_all('istft',
+               r"""
+istft(n_fft, hop_length=None, win_length=None, window=None,
+      center=True, normalized=False, onesided=True, length=None) -> Tensor
+
+See :func:`torch.istft`
+""")
+
 add_docstr_all('fft',
                r"""
 fft(signal_ndim, normalized=False) -> Tensor
diff --git a/torch/functional.py b/torch/functional.py
index c2e3a5621d63..b0233154632a 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -18,6 +18,7 @@
     'cdist',
     'chain_matmul',
     'einsum',
+    'istft',
     'lu',
     'lu_unpack',
     'norm',
@@ -438,6 +439,68 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
 
 
+def istft(input, n_fft, hop_length=None, win_length=None, window=None,
+          center=True, normalized=False, onesided=True, length=None):
+    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, bool, bool, Optional[int]) -> Tensor
+    r"""Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.
+    It has the same parameters (plus an additional optional parameter :attr:`length`) and it should return the
+    least squares estimation of the original signal. The algorithm will check that the NOLA condition
+    (nonzero overlap) is satisfied.
+
+    It is important that the parameters :attr:`window` and :attr:`center` are chosen so that the envelope
+    created by the summation of all the windows is never zero at any point in time. Specifically,
+    :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \neq 0`.
+
+    Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
+    ``istft`` may return a shorter signal than the original signal (this can occur if :attr:`center` is
+    False, since the signal isn't padded).
+
+    If :attr:`center` is ``True``, then there will be padding, e.g. ``'constant'``, ``'reflect'``, etc.
+    Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
+    calculated without additional information.
+
+    Example: Suppose the last window is:
+    ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``
+
+    The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same, which prevents the
+    calculation of right padding. These additional values could be zeros or a reflection of the signal,
+    so providing :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be
+    aggressively removed (some loss of signal).
+
+    [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
+    IEEE Trans. ASSP, vol. 32, no. 2, pp. 236-243, Apr. 1984.
+
+    Arguments:
+        input (Tensor): The input tensor. Expected to be the output of :func:`~torch.stft`,
+            either 3D (``fft_size``, ``n_frame``, 2) or 4D (``channel``, ``fft_size``, ``n_frame``, 2).
+        n_fft (int): Size of Fourier transform
+        hop_length (Optional[int]): The distance between neighboring sliding window frames.
+            (Default: ``n_fft // 4``)
+        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
+        window (Optional[torch.Tensor]): The optional window function.
+            (Default: ``torch.ones(win_length)``)
+        center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
+            centered at time :math:`t \times \text{hop\_length}`.
+            (Default: ``True``)
+        normalized (bool): Whether the STFT was normalized. (Default: ``False``)
+        onesided (bool): Whether the STFT is onesided. (Default: ``True``)
+        length (Optional[int]): The length to trim the signal to, i.e. the
+            length of the original signal. (Default: whole signal)
+
+    Returns:
+        Tensor: Least squares estimation of the original signal of size (..., signal_length)
+    """
+    if not torch.jit.is_scripting():
+        if type(input) is not Tensor and has_torch_function((input,)):
+            return handle_torch_function(
+                istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
+                window=window, center=center, normalized=normalized, onesided=onesided,
+                length=length)
+
+    return _VF.istft(
+        input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)
+
+
 del torch.unique_dim
diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py
index 98f8639805b0..9630d3f475a5 100644
--- a/torch/jit/_builtins.py
+++ b/torch/jit/_builtins.py
@@ -78,6 +78,7 @@
     (torch._C._get_tracing_state, "aten::_get_tracing_state"),
     (warnings.warn, "aten::warn"),
     (torch._VF.stft, "aten::stft"),
+    (torch._VF.istft, "aten::istft"),
     (torch._VF.cdist, "aten::cdist"),
     (torch._VF.norm, "aten::norm"),
     (torch._VF.nuclear_norm, "aten::nuclear_norm"),
@@ -93,7 +94,7 @@ def _gen_torch_functional_registered_ops():
     # but we are currently only able to compile some of the functions. additionally,
     # some functions directly map to their aten:: implementations.
     # TODO: add support for more ops
-    ops = ["stft", "lu", "lu_unpack", "cdist", "norm"]
+    ops = ["stft", "istft", "lu", "lu_unpack", "cdist", "norm"]
     return set(getattr(torch.functional, name) for name in ops)
 
 _functional_registered_ops = _gen_torch_functional_registered_ops()
diff --git a/torch/tensor.py b/torch/tensor.py
index f025bdcbf372..ec424d797456 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -347,6 +347,12 @@ def stft(self, n_fft, hop_length=None, win_length=None, window=None,
         return torch.stft(self, n_fft, hop_length, win_length, window, center,
                           pad_mode, normalized, onesided)
 
+    def istft(self, n_fft, hop_length=None, win_length=None, window=None,
+              center=True, normalized=False, onesided=True, length=None):
+        r"""See :func:`torch.istft`"""
+        return torch.istft(self, n_fft, hop_length, win_length, window, center,
+                           normalized, onesided, length)
+
     def resize(self, *sizes):
         warnings.warn("non-inplace resize is deprecated")
         from torch.autograd._functions import Resize
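---

Usage sketch (illustrative, not part of the diff above): a minimal round trip through the new `torch.istft`, assuming a build that includes this patch. The signal shape, `n_fft`, and hop length are arbitrary choices; passing `length` lets `istft` trim the reconstruction back to the original signal length, as the docstring describes.

```python
import torch

signal = torch.randn(2, 1000)          # (channel, time); arbitrary test data
n_fft, hop = 128, 32                   # hop = n_fft // 4, the default ratio
window = torch.hann_window(n_fft)      # Hann at 75% overlap satisfies NOLA

spec = torch.stft(signal, n_fft, hop_length=hop, window=window)
# spec has size (channel, fft_size, n_frames, 2): the 4D layout istft accepts

recovered = torch.istft(spec, n_fft, hop_length=hop, window=window,
                        length=signal.size(1))  # trim back to the original length
print(torch.allclose(signal, recovered, atol=1e-4))  # True up to float rounding
```

The ATen implementation performs its overlap-add with `conv_transpose1d` and an identity (`eye`) kernel; a small standalone sketch of that trick, with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

frames = torch.randn(1, 4, 3)      # (batch, n_fft=4, n_frames=3) windowed frames
eye = torch.eye(4).unsqueeze(1)    # (n_fft, 1, n_fft) identity kernel
# channel c of frame t lands at output index t * stride + c, so overlapping
# frame contents are summed -- exactly the overlap-add step in istft
ola = F.conv_transpose1d(frames, eye, stride=2)  # size (1, 1, (3 - 1) * 2 + 4) = (1, 1, 8)
```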