Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt PyTorch's test util to librosa compatibilities test #646

Merged
merged 1 commit into from
May 15, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
166 changes: 80 additions & 86 deletions test/test_librosa_compatibility.py
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA
Expand All @@ -17,15 +18,8 @@
import common_utils


class _LibrosaMixin:
"""Automatically skip tests if librosa is not available"""
def setUp(self):
super().setUp()
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa not available')


class TestFunctional(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestFunctional(TestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
# NOTE: This test is flaky without a fixed random seed
Expand All @@ -51,7 +45,7 @@ def test_griffinlim(self):
momentum=momentum, init=init, length=length)
lr_out = torch.from_numpy(lr_out).unsqueeze(0)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
librosa_fb = librosa.filters.mel(sr=sample_rate,
Expand All @@ -69,8 +63,8 @@ def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fm
norm=norm)

for i_mel_bank in range(n_mels):
torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
atol=1e-4, rtol=1e-5)
self.assertEqual(
fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5)

def test_create_fb(self):
self._test_create_fb()
Expand Down Expand Up @@ -101,7 +95,7 @@ def test_amplitude_to_DB(self):
lr_out = librosa.core.power_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

# Amplitude to DB
multiplier = 20.0
Expand All @@ -110,7 +104,7 @@ def test_amplitude_to_DB(self):
lr_out = librosa.core.amplitude_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)

torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)


@pytest.mark.parametrize('complex_specgrams', [
Expand Down Expand Up @@ -161,73 +155,73 @@ def _load_audio_asset(*asset_paths, **kwargs):
return sound, sample_rate


def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)

# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)

out_torch = spect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
torch.testing.assert_allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)

mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
torch.testing.assert_allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)

power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
torch.testing.assert_allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)

# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)

# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)

librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()

torch.testing.assert_allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)


class TestTransforms(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestTransforms(TestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)

# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)

out_torch = spect_transform(sound).squeeze().cpu()
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertEqual(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)

mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
self.assertEqual(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)

power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertEqual(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)

# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)

# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)

librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()

self.assertEqual(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)

def test_basics1(self):
kwargs = {
'n_fft': 400,
Expand All @@ -237,7 +231,7 @@ def test_basics1(self):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

def test_basics2(self):
kwargs = {
Expand All @@ -248,7 +242,7 @@ def test_basics2(self):
'n_mfcc': 20,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
@unittest.skipIf('CI' in os.environ, 'Test is known to fail on CI')
Expand All @@ -261,7 +255,7 @@ def test_basics3(self):
'n_mfcc': 50,
'sample_rate': 24000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

def test_basics4(self):
kwargs = {
Expand All @@ -272,7 +266,7 @@ def test_basics4(self):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)

@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope("sox")
Expand All @@ -295,7 +289,7 @@ def test_MelScale(self):
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
torch.testing.assert_allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
self.assertEqual(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)

def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
Expand Down Expand Up @@ -338,7 +332,7 @@ def test_InverseMelScale(self):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
torch.testing.assert_allclose(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
self.assertEqual(spec_ta, spec_lr, atol=threshold, rtol=1e-5)

threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
Expand Down