Skip to content

Commit

Permalink
Add test for InverseMelScale
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 7, 2021
1 parent 2943214 commit bd9486a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_cpu_test.py
@@ -0,0 +1,14 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from . transforms_test_impl import TransformsTestBase


class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
device = 'cpu'
dtype = torch.float32


class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
device = 'cpu'
dtype = torch.float64
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_cuda_test.py
@@ -0,0 +1,19 @@
import torch

from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoCuda,
)
from . transforms_test_impl import TransformsTestBase


@skipIfNoCuda
class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase):
device = 'cuda'
dtype = torch.float32


@skipIfNoCuda
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase):
device = 'cuda'
dtype = torch.float64
61 changes: 61 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -0,0 +1,61 @@
import torch
import torchaudio.transforms as T

from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
)


def _get_ratio(mat):
return (mat.sum() / mat.numel()).item()


class TransformsTestBase(TestBaseMixin):
def test_InverseMelScale(self):
"""Gauge the quality of InverseMelScale transform.
As InverseMelScale is currently implemented with
random initialization + iterative optimization,
it is not practically possible to assert the difference between
the estimated spectrogram and the original spectrogram as a whole.
Estimated spectrogram has very huge descrepency locally.
Thus in this test we gauge what percentage of elements are bellow
certain tolerance.
At the moment, the quality of estimated spectrogram is not good.
When implementation is changed in a way it makes the quality even worse,
this test will fail.
"""
n_fft = 400
power = 1
n_mels = 64
sample_rate = 8000

n_stft = n_fft // 2 + 1

# Generate reference spectrogram and input mel-scaled spectrogram
expected = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
n_fft=n_fft, power=power).to(self.device, self.dtype)
input = T.MelScale(
n_mels=n_mels, sample_rate=sample_rate
).to(self.device, self.dtype)(expected)

# Run transform
transform = T.InverseMelScale(
n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
torch.random.manual_seed(0)
result = transform(input)

# Compare
epsilon = 1e-60
relative_diff = torch.abs((result - expected) / (expected + epsilon))

for tol in [1e-1, 1e-3, 1e-5, 1e-10]:
print(
f"Ratio of relative diff smaller than {tol:e} is "
f"{_get_ratio(relative_diff < tol)}")
assert _get_ratio(relative_diff < 1e-1) > 0.2
assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5

0 comments on commit bd9486a

Please sign in to comment.