Skip to content

Commit

Permalink
Update spectrogram to use complex
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Mar 10, 2021
1 parent ea85794 commit 2212bfa
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
17 changes: 15 additions & 2 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def spectrogram(
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True
onesided: bool = True,
return_complex: bool = False,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
Expand All @@ -70,12 +71,22 @@ def spectrogram(
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
If ``return_complex = True``, this function returns the resulting Tensor in
complex dtype; otherwise it returns the resulting Tensor in real dtype with an extra
dimension for real and imaginary parts (see ``torch.view_as_real``).
When ``power`` is provided, ``return_complex`` must be ``False``, as the resulting
Tensor represents real-valued power.
Returns:
Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
if power is not None and return_complex:
raise ValueError(
'When `power` is provided, the return value is real-valued. '
'Therefore, `return_complex` must be False.')

if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it
Expand Down Expand Up @@ -108,7 +119,9 @@ def spectrogram(
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
return torch.view_as_real(spec_f)
if not return_complex:
return torch.view_as_real(spec_f)
return spec_f


def griffinlim(
Expand Down
35 changes: 27 additions & 8 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class Spectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
If ``return_complex = True``, this transform returns the resulting Tensor in
complex dtype; otherwise it returns the resulting Tensor in real dtype with an extra
dimension for real and imaginary parts (see ``torch.view_as_real``).
When ``power`` is provided, ``return_complex`` must be ``False``, as the resulting
Tensor represents real-valued power.
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

Expand All @@ -67,7 +73,8 @@ def __init__(self,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
onesided: bool = True,
return_complex: bool = False) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
Expand All @@ -82,6 +89,7 @@ def __init__(self,
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.return_complex = return_complex

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -104,7 +112,8 @@ def forward(self, waveform: Tensor) -> Tensor:
self.normalized,
self.center,
self.pad_mode,
self.onesided
self.onesided,
self.return_complex,
)


Expand Down Expand Up @@ -457,7 +466,8 @@ def __init__(self,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
mel_scale: str = "htk",
return_complex: bool = False) -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
Expand All @@ -469,11 +479,20 @@ def __init__(self,
self.n_mels = n_mels # number of mel frequency bins
self.f_max = f_max
self.f_min = f_min
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided)
self.spectrogram = Spectrogram(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad,
window_fn=window_fn,
power=self.power,
normalized=self.normalized,
wkwargs=wkwargs,
center=center,
pad_mode=pad_mode,
onesided=onesided,
return_complex=return_complex,
)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
Expand Down

0 comments on commit 2212bfa

Please sign in to comment.