-
Notifications
You must be signed in to change notification settings - Fork 634
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
List backends dynamically based on availability
- Loading branch information
Showing
2 changed files
with
28 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,49 @@ | ||
from typing import Any | ||
from typing import Any, Optional | ||
|
||
import platform | ||
from torchaudio._internal import module_utils as _mod_utils | ||
from . import _soundfile_backend, _sox_backend | ||
|
||
_BACKEND = None | ||
_BACKENDS = {} | ||
|
||
from . import _soundfile_backend, _sox_backend | ||
if _mod_utils.is_module_available('soundfile'): | ||
_BACKENDS['soundfile'] = _soundfile_backend | ||
_BACKEND = 'soundfile' | ||
if _mod_utils.is_module_available('torchaudio._torchaudio'): | ||
_BACKENDS['sox'] = _sox_backend | ||
_BACKEND = 'sox' | ||
|
||
|
||
if platform.system() == "Windows": | ||
_audio_backend = "soundfile" | ||
_audio_backends = {"soundfile": _soundfile_backend} | ||
else: | ||
_audio_backend = "sox" | ||
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend} | ||
def list_audio_backends(): | ||
return list(_BACKENDS.keys()) | ||
|
||
|
||
def set_audio_backend(backend: str) -> None: | ||
""" | ||
Specifies the package used to load. | ||
Args: | ||
backend (str): Name of the backend. One of {}. | ||
""".format(_audio_backends.keys()) | ||
global _audio_backend | ||
if backend not in _audio_backends: | ||
raise ValueError( | ||
"Invalid backend '{}'. Options are {}.".format(backend, _audio_backends.keys()) | ||
) | ||
_audio_backend = backend | ||
backend (str): Name of the backend. One of "sox" or "soundfile", | ||
based on availability of the system. | ||
""" | ||
if backend not in _BACKENDS: | ||
raise RuntimeError( | ||
f'Backend "{backend}" is not one of ' | ||
f'available backends: {list_audio_backends()}.') | ||
global _BACKEND | ||
_BACKEND = backend | ||
|
||
|
||
def get_audio_backend() -> str: | ||
def get_audio_backend() -> Optional[str]: | ||
""" | ||
Gets the name of the package used to load. | ||
""" | ||
return _audio_backend | ||
return _BACKEND | ||
|
||
|
||
def _get_audio_backend_module() -> Any: | ||
""" | ||
Gets the module backend to load. | ||
""" | ||
backend = get_audio_backend() | ||
return _audio_backends[backend] | ||
if _BACKEND is None: | ||
raise RuntimeError('Backend is not initialized.') | ||
return _BACKENDS[_BACKEND] |