# QuartzNet + LM model

## TODO: 
* Look at the distribution of mel-spec and probably clip values
* Simple LSTM seq2seq model
* Beam Search
* cuda.amp
* (optional) BPE

In [1]:
import string
import re
import json
from typing import Dict, Union
from collections import OrderedDict, defaultdict
from IPython.display import clear_output
import copy
import gc
import random
from tqdm import tqdm
import sys
from functools import partial
from typing import List
from dataclasses import dataclass

import numpy as np
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torch.optim import AdamW

from datasets import load_metric

from bpemb import BPEmb

import matplotlib.pyplot as plt
%matplotlib inline
import librosa
import librosa.display

BPEMB_EN = BPEmb(lang="en", dim=300, vs=1000)
PAD_IDX = BPEMB_EN.EOS
DEVICE = torch.device('cpu' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

## Set seed to all processes

In [2]:
def set_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [3]:
set_seed(42)

## Data preprocessor

In [4]:
class Unsqueeze(nn.Module):
    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim
    
    def forward(self, x):
        return x.unsqueeze(self.dim)

In [5]:
class ASRProcessor(object):
    def __init__(self, tokenizer, sampling_rate=16000, n_fft=1024, 
                 hop_length=256, n_mels=64, split='test'):
        self.mel_spec_processor = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate, 
            n_fft=n_fft, 
            hop_length=hop_length, 
            n_mels=n_mels
        )
        
        if split == 'train':
            self.augmentation = nn.Sequential(
                Unsqueeze(), 
                torchaudio.transforms.TimeMasking(time_mask_param=5), 
                torchaudio.transforms.FrequencyMasking(freq_mask_param=5), 
            )
        else:
            self.augmentation = None
        
        self.tokenizer = tokenizer

    def text_processor(self, text):
        encoded_text = [self.tokenizer.BOS] + \
                       self.tokenizer.encode_ids(text) + \
                       [self.tokenizer.EOS]
        
        return torch.tensor(encoded_text)
    
    def __call__(self, input_values: torch.tensor, labels: str) -> Dict[str, Union[torch.tensor, torch.LongTensor]]:
        # Got such boundaries for 99.8% of non-augmented train data: [-10.7776, 6.4294].  
        log_mel_spec_image = torch.log(self.mel_spec_processor(input_values) + 1e-6).clamp_(-10, 6)
        input_preprocessed = self.augmentation(log_mel_spec_image).squeeze(0) \
                                if self.augmentation else log_mel_spec_image
        
        text_preprocessed = self.text_processor(labels)
        
        return {
            'input_features': input_preprocessed, 
            'labels': text_preprocessed
        }

## Dataset

