# QuartzNet model

In [92]:
import string
import json
from typing import Dict, Union
from collections import OrderedDict
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [2]:
# Number of samples per mini-batch; passed as `batch_size` to every DataLoader built in this notebook.
BATCH_SIZE = 32

## Data preprocessor

In [97]:
class ASRProcessor(object):
    """Convert a raw (waveform, transcript) pair into model-ready tensors.

    Audio goes through ``torchaudio.transforms.MelSpectrogram``. Text is encoded
    character by character with a fixed 30-symbol vocabulary:
    ``a``-``z`` (0-25), space (26), apostrophe (27), ``<EOS>`` (28), ``<PAD>`` (29).

    Attributes:
        audio_preprocessor: the mel-spectrogram transform.
        let2idx: list mapping index -> symbol (NOTE: despite the name, this is
            the *inverse* lookup; the name is kept for backward compatibility).
        vocab: dict mapping symbol -> index.
    """

    # Marker symbols that are not real characters; droppable when decoding.
    SPECIAL_TOKENS = ("<EOS>", "<PAD>")

    def __init__(self, sampling_rate, n_fft=1024, hop_length=512, n_mels=128):
        """Build the mel-spectrogram transform and the character vocabulary.

        Args:
            sampling_rate: sample rate of the incoming waveforms, in Hz.
            n_fft: forwarded to ``MelSpectrogram``.
            hop_length: forwarded to ``MelSpectrogram``.
            n_mels: forwarded to ``MelSpectrogram``.
        """
        self.audio_preprocessor = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )

        self.let2idx = list(string.ascii_lowercase)
        self.let2idx.extend([" ", "'", "<EOS>", "<PAD>"])
        self.vocab = {w: idx for idx, w in enumerate(self.let2idx)}

    def text_preprocessor(self, text):
        """Encode ``text`` one character at a time into a 1-D ``LongTensor``.

        Args:
            text: transcript made only of lower-case ASCII letters, spaces and
                apostrophes.

        Returns:
            ``torch.LongTensor`` with one vocabulary index per character.

        Raises:
            KeyError: if ``text`` contains a character outside the vocabulary
                (same exception type as before, now with a readable message).
        """
        unknown = sorted({symbol for symbol in text if symbol not in self.vocab})
        if unknown:
            raise KeyError(
                f"Unsupported symbols {unknown} in transcript {text!r}; "
                "expected lower-case letters, space or apostrophe."
            )

        return torch.LongTensor([self.vocab[symbol] for symbol in text])

    def __call__(self, input_values: torch.Tensor, labels: str) -> Dict[str, torch.Tensor]:
        """Process one sample.

        Args:
            input_values: waveform tensor handed to the mel-spectrogram transform.
            labels: transcript to encode (see ``text_preprocessor``).

        Returns:
            Dict with ``'input_features'`` (mel spectrogram) and ``'labels'``
            (``LongTensor`` of character indices).
        """
        return {
            'input_features': self.audio_preprocessor(input_values),
            'labels': self.text_preprocessor(labels)
        }

    def labels_decode(self, labels, skip_special_tokens=False):
        """Turn a 1-D sequence of vocabulary indices back into a string.

        Args:
            labels: iterable of ints (or a 1-D integer tensor).
            skip_special_tokens: when True, ``<EOS>`` and ``<PAD>`` are dropped
                instead of being written literally into the result. Defaults to
                False, which reproduces the historical output.

        Returns:
            The decoded text.
        """
        symbols = [self.let2idx[int(label)] for label in labels]
        if skip_special_tokens:
            symbols = [s for s in symbols if s not in self.SPECIAL_TOKENS]

        return ''.join(symbols)

## Dataset

