Skip to content

Commit

Permalink
Fix timestretch docstring and tutorial visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Nov 10, 2023
1 parent 65df10b commit 31697a2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
13 changes: 8 additions & 5 deletions examples/tutorials/audio_feature_augmentation_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ def get_spectrogram(
return spectrogram(waveform)


def plot_spec(ax, spec, title, ylabel="freq_bin"):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")


######################################################################
# SpecAugment
# -----------
Expand Down Expand Up @@ -103,6 +98,10 @@ def plot_spec(ax, spec, title, ylabel="freq_bin"):


def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto")

fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
Expand Down Expand Up @@ -131,6 +130,10 @@ def plot():


def plot():
def plot_spec(ax, spec, title):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")

fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
Expand Down
30 changes: 13 additions & 17 deletions src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
Proposed in *SpecAugment* :cite:`specaugment`.
Args:
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
hop_length (int or None, optional): Length of hop between STFT windows.
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)
.. note::
The expected input is raw, complex-valued spectrogram.
Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
>>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
>>> stretch = torchaudio.transforms.TimeStretch()
>>>
>>> original = spectrogram(waveform)
>>> streched_1_2 = stretch(original, 1.2)
>>> streched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
:width: 600
:alt: Spectrogram streched by 1.2
>>> stretched_1_2 = stretch(original, 1.2)
>>> stretched_0_9 = stretch(original, 0.9)
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
:width: 600
:alt: The original spectrogram
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
:width: 600
:alt: Spectrogram streched by 0.9
:alt: The visualization of stretched spectrograms.
"""
__constants__ = ["fixed_rate"]

Expand All @@ -1067,8 +1063,8 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
Returns:
Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Stretched spectrogram. The resulting tensor has the complex dtype corresponding to
that of the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
"""
if overriding_rate is None:
if self.fixed_rate is None:
Expand Down

0 comments on commit 31697a2

Please sign in to comment.