Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torchaudio melscale 'slaney' instead of librosa in WaveRNN pipeline preprocessing #1444

Merged
merged 2 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint


Expand Down Expand Up @@ -269,12 +269,12 @@ def main(args):
}

transforms = torch.nn.Sequential(
torchaudio.transforms.Spectrogram(**melkwargs),
LinearToMel(
torchaudio.transforms.MelSpectrogram(
sample_rate=args.sample_rate,
n_fft=args.n_fft,
n_mels=args.n_freq,
fmin=args.f_min,
f_min=args.f_min,
mel_scale='slaney',
**melkwargs,
),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
)
Expand Down
27 changes: 1 addition & 26 deletions examples/pipeline_wavernn/processing.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
import librosa
import torch
import torch.nn as nn


# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved
class LinearToMel(nn.Module):
def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.n_mels = n_mels
self.fmin = fmin
self.htk = htk
self.norm = norm

def forward(self, specgram):
specgram = librosa.feature.melspectrogram(
S=specgram.squeeze(0).numpy(),
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.fmin,
htk=self.htk,
norm=self.norm,
)
return torch.from_numpy(specgram)


class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value
"""
Expand All @@ -37,7 +12,7 @@ def __init__(self, min_level_db, normalization):
self.normalization = normalization

def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram, min=1e-5))
specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization:
return torch.clamp(
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
Expand Down