Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move I/O related Tensor ops to misc_ops module #694

Merged
merged 1 commit into from
Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 5 additions & 24 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
transforms
)
from torchaudio._backend import (
check_input,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio.sox_effects import initialize_sox, shutdown_sox

try:
Expand Down Expand Up @@ -161,7 +163,7 @@ def save_encinfo(filepath: str,
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors as assumed to be mono signals
Expand Down Expand Up @@ -328,24 +330,3 @@ def get_sox_bool(i: int = 0) -> Any:
return _torchaudio.sox_bool
else:
return _torchaudio.sox_bool(i)


def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Scale ``signal`` in-place according to ``normalization``.

    Args:
        signal: Tensor that is divided in-place.
        normalization: Selects the divisor.
            * Falsy value: ``signal`` is left untouched.
            * ``True``: divide by ``1 << 31`` (SoX uses 32-bit signed
              integers internally, so a bool normalizes on that basis).
            * ``float`` / ``int``: divide by that number.
            * Callable: divide by whatever ``normalization(signal)`` returns.
            Any other truthy value leaves ``signal`` unchanged.
    """
    if normalization:
        # A bare ``True`` stands for the 32-bit signed integer scale.
        divisor = (1 << 31) if isinstance(normalization, bool) else normalization

        if isinstance(divisor, (float, int)):
            signal /= divisor
        elif callable(divisor):
            signal /= divisor(signal)
13 changes: 2 additions & 11 deletions torchaudio/_backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from functools import wraps
from typing import Any, List, Union
from typing import Any

import platform
import torch
from torch import Tensor


from . import _soundfile_backend, _sox_backend

Expand Down Expand Up @@ -43,10 +41,3 @@ def _get_audio_backend_module() -> Any:
"""
backend = get_audio_backend()
return _audio_backends[backend]


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if torch.is_tensor(src):
        if src.is_cuda:
            raise TypeError('Expected a CPU based tensor, got %s' % type(src))
        return
    raise TypeError('Expected a tensor, got %s' % type(src))
30 changes: 30 additions & 0 deletions torchaudio/_internal/misc_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Union, Callable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: That's a very generic name for the file :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I do not know what to call it, but this is just evicting these functions from __init__.py, so if this module grows into something then we can give it a better name.


import torch
from torch import Tensor


def normalize_audio(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Scale ``signal`` in-place according to ``normalization``.

    Args:
        signal: Tensor that is divided in-place.
        normalization: Selects the divisor.
            * Falsy value: ``signal`` is left untouched.
            * ``True``: divide by ``1 << 31`` (SoX uses 32-bit signed
              integers internally, so a bool normalizes on that basis).
            * ``float`` / ``int``: divide by that number.
            * Callable: divide by whatever ``normalization(signal)`` returns.
            Any other truthy value leaves ``signal`` unchanged.
    """
    if normalization:
        # A bare ``True`` stands for the 32-bit signed integer scale.
        divisor = (1 << 31) if isinstance(normalization, bool) else normalization

        if isinstance(divisor, (float, int)):
            signal /= divisor
        elif callable(divisor):
            signal /= divisor(signal)


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if torch.is_tensor(src):
        if src.is_cuda:
            raise TypeError('Expected a CPU based tensor, got %s' % type(src))
        return
    raise TypeError('Expected a tensor, got %s' % type(src))
12 changes: 4 additions & 8 deletions torchaudio/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import torch
from torch import Tensor

from torchaudio._internal import misc_ops as _misc_ops


_subtype_to_precision = {
'PCM_S8': 8,
'PCM_16': 16,
Expand Down Expand Up @@ -43,13 +46,6 @@ def __init__(self,
self.opposite_endian = opposite_endian


def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if torch.is_tensor(src):
        if src.is_cuda:
            raise TypeError("Expected a CPU based tensor, got %s" % type(src))
        return
    raise TypeError("Expected a tensor, got %s" % type(src))


def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
Expand Down Expand Up @@ -108,7 +104,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors as assumed to be mono signals
Expand Down
9 changes: 6 additions & 3 deletions torchaudio/_sox_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from torch import Tensor

import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo

if _mod_utils.is_module_available('torchaudio._torchaudio'):
Expand All @@ -32,7 +35,7 @@ def load(filepath: str,

# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()

Expand All @@ -53,7 +56,7 @@ def load(filepath: str,
)

# normalize if needed
torchaudio._audio_normalization(out, normalization)
_misc_ops.normalize_audio(out, normalization)

return out, sample_rate

Expand Down
9 changes: 6 additions & 3 deletions torchaudio/sox_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torchaudio
from torch import Tensor

from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)

if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
Expand Down Expand Up @@ -200,7 +203,7 @@ def sox_build_flow_effects(self,
"""
# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()
if not len(self.chain):
Expand All @@ -220,7 +223,7 @@ def sox_build_flow_effects(self,
self.chain,
self.MAX_EFFECT_OPTS)

torchaudio._audio_normalization(out, self.normalization)
_misc_ops.normalize_audio(out, self.normalization)

return out, sr

Expand Down