
add slaney normalization #589

Merged · 5 commits · May 14, 2020
13 changes: 10 additions & 3 deletions test/test_librosa_compatibility.py
@@ -53,19 +53,20 @@ def test_griffinlim(self):

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)

-def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
+def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
fmax=fmax,
fmin=fmin,
htk=True,
-norm=None)
+norm=norm)
fb = F.create_fb_matrix(sample_rate=sample_rate,
n_mels=n_mels,
f_max=fmax,
f_min=fmin,
-n_freqs=(n_fft // 2 + 1))
+n_freqs=(n_fft // 2 + 1),
+norm=norm)

for i_mel_bank in range(n_mels):
torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
@@ -79,6 +80,12 @@ def test_create_fb(self):
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
+self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
+self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, norm="slaney")
+self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, norm="slaney")
+self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, norm="slaney")
+self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, norm="slaney")
+self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, norm="slaney")

def test_amplitude_to_DB(self):
spec = torch.rand((6, 201))
3 changes: 2 additions & 1 deletion test/test_torchscript_consistency.py
@@ -99,7 +99,8 @@ def func(_):
f_max = 20.0
n_mels = 10
sample_rate = 16000
-return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)
+norm = ""
+return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm)

dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
11 changes: 10 additions & 1 deletion torchaudio/functional.py
@@ -424,7 +424,8 @@ def create_fb_matrix(
f_min: float,
f_max: float,
n_mels: int,
-sample_rate: int
+sample_rate: int,
+norm: str = "",
Collaborator:

As a public API signature, I think Optional[str] looks cleaner.

Contributor Author:

This was done in response to a comment. I'll leave it as is for now; we can always widen str to Optional[str] later without breaking BC :)

Collaborator:

It looks to me that that comment was about the type of the variable to pass when running the TorchScript test, not about the function signature.

Contributor Author:

If we allow None in the signature, then the code should work with and without JIT when passing None. It didn't, though. Is that what you meant?

Collaborator (@mthrok, May 14, 2020):

> If we allow None in the signature, then the code should work with/without jit when passing None.

Yes, it works.

from typing import Optional

import torch
from torch import Tensor


def bar(foo: Optional[str]=None) -> Tensor:
    if foo is None:
        return torch.zeros(1, 2)
    if foo == "a":
        return torch.ones(1, 1)

    return torch.empty(1, 1)


ts_bar = torch.jit.script(bar)

for v in [None, "a", "b"]:
    print(v)
    print(bar(v))
    print(ts_bar(v))

produces

None
tensor([[0., 0.]])
tensor([[0., 0.]])
a
tensor([[1.]])
tensor([[1.]])
b
tensor([[-2.8910e+12]])
tensor([[0.]])

Also, dcshift uses Optional[float] and works fine under TorchScript for both None and float inputs. See #558.

Collaborator:

It's just that when the type is optional, it first needs to be compared against None using if var is None or if var is not None.

Collaborator (@mthrok, May 14, 2020):

https://pytorch.org/docs/master/jit_language_reference.html#optional-type-refinement

> TorchScript will refine the type of a variable of type Optional[T] when a comparison to None is made inside the conditional of an if-statement or checked in an assert. The compiler can reason about multiple None checks that are combined with and, or, and not. Refinement will also occur for else blocks of if-statements that are not explicitly written.
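The guard pattern this refinement rule requires can be sketched with a minimal scripted function (a hedged illustration, not code from the PR; `norm_flag` is a hypothetical name):

```python
from typing import Optional

import torch


def norm_flag(norm: Optional[str] = None) -> int:
    # Comparing against None first lets TorchScript refine
    # Optional[str] to str in the branches that follow.
    if norm is None:
        return 0
    if norm == "slaney":
        return 1
    return 2


scripted = torch.jit.script(norm_flag)
```

Both the eager function and the scripted version then accept None as well as str values without a type error.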

Contributor Author:

Alrighty, #641 :)

) -> Tensor:
r"""Create a frequency bin conversion matrix.

@@ -434,6 +435,8 @@ def create_fb_matrix(
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
+norm (str): If 'slaney', divide the triangular mel weights by the width of the mel band
+    (area normalization). (Default: '')

Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
@@ -461,6 +464,12 @@
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+
+if norm == "slaney":
+    # Slaney-style mel is scaled to be approx constant energy per channel
+    enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
+    fb *= enorm.unsqueeze(0)
+
return fb
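For readers without the surrounding source, the area normalization added above can be sketched standalone: each of the `n_mels` triangular filters spans three consecutive mel-band edge points, and "slaney" scales it by 2 divided by its base width in Hz. A minimal sketch, assuming the HTK mel formula (matching the `htk=True` librosa comparison in the tests); `slaney_enorm` is a hypothetical helper, not a torchaudio function:

```python
import math


def hz_to_mel(f: float) -> float:
    # HTK mel scale
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m: float) -> float:
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def slaney_enorm(f_min: float, f_max: float, n_mels: int) -> list:
    # n_mels + 2 equally spaced points on the mel scale give the band edges in Hz
    m_min, m_max = hz_to_mel(f_min), hz_to_mel(f_max)
    m_pts = [m_min + i * (m_max - m_min) / (n_mels + 1) for i in range(n_mels + 2)]
    f_pts = [mel_to_hz(m) for m in m_pts]
    # filter i spans f_pts[i]..f_pts[i + 2]; scale by 2 / base width,
    # mirroring enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) above
    return [2.0 / (f_pts[i + 2] - f_pts[i]) for i in range(n_mels)]


factors = slaney_enorm(0.0, 8000.0, 40)
```

Because mel bands get wider in Hz toward high frequencies, the scaling factors decrease monotonically, which is what keeps the per-channel energy roughly constant.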