In [6]:
class LibriDataset(Dataset):
    def __init__(self, processor, root='', split='dev-clean'):
        assert split in ['dev-clean', 'dev-other', 'test-clean', 'test-other', 'train-clean-100'], \
                'Split error!'
        
        self.data_iterator = torchaudio.datasets.LIBRISPEECH(root=root, url=split)
        
        self.processor = processor
    
    def __getitem__(self, idx):
        sample = self.data_iterator[idx]
        sample = self.processor(sample[0][0], sample[2].lower())
        
        return sample
    
    def __len__(self):
        return len(self.data_iterator)
    
    @staticmethod
    def normalize_batch(sample, input_lengths):
        input_features = sample['input_features']
        all_means = torch.zeros(input_features.size(0))
        all_stds = torch.zeros(input_features.size(0))

        for s_idx, s_len in enumerate(input_lengths):
            valid_features = input_features[s_idx, :, :s_len]
            sample_mean = torch.mean(valid_features)
            sample_std = torch.sqrt(torch.mean(valid_features**2) - sample_mean**2)

            input_features[s_idx, :, :s_len] = \
                (valid_features - sample_mean) / sample_std
        
        return input_features
    
    @staticmethod
    def batch_collate(batch, value_to_pad_tokens):
        # Collate audio samples.  
        sample_tokens_lengths = torch.tensor([x['input_features'].size(1) for x in batch])
        max_len_per_samples = torch.max(sample_tokens_lengths)

        # Extend to even max_len.  
        additive = (max_len_per_samples % 2)
        max_len_per_samples += additive
        samples_lengths_to_pad = max_len_per_samples - sample_tokens_lengths - additive

        input_features = torch.stack([
            F.pad(x['input_features'], pad=(0, val_to_pad)) 
            for x, val_to_pad in zip(batch, samples_lengths_to_pad)
        ])

        # Collate label samples.  
        label_tokens_lengths = torch.tensor([x['labels'].size(0) for x in batch])
        max_len_per_labels = torch.max(label_tokens_lengths)

        # Extend to even max_len.  
        additive = (max_len_per_labels % 2)
        max_len_per_labels += additive
        labels_lengths_to_pad = max_len_per_labels - label_tokens_lengths - additive

        labels = torch.vstack([
            F.pad(x['labels'], pad=(0, val_to_pad), value=value_to_pad_tokens) 
            for x, val_to_pad in zip(batch, labels_lengths_to_pad)
        ]).type(torch.int64)

        collated = {
            'input_features': input_features,
            'targets': labels,
            'attention_mask': (labels != PAD_IDX).type(torch.int32)
        }

        collated['input_features'] = LibriDataset.normalize_batch(
            collated, sample_tokens_lengths
        )

        return collated

In [7]:
@dataclass
class Data:
    datasets: Dict[str, LibriDataset]
    dataloaders: Dict[str, torch.utils.data.dataloader.DataLoader]

In [8]:
def set_split(batch_size=32, train_shuffle=False, collator=None):
    train_processor = ASRProcessor(BPEMB_EN, split='test')
    test_processor = ASRProcessor(BPEMB_EN, split='test')
    
    datasets = {
        'train': LibriDataset(train_processor, split='train-clean-100'), 
        'val': LibriDataset(test_processor, split='dev-clean'), 
        'test': LibriDataset(test_processor, split='test-clean')
    }
    
    dataloaders = {
        'train': DataLoader(datasets['train'], batch_size=batch_size, shuffle=train_shuffle, 
                                collate_fn=collator, num_workers=4), 
        'val': DataLoader(datasets['val'], batch_size=32, shuffle=False, 
                            collate_fn=collator, num_workers=4), 
        'test': DataLoader(datasets['test'], batch_size=32, shuffle=False, 
                            collate_fn=collator, num_workers=4)
    }
    
    return Data(datasets, dataloaders)

## Model

### QuartzNet 5x5

In [9]:
class SingleBBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, activation=True):
        super().__init__()
        
        # Padding 'same'.  
        padding = (kernel_size // 2) * dilation
        
        self.depthwise = nn.Conv1d(
            in_channels, in_channels, kernel_size, stride, 
            padding=padding, dilation=dilation, groups=in_channels
        )
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.batch_norm = nn.BatchNorm1d(num_features=out_channels)
        self.activation = activation
    
    def forward(self, x):
        TCS_out = self.pointwise(self.depthwise(x))
        
        bn_out = self.batch_norm(TCS_out)
        
        return F.relu(bn_out) if self.activation else bn_out

In [10]:
class RepeatedBBlocks(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, R=5):
        super().__init__()
        
        # The first block to match in_channels and out_channels.  
        self.B = [
            SingleBBlock(
                in_channels, out_channels, kernel_size, 
                stride, dilation, activation=True
            )
        ]
        
        # BBlocks between the first and the last blocks.  
        self.B.extend([
            SingleBBlock(
                out_channels, out_channels, kernel_size, 
                stride, dilation, activation=True
            )
            for _ in range(R - 2)
        ])
        
        # The last block to prevent nonlinearity.  
        self.B.append(
            SingleBBlock(
                out_channels, out_channels, kernel_size, 
                stride, dilation, activation=False
            )
        )
        self.B = nn.Sequential(*self.B)
        
        # Skip connection.  
        self.skip_connection = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1), 
            nn.BatchNorm1d(num_features=out_channels)
        )
    
    def forward(self, x):
        RBlocks_out = self.B(x)
        skip_out = self.skip_connection(x)
        
        return F.relu(RBlocks_out + skip_out)

