Skip to content

Commit

Permalink
Refactor Kaldi compatibility tests (#1359)
Browse files Browse the repository at this point in the history
* Refactor Kaldi compatibility tests

Co-authored-by: Jeff Hwang <jeffhwang@fb.com>
  • Loading branch information
hwangjeff and hwangjeff committed Mar 5, 2021
1 parent 64551a6 commit 301a6e3
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 20 deletions.
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ The following is an overview of the tests and related modules for `torchaudio`.
Test suite for numerical compatibility against librosa.
- [SoX compatibility test](./transforms/sox_compatibility_test.py)
Test suite for numerical compatibility against SoX.
- [Kaldi compatibility test](./kaldi_compatibility_test.py)
- [Kaldi compatibility test](./transforms/kaldi_compatibility_impl.py)
Test suite for numerical compatibility against Kaldi.

#### Result consistency with PyTorch framework
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import KaldiCPUOnly
from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly


class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .kaldi_compatibility_test_impl import Kaldi


@skipIfNoCuda
class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from parameterized import parameterized
import torch
import torchaudio.functional as F

from torchaudio_unittest.common_utils import (
Expand All @@ -15,6 +16,28 @@
)


class Kaldi(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}

tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)


class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Test suites for checking numerical compatibility against Kaldi"""
import torch
import torchaudio.functional as F
import torchaudio.compliance.kaldi
from parameterized import parameterized

Expand All @@ -23,22 +21,6 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}

tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)

@parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
Expand Down

0 comments on commit 301a6e3

Please sign in to comment.