add cmvn #540

Merged
merged 8 commits into from Apr 17, 2020
Changes from 5 commits
5 changes: 5 additions & 0 deletions docs/source/functional.rst
@@ -112,3 +112,8 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: detect_pitch_frequency

:hidden:`sliding_window_cmn_internal`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: sliding_window_cmn_internal
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
@@ -128,3 +128,10 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Vol

.. automethod:: forward

:hidden:`SlidingWindowCmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SlidingWindowCmn

.. automethod:: forward
4 changes: 4 additions & 0 deletions test/test_torchscript_consistency.py
@@ -490,6 +490,10 @@ def test_Vol(self):
waveform, _ = torchaudio.load(test_filepath)
self._assert_consistency(T.Vol(1.1), waveform)

def test_SlidingWindowCmn(self):
tensor = torch.rand((1000, 10))
self._assert_consistency(T.SlidingWindowCmn(), tensor)


class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on CPU"""
86 changes: 85 additions & 1 deletion torchaudio/functional.py
@@ -32,7 +32,8 @@
"riaa_biquad",
"biquad",
'mask_along_axis',
'mask_along_axis_iid'
'mask_along_axis_iid',
'sliding_window_cmn_internal',
]


@@ -1636,3 +1637,86 @@ def detect_pitch_frequency(
freq = freq.view(shape[:-1] + list(freq.shape[-1:]))

return freq


def sliding_window_cmn_internal(
waveform: Tensor,
cmvn_window: int = 600,
min_cmn_window: int = 100,
center: bool = False,
norm_vars: bool = False,
) -> Tensor:
r"""
Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

Args:
waveform (Tensor): Tensor of audio of dimension (..., freq, time)
cmvn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
Only applicable if center == false, ignored if center==true (int, default = 100)
center (bool, optional): If true, use a window centered on the current frame
(to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

Returns:
Tensor: Tensor of freq of dimension (..., frame)
"""
    dtype = waveform.dtype
    device = waveform.device
    last_window_start = last_window_end = -1
    # The implementation operates on a 2D (frame, freq) tensor and keeps
    # running sums (and sums of squares) over the current sliding window.
    num_frames, num_feats = waveform.shape
    cur_sum = torch.zeros(num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_feats, dtype=dtype, device=device)
    cmvn_waveform = torch.zeros(
        num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        # Determine the window [window_start, window_end) of frames used to
        # normalize frame t.
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmvn_window // 2
            window_end = window_start + cmvn_window
        else:
            window_start = t - cmvn_window
            window_end = t + 1
        if window_start < 0:
            # Window starts before the first frame: shift it right.
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            # Window runs past the last frame: shift it left.
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            # First frame: accumulate statistics over the whole initial window.
            input_part = waveform[window_start:window_end]
            cur_sum += torch.sum(input_part, 0)
            if norm_vars:
                cur_sumsq += torch.sum(input_part ** 2, 0)
        else:
            # The window moved by at most one frame on each side; update the
            # running statistics incrementally.
            if window_start > last_window_start:
                frame_to_remove = waveform[last_window_start]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = waveform[last_window_end]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        # Subtract the running window mean from the current frame.
        cmvn_waveform[t] = waveform[t] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                # A single-frame window has no variance to normalize by.
                cmvn_waveform[t] = torch.zeros(
                    num_feats, dtype=dtype, device=device)
            else:
                # var = E[x^2] - (E[x])^2, then scale by 1 / sqrt(var).
                variance = cur_sumsq / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmvn_waveform[t] *= variance
return cmvn_waveform
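
For reviewers who want to poke at the new functional locally, here is a minimal sketch (not part of the diff, and assuming this branch is installed) that runs it on a random (frame, freq) tensor, mirroring the shapes used in the new TorchScript test; the "near 0 / near 1" expectations are only rough sanity checks.

import torch
import torchaudio.functional as F

torch.manual_seed(0)
feats = torch.rand(1000, 10)   # (frame, freq), as in test_SlidingWindowCmn

normed = F.sliding_window_cmn_internal(
    feats, cmvn_window=600, min_cmn_window=100, center=False, norm_vars=True)
print(normed.shape)            # torch.Size([1000, 10])

# Once the first cmvn_window frames have passed, the window is full, so the
# normalized output should be roughly zero-mean / unit-variance per feature.
tail = normed[600:]
print(tail.mean(dim=0))        # values near 0
print(tail.std(dim=0))         # values near 1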
38 changes: 38 additions & 0 deletions torchaudio/transforms.py
@@ -26,6 +26,7 @@
'Fade',
'FrequencyMasking',
'TimeMasking',
'SlidingWindowCmn',
]


@@ -869,3 +870,40 @@ def forward(self, waveform: Tensor) -> Tensor:
waveform = F.gain(waveform, 10 * math.log10(self.gain))

return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
r"""
Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

Args:
cmvn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false; ignored if center == true (int, default = 100)
center (bool, optional): If true, use a window centered on the current frame
(to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
"""

def __init__(self,
cmvn_window: int = 600,
min_cmn_window: int = 100,
center: bool = False,
norm_vars: bool = False) -> None:
super().__init__()
self.cmvn_window = cmvn_window
self.min_cmn_window = min_cmn_window
self.center = center
self.norm_vars = norm_vars

def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
            waveform (Tensor): Tensor of features of dimension (frame, freq).

Returns:
            Tensor: Normalized feature tensor of dimension (frame, freq).
"""
cmvn_waveform = F.sliding_window_cmn_internal(
waveform, self.cmvn_window, self.min_cmn_window, self.center, self.norm_vars)
return cmvn_waveform
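
A short usage sketch for the transform (illustrative only, not part of the diff): it follows the 2D (frame, freq) layout that the underlying functional expects, and scripts the module the same way the new TorchScript consistency test does.

import torch
import torchaudio.transforms as T

cmn = T.SlidingWindowCmn(cmvn_window=600, min_cmn_window=100,
                         center=False, norm_vars=False)

feats = torch.rand(1000, 10)   # (frame, freq)
normed = cmn(feats)            # same shape as the input
print(normed.shape)            # torch.Size([1000, 10])

# The module is intended to be TorchScript-compatible, which is what
# test_SlidingWindowCmn exercises.
scripted = torch.jit.script(cmn)
assert torch.allclose(scripted(feats), normed)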