In [11]:
class QuartzNet(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        
        self.C1 = SingleBBlock(
            in_channels=in_features, out_channels=256, kernel_size=33, 
            stride=2, dilation=1, activation=True
        )
        
        self.B = nn.Sequential(
            OrderedDict([
                ('B1', RepeatedBBlocks(
                    in_channels=256, out_channels=256, kernel_size=33, 
                    stride=1, dilation=1, R=5
                )), 
                ('B2', RepeatedBBlocks(
                    in_channels=256, out_channels=256, kernel_size=39, 
                    stride=1, dilation=1, R=5
                )), 
                ('B3', RepeatedBBlocks(in_channels=256, out_channels=512, kernel_size=51, 
                                       stride=1, dilation=1, R=5
                )), 
                ('B4', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=63, 
                                       stride=1, dilation=1, R=5
                )), 
                ('B5', RepeatedBBlocks(in_channels=512, out_channels=512, kernel_size=75, 
                                       stride=1, dilation=1, R=5
                ))
            ])
        )
        self.C2 = SingleBBlock(
            in_channels=512, out_channels=512, kernel_size=87, 
            stride=1, dilation=2, activation=True
        )
        
        self.C3 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=1, stride=1, dilation=1), 
            nn.BatchNorm1d(num_features=1024), 
            nn.ReLU()
        )
        self.C4 = nn.Sequential(
            nn.Conv1d(in_channels=1024, out_channels=out_features, kernel_size=1, stride=1, dilation=1), 
            nn.BatchNorm1d(num_features=out_features), 
            nn.ReLU()
        )
    
    def forward(self, x):
        first_conv_out = self.C1(x)

        b_out = self.B(first_conv_out)
        c2_out = self.C2(b_out)
        c3_out = self.C3(c2_out)
        c4_out = self.C4(c3_out)
        
        return c4_out

### Seq2seq

In [12]:
class Encoder(nn.Module):
    def __init__(self, 
        in_features, 
        hid_size, 
        num_layers=2
    ):
        super().__init__()
        
        self.acoustic_model = QuartzNet(in_features, hid_size)
        self.lstm_enc = nn.LSTM(
            input_size=hid_size, hidden_size=hid_size, 
            num_layers=num_layers, batch_first=True, dropout=0.2
        )
    
    def forward(self, x):
        acoustic_out = self.acoustic_model(x).permute(0, 2, 1)
        
        # x: [B, SEQ, H]
        _, (h, c) = self.lstm_enc(acoustic_out)
        
        return h[-1], c[-1]

In [13]:
class Decoder(nn.Module):
    def __init__(self,
        hid_size
    ):
        super().__init__()
        
        self.emb_out = nn.Embedding(num_embeddings=BPEMB_EN.vs, embedding_dim=BPEMB_EN.dim, 
                                    padding_idx=PAD_IDX)
        self.emb_out.weight.data.copy_(torch.from_numpy(BPEMB_EN.vectors))
        
        self.lstm_dec = nn.LSTMCell(input_size=BPEMB_EN.dim, hidden_size=hid_size)
        self.logits = nn.Linear(in_features=hid_size, out_features=BPEMB_EN.vs)
    
    def decode_step(self, h_prev, c_prev, cur_token):
        emb_target = self.emb_out(cur_token)
        
        hx, cx = self.lstm_dec(emb_target, (h_prev, c_prev))
        logits = F.log_softmax(self.logits(hx), dim=-1)
        
        return logits, (hx, cx)
        
    def forward(self, hx, cx, target):
        seq_first_target = target.T  # (S, B)
        
        predictions = []
        for curr_token in seq_first_target:
            logits, (hx, cx) = self.decode_step(hx, cx, curr_token)
            predictions.append(logits)
        
        return torch.stack(predictions)

