From 172260f9cd6641b18455d228d0760ca1c81a44f3 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 9 Nov 2023 23:21:59 -0500 Subject: [PATCH] Update TimeStretch doc and tutorial (#3694) --- .../audio_feature_augmentation_tutorial.py | 44 +++++++++++++++---- src/torchaudio/transforms/_transforms.py | 30 ++++++------- 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/examples/tutorials/audio_feature_augmentation_tutorial.py b/examples/tutorials/audio_feature_augmentation_tutorial.py index 6e69ef5056..53395b4606 100644 --- a/examples/tutorials/audio_feature_augmentation_tutorial.py +++ b/examples/tutorials/audio_feature_augmentation_tutorial.py @@ -25,6 +25,7 @@ import librosa import matplotlib.pyplot as plt +from IPython.display import Audio from torchaudio.utils import download_asset ###################################################################### @@ -69,11 +70,6 @@ def get_spectrogram( return spectrogram(waveform) -def plot_spec(ax, spec, title, ylabel="freq_bin"): - ax.set_title(title) - ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto") - - ###################################################################### # SpecAugment # ----------- @@ -98,11 +94,15 @@ def plot_spec(ax, spec, title, ylabel="freq_bin"): spec_12 = stretch(spec, overriding_rate=1.2) spec_09 = stretch(spec, overriding_rate=0.9) -###################################################################### -# - +###################################################################### +# Visualization +# ~~~~~~~~~~~~~ def plot(): + def plot_spec(ax, spec, title): + ax.set_title(title) + ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto") + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2") plot_spec(axes[1], torch.abs(spec[0]), title="Original") @@ -112,6 +112,30 @@ def plot(): plot() + 
+###################################################################### +# Audio Samples +# ~~~~~~~~~~~~~ +def preview(spec, rate=16000): + ispec = T.InverseSpectrogram() + waveform = ispec(spec) + + return Audio(waveform[0].numpy().T, rate=rate) + + +preview(spec) + + +###################################################################### +# +preview(spec_12) + + +###################################################################### +# +preview(spec_09) + + ###################################################################### # Time and Frequency Masking # -------------------------- @@ -131,6 +155,10 @@ def plot(): def plot(): + def plot_spec(ax, spec, title): + ax.set_title(title) + ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto") + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) plot_spec(axes[0], spec[0], title="Original") plot_spec(axes[1], time_masked[0], title="Masked along time axis") diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 6da6a4cd2f..654b25c213 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module): Proposed in *SpecAugment* :cite:`specaugment`. Args: - hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + hop_length (int or None, optional): Length of hop between STFT windows. + (Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``) n_freq (int, optional): number of filter banks from stft. (Default: ``201``) fixed_rate (float or None, optional): rate to speed up or slow down by. If None is provided, rate must be passed to the forward method. (Default: ``None``) + .. note:: + + The expected input is a raw, complex-valued spectrogram. 
+ + Example - >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> spectrogram = torchaudio.transforms.Spectrogram(power=None) >>> stretch = torchaudio.transforms.TimeStretch() >>> >>> original = spectrogram(waveform) - >>> streched_1_2 = stretch(original, 1.2) - >>> streched_0_9 = stretch(original, 0.9) - - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png - :width: 600 - :alt: Spectrogram streched by 1.2 + >>> stretched_1_2 = stretch(original, 1.2) + >>> stretched_0_9 = stretch(original, 0.9) - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png :width: 600 - :alt: The original spectrogram - - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png - :width: 600 - :alt: Spectrogram streched by 0.9 - + :alt: The visualization of stretched spectrograms. """ __constants__ = ["fixed_rate"] @@ -1067,8 +1063,8 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = Returns: Tensor: - Stretched spectrogram. The resulting tensor is of the same dtype as the input - spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. + Stretched spectrogram. The resulting tensor has the complex dtype corresponding + to that of the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``. """ if overriding_rate is None: if self.fixed_rate is None: