Skip to content

Commit

Permalink
Refactor Kaldi compatibility tests
Browse files Browse the repository at this point in the history
ghstack-source-id: 95de1f270b5341685c87d2952be25cef15da55b0
Pull Request resolved: #1359
  • Loading branch information
hwangjeff committed Mar 5, 2021
1 parent 64551a6 commit 345c9ab
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import KaldiCPUOnly
from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly


class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .kaldi_compatibility_test_impl import Kaldi


@skipIfNoCuda
class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@skipIfNoCuda
class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from parameterized import parameterized
import torch
import torchaudio.functional as F

from torchaudio_unittest.common_utils import (
Expand All @@ -15,6 +16,24 @@
)


class Kaldi(TempDirMixin, TestBaseMixin):
@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}

tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)


class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Test suites for checking numerical compatibility against Kaldi"""
import torch
import torchaudio.functional as F
import torchaudio.compliance.kaldi
from parameterized import parameterized

Expand All @@ -23,22 +21,6 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}

tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)

@parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
Expand Down

0 comments on commit 345c9ab

Please sign in to comment.