diff --git a/examples/tutorials/audio_feature_augmentation_tutorial.py b/examples/tutorials/audio_feature_augmentation_tutorial.py index 6e69ef50560..be2b8c5a9bb 100644 --- a/examples/tutorials/audio_feature_augmentation_tutorial.py +++ b/examples/tutorials/audio_feature_augmentation_tutorial.py @@ -69,11 +69,6 @@ def get_spectrogram( return spectrogram(waveform) -def plot_spec(ax, spec, title, ylabel="freq_bin"): - ax.set_title(title) - ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto") - - ###################################################################### # SpecAugment # ----------- @@ -103,6 +98,10 @@ def plot_spec(ax, spec, title, ylabel="freq_bin"): def plot(): + def plot_spec(ax, spec, title): + ax.set_title(title) + ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto") + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2") plot_spec(axes[1], torch.abs(spec[0]), title="Original") @@ -131,6 +130,10 @@ def plot(): def plot(): + def plot_spec(ax, spec, title): + ax.set_title(title) + ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto") + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) plot_spec(axes[0], spec[0], title="Original") plot_spec(axes[1], time_masked[0], title="Masked along time axis") diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 6da6a4cd2f8..f5d8aebfcd5 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module): Proposed in *SpecAugment* :cite:`specaugment`. Args: - hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + hop_length (int or None, optional): Length of hop between STFT windows. + (Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``) n_freq (int, optional): number of filter banks from stft. (Default: ``201``) fixed_rate (float or None, optional): rate to speed up or slow down by. If None is provided, rate must be passed to the forward method. (Default: ``None``) + .. note:: + + The expected input is raw, complex-valued spectrogram. + Example - >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> spectrogram = torchaudio.transforms.Spectrogram(power=None) >>> stretch = torchaudio.transforms.TimeStretch() >>> >>> original = spectrogram(waveform) - >>> streched_1_2 = stretch(original, 1.2) - >>> streched_0_9 = stretch(original, 0.9) - - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png - :width: 600 - :alt: Spectrogram streched by 1.2 + >>> stretched_1_2 = stretch(original, 1.2) + >>> stretched_0_9 = stretch(original, 0.9) - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png :width: 600 - :alt: The original spectrogram - - .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png - :width: 600 - :alt: Spectrogram streched by 0.9 - + :alt: The visualization of stretched spectrograms. """ __constants__ = ["fixed_rate"]