A Law #930

README.md (22 changes: 12 additions & 10 deletions)
@@ -216,16 +216,18 @@ dimension (channels, time)")

Transforms expect and return the following dimensions.

-* `Spectrogram`: (channels, time) -> (channels, freq, time)
-* `AmplitudeToDB`: (channels, freq, time) -> (channels, freq, time)
-* `MelScale`: (channels, freq, time) -> (channels, mel, time)
-* `MelSpectrogram`: (channels, time) -> (channels, mel, time)
-* `MFCC`: (channels, time) -> (channel, mfcc, time)
-* `MuLawEncode`: (channels, time) -> (channels, time)
-* `MuLawDecode`: (channels, time) -> (channels, time)
-* `Resample`: (channels, time) -> (channels, time)
-* `Fade`: (channels, time) -> (channels, time)
-* `Vol`: (channels, time) -> (channels, time)
+* `Spectrogram`: (channel, time) -> (channel, freq, time)
+* `AmplitudeToDB`: (channel, freq, time) -> (channel, freq, time)
+* `MelScale`: (channel, freq, time) -> (channel, mel, time)
+* `MelSpectrogram`: (channel, time) -> (channel, mel, time)
+* `MFCC`: (channel, time) -> (channel, mfcc, time)
+* `MuLawEncode`: (channel, time) -> (channel, time)
+* `MuLawDecode`: (channel, time) -> (channel, time)
+* `ALawEncode`: (channel, time) -> (channel, time)
+* `ALawDecode`: (channel, time) -> (channel, time)
+* `Resample`: (channel, time) -> (channel, time)
+* `Fade`: (channel, time) -> (channel, time)
+* `Vol`: (channel, time) -> (channel, time)

Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
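
For example, a minimal sketch of both conventions (the tensor shapes below are arbitrary illustrations, not values taken from this README):

```python
import torch
import torchaudio
import torchaudio.functional as F

waveform = torch.randn(2, 16000)                        # (channel, time)
spec = torchaudio.transforms.Spectrogram()(waveform)    # (channel, freq, time)

# A "complex" tensor keeps real and imaginary parts in a trailing dimension of size 2.
complex_tensor = torch.randn(2, 201, 101, 2)            # (channel, freq, time, 2)
magnitude = F.complex_norm(complex_tensor, power=1.0)   # (channel, freq, time)
phase = F.angle(complex_tensor)                         # (channel, freq, time)
```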

docs/source/functional.rst (11 changes: 11 additions & 0 deletions)
@@ -159,6 +159,17 @@ vad

.. autofunction:: mu_law_decoding

:hidden:`a_law_encoding`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: a_law_encoding

:hidden:`a_law_decoding`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: a_law_decoding

:hidden:`complex_norm`
~~~~~~~~~~~~~~~~~~~~~~

docs/source/transforms.rst (14 changes: 14 additions & 0 deletions)
@@ -73,6 +73,20 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`ALawEncoding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ALawEncoding

.. automethod:: forward

:hidden:`ALawDecoding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ALawDecoding

.. automethod:: forward

:hidden:`Resample`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

test/torchaudio_unittest/batch_consistency_test.py (25 changes: 25 additions & 0 deletions)
@@ -190,6 +190,31 @@ def test_batch_mulaw(self):
# shape = (3, 2, 201, 1394)
self.assertEqual(computed, expected)

    def test_batch_alaw(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = torchaudio.transforms.ALawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = torchaudio.transforms.ALawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)

        # Single then transform then batch
        waveform_decoded = torchaudio.transforms.ALawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.ALawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)

def test_batch_spectrogram(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
test/torchaudio_unittest/torchscript_consistency_impl.py (18 changes: 18 additions & 0 deletions)
@@ -137,6 +137,24 @@ def func(tensor):
tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)

    def test_a_law_encoding(self):
        def func(tensor):
            qc = 256
            compression = 87.6
            return F.a_law_encoding(tensor, qc, compression)

        waveform = common_utils.get_whitenoise()
        self._assert_consistency(func, waveform)

    def test_a_law_decoding(self):
        def func(tensor):
            qc = 256
            compression = 87.6
            return F.a_law_decoding(tensor, qc, compression)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

def test_complex_norm(self):
def func(tensor):
power = 2.
test/torchaudio_unittest/transforms_test.py (31 changes: 31 additions & 0 deletions)
@@ -1,5 +1,6 @@
import math
import unittest
import audioop

import torch
import torchaudio
@@ -43,6 +44,36 @@ def test_mu_law_companding(self):
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_a_law_companding(self):

        quantization_channels = 256
        compression_param = 87.6

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()

        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_a = transforms.ALawEncoding(quantization_channels, compression_param)(waveform)
        self.assertTrue(waveform_a.min() >= 0. and waveform_a.max() <= quantization_channels)

        waveform_exp = transforms.ALawDecoding(quantization_channels, compression_param)(waveform_a)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

        # Cross-check against the standard library's audioop A-law codec (1-byte samples).
        segment_length = 1
        small_int_waveform = waveform.to(torch.uint8)
        waveform_bytes = bytearray(small_int_waveform[0, :])

        encoded = audioop.lin2alaw(waveform_bytes, segment_length)
        torch_encoded = transforms.ALawEncoding(quantization_channels, compression_param)(small_int_waveform)
        self.assertEqual(torch.tensor(list(encoded)), torch_encoded[0])

        decoded = audioop.alaw2lin(encoded, segment_length)
        torch_decoded = transforms.ALawDecoding(quantization_channels, compression_param)(torch_encoded)
        self.assertEqual(torch.tensor(list(decoded)), torch_decoded[0])

def test_AmplitudeToDB(self):
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform = common_utils.load_wav(filepath)[0]
torchaudio/functional/functional.py (65 changes: 65 additions & 0 deletions)
@@ -20,6 +20,8 @@
"DB_to_amplitude",
"mu_law_encoding",
"mu_law_decoding",
"a_law_encoding",
"a_law_decoding",
"complex_norm",
"angle",
"magphase",
@@ -417,6 +419,69 @@ def mu_law_decoding(
return x


def a_law_encoding(
        x: Tensor,
        quantization_channels: int,
        compression: float
) -> Tensor:
    r"""Encode signal based on A-law companding. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/A-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels
        compression (float): A-law compression parameter (the standard value is 87.6)

    Returns:
        Tensor: Input after A-law encoding
    """
    quant = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    A = torch.tensor(compression, dtype=x.dtype)
    x_abs = torch.abs(x)
    # Piecewise companding: linear segment for |x| < 1/A, logarithmic segment otherwise.
    x_narrow = A * x_abs
    x_wide = 1 + torch.log(x_narrow)
    x_numerator = torch.where(x_abs < (1 / A), x_narrow, x_wide)
    a = torch.sign(x) * x_numerator / (1 + torch.log(A))
    # Map [-1, 1] to integer codes in [0, quantization_channels - 1].
    x_a = ((a + 1) / 2 * quant + 0.5).to(torch.int64)
    return x_a


def a_law_decoding(
        x_q: Tensor,
        quantization_channels: int,
        compression: float
) -> Tensor:
    r"""Decode A-law encoded signal. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/A-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_q (Tensor): Input tensor
        quantization_channels (int): Number of channels
        compression (float): A-law compression parameter (the standard value is 87.6)

    Returns:
        Tensor: Input after A-law decoding
    """
    quant = quantization_channels - 1.0
    if not x_q.is_floating_point():
        x_q = x_q.to(torch.float)
    A = torch.tensor(compression, dtype=x_q.dtype)
    # Map integer codes back to [-1, 1] before expanding.
    x_a = (x_q / quant) * 2 - 1.0
    ln_a = 1 + torch.log(A)
    x_abs = torch.abs(x_a)
    # Inverse companding: linear segment for |y| < 1/(1 + ln A), exponential segment otherwise.
    x_narrow = x_abs * ln_a
    x_wide = torch.exp(x_narrow - 1)
    x_numerator = torch.where(x_abs < (1 / ln_a), x_narrow, x_wide)
    x = torch.sign(x_a) * x_numerator / A
    return x
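
# A minimal round-trip sketch of the two functions above (illustrative only; the
# waveform and the compression value are assumptions, not taken from this file):
#
#     >>> waveform = torch.rand(1, 100) * 2 - 1           # values in [-1, 1]
#     >>> encoded = a_law_encoding(waveform, 256, 87.6)    # int64 codes in [0, 255]
#     >>> decoded = a_law_decoding(encoded, 256, 87.6)     # approximately recovers waveform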


def complex_norm(
complex_tensor: Tensor,
power: float = 1.0
torchaudio/transforms.py (57 changes: 57 additions & 0 deletions)
@@ -20,6 +20,8 @@
'MFCC',
'MuLawEncoding',
'MuLawDecoding',
'ALawEncoding',
'ALawDecoding',
'Resample',
'ComplexNorm',
'TimeStretch',
@@ -568,6 +570,61 @@ def forward(self, x_mu: Tensor) -> Tensor:
"""
return F.mu_law_decoding(x_mu, self.quantization_channels)


class ALawEncoding(torch.nn.Module):
    r"""Encode signal based on A-law companding. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/A-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
        compression_param (float, optional): A-law compression parameter. (Default: ``87.6``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256, compression_param: float = 87.6) -> None:
        super(ALawEncoding, self).__init__()
        self.quantization_channels = quantization_channels
        self.compression_param = compression_param

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_a (Tensor): An encoded signal.
        """
        return F.a_law_encoding(x, self.quantization_channels, self.compression_param)


class ALawDecoding(torch.nn.Module):
    r"""Decode A-law encoded signal. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/A-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
        compression_param (float, optional): A-law compression parameter. (Default: ``87.6``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256, compression_param: float = 87.6) -> None:
        super(ALawDecoding, self).__init__()
        self.quantization_channels = quantization_channels
        self.compression_param = compression_param

    def forward(self, x_a: Tensor) -> Tensor:
        r"""
        Args:
            x_a (Tensor): An A-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The decoded signal.
        """
        return F.a_law_decoding(x_a, self.quantization_channels, self.compression_param)
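
# A minimal usage sketch for the two transforms above (illustrative only; the
# waveform shape is an assumption, not taken from this file):
#
#     >>> waveform = torch.rand(2, 16000) * 2 - 1    # (channel, time), values in [-1, 1]
#     >>> encoded = ALawEncoding()(waveform)         # (channel, time), int64 codes in [0, 255]
#     >>> restored = ALawDecoding()(encoded)         # (channel, time), floats in [-1, 1]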


class Resample(torch.nn.Module):
r"""Resample a signal from one frequency to another. A resampling method can be given.