In [4]:
class LibriDataset(Dataset):
    """LibriSpeech wrapper yielding mel-spectrogram features and character labels.

    Every item is the dict produced by ``ASRProcessor``: ``'input_features'``
    (mel spectrogram) and ``'labels'`` (encoded, lower-cased transcript).
    ``batch_collate`` is meant to be passed as ``collate_fn`` to a ``DataLoader``.
    """

    # Splits accepted by the constructor (forwarded as `url` to torchaudio).
    SPLITS = ('dev-clean', 'dev-other', 'test-clean', 'test-other', 'train-clean-100')

    # Index of "<PAD>" in the ASRProcessor vocabulary; used to pad label rows.
    PAD_TOKEN_ID = 29

    def __init__(self, root='', split='dev-clean', max_length=150):
        """Open one LibriSpeech split.

        Args:
            root: directory containing the dataset (forwarded to torchaudio).
            split: one of ``SPLITS``. The previous default, ``'val'``, was not a
                valid split and made the default constructor fail its own
                assertion; it is now ``'dev-clean'``.
            max_length: stored for callers but not applied anywhere in this
                class (no truncation or filtering is performed).

        Raises:
            AssertionError: if ``split`` is not one of ``SPLITS``.
        """
        assert split in self.SPLITS, 'Split error!'

        self.max_length = max_length
        self.data_iterator = torchaudio.datasets.LIBRISPEECH(root=root, url=split)

        # 16 kHz is hard-coded here as the sampling rate given to the processor.
        self.processor = ASRProcessor(sampling_rate=16000)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        Element 0 of the underlying item is the waveform (its first row is
        used) and element 2 is the transcript, which is lower-cased before
        encoding.
        """
        sample = self.data_iterator[idx]
        waveform, transcript = sample[0], sample[2]

        return self.processor(waveform[0], transcript.lower())

    def __len__(self):
        """Number of utterances in the underlying split."""
        return len(self.data_iterator)

    @staticmethod
    def _pad_to_even_length(sequences, value=0):
        """Right-pad tensors along their last axis to a common, even length and stack them."""
        lengths = [int(seq.size(-1)) for seq in sequences]
        target_len = max(lengths)
        # Round the target up to the next even number.
        target_len += target_len % 2

        # The pad amount is always >= 0, so no sample is ever cropped.
        return torch.stack([
            F.pad(seq, pad=(0, target_len - length), value=value)
            for seq, length in zip(sequences, lengths)
        ])

    @staticmethod
    def batch_collate(batch):
        """Collate a list of samples into padded batch tensors.

        Features are zero-padded along the last (time) axis and labels are
        padded with ``PAD_TOKEN_ID``; both are extended to the longest sample
        in the batch, rounded up to an even length.

        The previous implementation subtracted an extra 1 from the pad amount,
        which produced a negative pad for the longest sample whenever the
        maximum length was even. ``F.pad`` treats negative padding as cropping,
        so the last frame / last label token was silently dropped.

        Args:
            batch: list of dicts with ``'input_features'`` and ``'labels'``.

        Returns:
            Dict with ``'input_features'`` (stacked features) and ``'labels'``
            (``int64`` tensor, one row per sample).
        """
        input_features = LibriDataset._pad_to_even_length(
            [x['input_features'] for x in batch]
        )
        labels = LibriDataset._pad_to_even_length(
            [x['labels'] for x in batch], value=LibriDataset.PAD_TOKEN_ID
        ).type(torch.int64)

        return {
            'input_features': input_features,
            'labels': labels
        }

    def labels_decode(self, labels):
        """Decode a sequence of label indices back to text via the processor."""
        return self.processor.labels_decode(labels)

In [5]:
# Phase name -> LibriSpeech split backing it ('train' first, then 'val').
dataset = {
    phase: LibriDataset(split=split_name)
    for phase, split_name in (
        ('train', 'train-clean-100'),
        ('val', 'dev-clean'),
    )
}

In [6]:
# One DataLoader per phase. Only the training split is shuffled: shuffling the
# validation split brings no benefit and makes its batches (and therefore their
# padding) differ from run to run.
dataloaders = {
    phase: DataLoader(
        phase_dataset,
        batch_size=BATCH_SIZE,
        shuffle=(phase == 'train'),
        collate_fn=LibriDataset.batch_collate,
        num_workers=1,
    )
    for phase, phase_dataset in dataset.items()
}

## Model (QuartzNet 5x5)

In [82]:
class SingleBBlock(nn.Module):
    """Depthwise conv -> pointwise conv -> batch norm -> optional ReLU.

    The depthwise convolution uses ``groups=in_channels`` and padding
    ``(kernel_size - 1) // 2``; the pointwise convolution (kernel size 1) maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: kernel width of the depthwise convolution.
        activation: if truthy, ReLU is applied after batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation=True):
        super().__init__()

        self.depthwise = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=in_channels,
        )
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.batch_norm = nn.BatchNorm1d(num_features=out_channels)
        self.activation = activation

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        features = self.depthwise(x)
        features = self.pointwise(features)
        normalized = self.batch_norm(features)

        if not self.activation:
            return normalized
        return F.relu(normalized)

