Skip to content

Commit

Permalink
Update calls to torch.stft to have return_complex=True (#1096)
Browse files (browse the repository at this point in the history)
* Resolves #1095
  • Loading branch information
mthrok committed Dec 17, 2020
1 parent d25a4dd commit 3ace593
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
25 changes: 14 additions & 11 deletions test/torchaudio_unittest/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,22 @@ def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100

- kwargs = {
- 'n_fft': 2048,
- 'hop_length': 512,
- 'win_length': 2048,
- 'window': torch.hann_window(2048),
- 'center': True,
- 'pad_mode': 'reflect',
- 'normalized': True,
- 'onesided': True,
- }
rate = 2

- complex_specgrams = torch.stft(waveform, **kwargs)
+ complex_specgrams = torch.view_as_real(
+ torch.stft(
+ input=waveform,
+ n_fft=2048,
+ hop_length=512,
+ win_length=2048,
+ window=torch.hann_window(2048),
+ center=True,
+ pad_mode='reflect',
+ normalized=True,
+ onesided=True,
+ return_complex=True,
+ )
+ )

# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
Expand Down
31 changes: 27 additions & 4 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,19 @@ def spectrogram(
waveform = waveform.reshape(-1, shape[-1])

# default values are consistent with librosa.core.spectrum._spectrogram
- spec_f = torch.stft(
- waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
+ spec_f = torch.view_as_real(
+ torch.stft(
+ input=waveform,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=True,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
)

# unpack batch
Expand Down Expand Up @@ -174,8 +185,20 @@ def griffinlim(
length=length).float()

# Rebuild the spectrogram
- rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window,
- True, 'reflect', False, True)
+ rebuilt = torch.view_as_real(
+ torch.stft(
+ input=inverse,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=True,
+ pad_mode='reflect',
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ )

# Update our phase estimates
angles = rebuilt
Expand Down

0 comments on commit 3ace593

Please sign in to comment.