# QuartzNet model

In [6]:
import string
import json
from typing import Dict, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio

In [7]:
# Number of samples per batch; passed as `batch_size` to the DataLoaders below.
BATCH_SIZE = 32

## Data preprocessor

In [147]:
class ASRProcessor(object):
    """Turn raw audio and transcripts into model-ready tensors.

    Audio is converted with ``torchaudio.transforms.MelSpectrogram``. Text is
    encoded character by character with a fixed vocabulary: the 26 lowercase
    ASCII letters, space, apostrophe, ``<UNK>`` and ``<PAD>`` (in that order).
    """

    UNK_TOKEN = "<UNK>"
    PAD_TOKEN = "<PAD>"

    def __init__(self, sampling_rate, n_fft=1024, hop_length=512, n_mels=128):
        """Build the mel-spectrogram transform and the character vocabulary.

        Args:
            sampling_rate: Sample rate forwarded to ``MelSpectrogram``.
            n_fft: FFT size forwarded to ``MelSpectrogram``.
            hop_length: Hop length forwarded to ``MelSpectrogram``.
            n_mels: Number of mel bins forwarded to ``MelSpectrogram``.
        """
        self.audio_preprocessor = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )

        # Despite its name, `let2idx` is the index -> symbol list;
        # `vocab` is the inverse symbol -> index mapping.
        self.let2idx = list(string.ascii_lowercase)
        self.let2idx.extend([" ", "'", self.UNK_TOKEN, self.PAD_TOKEN])
        self.vocab = {w: idx for idx, w in enumerate(self.let2idx)}

    def text_preprocessor(self, text):
        """Encode ``text`` into a 1-D int64 tensor of vocabulary indices.

        Characters that are not in the vocabulary (digits, punctuation,
        uppercase letters, ...) are mapped to the ``<UNK>`` index instead of
        raising ``KeyError``.
        """
        unk_idx = self.vocab[self.UNK_TOKEN]
        indices = [self.vocab.get(symbol, unk_idx) for symbol in text]

        return torch.tensor(indices, dtype=torch.long)

    def __call__(self, input_values: torch.Tensor, labels: str) -> Dict[str, Union[torch.Tensor, torch.LongTensor]]:
        """Process one (audio, transcript) pair.

        Returns:
            A dict with ``'input_features'`` (output of the mel-spectrogram
            transform) and ``'labels'`` (encoded transcript).
        """
        return {
            'input_features': self.audio_preprocessor(input_values),
            'labels': self.text_preprocessor(labels)
        }

    def labels_decode(self, labels):
        """Map an iterable of vocabulary indices back to a string.

        Special tokens are rendered literally (e.g. ``<PAD>``).
        """
        return ''.join(self.let2idx[idx] for idx in labels)

## Dataset

In [176]:
class LibriDataset(Dataset):
    """LibriSpeech dataset yielding mel-spectrogram features and encoded transcripts.

    NOTE(review): the default ``split='val'`` is not one of the accepted
    split names, so callers must pass ``split`` explicitly. ``max_length`` is
    accepted but currently unused.
    """

    # Index of "<PAD>" in the ASRProcessor vocabulary
    # (26 letters, " ", "'" and "<UNK>" come before it).
    PAD_IDX = 29

    def __init__(self, root='', split='val', max_length=150):
        assert split in ['dev-clean', 'dev-other', 'test-clean', 'test-other', 'train-clean-100'], \
                'Split error!'

        self.data_iterator = torchaudio.datasets.LIBRISPEECH(root=root, url=split)

        self.processor = ASRProcessor(sampling_rate=16000)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx`` as a dict of tensors."""
        sample = self.data_iterator[idx]
        # Presumably sample[0] is the waveform (first channel taken) and
        # sample[2] the transcript -- confirm against torchaudio LIBRISPEECH.
        sample = self.processor(sample[0][0], sample[2].lower())

        return sample

    def __len__(self):
        return len(self.data_iterator)

    @staticmethod
    def _pad_to_even(tensors, value=0):
        """Right-pad tensors along the last dim to a common even length and stack them."""
        lengths = [t.size(-1) for t in tensors]
        target_len = max(lengths)
        # Round the common length up to the next even number.
        target_len += target_len % 2

        # Pad amounts are always >= 0, so no sample is ever truncated
        # (a negative pad would make F.pad trim data).
        return torch.stack([
            F.pad(t, pad=(0, target_len - length), value=value)
            for t, length in zip(tensors, lengths)
        ])

    @staticmethod
    def batch_collate(batch):
        """Collate a list of samples into padded batch tensors.

        Features are zero-padded along the last dimension and labels are
        padded with ``PAD_IDX``; both are extended to an even length.

        Args:
            batch: List of dicts with ``'input_features'`` and ``'labels'``.

        Returns:
            Dict with stacked ``'input_features'`` and int64 ``'labels'``.
        """
        input_features = LibriDataset._pad_to_even(
            [x['input_features'] for x in batch]
        )
        labels = LibriDataset._pad_to_even(
            [x['labels'] for x in batch], value=LibriDataset.PAD_IDX
        ).to(torch.int64)

        return {
            'input_features': input_features,
            'labels': labels
        }

    def labels_decode(self, labels):
        """Decode label indices back to text via the underlying processor."""
        return self.processor.labels_decode(labels)

In [177]:
# Training and validation datasets, keyed by phase name.
dataset = dict(
    train=LibriDataset(split='train-clean-100'),
    val=LibriDataset(split='dev-clean'),
)

In [178]:
# One DataLoader per phase, batching with the dataset's own collate function.
dataloaders = {
    phase: DataLoader(
        phase_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=LibriDataset.batch_collate,
        num_workers=1,
    )
    for phase, phase_dataset in dataset.items()
}

## Model

In [None]:
class QuartzNet(nn.Module):
    def __init__(self):
        super().__init__()
        