Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix timestretch docstring and tutorial visualization #3694

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions examples/tutorials/audio_feature_augmentation_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset

######################################################################
Expand Down Expand Up @@ -69,11 +70,6 @@ def get_spectrogram(
return spectrogram(waveform)


def plot_spec(ax, spec, title, ylabel="freq_bin"):
    """Draw a spectrogram on a matplotlib axes as a decibel-scaled image.

    Args:
        ax: Axes-like object to draw on; must provide ``set_title`` and ``imshow``.
        spec: Spectrogram passed to ``librosa.power_to_db`` before display.
        title: Title text set on ``ax``.
        ylabel: Accepted for call compatibility; this function does not use it.
    """
    ax.set_title(title)
    # Convert to a dB scale first so the image is readable.
    db_spec = librosa.power_to_db(spec)
    # origin="lower" puts the first row at the bottom of the image.
    ax.imshow(db_spec, origin="lower", aspect="auto")


######################################################################
# SpecAugment
# -----------
Expand All @@ -98,11 +94,15 @@ def plot_spec(ax, spec, title, ylabel="freq_bin"):
spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9)

######################################################################
#


######################################################################
# Visualization
# ~~~~~~~~~~~~~
def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto")

fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
Expand All @@ -112,6 +112,30 @@ def plot():

plot()


######################################################################
# Audio Samples
# ~~~~~~~~~~~~~
def preview(spec, rate=16000):
    """Return an IPython ``Audio`` widget for the audio recovered from ``spec``.

    Args:
        spec: Spectrogram handed to ``T.InverseSpectrogram()`` to recover a
            waveform (presumably complex-valued, as produced earlier in the
            tutorial -- confirm against the caller).
        rate: Sample rate passed to ``Audio``. (Default: ``16000``)

    Returns:
        ``Audio`` built from the first entry of the reconstructed waveform.
    """
    # A default-configured inverse transform is built on every call.
    reconstructed = T.InverseSpectrogram()(spec)

    # Only index 0 of the reconstructed waveform is played back.
    samples = reconstructed[0].numpy().T
    return Audio(samples, rate=rate)


preview(spec)


######################################################################
#
preview(spec_12)


######################################################################
#
preview(spec_09)


######################################################################
# Time and Frequency Masking
# --------------------------
Expand All @@ -131,6 +155,10 @@ def plot():


def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")

fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
Expand Down
30 changes: 13 additions & 17 deletions src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
Proposed in *SpecAugment* :cite:`specaugment`.

Args:
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
hop_length (int or None, optional): Length of hop between STFT windows.
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)

.. note::

The expected input is a raw, complex-valued spectrogram.

Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
>>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
>>> stretch = torchaudio.transforms.TimeStretch()
>>>
>>> original = spectrogram(waveform)
>>> streched_1_2 = stretch(original, 1.2)
>>> streched_0_9 = stretch(original, 0.9)

.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
:width: 600
:alt: Spectrogram streched by 1.2
>>> stretched_1_2 = stretch(original, 1.2)
>>> stretched_0_9 = stretch(original, 0.9)

.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
:width: 600
:alt: The original spectrogram

.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
:width: 600
:alt: Spectrogram streched by 0.9

:alt: The visualization of stretched spectrograms.
"""
__constants__ = ["fixed_rate"]

Expand All @@ -1067,8 +1063,8 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =

Returns:
Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Stretched spectrogram. The resulting tensor has the complex dtype corresponding to
that of the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
"""
if overriding_rate is None:
if self.fixed_rate is None:
Expand Down
Loading