# QuartzNet model

In [195]:
import string
import json
from typing import Dict, Union
from collections import OrderedDict
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio

from datasets import load_metric

In [196]:
BATCH_SIZE = 32

## Data preprocessor

In [271]:
class ASRProcessor(object):
    def __init__(self, sampling_rate=16000, n_fft=1024, hop_length=512, n_mels=128):
        self.audio_preprocessor = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate, 
            n_fft=n_fft, 
            hop_length=hop_length, 
            n_mels=n_mels
        )
        
        self.let2idx = [s for s in string.ascii_lowercase]
        self.let2idx.extend(["'", " ", "<PAD>"])
        self.vocab = {w: idx for idx, w in enumerate(self.let2idx)}

    def text_preprocessor(self, text):
        sym_tokenize = [s for s in text]
        
        return torch.LongTensor([self.vocab[symbol] for symbol in sym_tokenize])
    
    def __call__(self, input_values: torch.tensor, labels: str) -> Dict[str, Union[torch.tensor, torch.LongTensor]]:
        return {
            'input_features': self.audio_preprocessor(input_values), 
            'labels': self.text_preprocessor(labels)
        }
    
    def labels_decode(self, batched_labels):
        pad_idx = len(processor.vocab) - 1
        space_idx = len(processor.vocab) - 2
        decoding_labels = batched_labels.clone().detach()
        
        decoding_labels[decoding_labels == pad_idx] = space_idx
        return [''.join([self.let2idx[l] for l in bl]).strip() for bl in decoding_labels]

## Dataset

In [368]:
class LibriDataset(Dataset):
    def __init__(self, processor, root='', split='val', max_length=150):
        assert split in ['dev-clean', 'dev-other', 'test-clean', 'test-other', 'train-clean-100'], \
                'Split error!'
        
        self.data_iterator = torchaudio.datasets.LIBRISPEECH(root=root, url=split)
        
        self.processor = processor
            
    def __getitem__(self, idx):
        sample = self.data_iterator[idx]
        sample = self.processor(sample[0][0], sample[2].lower())
        
        return sample
    
    def __len__(self):
        return len(self.data_iterator)
    
    @staticmethod
    def batch_collate(batch):
        # Collate audio samples.  
        sample_tokens_lengths = torch.tensor([x['input_features'].size(1) for x in batch])
        max_len_per_samples = torch.max(sample_tokens_lengths)
        
        # Extend to even max_len.  
        max_len_per_samples += (max_len_per_samples % 2)
        lengths_to_pad = max_len_per_samples - sample_tokens_lengths - 1
        
        input_features = torch.stack([
            F.pad(x['input_features'], pad=(0, val_to_pad)) 
            for x, val_to_pad in zip(batch, lengths_to_pad)
        ])
        
        # Collate label samples.  
        label_tokens_lengths = torch.tensor([len(x['labels']) for x in batch])
        max_len_per_labels = torch.max(label_tokens_lengths)
        
        # Extend to even max_len.  
        max_len_per_labels += (max_len_per_labels % 2)
        lengths_to_pad = max_len_per_labels - label_tokens_lengths - 1

        labels = torch.vstack([
            F.pad(x['labels'], pad=(0, val_to_pad), value=len(processor.vocab) - 1) 
            for x, val_to_pad in zip(batch, lengths_to_pad)
        ]).type(torch.int64)

        return {
            'input_features': input_features, 
            'input_lengths': sample_tokens_lengths, 
            'labels': labels, 
            'labels_lengths': label_tokens_lengths
        }

In [369]:
processor = ASRProcessor()

In [370]:
dataset = {
    'train': LibriDataset(processor, split='train-clean-100'), 
    'val': LibriDataset(processor, split='dev-clean')
}

In [371]:
dataloaders = {
    k: DataLoader(dataset[k], batch_size=BATCH_SIZE, shuffle=True, 
                    collate_fn=LibriDataset.batch_collate, num_workers=1)
    for k in dataset.keys()
}

## Model (QuartzNet 5x5)

In [358]:
class SingleBBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation=True):
        super().__init__()
        
        padding = (kernel_size - 1) // 2
        self.depthwise = nn.Conv1d(
            in_channels, in_channels, kernel_size, 
            padding=padding, groups=in_channels
        )
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.batch_norm = nn.BatchNorm1d(num_features=out_channels)
        self.activation = activation
        
    def forward(self, x):
        TCS_out = self.pointwise(self.depthwise(x))
        
        bn_out = self.batch_norm(TCS_out)
        
        return F.relu(bn_out) if self.activation else bn_out

In [359]:
class RepeatedBBlocks(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, R):
        super().__init__()
        # The first block to match in_channels and out_channels
        self.B = [SingleBBlock(in_channels, out_channels, kernel_size, activation=True)]
        
        # BBlocks between the first and the last blocks.  
        self.B.extend([
            SingleBBlock(out_channels, out_channels, kernel_size, activation=True)
            for _ in range(R - 2)
        ])
        
        # The last block to prevent nonlinearity.  
        self.B.append(SingleBBlock(out_channels, out_channels, kernel_size, activation=False))
        self.B = nn.Sequential(*self.B)
        
        # Skip connection.  
        self.skip_connection = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1), 
            nn.BatchNorm1d(num_features=out_channels)
        )
        
    def forward(self, x):
        RBlocks_out = self.B(x)
        skip_out = self.skip_connection(x)
        
        return F.relu(RBlocks_out + skip_out)