### QuartzLM

In [14]:
class QuartzLM(nn.Module):
    def __init__(self, 
        in_features, hid_size, num_layers
    ):
        super().__init__()
        
        self.encoder = Encoder(in_features, hid_size, num_layers)
        self.decoder = Decoder(hid_size)
    
    def forward(self, inp, target):
        hx, cx = self.encoder(inp)
        pred_sequence = self.decoder(hx, cx, target)
        
        return pred_sequence.permute(1, 0, 2)

## Beam Search Experimental

In [39]:
temp_result = defaultdict(list)
            
for i in range(4):
    curr_token = torch.tensor([result['seq'][i, -1]])

    pred, states = decoder.decode_step(*result['states'][i], curr_token)

    top = torch.topk(-pred, k=4, dim=-1)

    temp_result['log_probs'].append(result['log_probs'] + top.values.squeeze())
    temp_result['seq'].append(top.indices[0])
    temp_result['states'].append(states)  # states = (h, c)

temp_result['log_probs'] = torch.stack(temp_result['log_probs'])
temp_result['seq'] = concat_seqs(result['seq'], torch.stack(temp_result['seq']))
top_ids = get_topk_in_matrix(
    temp_result['log_probs'], k=4
)
result['log_probs'] = temp_result['log_probs'][top_ids]
result['seq'] = temp_result['seq'][top_ids]
result['states'] = [temp_result['states'][idx]
                    for idx in top_ids[0]]


## Inference Language Model

In [36]:
def concat_raw_arrays(old_seq: torch.Tensor, new_seq: torch.Tensor):
    '''
    Parallel arrays concatenation:
    torch.tensor([1, 2])
    torch.tensor([5, 4, 3, 10])
    
    Result:
    tensor([[ 1,  2,  5],
        [ 1,  2,  4],
        [ 1,  2,  3],
        [ 1,  2, 10]])
    '''
    
    return torch.stack([
        torch.cat((old_seq, j)) for j in new_seq.unsqueeze(-1)
    ])

In [37]:
def concat_seqs(old_seq: torch.Tensor, new_seq: torch.Tensor):
    '''
    Parallel sequences concatenation.
    '''
    
    return torch.stack([
        concat_raw_arrays(o, n) for o, n in zip(old_seq, new_seq)
    ])

In [38]:
def get_topk_in_matrix(matrix: torch.Tensor, k: int):
    '''
    Search top k values in matrix
    
    Returns: indices of top k values in matrix.
    '''
    
    _, topk_ids = torch.topk(matrix.flatten(), k=4, dim=-1)
    
    return torch.div(topk_ids, 4, rounding_mode='floor'), topk_ids % 4

In [39]:
def start_beam_search(sample):
    '''
    Encode sample (melspec), predict the first token of sentence 
    after BOS token and return the next hidden states
    
    Returns: 
    dict of:
    * the top k sequnces,
    * the top k of their probabilities,
    * their hidden and cell states of LSTM
    '''
    
    result = {
        'seq': torch.full([4], BPEMB_EN.BOS, dtype=torch.int64),
        'log_probs': torch.zeros(4)
    }
    init_seq = torch.full([1], BPEMB_EN.BOS, dtype=torch.int64)
    encoder = Encoder(in_features=N_MELS, hid_size=300, num_layers=2)
    decoder = Decoder(hid_size=300)
    
    # Get initial state
    h0, c0 = encoder(sample)
    h0 = h0[0].unsqueeze(0)
    c0 = c0[0].unsqueeze(0)
    
    pred1, (h1, c1) = decoder.decode_step(h0, c0, init_seq)
    top1 = torch.topk(-pred1, k=4, dim=-1)
    
    result['seq'] = torch.vstack((result['seq'], top1.indices[0])).T
    result['log_probs'] += top1.values[0]
    result['states'] = [(h1, c1)] * 4
    
    return result

