Add Inverse Short Time Fourier Transform in ATen native (#35569)
Summary:
Ported `torchaudio`'s implementation (along with its tests and documentation) to ATen.

Note
 - Batch packing/unpacking is performed in Python; the ATen implementation expects a 3D or 4D input tensor (see the sketch below).
 - `hop_length` is initialized in the same way as in the `stft` implementation. [Torchaudio's version tried to mimic this behavior, but differs slightly](https://github.com/pytorch/audio/blob/7da61a4beeec7b7ff9ff5f1532b2adf99220a9b1/torchaudio/functional.py#L152-L157).
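
A minimal sketch of how a Python-side wrapper might do that batch packing (hypothetical code, not part of this commit; the name `istft_with_batches` and the reshape strategy are illustrative assumptions):

```python
import torch

# Sketch only: hypothetical wrapper, not the code in this PR. It flattens any
# leading batch dims down to the 4D layout the ATen kernel handles, then
# restores them afterwards.
def istft_with_batches(spec, n_fft, **kwargs):
    leading = spec.shape[:-3]                    # arbitrary leading batch dims
    packed = spec.reshape(-1, *spec.shape[-3:])  # -> (batch, fft_size, n_frames, 2)
    out = torch.istft(packed, n_fft, **kwargs)   # ATen istft accepts 3D/4D input
    return out.reshape(*leading, out.shape[-1])  # restore the batch dims
```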

Closes #34827
Relates to #3775
Pull Request resolved: #35569

Differential Revision: D21178090

Pulled By: mthrok

fbshipit-source-id: 2701a8b241a36a6fb1b740c2fb2b07cb938185d4
mthrok authored and facebook-github-bot committed Apr 24, 2020
1 parent 20328f6 commit 5a27ec0
Showing 14 changed files with 465 additions and 3 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -393,6 +393,7 @@ _(aten, is_set_to) \
_(aten, is_signed) \
_(aten, is_sparse) \
_(aten, isclose) \
_(aten, istft) \
_(aten, kl_div) \
_(aten, kl_div_backward) \
_(aten, kthvalue) \
136 changes: 136 additions & 0 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -257,4 +257,140 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
}
}

Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool center, const bool normalized, const bool onesided,
const optional<int64_t> lengthOpt) {
#define REPR(SS) \
SS << "istft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.toString() << "{" << window.sizes() << "}"; \
} else { \
SS << "None"; \
} \
SS << ", center=" << center << ", normalized=" << normalized << ", onesided=" << onesided << ", length="; \
if (lengthOpt.has_value()) { \
SS << lengthOpt.value(); \
} else { \
SS << "None"; \
} \
SS << ")"

// default init hop_length and win_length
const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
const auto win_length = win_lengthOpt.value_or(n_fft);

const auto input_dim = self.dim();
const auto n_frames = self.size(-2);
const auto fft_size = self.size(-3);

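// Each frame spans n_fft samples, and consecutive frames start hop_length
// samples apart, so before trimming the overlap-add output covers
// n_fft + hop_length * (n_frames - 1) samples.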
const auto expected_output_signal_len = n_fft + hop_length * (n_frames - 1);

const auto options = at::device(self.device()).dtype(self.dtype());
if (self.numel() == 0) {
std::ostringstream ss;
REPR(ss) << ": input tensor cannot be empty.";
AT_ERROR(ss.str());
}
if (input_dim != 3 && input_dim != 4) {
std::ostringstream ss;
REPR(ss) << ": expected a tensor with 3 or 4 dimensions, but got " << input_dim;
AT_ERROR(ss.str());
}
if (self.size(-1) != 2) {
std::ostringstream ss;
REPR(ss) << ": expected the last dimension to be 2 (corresponding to real and imaginary parts), but got " << self.size(-1);
AT_ERROR(ss.str());
}

if (onesided) {
if (n_fft / 2 + 1 != fft_size) {
std::ostringstream ss;
REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onsided=True, but got " << fft_size;
AT_ERROR(ss.str());
}
} else {
if (n_fft != fft_size) {
std::ostringstream ss;
REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onsided=False, but got " << fft_size;
AT_ERROR(ss.str());
}
}

if (!(0 < hop_length && hop_length <= win_length)) {
std::ostringstream ss;
REPR(ss) << ": expected 0 < hop_length <= win_length";
AT_ERROR(ss.str());
}

if (!(0 < win_length && win_length <= n_fft)) {
std::ostringstream ss;
REPR(ss) << ": expected 0 < win_length <= n_fft";
AT_ERROR(ss.str());
}
if (window.defined()) {
if (window.dim() != 1 || window.size(0) != win_length) {
std::ostringstream ss;
REPR(ss) << ": Invalid window shape. window has to be 1D and length of `win_length`";
AT_ERROR(ss.str());
}
}

Tensor window_tmp = window.defined() ? window : at::ones({win_length,}, options);
if (win_length != n_fft) {
// center window by padding zeros on right and left side
int64_t left = (n_fft - win_length) / 2;
window_tmp = at::constant_pad_nd(window_tmp, {left, n_fft - win_length - left}, 0);
TORCH_INTERNAL_ASSERT(window_tmp.size(0) == n_fft);
}

Tensor input = self;
if (input_dim == 3) {
input = input.unsqueeze(0);
}

input = input.transpose(1, 2); // size: (channel, n_frames, fft_size, 2)
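// Invert each frame. Passing signal_sizes={n_fft} pins the output length,
// which is otherwise ambiguous for a onesided spectrum (even and odd
// signal lengths map to the same number of frequency bins).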
input = at::native::irfft(input, 1, normalized, onesided, {n_fft, }); // size: (channel, n_frames, n_fft)
TORCH_INTERNAL_ASSERT(input.size(2) == n_fft);

Tensor y_tmp = input * window_tmp.view({1, 1, n_fft}); // size: (channel, n_frames, n_fft)
y_tmp = y_tmp.transpose(1, 2); // size: (channel, n_fft, n_frames)

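// Overlap-add implemented as a transposed 1D convolution: the kernel is an
// (n_fft, 1, n_fft) identity, so each frame (a column of y_tmp) is copied to
// its time offset at stride hop_length and overlapping samples are summed.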
const Tensor eye = at::native::eye(n_fft, options).unsqueeze(1);
Tensor y = at::conv_transpose1d(y_tmp, eye,
/*bias*/ Tensor(),
/*stride*/ {hop_length,},
/*padding*/{0,}); // size: (channel, 1, expected_output_signal_len)
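// Build the matching envelope of squared, overlap-added windows; dividing
// by it below undoes the analysis/synthesis windowing.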
window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0); // size: (1, n_fft, n_frames)
Tensor window_envelop = at::conv_transpose1d(window_tmp, eye,
/*bias*/ Tensor(),
/*stride*/ {hop_length, },
/*padding*/{0, }); // size: (1, 1, expected_output_signal_len)
TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2));
TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2));

// We need to trim the front padding away if centered
const auto start = center ? n_fft / 2 : 0;
const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : -n_fft / 2;

y = y.slice(2, start, end, 1);
window_envelop = window_envelop.slice(2, start, end, 1);
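// A (near-)zero in the squared-window envelope means the window/hop
// combination violates the nonzero overlap-add (NOLA) condition, so the
// signal cannot be recovered there and we raise instead.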
const auto window_envelop_lowest = window_envelop.abs().min().item().toDouble();
if (window_envelop_lowest < 1e-11) {
std::ostringstream ss;
REPR(ss) << "window overlap add min: " << window_envelop_lowest;
AT_ERROR(ss.str());
}

y = (y / window_envelop).squeeze(1); // size: (channel, expected_output_signal_len)
if (input_dim == 3) {
y = y.squeeze(0);
}
return y;

#undef REPR
}

}} // at::native
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2683,6 +2683,9 @@
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool onesided=True) -> Tensor
variants: function, method

- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool onesided=True, int? length=None) -> Tensor
variants: function, method

- func: stride.int(Tensor self, int dim) -> int
use_c10_dispatcher: full
variants: function, method
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -323,6 +323,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: is_shared
.. automethod:: is_signed
.. autoattribute:: is_sparse
.. automethod:: istft
.. automethod:: item
.. automethod:: kthvalue
.. automethod:: le
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -305,6 +305,7 @@ Spectral Ops
.. autofunction:: rfft
.. autofunction:: irfft
.. autofunction:: stft
.. autofunction:: istft
.. autofunction:: bartlett_window
.. autofunction:: blackman_window
.. autofunction:: hamming_window
11 changes: 9 additions & 2 deletions test/test_jit.py
@@ -9942,12 +9942,19 @@ def test_pack_unpack_state(self):

    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    def test_torch_functional(self):
        def stft(input, n_fft):
            # type: (Tensor, int) -> Tensor
            return torch.stft(input, n_fft)

        inps = (torch.randn(10), 7)
        self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))

        def istft(input, n_fft):
            # type: (Tensor, int) -> Tensor
            return torch.istft(input, n_fft)

        inps2 = (torch.stft(*inps), inps[1])
        self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
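        # Round-trip note: torch.stft returns a real tensor of shape
        # (fft_size, n_frames, 2) here (real/imag in the last dim), which is
        # exactly the layout the new torch.istft inverts.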

        def lu(x):
            # type: (Tensor) -> Tuple[Tensor, Tensor]
