wav norm might be upstream specific
leo19941227 committed May 24, 2021
1 parent a92684c commit 40a173c
Showing 5 changed files with 26 additions and 24 deletions.
1 change: 0 additions & 1 deletion downstream/runner.py
@@ -80,7 +80,6 @@ def _get_upstream(self):
            ckpt = self.args.upstream_ckpt,
            model_config = self.args.upstream_model_config,
            refresh = upstream_refresh,
            wav_normalize = self.args.wav_normalize,
        ).to(self.args.device)

        if is_initialized() and get_rank() == 0:
1 change: 1 addition & 0 deletions requirements.txt
@@ -28,3 +28,4 @@ pysndfx
nltk
normalise
editdistance
omegaconf
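
The new omegaconf dependency is needed because recent fairseq checkpoints store their config as an omegaconf DictConfig (with the normalize flag nested under cfg.task), while older checkpoints use a flat argparse.Namespace; upstream/wav2vec2/expert.py below branches on the two types. A minimal self-contained sketch of that branching, using stand-in configs rather than a real checkpoint:

import argparse

from omegaconf import DictConfig, OmegaConf


def read_normalize_flag(cfg) -> bool:
    # Older fairseq checkpoints: flat argparse.Namespace.
    if isinstance(cfg, argparse.Namespace):
        return cfg.normalize
    # Newer fairseq checkpoints: nested omegaconf DictConfig.
    if isinstance(cfg, DictConfig):
        return cfg.task.normalize
    raise TypeError(f"unexpected fairseq config type: {type(cfg)}")


assert read_normalize_flag(argparse.Namespace(normalize=True)) is True
assert read_normalize_flag(OmegaConf.create({"task": {"normalize": False}})) is False
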
1 change: 0 additions & 1 deletion run_downstream.py
@@ -73,7 +73,6 @@ def get_downstream_args():
    parser.add_argument('--seed', default=1337, type=int)
    parser.add_argument('--device', default='cuda', help='model.to(device)')
    parser.add_argument('--cache_dir', help='The cache directory for pretrained model downloading')
    parser.add_argument('--wav_normalize', action='store_true', help='Zero mean and unit variance normalization on waveform')
    parser.add_argument('--verbose', action='store_true', help='Print model information')

    args = parser.parse_args()
21 changes: 0 additions & 21 deletions upstream/interfaces.py
@@ -36,7 +36,6 @@ def __call__(cls, *args, **kwargs):
class UpstreamBase(nn.Module, metaclass=initHook):
    def __init__(
        self,
        wav_normalize: bool = False,
        hooks: List[Tuple] = None,
        hook_postprocess: Callable[
            [List[Tuple[str, Tensor]]], List[Tuple[str, Tensor]]
@@ -48,8 +47,6 @@ def __init__(
            hooks: each Tuple is an argument list for the Hook initializer
        """
        super().__init__()
        self.wav_normalize = wav_normalize

        self.hooks: List[Hook] = [Hook(*hook) for hook in hooks] if hooks else []
        self.hook_postprocess = hook_postprocess
        self._hook_hiddens: List[Tuple[str, Tensor]] = []
@@ -109,25 +106,7 @@ def tolist(paired_wavs: List[Tensor], paired_feature: Tensor):
        feature = [f[:l] for f, l in zip(paired_feature, feature_len)]
        return feature

    @staticmethod
    def zero_mean_unit_var_norm(input_values: List[np.ndarray]) -> List[np.ndarray]:
        """
        Every array in the list is normalized to have zero mean and unit variance
        Taken from huggingface to ensure the same behavior across s3prl and huggingface
        Reference: https://github.com/huggingface/transformers/blob/a26f4d620874b32d898a5b712006a4c856d07de1/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L81-L86
        """
        return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]

    def __call__(self, wavs: List[Tensor], *args, **kwargs):
        if self.wav_normalize:
            device = wavs[0].device
            wavs = [
                torch.from_numpy(wav).to(device)
                for wav in self.zero_mean_unit_var_norm(
                    [wav.cpu().numpy() for wav in wavs]
                )
            ]

        result = super().__call__(wavs, *args, **kwargs) or {}
        assert isinstance(result, dict)
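
For context on the surviving tolist hunk above: the padded feature batch is trimmed back to a list of per-utterance tensors, with each target length derived from the original waveform lengths. A minimal sketch of that trimming step; the 320x downsample ratio and 768-dim features are illustrative assumptions, not values from this commit:

import torch

# Hypothetical batch: three waveforms of different lengths, and a padded
# feature batch produced by a model that downsamples 320x (wav2vec 2.0-like).
wav_lengths = [48000, 32000, 16000]
max_frames = max(wav_lengths) // 320
paired_feature = torch.randn(len(wav_lengths), max_frames, 768)

# Scale each waveform length into frames, then trim away the padding,
# mirroring `feature = [f[:l] for f, l in zip(paired_feature, feature_len)]`.
feature_len = [round(l / max(wav_lengths) * max_frames) for l in wav_lengths]
feature = [f[:l] for f, l in zip(paired_feature, feature_len)]
assert [len(f) for f in feature] == [150, 100, 50]
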
26 changes: 25 additions & 1 deletion upstream/wav2vec2/expert.py
@@ -7,11 +7,16 @@
"""*********************************************************************************************"""


import argparse
from typing import List
from packaging import version

import torch
from torch.nn.utils.rnn import pad_sequence
import fairseq
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from omegaconf.dictconfig import DictConfig

from upstream.interfaces import UpstreamBase

@@ -24,6 +29,12 @@ def __init__(self, ckpt, **kwargs):
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt])
        self.model = model[0]

        if isinstance(cfg, argparse.Namespace):
            normalize = cfg.normalize
        elif isinstance(cfg, DictConfig):
            normalize = cfg.task.normalize
        self.wav_normalize = normalize

        if len(self.hooks) == 0:
            module_name = "self.model.encoder.layers"
            for module_id in range(len(eval(module_name))):
@@ -33,8 +44,21 @@
                )
        self.add_hook("self.model.encoder", lambda input, output: output)

    @staticmethod
    def zero_mean_unit_var_norm(input_values: List[np.ndarray]) -> List[np.ndarray]:
        """
        Every array in the list is normalized to have zero mean and unit variance
        Taken from huggingface to ensure the same behavior across s3prl and huggingface
        Reference: https://github.com/huggingface/transformers/blob/a26f4d620874b32d898a5b712006a4c856d07de1/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L81-L86
        """
        return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]

    def forward(self, wavs):
        device = wavs[0].device
        if self.wav_normalize:
            wavs = self.zero_mean_unit_var_norm([wav.cpu().numpy() for wav in wavs])
            wavs = [torch.from_numpy(wav).to(device) for wav in wavs]

        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
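
A small self-contained sketch of what the relocated normalization and the padding-mask construction in forward compute; the waveforms here are random stand-ins, and only the shapes and mask logic mirror the code above:

import numpy as np
import torch


def zero_mean_unit_var_norm(input_values):
    # Same formula as the method added above.
    return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]


wavs = [torch.randn(16000), torch.randn(8000)]  # two fake utterances
device = wavs[0].device

# Normalize on CPU in numpy, then move back to the original device, as forward() does.
wavs = [
    torch.from_numpy(wav).to(device)
    for wav in zero_mean_unit_var_norm([wav.cpu().numpy() for wav in wavs])
]
assert abs(wavs[0].mean().item()) < 1e-4  # ~zero mean after normalization

# Padding mask: True marks padded positions beyond each utterance's length.
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
wav_padding_mask = ~torch.lt(
    torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
    wav_lengths.unsqueeze(1),
)
assert wav_padding_mask.shape == (2, 16000)
assert wav_padding_mask[1, 8000:].all() and not wav_padding_mask[0].any()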
