# Best Model with Attention

In [1]:
# importing required libraries for the notebook
import lightning as lt
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
import os
import torch
import wandb
import torch.nn as nn
from IPython.display import display
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import plotly.graph_objects as go
import plotly.express as px
from torchaudio.functional import edit_distance as edit_dist
import random
from language import *
from dataset_dataloader import *
from encoder_decoder import *

In [2]:
# Report which accelerator is visible. Informational only: Lightning selects the device on its own,
# so `device` is not used anywhere else in this notebook.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

cuda


# Defining the source and target languages and loading data

In [3]:
# Language codes for the transliteration pair: English (romanised) input -> Tamil output.
SOURCE, TARGET = 'eng', 'tam'

In [4]:
# Load the train / valid / test splits for TARGET; each load_data call yields a pair that is
# unpacked into inputs (x_*) and targets (y_*). Then report how many samples each split holds.
(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = (
    load_data(TARGET, split) for split in ('train', 'valid', 'test')
)

for split_name, split_inputs in (('train', x_train), ('valid', x_valid), ('test', x_test)):
    print(f'Number of {split_name} samples = {len(split_inputs)}')

Number of train samples = 51200
Number of valid samples = 4096
Number of test samples = 4096


In [5]:
# One Language object per side of the pair; each stores the vocabulary plus the
# index2sym / sym2index mappings.
SRC_LANG = Language(SOURCE)
TAR_LANG = Language(TARGET)

# Vocabularies are built from the training split only, so valid/test words never add symbols.
for lang, words in ((SRC_LANG, x_train), (TAR_LANG, y_train)):
    lang.create_vocabulary(*words)

# Derive character <-> index mappings once both vocabularies exist.
for lang in (SRC_LANG, TAR_LANG):
    lang.generate_mappings()

# Show both vocabularies. Per the printed output, `symbols` holds only data characters, while
# `index2sym` also contains the special tokens at indices 0-3.
for tag, lang in (('Source', SRC_LANG), ('Target', TAR_LANG)):
    print(f'{tag} Vocabulary Size = {len(lang.symbols)}')
    print(f'{tag} Vocabulary = {lang.symbols}')
    print(f'{tag} Mapping {lang.index2sym}')

Source Vocabulary Size = 26
Source Vocabulary = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Source Mapping {0: '@', 1: '$', 2: '!', 3: '%', 4: 'a', 5: 'b', 6: 'c', 7: 'd', 8: 'e', 9: 'f', 10: 'g', 11: 'h', 12: 'i', 13: 'j', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'q', 21: 'r', 22: 's', 23: 't', 24: 'u', 25: 'v', 26: 'w', 27: 'x', 28: 'y', 29: 'z'}
Target Vocabulary Size = 46
Target Vocabulary = ['ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'க', 'ங', 'ச', 'ஜ', 'ஞ', 'ட', 'ண', 'த', 'ந', 'ன', 'ப', 'ம', 'ய', 'ர', 'ற', 'ல', 'ள', 'ழ', 'வ', 'ஷ', 'ஸ', 'ஹ', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ', '்']
Target Mapping {0: '@', 1: '$', 2: '!', 3: '%', 4: 'ஃ', 5: 'அ', 6: 'ஆ', 7: 'இ', 8: 'ஈ', 9: 'உ', 10: 'ஊ', 11: 'எ', 12: 'ஏ', 13: 'ஐ', 14: 'ஒ', 15: 'ஓ', 16: 'க', 17: 'ங', 18: 'ச', 19: 'ஜ', 20: 'ஞ', 21: 'ட', 22: 'ண', 23: 'த', 24: 'ந', 25: 'ன', 26: 'ப', 27: 'ம', 28: 'ய',

## Runner Class

In [6]:
class Runner(lt.LightningModule):
    '''
    LightningModule that trains and evaluates the encoder-decoder transliteration model.

    It builds an EncoderNet, an optional Attention layer and a DecoderNet, wraps them in an
    EncoderDecoder, and implements the Lightning optimizer, data and step hooks. Exact-match word
    accuracy is accumulated over each epoch and logged at epoch end. Test predictions (and, when
    attention is enabled, the attention matrices) can be kept and written to CSV through
    track_test_predictions() + save_test_predictions().
    '''
    def __init__(self, src_lang : Language, tar_lang : Language, common_embed_size, common_num_layers, 
                 common_hidden_size, common_cell_type, init_tf_ratio = 0.8, enc_bidirect=False, attention=False, dropout=0.0, 
                 opt_name='Adam', learning_rate=2e-3, batch_size=32):
        '''
        Parameters
        ----------
        src_lang, tar_lang : Language
            source / target language objects (vocabulary and symbol <-> index mappings).
        common_embed_size, common_num_layers, common_hidden_size, common_cell_type
            hyper-parameters shared by the encoder and the decoder.
        init_tf_ratio : float
            initial teacher-forcing ratio used by training_step; decayed in on_train_epoch_end.
        enc_bidirect : bool
            whether the encoder is bidirectional.
        attention : bool
            whether to create an Attention layer and hand it to the decoder.
        dropout : float
            dropout passed to both encoder and decoder.
        opt_name : str
            optimizer name; only 'Adam' is implemented.
        learning_rate : float
            learning rate for the optimizer.
        batch_size : int
            batch size of the train / valid dataloaders (the test loader always uses 1).

        Raises
        ------
        ValueError
            if opt_name is not 'Adam'.
        '''
        super(Runner,self).__init__()

        # Fail fast, before building any network: configure_optimizers only knows Adam.
        # (This used to call exit(-1), which aborted without saying why.)
        if opt_name != 'Adam':
            raise ValueError(f"unsupported optimizer '{opt_name}': only 'Adam' is implemented")

        # save the language objects
        self.src_lang = src_lang
        self.tar_lang = tar_lang

        # create all the sub-networks and the main model
        self.encoder = EncoderNet(vocab_size=src_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type=common_cell_type, bidirect=enc_bidirect, dropout=dropout)
        if attention:
            self.attention = True
            self.attn_layer = Attention(common_hidden_size, enc_bidirect)
        else:
            self.attention = False
            self.attn_layer = None
        
        self.decoder = DecoderNet(vocab_size=tar_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type=common_cell_type, attention=attention, attn_layer=self.attn_layer,
                             enc_bidirect=enc_bidirect, dropout=dropout)
        
        self.model = EncoderDecoder(encoder=self.encoder, decoder=self.decoder, src_lang=src_lang, 
                                    tar_lang=tar_lang)

        # For determinism: seed every RNG, then let init_weights overwrite all parameters, so the
        # starting weights do not depend on how the sub-networks initialised themselves above.
        torch.manual_seed(42); torch.cuda.manual_seed(42); np.random.seed(42); random.seed(42)

        # NOTE(review): Module.apply visits every submodule and named_parameters() recurses, so
        # parameters are re-initialised more than once. This is redundant but harmless; it is kept
        # so the seeded initial weights stay identical to earlier runs.
        self.model.apply(self.init_weights) # initialize model weights
        self.batch_size = batch_size

        # loss function that ignores positions where the target is the PAD token
        self.loss_criterion = nn.CrossEntropyLoss(ignore_index=tar_lang.sym2index[PAD_SYM])
        self.opt_name = opt_name
        self.learning_rate = learning_rate
        
        # epoch-level buffers of decoded words, used for exact-match accuracy
        self.pred_train_words = []; self.true_train_words = []
        self.pred_valid_words = []; self.true_valid_words = []
        self.test_X_words = []; self.pred_test_words = []; self.true_test_words = []
        self.save_test_preds = False
        self.cur_tf_ratio = init_tf_ratio
        self.min_tf_ratio = 0.01
        self.attn_matrices = []  # used only when there is attention layer

    def configure_optimizers(self):
        '''Return the optimizer selected by opt_name (only Adam; validated in __init__).'''
        optimizer = None
        if self.opt_name == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer

    @staticmethod
    def init_weights(m):
        '''
        Initialize the parameters of module m: weights ~ U(-0.04, 0.04), all others set to 0.
        '''
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param.data, -0.04, 0.04)
            else:
                nn.init.constant_(param.data, 0)
    
    @staticmethod
    def exact_accuracy(pred_words, tar_words):
        ''' 
        Fraction of positions where the predicted word equals the target word exactly.

        Both lists must have the same length. Returns 0.0 for empty inputs rather than
        dividing by zero.
        '''
        assert len(pred_words) == len(tar_words), 'number of predicted and target words differ'
        if not pred_words:
            return 0.0
        matches = sum(pred == tar for pred, tar in zip(pred_words, tar_words))
        return matches / len(pred_words)
    
    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        '''Load all splits (on every process). NOTE(review): uses the notebook-level TARGET constant.'''
        self.x_train, self.y_train = load_data(TARGET, 'train')
        self.x_valid, self.y_valid = load_data(TARGET, 'valid')
        self.x_test, self.y_test = load_data(TARGET, 'test')

    def train_dataloader(self):
        '''Training loader built with this module's own language objects.'''
        # NOTE(review): DataLoader defaults to shuffle=False, so training batches arrive in the same
        # order every epoch; consider shuffle=True.
        dataset = TransliterateDataset(self.x_train, self.y_train, src_lang=self.src_lang, tar_lang=self.tar_lang)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    def val_dataloader(self):
        '''Validation loader built with this module's own language objects.'''
        dataset = TransliterateDataset(self.x_valid, self.y_valid, src_lang=self.src_lang, tar_lang=self.tar_lang)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    def test_dataloader(self):
        '''Test loader; inference is done word by word, so batch_size = 1.'''
        dataset = TransliterateDataset(self.x_test, self.y_test, src_lang=self.src_lang, tar_lang=self.tar_lang)
        dataloader = DataLoader(dataset=dataset, batch_size=1,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    ####################
    # TRAINING / VALIDATION / TEST HOOKS
    # TODO: heatmap; beam decoding (and save top 3 preds);
    #       wandb sweeping stuff and model checkpointing;
    #       put all together
    ####################

    def training_step(self, train_batch, batch_idx):
        '''Forward pass with teacher-forcing ratio cur_tf_ratio; returns the PAD-masked CE loss.'''
        batch_X, batch_y, X_lens = train_batch
        # get the logits, preds for the current batch
        logits, preds = self.model(batch_X, batch_y, X_lens, tf_ratio=self.cur_tf_ratio)
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
        # move the class dimension to position 1, as nn.CrossEntropyLoss expects (N, C, ...)
        logits = logits.swapaxes(1, 2)
        loss = self.loss_criterion(logits, targets)
        # for epoch-level metrics[accuracy], log all the required data
        self.true_train_words += self.tar_lang.convert_to_words(batch_y)
        self.pred_train_words += self.tar_lang.convert_to_words(preds)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def on_train_epoch_end(self):
        '''Log epoch train accuracy, reset the buffers, log and decay the teacher-forcing ratio.'''
        self.log('train_accuracy', self.exact_accuracy(self.pred_train_words, self.true_train_words), 
                 on_epoch=True, prog_bar=True)
        self.pred_train_words.clear(); self.true_train_words.clear()

        self.log('tf_ratio', self.cur_tf_ratio, on_epoch=True, prog_bar=True)
        # for first 12 epochs, we dont change the tf ratio. Then we decrease it by 0.1 every epoch till
        # min_tf_ratio is reached. This is also logged.
        if (self.current_epoch >= 11):
            self.cur_tf_ratio -= 0.1
            self.cur_tf_ratio = max(self.cur_tf_ratio, self.min_tf_ratio)

    def validation_step(self, valid_batch, batch_idx):
        '''Forward pass without teacher forcing; logs the PAD-masked CE loss.'''
        batch_X, batch_y, X_lens = valid_batch
        # get the logits, preds for the current batch
        logits, preds = self.model(batch_X, batch_y, X_lens) # no teacher forcing
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
        logits = logits.swapaxes(1, 2) # class dimension second, as nn.CrossEntropyLoss expects
        loss = self.loss_criterion(logits, targets)
        # for epoch-level metrics[accuracy], log all the required data
        self.true_valid_words += self.tar_lang.convert_to_words(batch_y)
        self.pred_valid_words += self.tar_lang.convert_to_words(preds)
        self.log('validation_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    def on_validation_epoch_end(self):
        '''Log epoch validation accuracy and reset the buffers.'''
        # same (pred, target) argument order as the train/test hooks
        self.log('validation_accuracy', self.exact_accuracy(self.pred_valid_words, self.true_valid_words), 
                 on_epoch=True, prog_bar=True)
        self.pred_valid_words.clear(); self.true_valid_words.clear()
    
    def test_step(self, test_batch, batch_idx):
        '''Greedy-decode one word, record source/target/prediction (and attention), log the loss.'''
        batch_X, batch_y, X_lens = test_batch
        logits, pred_word, attn_matrix = self.model.greedy_inference(batch_X, X_lens)
        # update all the global lists
        # NOTE(review): presumably pred_word is a one-element list; += on a str would add characters
        self.pred_test_words += pred_word
        self.true_test_words += self.tar_lang.convert_to_words(batch_y)
        self.test_X_words += self.src_lang.convert_to_words(batch_X)
        # if there is attention, update the attention list also
        if (self.attention):
            self.attn_matrices += [attn_matrix]
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[1:, :]
        # Shrink the logits to the true decoded sequence length, for the loss computation only.
        # NOTE(review): assumes logits is (steps, vocab) and decoding ran for at least
        # true_dec_len steps; otherwise the shapes will not match the targets. TODO confirm.
        true_dec_len = targets.size(1)
        # swap/unsqueeze to (1, vocab, steps) as nn.CrossEntropyLoss expects
        logits = (logits[:true_dec_len, :]).swapaxes(0,1).unsqueeze(0)
        loss = self.loss_criterion(logits, targets)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    def track_test_predictions(self):
        '''Keep the test buffers after the test epoch ends (consumed by save_test_predictions).'''
        self.save_test_preds = True

    def on_test_epoch_end(self):
        '''Log test accuracy; clear the buffers unless predictions are being tracked.'''
        self.log('test_accuracy', self.exact_accuracy(self.pred_test_words, self.true_test_words), 
                 on_epoch=True, prog_bar=True)
        if not self.save_test_preds:
            self.pred_test_words.clear(); self.true_test_words.clear(); self.test_X_words.clear()
            self.attn_matrices.clear()
    
    def save_test_predictions(self, fname='pred'):
        '''
        Write the tracked test predictions to ./<fname>.csv with the Levenshtein distance of each.

        Returns (source words, predicted words, attention matrices) copies when attention is
        enabled, else None. Clears the buffers and stops tracking afterwards.
        '''
        edit_distances = [edit_dist(pred,tar) for pred, tar in zip(self.pred_test_words,self.true_test_words)]
        pred_df = pd.DataFrame(zip(self.test_X_words, self.true_test_words, self.pred_test_words, edit_distances),
                               columns=['Input', 'Target', 'Predicted', 'Levenshtein Distance'])
        pred_df.to_csv('./'+fname+'.csv', index=False, encoding='utf-8')

        # if attention layer is present, we return attention matrices as well.
        ret_info = None
        if self.attention:
            ret_info = (self.test_X_words.copy(), self.pred_test_words.copy(), self.attn_matrices.copy())
        self.save_test_preds = False
        # clear after saving to save memory 
        self.pred_test_words.clear(); self.true_test_words.clear(); self.test_X_words.clear()
        self.attn_matrices.clear()
        return ret_info

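# Best Model (With Attention): Attention Heatmaps

The helper below plots attention weights as a heatmap (source characters vs. predicted characters); it is first sanity-checked with a random attention matrix.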

In [36]:
def generate_attention_heatmap(number, X_word, pred_word, attn_matrix :torch.Tensor):
    '''
    Plot the attention weights of one example as a plotly heatmap.

    Parameters
    ----------
    number : int
        index of the example; only used by the (disabled) wandb logging line below.
    X_word : str
        source word; rows are labelled '<S>', its characters, then '<E>'.
    pred_word : str
        predicted word; columns are labelled with its characters followed by '<E>'.
    attn_matrix : torch.Tensor
        2-D matrix indexed [decoder step, source position]. Row 0 is dropped because it was
        not produced by the attention layer, so the matrix needs len(pred_word) + 2 rows and
        len(X_word) + 2 columns.

    Raises
    ------
    AssertionError
        if the matrix shape does not match the labels.
    '''
    x_labels = list(pred_word) + ['<E>']
    y_labels = ['<S>'] + list(X_word) + ['<E>']
    # We ignore the 0th timestep, as it was not generated by attn_layer. detach().cpu() lets this
    # accept matrices that still track gradients or live on the GPU (plain .numpy() raises there).
    attn_matrix = np.transpose(attn_matrix[1:, :].detach().cpu().numpy())
    assert len(x_labels) == attn_matrix.shape[1], (
        f'{len(x_labels)} predicted-character labels but {attn_matrix.shape[1]} decoder steps')
    assert len(y_labels) == attn_matrix.shape[0], (
        f'{len(y_labels)} source-character labels but {attn_matrix.shape[0]} source positions')
    # plotly titles are HTML-like: a line break needs <br>; a literal '\n' is not rendered as one
    fig = px.imshow(attn_matrix, labels=dict(x="Predicted character", y="Source Character", color="Value"),
                    x=x_labels, y=y_labels, title=f'Src : {X_word}<br>Pred : {pred_word}',
                    color_continuous_scale='greys')
    fig.update_xaxes(side="top")
    fig.show()
    # optionally log the attention heatmap to wandb:
    # wandb.log({f"Attention Plot {number}" :  fig})

In [34]:

# Sanity-check the heatmap helper with a dummy (random) attention matrix.
pred = 'அரோபிந்தோ'
source = 'aurobindo'
print(len(source), len(pred))
# Rows: the dropped step 0, one per predicted char, <E>. Columns: <S>, one per source char, <E>.
# Deriving the shape from the words (instead of hard-coding 11x11) keeps it valid for other
# words, and the seeded generator makes the plot reproducible on Restart & Run All.
attn_matrix = torch.rand(len(pred) + 2, len(source) + 2, generator=torch.Generator().manual_seed(42))
print(attn_matrix)

9 9
tensor([[0.2008, 0.5093, 0.5218, 0.7471, 0.0690, 0.5574, 0.4257, 0.5440, 0.3209,
         0.3812, 0.8899],
        [0.6010, 0.7760, 0.6995, 0.2897, 0.6082, 0.0929, 0.4723, 0.7046, 0.0971,
         0.2403, 0.9353],
        [0.8981, 0.8716, 0.1455, 0.6104, 0.9677, 0.3442, 0.5847, 0.7872, 0.8383,
         0.6261, 0.5380],
        [0.4637, 0.7016, 0.0632, 0.2449, 0.3537, 0.9507, 0.5546, 0.1838, 0.4665,
         0.5429, 0.5275],
        [0.3641, 0.6584, 0.5259, 0.3637, 0.6762, 0.8266, 0.2758, 0.9560, 0.4595,
         0.5541, 0.3992],
        [0.7002, 0.8730, 0.0284, 0.6064, 0.9715, 0.8978, 0.0923, 0.5008, 0.2959,
         0.9820, 0.3799],
        [0.7114, 0.3798, 0.2799, 0.0328, 0.6672, 0.5033, 0.5355, 0.5494, 0.6346,
         0.5207, 0.2742],
        [0.8518, 0.1962, 0.3825, 0.7030, 0.1951, 0.7911, 0.2271, 0.0896, 0.9790,
         0.3448, 0.9414],
        [0.4732, 0.7089, 0.0048, 0.8510, 0.9119, 0.1475, 0.7496, 0.2593, 0.1893,
         0.0363, 0.3750],
        [0.2973, 0.1556, 0.0940, 

In [37]:
generate_attention_heatmap(number=1, X_word=source, pred_word=pred, attn_matrix=attn_matrix)