Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline typing _backend #527

Merged
merged 4 commits into from
Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 8 additions & 6 deletions torchaudio/_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import wraps
from typing import Any, List, Union

import platform
import torch
from torch import Tensor

from . import _soundfile_backend, _sox_backend

Expand All @@ -10,11 +12,11 @@
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}


def set_audio_backend(backend):
def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
Args:
backend (string): Name of the backend. One of {}.
backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
Expand All @@ -24,22 +26,22 @@ def set_audio_backend(backend):
_audio_backend = backend


def get_audio_backend():
def get_audio_backend() -> str:
"""
Gets the name of the package used to load.
"""
return _audio_backend


def _get_audio_backend_module():
def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
backend = get_audio_backend()
return _audio_backends[backend]


def _audio_backend_guard(backends):
def _audio_backend_guard(backends: Union[str, List[str]]) -> Any:

if isinstance(backends, str):
backends = [backends]
Expand All @@ -55,7 +57,7 @@ def wrapper(*args, **kwargs):
return decorator


def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda:
Expand Down
52 changes: 27 additions & 25 deletions torchaudio/_soundfile_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor

_subtype_to_precision = {
'PCM_S8': 8,
Expand All @@ -12,24 +14,26 @@


class SignalInfo:
def __init__(self, channels=None, rate=None, precision=None, length=None):
def __init__(self,
channels: Optional[int] = None,
rate: Optional[float] = None,
precision: Optional[int] = None,
length: Optional[int] = None) -> None:
self.channels = channels
self.rate = rate
self.precision = precision
self.length = length


class EncodingInfo:
def __init__(
self,
encoding=None,
bits_per_sample=None,
compression=None,
reverse_bytes=None,
reverse_nibbles=None,
reverse_bits=None,
opposite_endian=None
):
def __init__(self,
encoding: Any = None,
bits_per_sample: Optional[int] = None,
compression: Optional[float] = None,
reverse_bytes: Any = None,
reverse_nibbles: Any = None,
reverse_bits: Any = None,
opposite_endian: Optional[bool] = None) -> None:
self.encoding = encoding
self.bits_per_sample = bits_per_sample
self.compression = compression
Expand All @@ -39,24 +43,22 @@ def __init__(
self.opposite_endian = opposite_endian


def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError("Expected a tensor, got %s" % type(src))
if src.is_cuda:
raise TypeError("Expected a CPU based tensor, got %s" % type(src))


def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""

assert out is None
Expand Down Expand Up @@ -96,7 +98,7 @@ def load(
return out, sample_rate


def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
Expand Down Expand Up @@ -129,7 +131,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return soundfile.write(filepath, src, sample_rate, precision)


def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

import soundfile
Expand Down
32 changes: 16 additions & 16 deletions torchaudio/_sox_backend.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os.path
from typing import Any, Optional, Tuple, Union

import torch

from torch import Tensor
import torchaudio


def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo


def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""

# stringify if `pathlib.Path` (noop if already `str`)
Expand Down Expand Up @@ -53,7 +53,7 @@ def load(
return out, sample_rate


def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

si = torchaudio.sox_signalinfo_t()
Expand All @@ -65,7 +65,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return torchaudio.save_encinfo(filepath, src, channels_first, si)


def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

import _torch_sox
Expand Down