Skip to content

Commit

Permalink
Inline typing _backend (#527)
Browse files Browse the repository at this point in the history
* add inline typing

* fix error

* minor change

* minor fix
  • Loading branch information
tomassosorio committed Apr 9, 2020
1 parent 98fe8b4 commit c29598d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 47 deletions.
14 changes: 8 additions & 6 deletions torchaudio/_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import wraps
from typing import Any, List, Union

import platform
import torch
from torch import Tensor

from . import _soundfile_backend, _sox_backend

Expand All @@ -10,11 +12,11 @@
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}


def set_audio_backend(backend):
def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
Args:
backend (string): Name of the backend. One of {}.
backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
Expand All @@ -24,22 +26,22 @@ def set_audio_backend(backend):
_audio_backend = backend


def get_audio_backend():
def get_audio_backend() -> str:
"""
Gets the name of the package used to load.
"""
return _audio_backend


def _get_audio_backend_module():
def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
backend = get_audio_backend()
return _audio_backends[backend]


def _audio_backend_guard(backends):
def _audio_backend_guard(backends: Union[str, List[str]]) -> Any:

if isinstance(backends, str):
backends = [backends]
Expand All @@ -55,7 +57,7 @@ def wrapper(*args, **kwargs):
return decorator


def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda:
Expand Down
52 changes: 27 additions & 25 deletions torchaudio/_soundfile_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor

_subtype_to_precision = {
'PCM_S8': 8,
Expand All @@ -12,24 +14,26 @@


class SignalInfo:
def __init__(self, channels=None, rate=None, precision=None, length=None):
def __init__(self,
channels: Optional[int] = None,
rate: Optional[float] = None,
precision: Optional[int] = None,
length: Optional[int] = None) -> None:
self.channels = channels
self.rate = rate
self.precision = precision
self.length = length


class EncodingInfo:
def __init__(
self,
encoding=None,
bits_per_sample=None,
compression=None,
reverse_bytes=None,
reverse_nibbles=None,
reverse_bits=None,
opposite_endian=None
):
def __init__(self,
encoding: Any = None,
bits_per_sample: Optional[int] = None,
compression: Optional[float] = None,
reverse_bytes: Any = None,
reverse_nibbles: Any = None,
reverse_bits: Any = None,
opposite_endian: Optional[bool] = None) -> None:
self.encoding = encoding
self.bits_per_sample = bits_per_sample
self.compression = compression
Expand All @@ -39,24 +43,22 @@ def __init__(
self.opposite_endian = opposite_endian


def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError("Expected a tensor, got %s" % type(src))
if src.is_cuda:
raise TypeError("Expected a CPU based tensor, got %s" % type(src))


def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""

assert out is None
Expand Down Expand Up @@ -96,7 +98,7 @@ def load(
return out, sample_rate


def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
Expand Down Expand Up @@ -129,7 +131,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return soundfile.write(filepath, src, sample_rate, precision)


def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

import soundfile
Expand Down
32 changes: 16 additions & 16 deletions torchaudio/_sox_backend.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os.path
from typing import Any, Optional, Tuple, Union

import torch

from torch import Tensor
import torchaudio


def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo


def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""

# stringify if `pathlib.Path` (noop if already `str`)
Expand Down Expand Up @@ -53,7 +53,7 @@ def load(
return out, sample_rate


def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

si = torchaudio.sox_signalinfo_t()
Expand All @@ -65,7 +65,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return torchaudio.save_encinfo(filepath, src, channels_first, si)


def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

import _torch_sox
Expand Down

0 comments on commit c29598d

Please sign in to comment.