In [40]:
@torch.inference_mode()
class InferenceDecoder:
    def __init__(self, 
        language_model: QuartzLM,
        max_len: int = 100
    ):
        super().__init__()
        
        self.encoder = language_model.encoder
        self.decoder = language_model.decoder
        
        self.max_len = max_len
    
    def greedy_decoding(self, sample):
        batch_size = sample.size(0)
        outputs = [torch.full([batch_size], BPEMB_EN.BOS, 
                              dtype=torch.int64, device=DEVICE)]
        
        # Get initial state
        hx, cx = self.encoder(sample)
        for _ in range(self.max_len):
            logits, (hx, cx) = self.decoder.decode_step(hx, cx, outputs[-1])
            outputs.append(logits.argmax(dim=-1))
            
        return torch.stack(outputs, dim=1)
    
    def start_beam_search(self, h0, c0, beam_size):
        '''
        Encode sample (melspec), predict the first token of sentence 
        after BOS token and return the next hidden states

        Returns: 
        dict of:
        * the top k sequnces,
        * the top k of their probabilities,
        * their hidden and cell states of LSTM
        '''

        result = {
            'seq': torch.full([beam_size], BPEMB_EN.BOS, dtype=torch.int64),
            'log_probs': torch.zeros(beam_size)
        }
        init_seq = torch.tensor([BPEMB_EN.BOS], dtype=torch.int64)

        pred1, (h1, c1) = self.decoder.decode_step(h0, c0, init_seq)
        top1 = torch.topk(pred1, k=beam_size, dim=-1)

        result['seq'] = torch.vstack((result['seq'], top1.indices[0])).T
        result['log_probs'] -= top1.values[0]
        result['states'] = [(h1, c1)] * beam_size

        return result
    
    def beam_search_loop(self, result, beam_size):
        # Max len excluding BOS and start_iteration
        for _ in range(self.max_len - 2):
            temp_result = defaultdict(list)
            
            for i in range(beam_size):
                curr_token = torch.tensor([result['seq'][i, -1]])

                pred, states = self.decoder.decode_step(*result['states'][i], curr_token)

                top = torch.topk(pred, k=beam_size, dim=-1)

                temp_result['log_probs'].append(result['log_probs'] - top.values.squeeze())
                temp_result['seq'].append(top.indices[0])
                temp_result['states'].append(states)  # states = (h, c)

            temp_result['log_probs'] = torch.stack(temp_result['log_probs'])
            temp_result['seq'] = self.concat_seqs(result['seq'], torch.stack(temp_result['seq']))
            top_ids = get_topk_in_matrix(
                temp_result['log_probs'], k=beam_size
            )
            result['log_probs'] = temp_result['log_probs'][top_ids]
            result['seq'] = temp_result['seq'][top_ids]
            result['states'] = [temp_result['states'][idx]
                                for idx in top_ids[0]]
            
        high_p_idx = torch.argmax(result['log_probs'])
            
        return result['seq'][high_p_idx]
    
    def beam_search_decoding(self, sample, beam_size=4):
        batch_size = sample.size(0)
        
        # Get initial state
        h0, c0 = self.encoder(sample)
        predictions = torch.zeros(batch_size, self.max_len, dtype=torch.int64)
        for batch_idx in range(batch_size):
            hx = h0[batch_idx].unsqueeze(0)
            cx = c0[batch_idx].unsqueeze(0)
            
            result = self.start_beam_search(hx, cx, beam_size)
            predictions[batch_idx] = self.beam_search_loop(result, beam_size)
        
        return predictions
    
    def translate_lines(self, input_lines):
        result_str = []
        for i in input_lines:
            bpe_format_str = \
                ''.join([BPEMB_EN.emb.index_to_key[j] for j in input_lines[i][1:]])
            result_str.append(' '.join(bpe_format_str.split('▁')).strip())
        
        return result_str

