Port sox::vad #578

Merged · 24 commits · Apr 28, 2020

5 changes: 5 additions & 0 deletions docs/source/functional.rst
@@ -162,3 +162,8 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: sliding_window_cmn

:hidden:`vad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: vad
11 changes: 9 additions & 2 deletions docs/source/transforms.rst
@@ -19,7 +19,7 @@ Transforms are common audio transforms. They can be chained together using :clas
:hidden:`GriffinLim`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GriffinLim
.. autoclass:: GriffinLim

  .. automethod:: forward

@@ -128,10 +128,17 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Vol

  .. automethod:: forward

:hidden:`SlidingWindowCmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SlidingWindowCmn

  .. automethod:: forward

:hidden:`Vad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Vad

  .. automethod:: forward
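For reference (not part of the diff): a minimal sketch of how the documented transform is meant to be used, assuming the test asset added in this PR and a path relative to the repository root.

import torchaudio
import torchaudio.transforms as T

# Load the speech clip added as a test asset in this PR.
waveform, sample_rate = torchaudio.load("test/assets/vad-hello-mono-32000.wav")

# Trim leading silence/background noise; trigger_level may need tuning for noisier input.
vad = T.Vad(sample_rate=sample_rate, trigger_level=7.0)
trimmed = vad(waveform)
print(waveform.shape, trimmed.shape)  # trimmed output is no longer than the input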
Binary file added test/assets/vad-hello-mono-32000.wav
Binary file not shown.
Binary file added test/assets/vad-hello-stereo-44100.wav
Binary file not shown.
5 changes: 5 additions & 0 deletions test/test_batch_consistency.py
@@ -83,6 +83,11 @@ def test_sliding_window_cmn(self):
        _test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=True)
        _test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=False)

    def test_vad(self):
        filepath = common_utils.get_asset_path("vad-hello-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        _test_batch(F.vad, waveform, sample_rate=sample_rate)


class TestTransforms(unittest.TestCase):
    """Test suite for classes defined in `transforms` module"""
18 changes: 18 additions & 0 deletions test/test_sox_effects.py
@@ -249,6 +249,24 @@ def test_vol(self):
        # check if effect worked
        self.assertTrue(x.allclose(z, rtol=1e-4, atol=1e-4))

    def test_vad(self):
        sample_files = [
            common_utils.get_asset_path("vad-hello-stereo-44100.wav"),
            common_utils.get_asset_path("vad-hello-mono-32000.wav")
        ]

        for sample_file in sample_files:
            E = torchaudio.sox_effects.SoxEffectsChain()
            E.set_input_file(sample_file)
            E.append_effect_to_chain("vad")
            x, _ = E.sox_build_flow_effects()

            x_orig, sample_rate = torchaudio.load(sample_file)
            vad = torchaudio.transforms.Vad(sample_rate)

            y = vad(x_orig)
            self.assertTrue(x.allclose(y, rtol=1e-4, atol=1e-4))


if __name__ == '__main__':
    with AudioBackendScope("sox"):
5 changes: 5 additions & 0 deletions test/test_torchscript_consistency.py
@@ -557,6 +557,11 @@ def test_SlidingWindowCmn(self):
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-hello-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)


class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
    """Test suite for Functional module on CPU"""
297 changes: 297 additions & 0 deletions torchaudio/functional.py
@@ -37,6 +37,7 @@
    'mask_along_axis',
    'mask_along_axis_iid',
    'sliding_window_cmn',
    'vad',
]


@@ -1836,3 +1837,299 @@ def sliding_window_cmn(
    if len(input_shape) == 2:
        cmn_waveform = cmn_waveform.squeeze(0)
    return cmn_waveform


def _measure(
    measure_len_ws: int,
    samples: Tensor,
    spectrum: Tensor,
    noise_spectrum: Tensor,
    spectrum_window: Tensor,
    spectrum_start: int,
    spectrum_end: int,
    cepstrum_window: Tensor,
    cepstrum_start: int,
    cepstrum_end: int,
    noise_reduction_amount: float,
    measure_smooth_time_mult: float,
    noise_up_time_mult: float,
    noise_down_time_mult: float,
    index_ns: int,
    boot_count: int
) -> float:

    assert spectrum.size()[-1] == noise_spectrum.size()[-1]

    samplesLen_ns = samples.size()[-1]
    dft_len_ws = spectrum.size()[-1]

    dftBuf = torch.zeros(dft_len_ws)

    _index_ns = torch.tensor([index_ns] + [
        (index_ns + i) % samplesLen_ns
        for i in range(1, measure_len_ws)
    ])
    dftBuf[:measure_len_ws] = \
        samples[_index_ns] * spectrum_window[:measure_len_ws]

    # memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
    dftBuf[measure_len_ws:dft_len_ws].zero_()

    # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
    _dftBuf = torch.rfft(dftBuf, 1)

    # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
    _dftBuf[:spectrum_start].zero_()

    mult: float = boot_count / (1. + boot_count) \
        if boot_count >= 0 \
        else measure_smooth_time_mult

    _d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
    spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
    _d = spectrum[spectrum_start:spectrum_end] ** 2

    _zeros = torch.zeros(spectrum_end - spectrum_start)
    _mult = _zeros \
        if boot_count >= 0 \
        else torch.where(
            _d > noise_spectrum[spectrum_start:spectrum_end],
            torch.tensor(noise_up_time_mult),   # if
            torch.tensor(noise_down_time_mult)  # else
        )

    noise_spectrum[spectrum_start:spectrum_end].mul_(_mult).add_(_d * (1 - _mult))
    _d = torch.sqrt(
        torch.max(
            _zeros,
            _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end]))

    _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
    _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
    _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()

    # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
    _cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)

    result: float = float(torch.sum(
        complex_norm(
            _cepstrum_Buf[cepstrum_start:cepstrum_end],
            power=2.0)))
    result = \
        math.log(result / (cepstrum_end - cepstrum_start)) \
        if result > 0 \
        else -math.inf
    return max(0, 21 + result)


def vad(
    waveform: Tensor,
    sample_rate: int,
    trigger_level: float = 7.0,
    trigger_time: float = 0.25,
    search_time: float = 1.0,
    allowed_gap: float = 0.25,
    pre_trigger_time: float = 0.0,
    # Fine-tuning parameters
    boot_time: float = .35,
    noise_up_time: float = .1,
    noise_down_time: float = .01,
    noise_reduction_amount: float = 1.35,
    measure_freq: float = 20.0,
    measure_duration: Optional[float] = None,
    measure_smooth_time: float = .4,
    hp_filter_freq: float = 50.,
    lp_filter_freq: float = 6000.,
    hp_lifter_freq: float = 150.,
    lp_lifter_freq: float = 2000.,
) -> Tensor:
    r"""Voice Activity Detector. Similar to the SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    Returns:
        Tensor: Tensor of audio of dimension `(..., time)`.

    References:
        http://sox.sourceforge.net/sox.html
    """

    measure_duration: float = 2.0 / measure_freq \
        if measure_duration is None \
        else measure_duration

    measure_len_ws = int(sample_rate * measure_duration + .5)
    measure_len_ns = measure_len_ws
    # for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
    dft_len_ws = 16
    while (dft_len_ws < measure_len_ws):
        dft_len_ws *= 2

    measure_period_ns = int(sample_rate / measure_freq + .5)
    measures_len = math.ceil(search_time * measure_freq)
    search_pre_trigger_len_ns = measures_len * measure_period_ns
    gap_len = int(allowed_gap * measure_freq + .5)

    fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + .5)
    samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns

    spectrum_window = torch.zeros(measure_len_ws)
    for i in range(measure_len_ws):
        # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
        spectrum_window[i] = 2. / math.sqrt(float(measure_len_ws))
    # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
    spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)

    spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_start: int = max(spectrum_start, 1)
    spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_end: int = min(spectrum_end, dft_len_ws // 2)

    cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
    for i in range(spectrum_end - spectrum_start):
        cepstrum_window[i] = 2. / math.sqrt(float(spectrum_end) - spectrum_start)
    # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
    cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)

    cepstrum_start = math.ceil(sample_rate * .5 / lp_lifter_freq)
    cepstrum_end = math.floor(sample_rate * .5 / hp_lifter_freq)
    cepstrum_end = min(cepstrum_end, dft_len_ws // 4)

    assert cepstrum_end > cepstrum_start

    noise_up_time_mult = math.exp(-1. / (noise_up_time * measure_freq))
    noise_down_time_mult = math.exp(-1. / (noise_down_time * measure_freq))
    measure_smooth_time_mult = math.exp(-1. / (measure_smooth_time * measure_freq))
    trigger_meas_time_mult = math.exp(-1. / (trigger_time * measure_freq))

    boot_count_max = int(boot_time * measure_freq - .5)
    measure_timer_ns = measure_len_ns
    boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0

    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    n_channels, ilen = waveform.size()

    mean_meas = torch.zeros(n_channels)
    samples = torch.zeros(n_channels, samplesLen_ns)
    spectrum = torch.zeros(n_channels, dft_len_ws)
    noise_spectrum = torch.zeros(n_channels, dft_len_ws)
    measures = torch.zeros(n_channels, measures_len)

    has_triggered: bool = False
    num_measures_to_flush: int = 0
    pos: int = 0

    while (pos < ilen and not has_triggered):
        measure_timer_ns -= 1
        for i in range(n_channels):
            samples[i, samplesIndex_ns] = waveform[i, pos]
            # if (!p->measure_timer_ns) {
            if (measure_timer_ns == 0):
                index_ns: int = \
                    (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
                meas: float = _measure(
                    measure_len_ws=measure_len_ws,
                    samples=samples[i],
                    spectrum=spectrum[i],
                    noise_spectrum=noise_spectrum[i],
                    spectrum_window=spectrum_window,
                    spectrum_start=spectrum_start,
                    spectrum_end=spectrum_end,
                    cepstrum_window=cepstrum_window,
                    cepstrum_start=cepstrum_start,
                    cepstrum_end=cepstrum_end,
                    noise_reduction_amount=noise_reduction_amount,
                    measure_smooth_time_mult=measure_smooth_time_mult,
                    noise_up_time_mult=noise_up_time_mult,
                    noise_down_time_mult=noise_down_time_mult,
                    index_ns=index_ns,
                    boot_count=boot_count)
                measures[i, measures_index] = meas
                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1. - trigger_meas_time_mult)

                has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
                if has_triggered:
                    n: int = measures_len
                    k: int = measures_index
                    jTrigger: int = n
                    jZero: int = n
                    j: int = 0

                    for j in range(n):
                        if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
                            jZero = jTrigger = j
                        elif (measures[i, k] == 0) and (jTrigger >= jZero):
                            jZero = j
                        k = (k + n - 1) % n
                    j = min(j, jZero)
                    # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
                    num_measures_to_flush = (min(max(num_measures_to_flush, j), n))
                # end if has_triggered
            # end if (measure_timer_ns == 0):
        # end for
        samplesIndex_ns += 1
        pos += 1
        # end while
        if samplesIndex_ns == samplesLen_ns:
            samplesIndex_ns = 0
        if measure_timer_ns == 0:
            measure_timer_ns = measure_period_ns
            measures_index += 1
            measures_index = measures_index % measures_len
            if boot_count >= 0:
                boot_count = -1 if boot_count == boot_count_max else boot_count + 1

    if has_triggered:
        flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
        samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns

    res = waveform[:, pos - samplesLen_ns + flushedLen_ns:]
    # unpack batch
    return res.view(shape[:-1] + res.shape[-1:])
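
Not part of the diff: since the docstring notes that the effect trims only from the front, a minimal sketch of the usual "vad, reverse, vad, reverse" recipe using the functional API. The asset path is assumed relative to the repository root.

import torch
import torchaudio
import torchaudio.functional as F

waveform, sample_rate = torchaudio.load("test/assets/vad-hello-mono-32000.wav")

# Trim the front, flip the time axis, trim the (now leading) silence, then flip back.
front_trimmed = F.vad(waveform, sample_rate)
both_trimmed = torch.flip(F.vad(torch.flip(front_trimmed, [-1]), sample_rate), [-1])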