Skip to content

Commit

Permalink
Move utility Tensor functions to misc_ops module (#694)
Browse files Browse the repository at this point in the history
* also deletes duplicated func
  • Loading branch information
mthrok committed Jun 5, 2020
1 parent 9f3075c commit e5eb485
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 49 deletions.
29 changes: 5 additions & 24 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
transforms
)
from torchaudio._backend import (
check_input,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio.sox_effects import initialize_sox, shutdown_sox

try:
Expand Down Expand Up @@ -161,7 +163,7 @@ def save_encinfo(filepath: str,
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors are assumed to be mono signals
Expand Down Expand Up @@ -328,24 +330,3 @@ def get_sox_bool(i: int = 0) -> Any:
return _torchaudio.sox_bool
else:
return _torchaudio.sox_bool(i)


def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Normalize an audio tensor in-place.

    ``normalization`` selects the divisor applied to ``signal``:

    * a falsy value leaves ``signal`` untouched;
    * ``True`` divides by ``1 << 31`` (per the original note, SoX uses 32-bit
      signed integers internally, so a bool normalizes on that assumption);
    * a number divides by that number;
    * a callable is invoked with ``signal`` and its result is the divisor.

    Any other value leaves ``signal`` untouched.
    """
    if not normalization:
        return

    # ``bool`` is tested first because it is also an ``int``.
    if isinstance(normalization, bool):
        divisor = 1 << 31
    elif isinstance(normalization, (float, int)):
        # normalize with custom value
        divisor = normalization
    elif callable(normalization):
        divisor = normalization(signal)
    else:
        return

    signal /= divisor
13 changes: 2 additions & 11 deletions torchaudio/_backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from functools import wraps
from typing import Any, List, Union
from typing import Any

import platform
import torch
from torch import Tensor


from . import _soundfile_backend, _sox_backend

Expand Down Expand Up @@ -43,10 +41,3 @@ def _get_audio_backend_module() -> Any:
"""
backend = get_audio_backend()
return _audio_backends[backend]


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor stored on the CPU.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a tensor on a
            non-CPU device.
    """
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % type(src))
    # Check the device type rather than ``is_cuda`` so every non-CPU device
    # is rejected, and report the device: the Python type of a tensor does
    # not tell the caller where it is stored.
    if src.device.type != 'cpu':
        raise TypeError('Expected a CPU based tensor, got a tensor on %s' % src.device)
30 changes: 30 additions & 0 deletions torchaudio/_internal/misc_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Union, Callable

import torch
from torch import Tensor


def normalize_audio(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Scale ``signal`` in-place according to ``normalization``.

    Args:
        signal: Audio tensor, modified in-place.
        normalization: A falsy value disables normalization. ``True`` divides
            by ``1 << 31`` (per the original note, SoX uses 32-bit signed
            integers internally, so a bool normalizes on that assumption).
            A number is used directly as the divisor. A callable receives
            ``signal`` and returns the divisor. Any other value is ignored.
    """
    if not normalization:
        return

    # Resolve the ``True`` shorthand to the 32-bit full-scale value.
    scale = (1 << 31) if isinstance(normalization, bool) else normalization

    if isinstance(scale, (float, int)):
        # normalize with custom value
        signal /= scale
        return

    if callable(scale):
        signal /= scale(signal)


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor stored on the CPU.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a tensor on a
            non-CPU device.
    """
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % type(src))
    # Check the device type rather than ``is_cuda`` so every non-CPU device
    # is rejected, and report the device: the Python type of a tensor does
    # not tell the caller where it is stored.
    if src.device.type != 'cpu':
        raise TypeError('Expected a CPU based tensor, got a tensor on %s' % src.device)
12 changes: 4 additions & 8 deletions torchaudio/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import torch
from torch import Tensor

from torchaudio._internal import misc_ops as _misc_ops


_subtype_to_precision = {
'PCM_S8': 8,
'PCM_16': 16,
Expand Down Expand Up @@ -43,13 +46,6 @@ def __init__(self,
self.opposite_endian = opposite_endian


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor stored on the CPU.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a tensor on a
            non-CPU device.
    """
    if not torch.is_tensor(src):
        raise TypeError("Expected a tensor, got %s" % type(src))
    # Check the device type rather than ``is_cuda`` so every non-CPU device
    # is rejected, and report the device: the Python type of a tensor does
    # not tell the caller where it is stored.
    if src.device.type != "cpu":
        raise TypeError("Expected a CPU based tensor, got a tensor on %s" % src.device)


def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
Expand Down Expand Up @@ -108,7 +104,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors are assumed to be mono signals
Expand Down
9 changes: 6 additions & 3 deletions torchaudio/_sox_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from torch import Tensor

import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo

if _mod_utils.is_module_available('torchaudio._torchaudio'):
Expand All @@ -32,7 +35,7 @@ def load(filepath: str,

# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()

Expand All @@ -53,7 +56,7 @@ def load(filepath: str,
)

# normalize if needed
torchaudio._audio_normalization(out, normalization)
_misc_ops.normalize_audio(out, normalization)

return out, sample_rate

Expand Down
9 changes: 6 additions & 3 deletions torchaudio/sox_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torchaudio
from torch import Tensor

from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)

if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
Expand Down Expand Up @@ -200,7 +203,7 @@ def sox_build_flow_effects(self,
"""
# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()
if not len(self.chain):
Expand All @@ -220,7 +223,7 @@ def sox_build_flow_effects(self,
self.chain,
self.MAX_EFFECT_OPTS)

torchaudio._audio_normalization(out, self.normalization)
_misc_ops.normalize_audio(out, self.normalization)

return out, sr

Expand Down

0 comments on commit e5eb485

Please sign in to comment.