Skip to content

Commit

Permalink
Make kaldi selective in build (#1342)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Mar 3, 2021
1 parent 5521f6c commit 3c44837
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 15 deletions.
8 changes: 6 additions & 2 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio'


def _get_build(var):
def _get_build(var, default=False):
if var not in os.environ:
return default

val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
Expand All @@ -32,6 +35,7 @@ def _get_build(var):


_BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_KALDI = _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")


Expand Down Expand Up @@ -68,7 +72,7 @@ def build_extension(self, ext):
'-DCMAKE_VERBOSE_MAKEFILE=ON',
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
"-DBUILD_KALDI:BOOL=ON",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
Expand Down
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoKaldi,
skipIfNoSox,
skipIfNoSoxBackend,
)
Expand All @@ -31,5 +31,5 @@

__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSox',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
12 changes: 3 additions & 9 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import torchaudio
from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available
is_sox_available,
is_kaldi_available
)

from .backend_utils import set_audio_backend
Expand Down Expand Up @@ -99,11 +100,4 @@ def skipIfNoModule(module, display_name=None):
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')


def skipIfNoExtension(test_item):
if is_module_available('torchaudio._torchaudio'):
return test_item
if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ:
raise RuntimeError('torchaudio C++ extension is not available.')
return unittest.skip('torchaudio C++ extension is not available')(test_item)
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_vad_different_items(self):
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)

@common_utils.skipIfNoExtension
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
sample_rate = 44100
n_channels = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

@common_utils.skipIfNoExtension
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")
Expand Down
17 changes: 17 additions & 0 deletions torchaudio/_internal/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ def wrapped(*args, **kwargs):
return decorator


def is_kaldi_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_kaldi_available()


def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires kaldi')
return wrapped
return decorator


def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()

Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX)
endif()

if (BUILD_KALDI)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
endif()

target_include_directories(
_torchaudio
PRIVATE
Expand Down
9 changes: 9 additions & 0 deletions torchaudio/csrc/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@ bool is_sox_available() {
#endif
}

bool is_kaldi_available() {
#ifdef INCLUDE_KALDI
return true;
#else
return false;
#endif
}

} // namespace

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::is_sox_available", &is_sox_available);
m.def("torchaudio::is_kaldi_available", &is_kaldi_available);
}

} // namespace torchaudio
1 change: 1 addition & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ def apply_codec(
return augmented


@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
Expand Down

0 comments on commit 3c44837

Please sign in to comment.