Skip to content

Commit

Permalink
lfilter: add an optional arg clamp (#600)
Browse files Browse the repository at this point in the history
* lfilter: add an optional arg `clamp`

to give users control on the clamping behavior within the range [-1, 1],
which was hardcoded before.

Fixes #596

* doc string formatting

* lint

* doc string again

Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
  • Loading branch information
r9y9 and vincentqb committed Apr 30, 2020
1 parent fe30f30 commit 0fcb0c0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
9 changes: 9 additions & 0 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def test_simple(self):

torch.testing.assert_allclose(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)

def test_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True)
self.assertTrue(output_signal.max() <= 1)
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
self.assertTrue(output_signal.max() > 1)


class TestLfilterFloat32CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
Expand Down
21 changes: 13 additions & 8 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,21 +673,23 @@ def phase_vocoder(
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[b0, b1, b2, ...]`.
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of `(..., time)`. Output will be clipped to -1 to 1.
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
Expand Down Expand Up @@ -731,7 +733,10 @@ def lfilter(
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
padded_output_waveform[:, i_sample + n_order - 1] = o0

output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)
output = padded_output_waveform[:, (n_order - 1):]

if clamp:
output = torch.clamp(output, min=-1., max=1.)

# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])
Expand Down

0 comments on commit 0fcb0c0

Please sign in to comment.