## Load from checkpoint

In [16]:
def load_from_checkpoint(checkpoint_path, device=torch.device('cpu')):
    checkpoint = torch.load(checkpoint_path, map_location=device)

    epoch = checkpoint['epoch']
    model = checkpoint['model_architecture'].to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer = checkpoint['optimizer']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    history = checkpoint['whole_history']

    if checkpoint['scheduler']:
        scheduler = checkpoint['scheduler']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    return {
        'model': model, 
        'optimizer': optimizer, 
        'scheduler': scheduler if checkpoint['scheduler'] else None,
        'history': history,
        'epoch': epoch
    }

## Metrics and validation

In [17]:
@dataclass
class MetricsOut:
    pred_str: str
    target_str: str
    wer: float

In [18]:
class Metrics:
    def __init__(self):
        super().__init__()
        self.wer_metric = load_metric('wer')
        self.processor = ASRProcessor(BPEMB_EN)
    
    def __call__(self, y_pred, y_true):
        """
        Forward of labels prediction:
        :param y_pred: log_softmax(output) of the model, shape: (T, B, C), 
        :param y_true: true labels, shape (B, T)
        """
        pred_ids = torch.argmax(y_pred, dim=2).T
    
        pred_str = self.processor.labels_decode(pred_ids, apply_ctc=True)
        label_str = self.processor.labels_decode(y_true, apply_ctc=False)
        
        wer = self.wer_metric.compute(predictions=pred_str, references=label_str)

        return MetricsOut(pred_str, label_str, wer)

In [19]:
@dataclass
class ValidateOut:
    loss: float
    metrics: float

In [20]:
def validate_model(model, val_dataloader, criterion, metrics, 
                    device=torch.device('cpu'), return_train=False):
    model.eval()
    running_loss = 0.0
    running_score = 0.0
    
    with torch.inference_mode():
        print("\n")
        for batch_idx, sample in enumerate(val_dataloader):
            if batch_idx % 10 == 0 or batch_idx == len(val_dataloader) - 1:
                print(f"==> Batch: {batch_idx}/{len(val_dataloader)}")
            
            sample = {k: v.to(device) for k, v in sample.items()}
            
            y_pred = model(sample['input_features'])
            sample['log_probs'] = F.log_softmax(y_pred, dim=1).permute(2, 0, 1)
            del sample['input_features']
            
            loss = criterion(**sample)

            running_loss += loss.item()
            running_score += metrics(sample['log_probs'], sample['targets'])

        running_loss /= len(val_dataloader)
        running_score /= len(val_dataloader)
        
    if return_train:
        model.train()

    return ValidateOut(running_loss, running_score)

## Train function