In [360]:
class QuartzNet(nn.Module):
    def __init__(self, n_features, n_classes):
        super().__init__()
        
        self.C1 = nn.Sequential(
            nn.Conv1d(in_channels=n_features, out_channels=256, kernel_size=33, stride=2), 
            nn.BatchNorm1d(num_features=256), 
            nn.ReLU()
        )
        self.B = nn.Sequential(
            OrderedDict([
                ('B1', RepeatedBBlocks(in_channels=256, out_channels=256, kernel_size=33, R=5)), 
                ('B2', RepeatedBBlocks(in_channels=256, out_channels=256, kernel_size=39, R=5)), 
                ('B3', RepeatedBBlocks(in_channels=256, out_channels=512, kernel_size=51, R=5)), 
                ('B4', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=63, R=5)), 
                ('B5', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=75, R=5))
            ])
        )
        self.C2 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=87, padding=43), 
            nn.BatchNorm1d(num_features=512), 
            nn.ReLU()
        )
        self.C3 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=1), 
            nn.BatchNorm1d(num_features=1024), 
            nn.ReLU()
        )
        self.C4 = nn.Conv1d(in_channels=1024, out_channels=n_classes, kernel_size=1, dilation=2)
        
    def forward(self, x):
        first_conv_out = self.C1(x)
        b_out = self.B(first_conv_out)
        c2_out = self.C2(b_out)
        c3_out = self.C3(c2_out)
        c4_out = self.C4(c3_out)
        
        return c4_out

In [372]:
data_it = iter(dataloaders['train'])

In [380]:
sample = next(data_it)

In [381]:
sample['input_features'].shape

torch.Size([32, 128, 509])

In [382]:
sample['input_lengths']

tensor([260, 426, 260, 355, 429, 472, 391, 471, 475, 451, 451, 244, 351, 496,
        456,  71, 441, 463, 503, 462, 163, 316, 456, 460, 509, 388, 377, 469,
        501, 491, 178, 497])

In [321]:
model = QuartzNet(n_features=128, n_classes=len(ASRProcessor().vocab) - 1)

In [322]:
output = model(sample['input_features'])

In [323]:
logits = F.softmax(output, dim=1)

In [324]:
torch.argmax(logits, dim=1)

tensor([[11, 25, 26,  ..., 25, 11, 26],
        [16, 10,  6,  ..., 10, 24, 10],
        [24, 13, 26,  ..., 10, 10, 24],
        ...,
        [19, 26, 26,  ..., 10, 10, 10],
        [ 7, 10, 10,  ..., 24, 10, 10],
        [17, 19,  6,  ..., 24, 11, 10]])

## Metrics and validation

In [325]:
wer_metric = load_metric('wer')

In [326]:
def compute_metrics(y_pred_logits, y_true):
    pred_ids = torch.argmax(y_pred_logits, dim=1)
    
    pred_str = processor.labels_decode(pred_ids)
    label_str = processor.labels_decode(y_true)
    
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    
    return wer

In [327]:
logits_for_loss = torch.log(logits).permute(2, 0, 1)

In [328]:
loss = nn.CTCLoss(blank=len(processor.vocab) - 2)

In [329]:
compute_metrics(logits, sample['labels'])

1.0

In [331]:
logits_for_loss.shape

torch.Size([246, 32, 28])

In [332]:
sample['input_features'].shape

torch.Size([32, 128, 523])

In [337]:
loss(logits_for_loss, sample['labels'], sample['input_lengths'], sample['labels_lengths'])

RuntimeError: Expected tensor to have size at least 284 at dimension 1, but got size 283 for argument #2 'targets' (while checking arguments for ctc_loss_cpu)

In [336]:
sample['labels'].shape

torch.Size([32, 283])

In [186]:
next(data_it)['labels'].shape

torch.Size([32, 293])

In [None]:
def validate_model(model, val_dataloader, criterion, metrics, 
                    device=torch.device('cpu'), return_train=False):
    model.eval()
    running_loss = 0.0

    predictions = torch.FloatTensor([]).to(device)
    targets = torch.FloatTensor([]).to(device)
    with torch.inference_mode():
        print("\n")
        for batch_idx, sample in enumerate(val_dataloader):
            if batch_idx % 100 == 0 or batch_idx == len(val_dataloader) - 1:
                print(f"==> Batch: {batch_idx}/{len(val_dataloader)}")
            
            X = sample['input_features'].to(device)
            y_true = sample['labels'].to(device)
            
            y_pred = model(X)
            logits = F.log_softmax(y_pred)
            
            loss = criterion(logits, y_true)

            running_loss += loss.item()
            predictions = torch.cat((predictions, y_pred))
            targets = torch.cat((targets, y_true))

        all_metrics_score = metrics(predictions, targets)
        running_loss /= len(val_dataloader)
        
    if return_train:
        model.train()

    return running_loss, all_metrics_score