Skip to content

Commit

Permalink
Move the backend implementation to _backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Aug 12, 2023
1 parent 9467fc4 commit dcfb5db
Show file tree
Hide file tree
Showing 24 changed files with 896 additions and 850 deletions.
22 changes: 9 additions & 13 deletions test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
import torch

from parameterized import parameterized
from torchaudio._backend.utils import (
FFmpegBackend,
get_info_func,
get_load_func,
get_save_func,
SoundfileBackend,
SoXBackend,
)
from torchaudio._backend.ffmpeg import FFmpegBackend
from torchaudio._backend.soundfile import SoundfileBackend
from torchaudio._backend.sox import SoXBackend
from torchaudio._backend.utils import get_info_func, get_load_func, get_save_func
from torchaudio_unittest.common_utils import PytorchTestCase


Expand Down Expand Up @@ -47,7 +43,7 @@ def test_info_fileobj(self, available_backends, expected_backend):
format = "wav"
buffer_size = 8192
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.info"
f"{expected_backend.__module__}.{expected_backend.__name__}.info"
) as mock_info:
get_info_func()(f, format=format, buffer_size=buffer_size)
mock_info.assert_called_once_with(f, format, buffer_size)
Expand All @@ -64,7 +60,7 @@ def test_load(self, available_backends, expected_backend):
filename = "test.wav"
format = "wav"
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.load"
f"{expected_backend.__module__}.{expected_backend.__name__}.load"
) as mock_load:
get_load_func()(filename, format=format)
mock_load.assert_called_once_with(filename, 0, -1, True, True, format, 4096)
Expand All @@ -83,7 +79,7 @@ def test_load_fileobj(self, available_backends, expected_backend):
format = "wav"
buffer_size = 8192
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.load"
f"{expected_backend.__module__}.{expected_backend.__name__}.load"
) as mock_load:
get_load_func()(f, format=format, buffer_size=buffer_size)
mock_load.assert_called_once_with(f, 0, -1, True, True, format, buffer_size)
Expand All @@ -102,7 +98,7 @@ def test_save(self, available_backends, expected_backend):
format = "wav"
sample_rate = 16000
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
f"{expected_backend.__module__}.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
Expand All @@ -123,7 +119,7 @@ def test_save_fileobj(self, available_backends, expected_backend):
buffer_size = 8192
sample_rate = 16000
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
f"{expected_backend.__module__}.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/soundfile/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from unittest.mock import patch

import torch
from torchaudio._backend import soundfile as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
from torchaudio_unittest.common_utils import (
get_wav_data,
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/soundfile/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch
from parameterized import parameterized
from torchaudio._backend import soundfile as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
Expand Down
3 changes: 2 additions & 1 deletion test/torchaudio_unittest/backend/soundfile/save_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
from unittest.mock import patch

from torchaudio._backend import soundfile as soundfile_backend

from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.backend.common import get_encoding
from torchaudio_unittest.common_utils import (
get_asset_path,
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.common_utils import (
get_asset_path,
get_wav_data,
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.common_utils import get_wav_data, PytorchTestCase, skipIfNoExec, skipIfNoSox, TempDirMixin

from .common import get_enc_params, name_func
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
Expand Down
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.common_utils import get_wav_data, skipIfNoSox, TempDirMixin, TorchaudioTestCase

from .common import name_func
Expand Down
11 changes: 5 additions & 6 deletions test/torchaudio_unittest/backend/sox_io/torchscript_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torchaudio
from parameterized import parameterized
from torchaudio._backend import sox as sox_io_backend
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
Expand All @@ -19,12 +20,12 @@
from .common import get_enc_params, name_func


def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.backend.sox_io_backend.info(filepath)
def py_info_func(filepath: str) -> torchaudio.AudioMetaData:
return sox_io_backend.info(filepath)


def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return torchaudio.backend.sox_io_backend.load(filepath, normalize=normalize, channels_first=channels_first)
return sox_io_backend.load(filepath, normalize=normalize, channels_first=channels_first)


def py_save_func(
Expand All @@ -36,9 +37,7 @@ def py_save_func(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
torchaudio.backend.sox_io_backend.save(
filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample
)
sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample)


@skipIfNoExec("sox")
Expand Down
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/backend/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
@common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = "sox_io"
backend_module = torchaudio.backend.sox_io_backend
backend_module = torchaudio._backend.sox


@common_utils.skipIfNoModule("soundfile")
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = "soundfile"
backend_module = torchaudio.backend.soundfile_backend
backend_module = torchaudio._backend.soundfile
19 changes: 17 additions & 2 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
transforms,
utils,
)
from .backend.common import AudioMetaData
from ._backend.backend import AudioMetaData


try:
Expand All @@ -20,6 +20,9 @@
pass


################################################################################
# Backend stuff - to be cleaned up after 2.1 release
################################################################################
def _is_backend_dispatcher_enabled():
import os

Expand All @@ -29,11 +32,23 @@ def _is_backend_dispatcher_enabled():
if _is_backend_dispatcher_enabled():
from ._backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
else:
from .backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
from ._backend.legacy.utils import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend


_init_backend()

del _init_backend


# For backward compatibility.
# Looking at the previous __init__.py, we did not explicitly export `backend` here
# but, it was leaking and accessible as an attribute of torchaudio module.
#
# Perhaps we should remove it after couple of releases like 2.3?
from . import backend # noqa

################################################################################


__all__ = [
"AudioMetaData",
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch import Tensor

from torchaudio.backend.common import AudioMetaData
from .common import AudioMetaData


class Backend(ABC):
Expand Down
53 changes: 53 additions & 0 deletions torchaudio/_backend/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# TODO: Switch to dataclass when we completely get rid of torchscript support in I/O
class AudioMetaData:
"""AudioMetaData()
Return type of ``torchaudio.info`` function.
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding
The values encoding can take are one of the following:
* ``PCM_S``: Signed integer linear PCM
* ``PCM_U``: Unsigned integer linear PCM
* ``PCM_F``: Floating point linear PCM
* ``FLAC``: Flac, Free Lossless Audio Codec
* ``ULAW``: Mu-law
* ``ALAW``: A-law
* ``MP3`` : MP3, MPEG-1 Audio Layer III
* ``VORBIS``: OGG Vorbis
* ``AMR_WB``: Adaptive Multi-Rate Wideband
* ``AMR_NB``: Adaptive Multi-Rate Narrowband
* ``OPUS``: Opus
* ``HTK``: Single channel 16-bit PCM
* ``UNKNOWN`` : None of above
"""

def __init__(
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding

def __str__(self):
return (
f"AudioMetaData("
f"sample_rate={self.sample_rate}, "
f"num_frames={self.num_frames}, "
f"num_channels={self.num_channels}, "
f"bits_per_sample={self.bits_per_sample}, "
f"encoding={self.encoding}"
f")"
)
3 changes: 2 additions & 1 deletion torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import torch
import torchaudio
from torchaudio.backend.common import AudioMetaData
from torchaudio.io import StreamWriter

from .backend import Backend

from .common import AudioMetaData

if torchaudio._extension._FFMPEG_EXT is not None:
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from typing import List, Optional

import torchaudio
from torchaudio._backend import soundfile as soundfile_backend, sox as sox_io_backend
from torchaudio._internal import module_utils as _mod_utils

from . import no_backend, soundfile_backend, sox_io_backend
from . import no_backend

__all__ = [
"list_audio_backends",
Expand Down

0 comments on commit dcfb5db

Please sign in to comment.