Skip to content

Commit

Permalink
Call torch.stft directly
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 14, 2020
1 parent f4397fb commit d12b635
Showing 1 changed file with 3 additions and 29 deletions.
32 changes: 3 additions & 29 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,6 @@
]


# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(
waveform: Tensor,
n_fft: int,
hop_length: Optional[int],
win_length: Optional[int],
window: Optional[Tensor],
center: bool,
pad_mode: str,
normalized: bool,
onesided: bool
) -> Tensor:
return torch.stft(
waveform,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
)


def istft(
stft_matrix: Tensor,
n_fft: int,
Expand Down Expand Up @@ -265,7 +239,7 @@ def spectrogram(
waveform = waveform.view(-1, shape[-1])

# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft(
spec_f = torch.stft(
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
)

Expand Down Expand Up @@ -365,8 +339,8 @@ def griffinlim(
length=length).float()

# Rebuild the spectrogram
rebuilt = _stft(inverse, n_fft, hop_length, win_length, window,
True, 'reflect', False, True)
rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window,
True, 'reflect', False, True)

# Update our phase estimates
angles = rebuilt - tprev.mul_(momentum / (1 + momentum))
Expand Down

0 comments on commit d12b635

Please sign in to comment.