In [2]:
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from dataclasses import dataclass
import torch.nn.functional as F
import numpy as np



In [3]:


@dataclass
class ASRconfig:
    sample_rate: int = 16000
    n_fft: int = 400
    win_length: int = 400
    hop_length: int = 160
    n_mels: int = 80
    center: bool = True
    time_mask_param: int = 30
    freq_mask_param: int = 15
    model_dim: int = 256
    feedforward_dim: int = 1024
    dropout: float = 0.1
    num_heads: int = 4
    num_layers: int = 6
    encoder_normalize_first: bool = True
    max_len: int = 5000
    vocab_size: int = 1000



In [7]:

class AudioFeatureExtractor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.MelSpec = T.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            normalized=False,
            center=config.center
        )
        self.log = lambda x: torch.log1p(x)
    
    def forward(self, x, lengths):
        
        print("zeros in speech signal : ")
        print([(u==0).sum() for u in x])
        x = self.MelSpec(x)
        x = x.transpose(1, 2)

        if self.config.center:
            frame_lengths = 1 + lengths // self.config.hop_length
        else:
            frame_lengths = 1 + (lengths - self.config.win_length) // self.config.hop_length

        print("frame lengths : ", frame_lengths)
        max_len = x.size(1)
        range_ = torch.arange(max_len, device=x.device)
        mask = range_[None, :] <= frame_lengths[:, None]
        print(mask)
        mask = mask.unsqueeze(-1)
        
        x = torch.where(mask, x, torch.zeros_like(x))
        for i, utt in enumerate(x):
            print("utterance : ", i)
            print(utt)
            print(utt.shape)
            print(lengths[i])
            print(frame_lengths[i])
            print("number of zero frames : ", (utt==0).all(dim=1).sum().item())
            print("number of zero features : ", (utt==0).all(dim=0).sum().item())
            arr = utt.numpy()
            np.savetxt(f"{i}_utterance_mel.txt", arr)

        print("zeros in MelSpec : ")
        print([(u==0).sum() for u in x])
        x = self.log(x)
        print("zeros in MelSpec after log : ")
        print([(u==0).sum() for u in x])
        return x, frame_lengths


In [5]:

class UtteranceMVN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
    
    def forward(self, x, lengths):
        normalized_x = []        
        for utt, l in zip(x, lengths):
            print(utt.shape)
            print(utt)
            mean = utt[:l, :].mean(dim=0, keepdims=True)
            std = utt[:l, :].std(dim=0, keepdims=True)
            print("std shape : ", std.shape)
            print("std : ", std)
            print("zeros in std : ", (std == 0).sum())
            utt[:l, :] = (utt[:l, :] - mean) / std
            normalized_x.append(utt)
        
        normalized_x = torch.stack(normalized_x)
        return normalized_x

In [8]:

config = ASRconfig()
extractor = AudioFeatureExtractor(config)
normalizer = UtteranceMVN(config)
input1 = torch.cat([torch.randn(8000), torch.zeros(4000), torch.randn(4000)])
# print(input1.shape)
input2 = torch.randn(8000)
input = [input1, input2]
input = torch.nn.utils.rnn.pad_sequence(input, batch_first=True, padding_value=0.0)
out, out_frame_lengths = extractor(input, torch.tensor([16000, 8000]))

out = normalizer(out, out_frame_lengths)
# print(out.shape)

zeros in speech signal : 
[tensor(4000), tensor(8000)]
frame lengths :  tensor([101,  51])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, 