Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable mel_scale option #593

Merged
merged 1 commit into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 19 additions & 3 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,29 @@ def test_griffinlim(self):

self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
def _test_create_fb(
self, n_mels=40,
sample_rate=22050,
n_fft=2048,
fmin=0.0,
fmax=8000.0,
norm=None,
mel_scale="htk",
):
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
fmax=fmax,
fmin=fmin,
htk=True,
htk=mel_scale == "htk",
norm=norm)
fb = F.create_fb_matrix(sample_rate=sample_rate,
n_mels=n_mels,
f_max=fmax,
f_min=fmin,
n_freqs=(n_fft // 2 + 1),
norm=norm)
norm=norm,
mel_scale=mel_scale)

for i_mel_bank in range(n_mels):
self.assertEqual(
Expand All @@ -73,6 +82,13 @@ def test_create_fb(self):
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
self._test_create_fb(mel_scale="slaney")
vincentqb marked this conversation as resolved.
Show resolved Hide resolved
self._test_create_fb(n_mels=128, sample_rate=44100, mel_scale="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, mel_scale="slaney")
if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
return
self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
Expand Down
18 changes: 10 additions & 8 deletions test/torchaudio_unittest/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,31 +46,32 @@ def test_spectrogram(self, n_fft, hop_length, power):
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

@parameterized.expand([
param(norm=norm, **p.kwargs)
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
param(n_fft=400, hop_length=200, n_mels=128),
param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=128),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm):
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
sound_librosa = sound.cpu().numpy().squeeze()
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 👍

librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

@parameterized.expand([
param(norm=norm, **p.kwargs)
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
param(n_fft=400, hop_length=200, power=2.0, n_mels=128),
param(n_fft=600, hop_length=100, power=2.0, n_mels=128),
Expand All @@ -79,8 +80,9 @@ def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm):
param(n_fft=200, hop_length=50, power=2.0, n_mels=128, skip_ci=True),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, skip_ci=False):
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, mel_scale, skip_ci=False):
if skip_ci and 'CI' in os.environ:
self.skipTest('Test is known to fail on CI')
sample_rate = 16000
Expand All @@ -92,10 +94,10 @@ def test_s2db(self, n_fft, hop_length, power, n_mels, norm, skip_ci=False):
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)

power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
Expand Down
81 changes: 75 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,81 @@ def DB_to_amplitude(
return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
r"""Convert Hz to Mels.

Args:
freqs (float): Frequencies in Hz
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

Returns:
mels (float): Frequency in Mels
"""

if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == "htk":
return 2595.0 * math.log10(1.0 + (freq / 700.0))

# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3

mels = (freq - f_min) / f_sp

# Fill in the log-scale part
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0

if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz) / logstep

return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
"""Convert mel bin numbers to frequencies.

Args:
mels (Tensor): Mel frequencies
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

Returns:
freqs (Tensor): Mels converted in Hz
"""

if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == "htk":
return 700.0 * (10.0**(mels / 2595.0) - 1.0)

# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels

# And now the nonlinear scale
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0

log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

return freqs


def create_fb_matrix(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor:
r"""Create a frequency bin conversion matrix.

Expand All @@ -316,6 +384,7 @@ def create_fb_matrix(
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
Expand All @@ -333,12 +402,12 @@ def create_fb_matrix(
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

# calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
Expand Down
32 changes: 25 additions & 7 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class MelScale(torch.nn.Module):
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

Expand All @@ -258,18 +259,21 @@ def __init__(self,
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: Optional[int] = None,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min
self.norm = norm
self.mel_scale = mel_scale

assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm)
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
self.register_buffer('fb', fb)

def forward(self, specgram: Tensor) -> Tensor:
Expand All @@ -287,7 +291,8 @@ def forward(self, specgram: Tensor) -> Tensor:

if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
self.n_mels, self.sample_rate, self.norm)
self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
# Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb)
Expand Down Expand Up @@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module):
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs']
Expand All @@ -335,7 +341,8 @@ def __init__(self,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
Expand All @@ -348,7 +355,8 @@ def __init__(self,

assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
mel_scale)
self.register_buffer('fb', fb)

def forward(self, melspec: Tensor) -> Tensor:
Expand Down Expand Up @@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module):
avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
Expand All @@ -450,7 +459,8 @@ def __init__(self,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
Expand All @@ -467,7 +477,15 @@ def __init__(self,
pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
norm,
mel_scale
)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand Down