Migrate inverse short-time Fourier transform from torchaudio #34827

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -303,6 +303,7 @@ Spectral Ops
.. autofunction:: rfft
.. autofunction:: irfft
.. autofunction:: stft
.. autofunction:: istft
.. autofunction:: bartlett_window
.. autofunction:: blackman_window
.. autofunction:: hamming_window
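For readers trying the newly documented function, here is a minimal round-trip sketch; the `n_fft`, `hop_length`, and Hann-window choices are illustrative, not values taken from this PR:

```python
import torch

# Hedged sketch of the stft -> istft round trip; parameter values are
# illustrative. With a Hann window at 75% overlap the reconstruction is
# exact up to floating-point error.
signal = torch.randn(1000)
window = torch.hann_window(64)
spec = torch.stft(signal, n_fft=64, hop_length=16, window=window)
recovered = torch.istft(spec, n_fft=64, hop_length=16, window=window,
                        length=signal.size(0))
assert torch.allclose(signal, recovered, atol=1e-5)
```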
11 changes: 9 additions & 2 deletions test/test_jit.py
@@ -11272,12 +11272,19 @@ def test_pack_unpack_state(self):

    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    def test_torch_functional(self):
-       def foo(input, n_fft):
+       def stft(input, n_fft):
            # type: (Tensor, int) -> Tensor
            return torch.stft(input, n_fft)

        inps = (torch.randn(10), 7)
-       self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps))
+       self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
+
+       def istft(input, n_fft):
+           # type: (Tensor, int) -> Tensor
+           return torch.istft(input, n_fft)
+
+       inps2 = (torch.stft(*inps), inps[1])
+       self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

        def lu(x):
            # type: (Tensor) -> Tuple[Tensor, Tensor]
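As the test above shows, the builtin is exercised under TorchScript through a small local wrapper rather than by scripting `torch.istft` directly. A standalone sketch of the same pattern (not code from this PR):

```python
import torch

# Hedged sketch mirroring the test: script a wrapper around the builtin.
@torch.jit.script
def inverse(spec: torch.Tensor, n_fft: int) -> torch.Tensor:
    return torch.istft(spec, n_fft)

spec = torch.stft(torch.randn(10), 7)  # same inputs as the test above
out = inverse(spec, 7)
```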
222 changes: 222 additions & 0 deletions test/test_torch.py
@@ -11888,6 +11888,228 @@ def test_fft_input_modification(self, device):
        _ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2))
        self.assertEqual(half_spectrum, half_spectrum_copy)

    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    @dtypes(torch.double)
    def test_istft_round_trip_simple_cases(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test(input, n_fft, length):
            stft = torch.stft(input, n_fft=n_fft)
            inverse = torch.istft(stft, n_fft=n_fft, length=length)
            self.assertEqual(input, inverse, exact_dtype=True)

        _test(torch.ones(4, dtype=dtype, device=device), 4, 4)
        _test(torch.zeros(4, dtype=dtype, device=device), 4, 4)

    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    @dtypes(torch.double)
    def test_istft_round_trip_various_params(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test_istft_is_inverse_of_stft(kwargs):
            # generates a random sound signal for each trial and then runs the
            # stft/istft pair to check whether we can reconstruct the signal
            data_sizes = [(2, 20), (3, 15), (4, 10)]
            num_trials = 100
            for sizes in data_sizes:
                for i in range(num_trials):
                    original = torch.randn(*sizes, dtype=dtype, device=device)
                    stft = torch.stft(original, **kwargs)
                    inversed = torch.istft(stft, length=original.size(1), **kwargs)

                    # trim the original for the case when the reconstructed signal is shorter than the original
                    original = original[..., :inversed.size(-1)]
                    self.assertEqual(
                        inversed, original, 7e-6, 'istft comparison against original', exact_dtype=True)

        patterns = [
            # hann_window, centered, normalized, onesided
            {
                'n_fft': 12,
                'hop_length': 4,
                'win_length': 12,
                'window': torch.hann_window(12, dtype=dtype, device=device),
                'center': True,
                'normalized': True,
                'onesided': True,
            },
            # hann_window, centered, not normalized, not onesided
            {
                'n_fft': 12,
                'hop_length': 2,
                'win_length': 8,
                'window': torch.hann_window(8, dtype=dtype, device=device),
                'center': True,
                'normalized': False,
                'onesided': False,
            },
            # hamming_window, centered, normalized, not onesided
            {
                'n_fft': 15,
                'hop_length': 3,
                'win_length': 11,
                'window': torch.hamming_window(11, dtype=dtype, device=device),
                'center': True,
                'normalized': True,
                'onesided': False,
            },
            # hamming_window, not centered, not normalized, onesided
            # window same size as n_fft
            {
                'n_fft': 5,
                'hop_length': 2,
                'win_length': 5,
                'window': torch.hamming_window(5, dtype=dtype, device=device),
                'center': False,
                'normalized': False,
                'onesided': True,
            },
            # hamming_window, not centered, not normalized, not onesided
            # window same size as n_fft
            {
                'n_fft': 3,
                'hop_length': 2,
                'win_length': 3,
                'window': torch.hamming_window(3, dtype=dtype, device=device),
                'center': False,
                'normalized': False,
                'onesided': False,
            },
        ]
        for i, pattern in enumerate(patterns):
            _test_istft_is_inverse_of_stft(pattern)
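The trim of `original` inside the test above is needed because, with `center=False`, `torch.stft` keeps only complete frames, so the reconstruction can come back a few samples short. A sketch of the usual framing arithmetic (the helper is an assumption for illustration, not PR code):

```python
# Number of complete frames for an uncentered STFT (standard framing assumption).
def n_frames(signal_len: int, n_fft: int, hop: int) -> int:
    return 1 + (signal_len - n_fft) // hop

# e.g. a 20-sample signal with n_fft=5, hop=2 gives 8 frames covering
# 5 + 7 * 2 = 19 samples, so istft returns 19 samples and the comparison
# drops the last one.
print(n_frames(20, 5, 2))  # 8
```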

    @onlyOnCPUAndCUDA
    def test_istft_throws(self, device):
        """istft should throw an exception for invalid parameters"""
        stft = torch.zeros((3, 5, 2), device=device)
        # the window is of size 1 but the hop is 20, so there are gaps between frames, which throws an error
        self.assertRaises(
            AssertionError, torch.istft, stft, n_fft=4,
            hop_length=20, win_length=1, window=torch.ones(1))
        # A window of zeros does not meet NOLA
        invalid_window = torch.zeros(4, device=device)
        self.assertRaises(
            AssertionError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
        # Input cannot be empty
        self.assertRaises(AssertionError, torch.istft, torch.zeros((3, 0, 2)), 2)
        self.assertRaises(AssertionError, torch.istft, torch.zeros((0, 3, 2)), 2)
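The first two failure cases both come down to the NOLA (nonzero overlap-add) condition: the squared window, overlap-added at the hop length, must stay nonzero everywhere, or the final division by the window envelope is undefined. A standalone sketch of that check (the helper name and tolerance are assumptions, not the PR's internals):

```python
import torch

# Hedged sketch of a NOLA check: overlap-add the squared window and verify
# the envelope never (numerically) touches zero.
def meets_nola(window: torch.Tensor, hop: int, n_frames: int) -> bool:
    win_len = window.size(0)
    envelope = torch.zeros((n_frames - 1) * hop + win_len)
    for t in range(n_frames):
        envelope[t * hop:t * hop + win_len] += window ** 2
    return bool((envelope > 1e-11).all())

print(meets_nola(torch.ones(1), hop=20, n_frames=5))  # False: gaps between frames
print(meets_nola(torch.zeros(4), hop=2, n_frames=5))  # False: all-zero window
print(meets_nola(torch.ones(4), hop=2, n_frames=5))   # True: frames fully overlap
```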

    @onlyOnCPUAndCUDA
    @dtypes(torch.double)
    def test_istft_of_sine(self, device, dtype):
        def _test(amplitude, L, n):
            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
            x = torch.arange(2 * L + 1, dtype=dtype)
            original = amplitude * torch.sin(2 * math.pi / L * x * n)
            # stft = torch.stft(original, L, hop_length=L, win_length=L,
            #                   window=torch.ones(L), center=False, normalized=False)
            stft = torch.zeros((L // 2 + 1, 2, 2), dtype=dtype)
            stft_largest_val = (amplitude * L) / 2.0
            if n < stft.size(0):
                stft[n, :, 1] = -stft_largest_val

            if 0 <= L - n < stft.size(0):
                # symmetric about L // 2
                stft[L - n, :, 1] = stft_largest_val

            inverse = torch.istft(
                stft, L, hop_length=L, win_length=L,
                window=torch.ones(L, dtype=dtype), center=False, normalized=False)
            # a larger tolerance is needed because of the amplitude scaling
            original = original[..., :inverse.size(-1)]
            self.assertEqual(inverse, original, 1e-3)

        _test(amplitude=123, L=5, n=1)
        _test(amplitude=150, L=5, n=2)
        _test(amplitude=111, L=5, n=3)
        _test(amplitude=160, L=7, n=4)
        _test(amplitude=145, L=8, n=5)
        _test(amplitude=80, L=9, n=6)
        _test(amplitude=99, L=10, n=7)
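The hard-coded spectrum above follows from the DFT of a pure sinusoid under a rectangular window of length L; a quick derivation for reference (standard DFT convention, not text from the PR):

$$
X[k] = \sum_{x=0}^{L-1} A \sin\!\left(\frac{2\pi n x}{L}\right) e^{-2\pi i k x / L}
= \begin{cases}
-\,i\,\dfrac{A L}{2} & k = n,\\
+\,i\,\dfrac{A L}{2} & k = L - n,\\
0 & \text{otherwise},
\end{cases}
$$

which is why only the imaginary slots `stft[n, :, 1]` and `stft[L - n, :, 1]` are populated, with values of minus and plus `amplitude * L / 2` respectively.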

    @onlyOnCPUAndCUDA
    @dtypes(torch.double)
    def test_istft_linearity(self, device, dtype):
        num_trials = 100

        def _test(data_size, kwargs):
            for i in range(num_trials):
                tensor1 = torch.randn(data_size, device=device, dtype=dtype)
                tensor2 = torch.randn(data_size, device=device, dtype=dtype)
                a, b = torch.rand(2, dtype=dtype, device=device)
                istft1 = torch.istft(tensor1, **kwargs)
                istft2 = torch.istft(tensor2, **kwargs)
                istft = a * istft1 + b * istft2
                estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
                self.assertEqual(istft, estimate, 1e-5)

        patterns = [
            # hann_window, centered, normalized, onesided
            (
                (2, 7, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hann_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': True,
                    'onesided': True,
                },
            ),
            # hann_window, centered, not normalized, not onesided
            (
                (2, 12, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hann_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': False,
                    'onesided': False,
                },
            ),
            # hamming_window, centered, normalized, not onesided
            (
                (2, 12, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hamming_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': True,
                    'onesided': False,
                },
            ),
            # hamming_window, not centered, not normalized, onesided
            (
                (2, 7, 3, 2),
                {
                    'n_fft': 12,
                    'window': torch.hamming_window(12, device=device, dtype=dtype),
                    'center': False,
                    'normalized': False,
                    'onesided': True,
                },
            ),
        ]
        for data_size, kwargs in patterns:
            _test(data_size, kwargs)
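Linearity holds because every stage of istft (inverse FFT, windowing, overlap-add, and the final division by the precomputed window envelope) is linear in the input spectrogram. A standalone sketch of the asserted property, with illustrative shapes and coefficients matching the first pattern above:

```python
import torch

# Hedged sketch: istft(a*x1 + b*x2) should equal a*istft(x1) + b*istft(x2).
x1 = torch.randn(2, 7, 7, 2, dtype=torch.double)   # (batch, freq, frames, 2)
x2 = torch.randn(2, 7, 7, 2, dtype=torch.double)
a, b = 0.3, 1.7
win = torch.hann_window(12, dtype=torch.double)
lhs = torch.istft(a * x1 + b * x2, n_fft=12, window=win)
rhs = (a * torch.istft(x1, n_fft=12, window=win)
       + b * torch.istft(x2, n_fft=12, window=win))
assert torch.allclose(lhs, rhs, atol=1e-10)
```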

    @onlyOnCPUAndCUDA
    def test_batch_istft(self, device):
        original = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ], device=device)

        single = original.repeat(1, 1, 1, 1)
        multi = original.repeat(4, 1, 1, 1)

        i_original = torch.istft(original, n_fft=4, length=4)
        i_single = torch.istft(single, n_fft=4, length=4)
        i_multi = torch.istft(multi, n_fft=4, length=4)

        self.assertEqual(i_original.repeat(1, 1), i_single, 1e-6, exact_dtype=True)
        self.assertEqual(i_original.repeat(4, 1), i_multi, 1e-6, exact_dtype=True)
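A standalone illustration of the batch semantics checked above: istft treats a leading dimension as batch, mapping a (batch, freq, frames, 2) input to a (batch, length) output (the values here are illustrative):

```python
import torch

# Hedged sketch: a leading batch dimension is carried through istft unchanged.
spec = torch.randn(4, 3, 5, 2)    # 4 onesided spectrograms: 3 bins (n_fft=4), 5 frames
out = torch.istft(spec, n_fft=4, length=4)
print(out.shape)                  # torch.Size([4, 4])
```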

    @skipCUDAIfRocm
    def test_blas_empty(self, device):

2 changes: 2 additions & 0 deletions torch/_overrides.py
@@ -324,6 +324,8 @@ def get_testing_overrides():
        torch.is_signed: lambda input: -1,
        torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
        torch.isnan: lambda input: -1,
        torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
                      normalized=False, onesided=True, length=None: -1),
        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
        torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
        torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,