Skip to content

Commit

Permalink
Add flanger to functional.py (#651)
Browse files Browse the repository at this point in the history
* Add flanger to functional

Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>

* Add random seed

Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>

* fix flanger

Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>

* shape

* Change bool arguments to strings

Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>

* Refactor tests

Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>

Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
  • Loading branch information
bhargavkathivarapu and vincentqb committed Jun 2, 2020
1 parent 2ed97b6 commit 9e27cf3
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ Functions to perform common audio operations.

.. autofunction:: phaser

:hidden:`flanger`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: flanger

:hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def test_phaser(self):
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(F.phaser, waveform, sample_rate)

def test_flanger(self):
torch.random.manual_seed(40)
waveform = torch.rand(2, 100) - 0.5
sample_rate = 44100
self.assert_batch_consistencies(F.flanger, waveform, sample_rate)

def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=True, norm_vars=True)
Expand Down
96 changes: 96 additions & 0 deletions test/test_sox_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,102 @@ def test_phaser_triangle(self):

self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_triangle_linear(self):
"""
Test flanger effect with triangle modulation and linear interpolation, compare to SoX implementation
"""
delay = 0.6
depth = 0.87
regen = 3.0
width = 0.9
speed = 0.5
phase = 30
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("flanger", [delay, depth, regen, width, speed, "triangle", phase, "linear"])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.flanger(waveform, sample_rate, delay, depth, regen, width, speed, phase,
modulation='triangular', interpolation='linear')

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_triangle_quad(self):
"""
Test flanger effect with triangle modulation and quadratic interpolation, compare to SoX implementation
"""
delay = 0.8
depth = 0.88
regen = 3.0
width = 0.4
speed = 0.5
phase = 40
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("flanger", [delay, depth, regen, width, speed, "triangle", phase, "quadratic"])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.flanger(waveform, sample_rate, delay, depth, regen, width, speed, phase,
modulation='triangular', interpolation='quadratic')

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_sine_linear(self):
"""
Test flanger effect with sine modulation and linear interpolation, compare to SoX implementation
"""
delay = 0.8
depth = 0.88
regen = 3.0
width = 0.23
speed = 1.3
phase = 60
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("flanger", [delay, depth, regen, width, speed, "sine", phase, "linear"])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.flanger(waveform, sample_rate, delay, depth, regen, width, speed, phase,
modulation='sinusoidal', interpolation='linear')

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_sine_quad(self):
"""
Test flanger effect with sine modulation and quadratic interpolation, compare to SoX implementation
"""
delay = 0.9
depth = 0.9
regen = 4.0
width = 0.23
speed = 1.3
phase = 25
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("flanger", [delay, depth, regen, width, speed, "sine", phase, "quadratic"])
sox_output_waveform, sr = E.sox_build_flow_effects()

waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.flanger(waveform, sample_rate, delay, depth, regen, width, speed, phase,
modulation='sinusoidal', interpolation='quadratic')

torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_equalizer(self):
Expand Down
17 changes: 17 additions & 0 deletions test/torchscript_consistency_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,23 @@ def func(tensor):

self._assert_consistency(func, waveform)

def test_flanger(self):
torch.random.manual_seed(40)
waveform = torch.rand(2, 100) - 0.5

def func(tensor):
delay = 0.8
depth = 0.88
regen = 3.0
width = 0.23
speed = 1.3
phase = 60.
sample_rate = 44100
return F.flanger(tensor, sample_rate, delay, depth, regen, width, speed,
phase, modulation='sinusoidal', interpolation='linear')

self._assert_consistency(func, waveform)