In [86]:
class RepeatedBBlocks(nn.Module):
    """Residual block made of ``R`` chained ``SingleBBlock`` sub-blocks.

    The first sub-block maps ``in_channels`` to ``out_channels``; the rest keep
    ``out_channels``. Every sub-block except the last applies ReLU. The last one
    is left linear so the skip branch (1x1 conv + batch norm of the block input)
    can be added before the final ReLU.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        kernel_size: kernel width shared by all sub-blocks.
        R: number of sub-blocks; must be at least 1.

    Raises:
        ValueError: if ``R`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, R):
        super().__init__()
        if R < 1:
            raise ValueError(f"R must be >= 1, got {R}")

        # Build exactly R sub-blocks. The earlier version always created a
        # "first" and a "last" block, so R=1 (or less) still produced two.
        sub_blocks = [
            SingleBBlock(
                in_channels if idx == 0 else out_channels,
                out_channels,
                kernel_size,
                activation=(idx < R - 1),  # last sub-block stays linear
            )
            for idx in range(R)
        ]
        self.B = nn.Sequential(*sub_blocks)

        # Skip connection: 1x1 convolution matches the channel count.
        self.skip_connection = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm1d(num_features=out_channels)
        )

    def forward(self, x):
        """Return ``relu(B(x) + skip_connection(x))``."""
        return F.relu(self.B(x) + self.skip_connection(x))

In [126]:
class QuartzNet(nn.Module):
    """QuartzNet 5x5: C1 -> five residual blocks (R=5 each) -> C2 -> C3 -> C4.

    Args:
        in_channels: number of input feature channels (mel bins). Defaults to
            128, the value that used to be hard-coded.
        num_classes: number of output channels of the final 1x1 convolution.
            Defaults to 29, the value that used to be hard-coded.

    NOTE(review): the character vocabulary built by ``ASRProcessor`` in this
    notebook has 30 symbols and labels are padded with index 29, which is out
    of range for the default 29 classes. Confirm the intended class count
    before training.
    """

    def __init__(self, in_channels=128, num_classes=29):
        super().__init__()

        # Stride 2 and no padding: output length is floor((T - 33) / 2) + 1.
        self.C1 = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=256, kernel_size=33, stride=2),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU()
        )
        self.B = nn.Sequential(
            OrderedDict([
                ('B1', RepeatedBBlocks(in_channels=256, out_channels=256, kernel_size=33, R=5)),
                ('B2', RepeatedBBlocks(in_channels=256, out_channels=256, kernel_size=39, R=5)),
                ('B3', RepeatedBBlocks(in_channels=256, out_channels=512, kernel_size=51, R=5)),
                ('B4', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=63, R=5)),
                ('B5', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=75, R=5))
            ])
        )
        # padding=43 == (87 - 1) // 2, so the time axis keeps its length.
        self.C2 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=87, padding=43),
            nn.BatchNorm1d(num_features=512),
            nn.ReLU()
        )
        self.C3 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=1),
            nn.BatchNorm1d(num_features=1024),
            nn.ReLU()
        )
        # With kernel_size=1 the dilation has no effect; the argument is kept as written.
        self.C4 = nn.Conv1d(in_channels=1024, out_channels=num_classes, kernel_size=1, dilation=2)

    def forward(self, x):
        """Run the network.

        Args:
            x: float tensor of shape ``(batch, in_channels, time)``.

        Returns:
            Tensor of shape ``(batch, num_classes, time')`` holding the raw
            output of the last convolution (no softmax is applied here).
            ``time'`` is set by C1's stride-2, unpadded convolution.
        """
        hidden = self.C1(x)
        hidden = self.B(hidden)
        hidden = self.C2(hidden)
        hidden = self.C3(hidden)

        return self.C4(hidden)

In [129]:
# Create an iterator over the training DataLoader so single batches can be pulled on demand.
data_it = iter(dataloaders['train'])

In [134]:
# Pull the next collated batch and keep only its padded feature tensor (the labels are discarded here).
sample = next(data_it)['input_features']

In [135]:
# Instantiate the network with its default constructor; no weights are loaded in this cell.
model = QuartzNet()

In [136]:
# Forward pass on one batch as a smoke test.
# NOTE(review): no model.eval() / torch.no_grad() is visible before this call, so this presumably runs in training mode with autograd enabled — confirm that is intended.
output = model(sample)

In [137]:
# Inspect the result shape: (batch, output channels of C4, time steps after C1's stride-2 convolution).
output.shape

torch.Size([32, 29, 241])