Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor backend and not rely on global variables on switching #698

Merged
merged 2 commits into from
Jun 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import torchaudio
from torchaudio._internal.module_utils import is_module_available


class BackendSwitch:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None

def test_switch(self):
torchaudio.set_audio_backend(self.backend)
if self.backend is None:
assert torchaudio.get_audio_backend() is None
else:
assert torchaudio.get_audio_backend() == self.backend
assert torchaudio.load == self.backend_module.load
assert torchaudio.load_wav == self.backend_module.load_wav
assert torchaudio.save == self.backend_module.save
assert torchaudio.info == self.backend_module.info


class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
backend = None
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
3 changes: 3 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def test_librispeech(self):
data[0]

@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #715

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this fixes #715 completely?

Copy link
Collaborator Author

@mthrok mthrok Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. This is a temporary fix. #719 gives the principled fix which fully utilizes no_backend from this PR.

def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
data[0]

@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_diskcache(self):
data = COMMONVOICE(self.path, url="tatar")
data = diskcache_iterator(data)
Expand All @@ -43,6 +45,7 @@ def test_commonvoice_diskcache(self):
data[0]

@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_bg(self):
data = COMMONVOICE(self.path, url="tatar")
data = bg_iterator(data, 5)
Expand Down
118 changes: 0 additions & 118 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
Expand All @@ -11,7 +7,6 @@
transforms
)
from torchaudio.backend import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
Expand Down Expand Up @@ -57,116 +52,3 @@ def shutdown_sox():
This function is deprecated. See ``torchaudio.sox_effects.shutdown_sox_effects``
"""
_shutdown_sox_effects()


def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""Loads an audio file from disk into a tensor

Args:
filepath (str or Path): Path to audio file
out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `float`, then output is divided by that number
If `Callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result. (Default: ``True``)
channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)
offset (int, optional): Number of frames from the start of the file to begin data loading.
(Default: ``0``)
signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)
encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)
filetype (str, optional): A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)

Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)

Example
>>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size())
torch.Size([2, 278756])
>>> print(sample_rate)
44100
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_vol_normalized.abs().max())
1.

"""
return _get_audio_backend_module().load(
filepath,
out=out,
normalization=normalization,
channels_first=channels_first,
num_frames=num_frames,
offset=offset,
signalinfo=signalinfo,
encodinginfo=encodinginfo,
filetype=filetype,
)


def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
the input right by 16 bits.

Args:
filepath (str or Path): Path to audio file

Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""
kwargs['normalization'] = 1 << 16
return load(filepath, **kwargs)


def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""Convenience function for `save_encinfo`.

Args:
filepath (str): Path to audio file
src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate (int): An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision (int, optional): Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""

return _get_audio_backend_module().save(
filepath, src, sample_rate, precision=precision, channels_first=channels_first
)


def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""Gets metadata from an audio file without loading the signal.

Args:
filepath (str): Path to audio file

Returns:
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info

Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
return _get_audio_backend_module().info(filepath)
1 change: 0 additions & 1 deletion torchaudio/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from . import utils
from .utils import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
Expand Down
114 changes: 113 additions & 1 deletion torchaudio/backend/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, Optional


class SignalInfo:
Expand Down Expand Up @@ -29,3 +29,115 @@ def __init__(self,
self.reverse_nibbles = reverse_nibbles
self.reverse_bits = reverse_bits
self.opposite_endian = opposite_endian


_LOAD_DOCSTRING = r"""Loads an audio file from disk into a tensor

Args:
filepath: Path to audio file

out: An optional output tensor to use instead of creating one. (Default: ``None``)

normalization: Optional normalization.
If boolean `True`, then output is divided by `1 << 31`.
Assuming the input is signed 32-bit audio, this normalizes to `[-1, 1]`.
If `float`, then output is divided by that number.
If `Callable`, then the output is passed as a paramete to the given function,
then the output is divided by the result. (Default: ``True``)

channels_first: Set channels first or length first in result. (Default: ``True``)

num_frames: Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)

offset: Number of frames from the start of the file to begin data loading.
(Default: ``0``)

signalinfo: A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)

encodinginfo: A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)

filetype: A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)

Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where
L is the number of audio frames and
C is the number of channels.
An integer which is the sample rate of the audio (as listed in the metadata of the file)

Example
>>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size())
torch.Size([2, 278756])
>>> print(sample_rate)
44100
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_vol_normalized.abs().max())
1.
"""


_LOAD_WAV_DOCSTRING = r""" Loads a wave file.

It assumes that the wav file uses 16 bit per sample that needs normalization by
shifting the input right by 16 bits.

Args:
filepath: Path to audio file

Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""

_SAVE_DOCSTRING = r"""Saves a Tensor on file as an audio file

Args:
filepath: Path to audio file
src: An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate: An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""


_INFO_DOCSTRING = r"""Gets metadata from an audio file without loading the signal.

Args:
filepath: Path to audio file

Returns:
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info

Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""


def _impl_load(func):
setattr(func, '__doc__', _LOAD_DOCSTRING)
return func


def _impl_load_wav(func):
setattr(func, '__doc__', _LOAD_WAV_DOCSTRING)
return func


def _impl_save(func):
setattr(func, '__doc__', _SAVE_DOCSTRING)
return func


def _impl_info(func):
setattr(func, '__doc__', _INFO_DOCSTRING)
return func
Comment on lines +126 to +143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: how about calling them _add_docstring_* instead?

Instead of having these functions, we could also use "no backend" as the template to grab the docstring from.

35 changes: 35 additions & 0 deletions torchaudio/backend/no_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

from torch import Tensor

from . import common
from .common import SignalInfo, EncodingInfo


@common._impl_load
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_load_wav
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
raise RuntimeError('No audio I/O backend is available.')
11 changes: 11 additions & 0 deletions torchaudio/backend/soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from . import common
from .common import SignalInfo, EncodingInfo

if _mod_utils.is_module_available('soundfile'):
Expand All @@ -24,6 +25,7 @@


@_mod_utils.requires_module('soundfile')
@common._impl_load
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
Expand Down Expand Up @@ -71,6 +73,14 @@ def load(filepath: str,


@_mod_utils.requires_module('soundfile')
@common._impl_load_wav
def load_wav(filepath, **kwargs):
kwargs['normalization'] = 1 << 16
return load(filepath, **kwargs)


@_mod_utils.requires_module('soundfile')
@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

Expand Down Expand Up @@ -104,6 +114,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan


@_mod_utils.requires_module('soundfile')
@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

Expand Down