Skip to content

Commit

Permalink
List backends dynamically based on availability
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 5, 2020
1 parent 0c6e29c commit a18ed45
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torchaudio

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends
BACKENDS = torchaudio._backend._BACKENDS


def get_asset_path(*paths):
Expand Down
48 changes: 27 additions & 21 deletions torchaudio/_backend.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
from typing import Any
from typing import Any, Optional

import platform
from torchaudio._internal import module_utils as _mod_utils
from . import _soundfile_backend, _sox_backend

_BACKEND = None
_BACKENDS = {}

from . import _soundfile_backend, _sox_backend
if _mod_utils.is_module_available('soundfile'):
_BACKENDS['soundfile'] = _soundfile_backend
_BACKEND = 'soundfile'
if _mod_utils.is_module_available('torchaudio._torchaudio'):
_BACKENDS['sox'] = _sox_backend
_BACKEND = 'sox'


if platform.system() == "Windows":
_audio_backend = "soundfile"
_audio_backends = {"soundfile": _soundfile_backend}
else:
_audio_backend = "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def list_audio_backends():
return list(_BACKENDS.keys())


def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
Args:
backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
raise ValueError(
"Invalid backend '{}'. Options are {}.".format(backend, _audio_backends.keys())
)
_audio_backend = backend
backend (str): Name of the backend. One of "sox" or "soundfile",
based on availability of the system.
"""
if backend not in _BACKENDS:
raise RuntimeError(
f'Backend "{backend}" is not one of '
f'available backends: {list_audio_backends()}.')
global _BACKEND
_BACKEND = backend


def get_audio_backend() -> str:
def get_audio_backend() -> Optional[str]:
"""
Gets the name of the package used to load.
"""
return _audio_backend
return _BACKEND


def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
backend = get_audio_backend()
return _audio_backends[backend]
if _BACKEND is None:
raise RuntimeError('Backend is not initialized.')
return _BACKENDS[_BACKEND]

0 comments on commit a18ed45

Please sign in to comment.