class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
Expand Down
151 changes: 151 additions & 0 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"dcshift",
"overdrive",
"phaser",
"flanger",
'mask_along_axis',
'mask_along_axis_iid',
'sliding_window_cmn',
Expand Down Expand Up @@ -1357,6 +1358,156 @@ def _generate_wave_table(
return d


def flanger(
waveform: Tensor,
sample_rate: int,
delay: float = 0.,
depth: float = 2.,
regen: float = 0.,
width: float = 71.,
speed: float = 0.5,
phase: float = 25.,
modulation: str = 'sinusoidal',
interpolation: str = 'linear'
) -> Tensor:
r"""Apply a flanger effect to the audio. Similar to SoX implementation.
Args:
waveform (Tensor): audio waveform of dimension of `(..., channel, time)` .
Max 4 channels allowed
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
delay (float): desired delay in milliseconds(ms)
Allowed range of values are 0 to 30
depth (float): desired delay depth in milliseconds(ms)
Allowed range of values are 0 to 10
regen (float): desired regen(feeback gain) in dB
Allowed range of values are -95 to 95
width (float): desired width(delay gain) in dB
Allowed range of values are 0 to 100
speed (float): modulation speed in Hz
Allowed range of values are 0.1 to 10
phase (float): percentage phase-shift for multi-channel
Allowed range of values are 0 to 100
modulation (str): Use either "sinusoidal" or "triangular" modulation. (Default: ``sinusoidal``)
interpolation (str): Use either "linear" or "quadratic" for delay-line interpolation. (Default: ``linear``)
Returns:
Tensor: Waveform of dimension of `(..., channel, time)`
References:
http://sox.sourceforge.net/sox.html
Scott Lehman, Effects Explained,
https://web.archive.org/web/20051125072557/http://www.harmony-central.com/Effects/effects-explained.html
"""

if modulation not in ('sinusoidal', 'triangular'):
raise ValueError("Only 'sinusoidal' or 'triangular' modulation allowed")

if interpolation not in ('linear', 'quadratic'):
raise ValueError("Only 'linear' or 'quadratic' interpolation allowed")

actual_shape = waveform.shape
device, dtype = waveform.device, waveform.dtype

if actual_shape[-2] > 4:
raise ValueError("Max 4 channels allowed")

# convert to 3D (batch, channels, time)
waveform = waveform.view(-1, actual_shape[-2], actual_shape[-1])

# Scaling
feedback_gain = regen / 100
delay_gain = width / 100
channel_phase = phase / 100
delay_min = delay / 1000
delay_depth = depth / 1000

n_channels = waveform.shape[-2]

if modulation == 'sinusoidal':
wave_type = 'SINE'
else:
wave_type = 'TRIANGLE'

# Balance output:
in_gain = 1. / (1 + delay_gain)
delay_gain = delay_gain / (1 + delay_gain)

# Balance feedback loop:
delay_gain = delay_gain * (1 - abs(feedback_gain))

delay_buf_length = int((delay_min + delay_depth) * sample_rate + 0.5)
delay_buf_length = delay_buf_length + 2

delay_bufs = torch.zeros(waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device)
delay_last = torch.zeros(waveform.shape[0], n_channels, dtype=dtype, device=device)

lfo_length = int(sample_rate / speed)

lfo = torch.zeros(lfo_length, dtype=dtype, device=device)

table_min = math.floor(delay_min * sample_rate + 0.5)
table_max = delay_buf_length - 2.

lfo = _generate_wave_table(wave_type=wave_type,
data_type='FLOAT',
table_size=lfo_length,
min=float(table_min),
max=float(table_max),
phase=3 * math.pi / 2)

output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)

delay_buf_pos = 0
lfo_pos = 0
channel_idxs = torch.arange(0, n_channels)

for i in range(waveform.shape[-1]):

delay_buf_pos = (delay_buf_pos + delay_buf_length - 1) % delay_buf_length

cur_channel_phase = (channel_idxs * lfo_length * channel_phase + .5).to(torch.int64)
delay_tensor = lfo[(lfo_pos + cur_channel_phase) % lfo_length]
frac_delay = torch.frac(delay_tensor)
delay_tensor = torch.floor(delay_tensor)

int_delay = delay_tensor.to(torch.int64)

temp = waveform[:, :, i]

delay_bufs[:, :, delay_buf_pos] = temp + delay_last * feedback_gain

delayed_0 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]

int_delay = int_delay + 1

delayed_1 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]

int_delay = int_delay + 1

if interpolation == 'linear':
delayed = delayed_0 + (delayed_1 - delayed_0) * frac_delay
else:
delayed_2 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]

int_delay = int_delay + 1

delayed_2 = delayed_2 - delayed_0
delayed_1 = delayed_1 - delayed_0
a = delayed_2 * .5 - delayed_1
b = delayed_1 * 2 - delayed_2 * .5

delayed = delayed_0 + (a * frac_delay + b) * frac_delay

delay_last = delayed
output_waveform[:, :, i] = waveform[:, :, i] * in_gain + delayed * delay_gain

lfo_pos = (lfo_pos + 1) % lfo_length

return output_waveform.clamp(min=-1, max=1).view(actual_shape)


def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
Expand Down

0 comments on commit 9e27cf3

Please sign in to comment.