-
Notifications
You must be signed in to change notification settings - Fork 634
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
94 additions
and
0 deletions.
There are no files selected for viewing
14 changes: 14 additions & 0 deletions
14
test/torchaudio_unittest/transforms/transforms_cpu_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import torch | ||
|
||
from torchaudio_unittest.common_utils import PytorchTestCase | ||
from . transforms_test_impl import TransformsTestBase | ||
|
||
|
||
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase): | ||
device = 'cpu' | ||
dtype = torch.float32 | ||
|
||
|
||
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase): | ||
device = 'cpu' | ||
dtype = torch.float64 |
19 changes: 19 additions & 0 deletions
19
test/torchaudio_unittest/transforms/transforms_cuda_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
|
||
from torchaudio_unittest.common_utils import ( | ||
PytorchTestCase, | ||
skipIfNoCuda, | ||
) | ||
from . transforms_test_impl import TransformsTestBase | ||
|
||
|
||
@skipIfNoCuda | ||
class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase): | ||
device = 'cuda' | ||
dtype = torch.float32 | ||
|
||
|
||
@skipIfNoCuda | ||
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase): | ||
device = 'cuda' | ||
dtype = torch.float64 |
61 changes: 61 additions & 0 deletions
61
test/torchaudio_unittest/transforms/transforms_test_impl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import torchaudio.transforms as T | ||
|
||
from torchaudio_unittest.common_utils import ( | ||
TestBaseMixin, | ||
get_whitenoise, | ||
get_spectrogram, | ||
) | ||
|
||
|
||
def _get_ratio(mat): | ||
return (mat.sum() / mat.numel()).item() | ||
|
||
|
||
class TransformsTestBase(TestBaseMixin): | ||
def test_InverseMelScale(self): | ||
"""Gauge the quality of InverseMelScale transform. | ||
As InverseMelScale is currently implemented with | ||
random initialization + iterative optimization, | ||
it is not practically possible to assert the difference between | ||
the estimated spectrogram and the original spectrogram as a whole. | ||
Estimated spectrogram has very huge descrepency locally. | ||
Thus in this test we gauge what percentage of elements are bellow | ||
certain tolerance. | ||
At the moment, the quality of estimated spectrogram is not good. | ||
When implementation is changed in a way it makes the quality even worse, | ||
this test will fail. | ||
""" | ||
n_fft = 400 | ||
power = 1 | ||
n_mels = 64 | ||
sample_rate = 8000 | ||
|
||
n_stft = n_fft // 2 + 1 | ||
|
||
# Generate reference spectrogram and input mel-scaled spectrogram | ||
expected = get_spectrogram( | ||
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), | ||
n_fft=n_fft, power=power).to(self.device, self.dtype) | ||
input = T.MelScale( | ||
n_mels=n_mels, sample_rate=sample_rate | ||
).to(self.device, self.dtype)(expected) | ||
|
||
# Run transform | ||
transform = T.InverseMelScale( | ||
n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype) | ||
torch.random.manual_seed(0) | ||
result = transform(input) | ||
|
||
# Compare | ||
epsilon = 1e-60 | ||
relative_diff = torch.abs((result - expected) / (expected + epsilon)) | ||
|
||
for tol in [1e-1, 1e-3, 1e-5, 1e-10]: | ||
print( | ||
f"Ratio of relative diff smaller than {tol:e} is " | ||
f"{_get_ratio(relative_diff < tol)}") | ||
assert _get_ratio(relative_diff < 1e-1) > 0.2 | ||
assert _get_ratio(relative_diff < 1e-3) > 5e-3 | ||
assert _get_ratio(relative_diff < 1e-5) > 1e-5 |