Skip to content

Commit

Permalink
Add dcshift to functional (#558)
Browse files Browse the repository at this point in the history
* Add dcshift to functional

* Doc string change and remove inplace clamp

* Minor Fix to dcshit and separate sox test refactoring

* Minor change to limiter_gain type

* adding dcshift to __all__ in functional
  • Loading branch information
bhargavkathivarapu committed Apr 20, 2020
1 parent fc2537e commit 91e5923
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ Functions to perform common audio operations.

.. autofunction:: contrast

:hidden:`dcshift`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dcshift

:hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 4 additions & 0 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.contrast, waveform, enhancement_amount=80.)

def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)


class TestTransforms(unittest.TestCase):
"""Test suite for classes defined in `transforms` module"""
Expand Down
37 changes: 37 additions & 0 deletions test/test_sox_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,43 @@ def test_contrast(self):

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_with_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
"""
shift = 0.5
limiter_gain = 0.05
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("dcshift", [shift, limiter_gain])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, _ = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.dcshift(waveform, shift, limiter_gain)

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_without_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
"""
shift = 0.6
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("dcshift", [shift])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, _ = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.dcshift(waveform, shift)

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_equalizer(self):
Expand Down
12 changes: 12 additions & 0 deletions test/test_torchscript_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,18 @@ def func(tensor):

self._assert_consistency(func, waveform)

def test_dcshift(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

def func(tensor):
shift = 0.5
limiter_gain = 0.05
return F.dcshift(tensor, shift, limiter_gain)

self._assert_consistency(func, waveform)


class _TransformsTestMixin:
"""Implements test for Transforms that are performed for different devices"""
device = None
Expand Down
45 changes: 45 additions & 0 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"riaa_biquad",
"biquad",
"contrast",
"dcshift",
'mask_along_axis',
'mask_along_axis_iid',
'sliding_window_cmn',
Expand Down Expand Up @@ -1194,6 +1195,50 @@ def contrast(
return output_waveform


def dcshift(
waveform: Tensor,
shift: float,
limiter_gain: Optional[float] = None
) -> Tensor:
r"""Apply a DC shift to the audio. Similar to SoX implementation.
This can be useful to remove a DC offset
(caused perhaps by a hardware problem in the recording chain) from the audio
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`
shift (float): indicates the amount to shift the audio
Allowed range of values for shift : -2.0 to +2.0
limiter_gain (float): It is used only on peaks to prevent clipping
It should have a value much less than 1 (e.g. 0.05 or 0.02)
Returns:
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
"""
output_waveform = waveform
limiter_threshold = 0.

if limiter_gain is not None:
limiter_threshold = 1.0 - (abs(shift) - limiter_gain)

if limiter_gain is not None and shift > 0:
mask = waveform > limiter_threshold
temp = (waveform[mask] - limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp + limiter_threshold + shift).clamp(max=limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
elif limiter_gain is not None and shift < 0:
mask = waveform < -limiter_threshold
temp = (waveform[mask] + limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp - limiter_threshold + shift).clamp(min=-limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
else:
output_waveform = (waveform + shift).clamp(min=-1, max=1)

return output_waveform


def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
Expand Down

0 comments on commit 91e5923

Please sign in to comment.