In [21]:
def train(model, dataloaders, criterion, optimizer, metrics, scheduler=None, 
          num_epochs=5, start_epoch=-1, prev_metrics=dict(), device=torch.device('cpu'),
          folder_for_checkpoints='/'):
    for key, vals in prev_metrics.items():
        for val in vals:
            wandb.log({key :val[1]}, step=val[0])

    if len(prev_metrics) > 0:
        history = copy.deepcopy(prev_metrics)
        curr_step = prev_metrics['train_loss'][-1][0] + 1
    else:
        history = defaultdict(list)
        curr_step = 1

    model.train()
    for epoch in range(start_epoch + 1, start_epoch + 1 + num_epochs):
        running_loss = 0.0
        running_score = 0.0

        clear_output(True)

        print("-" * 20)
        print(f"Epoch: {epoch}/{start_epoch + num_epochs}")
        print("-" * 20)
        print("Train: ")

        for batch_idx, sample in enumerate(tqdm(dataloaders['train'])):            
            sample = {k: v.to(device) for k, v in sample.items()}
            
            y_pred = model(sample['input_features'])
            sample['log_probs'] = F.log_softmax(y_pred, dim=1).permute(2, 0, 1)
            del sample['input_features']
            
            loss = criterion(**sample)
            optimizer.zero_grad()

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if scheduler:
                scheduler.step()
            
            running_loss += loss.item()
            
            model_out = metrics(sample['log_probs'], sample['targets'])
            running_score += model_out.wer

        val_result = validate_model(model, dataloaders['val'], criterion, 
                                    metrics, device, return_train=True)
        
        val_loss = val_result.loss
        val_metrics = val_result.metrics
        
        wandb.log({'train_loss': running_loss / (batch_idx + 1)}, step=curr_step)
        wandb.log('val_loss', val_loss, step=curr_step)
        history['train_loss'].append((curr_step, running_loss / (batch_idx + 1)))
        history['val_loss'].append((curr_step, val_loss))

        wandb.log({'val_wer': val_metrics}, step=curr_step)
        history['val_wer'].append((curr_step, val_metrics))

        wandb.log({'train_wer': running_score / (batch_idx + 1)}, step=curr_step)
        history['train_wer'].append((curr_step, running_score / (batch_idx + 1)))

        curr_step += 1

        state = {
            'epoch': epoch,
            'batch_size_training': dataloaders['train'].batch_size,
            'model_architecture': model,
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler if scheduler else None,
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'whole_history': history
        }

        torch.save(state, folder_for_checkpoints + f'checkpoint_epoch_{epoch%5 + 1}.pt')

## Model configuration

In [22]:
PROCESSOR = ASRProcessor(BPEMB_EN)
BATCH_SIZE = 32
LR = 1e-3
NUM_EPOCHS = 50
CHECKPOINT_PATH = "ASR_checkpoints/"
N_MELS = 64

## Init new model

In [23]:
try:
    del optimizer
    del model
    torch.cuda.empty_cache()
    gc.collect()
except:
    pass

model = QuartzLM(in_features=N_MELS, hid_size=300, num_layers=2).to(DEVICE)

optimizer = AdamW(model.parameters(), lr=LR)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
START_EPOCH = -1
history = dict()

## Load model from checkpoint

In [None]:
check = load_from_checkpoint('ASR_checkpoints/checkpoint_epoch_3.pt', DEVICE)
model = check['model']
model.register_backward_hook(backward_hook)

optimizer = check['optimizer']
scheduler = check['scheduler']
history = check['history']

START_EPOCH = check['epoch']

In [24]:
metrics = Metrics()
criterion = nn.CrossEntropyLoss()
batch_collator = partial(LibriDataset.batch_collate, value_to_pad_tokens=PAD_IDX)

## Split data

In [25]:
data = set_split(
    batch_size=BATCH_SIZE, train_shuffle=True, collator=batch_collator
)

## Model inference

In [26]:
inference_dec = InferenceDecoder(model)

In [30]:
s = next(iter(data.dataloaders['val']))
sample = s['input_features']

In [43]:
pred0 = inference_dec.beam_search_decoding(sample)

In [44]:
pred1 = inference_dec.beam_search_decoding(sample)

In [47]:
pred0

tensor([[  1, 143, 690,  ..., 376, 801, 773],
        [  1, 143, 690,  ..., 376, 801, 773],
        [  1, 143, 690,  ..., 114, 550, 773],
        ...,
        [  1, 143, 690,  ..., 643, 643, 898],
        [  1, 143, 690,  ..., 643, 643, 898],
        [  1, 143, 690,  ..., 643, 643, 898]])

In [48]:
pred1

tensor([[  1, 143, 690,  ..., 957, 107, 532],
        [  1, 865, 593,  ...,   5, 122, 928],
        [  1, 143, 690,  ..., 643, 643, 898],
        ...,
        [  1, 143, 690,  ..., 643, 643, 898],
        [  1, 143, 690,  ..., 376, 801, 773],
        [  1, 143, 690,  ..., 550, 168, 865]])

In [None]:
inference_dec.greedy_decoding(sample)