# Sweeping Code (for both attention and no attention)

In [None]:
# importing required libraries for the notebook
import lightning as lt
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torchaudio.functional import edit_distance as edit_dist
import random
import wandb
from language import *
from dataset_dataloader import *
from encoder_decoder import *

In [None]:
# Report which accelerator torch can see. Informational only: Lightning
# handles device placement, so `device` is NOT USED by the rest of the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Defining the source and target languages and loading data

In [None]:
# language codes: SOURCE names the input language, TARGET the output language
SOURCE, TARGET = 'eng', 'tam'

In [None]:
# Load the train / valid / test splits for the target language in one pass;
# each load_data call yields an (inputs, targets) pair.
(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = [
    load_data(TARGET, split) for split in ('train', 'valid', 'test')
]

# report how many samples each split contains
print(f'Number of train samples = {len(x_train)}')
print(f'Number of valid samples = {len(x_valid)}')
print(f'Number of test samples = {len(x_test)}')

In [None]:
# Language objects hold the vocabulary plus the index2sym / sym2index mappings
SRC_LANG, TAR_LANG = Language(SOURCE), Language(TARGET)

# the vocabularies are built from the training split only
SRC_LANG.create_vocabulary(*x_train)
TAR_LANG.create_vocabulary(*y_train)

# build the symbol <-> index mappings once the vocabularies are complete
SRC_LANG.generate_mappings()
TAR_LANG.generate_mappings()

# show the source vocabulary, then the target vocabulary
print(f'Source Vocabulary Size = {len(SRC_LANG.symbols)}')
print(f'Source Vocabulary = {SRC_LANG.symbols}')
print(f'Source Mapping {SRC_LANG.index2sym}')
print(f'Target Vocabulary Size = {len(TAR_LANG.symbols)}')
print(f'Target Vocabulary = {TAR_LANG.symbols}')
print(f'Target Mapping {TAR_LANG.index2sym}')

## Runner Class

In [None]:
class Runner(lt.LightningModule):
    '''
    Pytorch lightning based module that encapsulates our seq2seq model with useful
    helper functions: data loading hooks, train/valid/test steps, epoch-level
    accuracy tracking, wandb logging and saving of test predictions.
    '''
    def __init__(self, src_lang : Language, tar_lang : Language, common_embed_size, common_num_layers, 
                 common_hidden_size, common_cell_type, init_tf_ratio = 0.8, enc_bidirect=False, attention=False, dropout=0.0, 
                 opt_name='Adam', learning_rate=2e-3, batch_size=32):
        '''
        Build the encoder, the (optional) attention layer, the decoder and the
        combined EncoderDecoder model, then initialise weights and bookkeeping.

        src_lang, tar_lang  : Language objects for the source / target side
        common_embed_size   : embed_size passed to both encoder and decoder
        common_num_layers   : num_layers passed to both encoder and decoder
        common_hidden_size  : hid_size passed to both encoder and decoder
        common_cell_type    : cell_type passed to both encoder and decoder
        init_tf_ratio       : teacher forcing ratio used at the start of training
        enc_bidirect        : bidirect flag of the encoder
        attention           : if truthy, an Attention layer is created and given to the decoder
        dropout             : dropout passed to both encoder and decoder
        opt_name            : optimizer name; only 'Adam' is supported
        learning_rate       : learning rate given to the optimizer
        batch_size          : batch size of the train and valid dataloaders

        Raises ValueError if opt_name is not 'Adam'.
        '''
        super(Runner,self).__init__()

        # only adam is present in configure_optimizers as of now; fail with a clear
        # error instead of exiting the interpreter / killing the notebook kernel
        if opt_name != 'Adam':
            raise ValueError(f"Unsupported optimizer {opt_name!r}; only 'Adam' is supported")

        # save the language objects
        self.src_lang = src_lang
        self.tar_lang = tar_lang
        # create all the sub-networks and the main model
        self.encoder = EncoderNet(vocab_size=src_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type=common_cell_type, bidirect=enc_bidirect, dropout=dropout)
        if attention:
            self.attention = True
            self.attn_layer = Attention(common_hidden_size, enc_bidirect)
        else:
            self.attention = False
            self.attn_layer = None
        
        self.decoder = DecoderNet(vocab_size=tar_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type=common_cell_type, attention=attention, attn_layer=self.attn_layer,
                             enc_bidirect=enc_bidirect, dropout=dropout)
        
        self.model = EncoderDecoder(encoder=self.encoder, decoder=self.decoder, src_lang=src_lang, 
                                    tar_lang=tar_lang)

        # for determinism
        torch.manual_seed(42); torch.cuda.manual_seed(42); np.random.seed(42); random.seed(42)

        # initialize model weights. NOTE: Module.apply visits every sub-module and
        # init_weights walks named_parameters() recursively, so parameters are
        # re-initialised more than once; kept as-is so seeded results stay reproducible.
        self.model.apply(self.init_weights)
        self.batch_size = batch_size

        # optimizer for the model and loss function [that ignores locs where target = PAD token]
        self.loss_criterion = nn.CrossEntropyLoss(ignore_index=tar_lang.sym2index[PAD_SYM])
        self.opt_name = opt_name
        self.learning_rate = learning_rate

        self.save_test_preds = False # true if we want to save test predictions and not clear them on test epoch end
        self.cur_tf_ratio = init_tf_ratio # the current epoch teacher forcing ratio
        self.min_tf_ratio = 0.01          # minimum allowed teacher forcing ratio

        # lists for tracking predictions/true words etc...
        self.pred_train_words = []; self.true_train_words = []
        self.pred_valid_words = []; self.true_valid_words = []
        self.test_X_words = []; self.pred_test_words = []; self.true_test_words = []
        self.attn_matrices = []  # used only when there is attention layer

        # lists for tracking losses
        self.train_losses = []
        self.valid_losses = []

        # dictionary for logging at end of val epoch
        self.wdb_logged_metrics = dict()
        self.best_val_acc_seen = -0.01 # to save model weights on wandb

    def configure_optimizers(self):
        '''
        return the optimizer for the model parameters (Adam is the only supported one).
        '''
        optimizer = None
        if self.opt_name == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer

    @staticmethod
    def init_weights(m):
        '''
        function to initialize the weights of the model parameters:
        parameters whose name contains 'weight' are drawn uniformly from
        [-0.04, 0.04]; every other parameter is set to 0.
        '''
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param.data, -0.04, 0.04)
            else:
                nn.init.constant_(param.data, 0)
    
    @staticmethod
    def exact_accuracy(pred_words, tar_words):
        ''' 
        compute the accuracy using (predicted words, target words) and return it.
        exact word matching is used.

        returns 0.0 when both lists are empty (nothing to score).
        raises ValueError when the two lists differ in length.
        '''
        if len(pred_words) != len(tar_words):
            raise ValueError(f'Length mismatch: {len(pred_words)} predictions vs {len(tar_words)} targets')
        if len(pred_words) == 0:
            return 0.0
        matches = sum(1 for pred, tar in zip(pred_words, tar_words) if pred == tar)
        return matches / len(pred_words)
    
    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        '''
        load the train / valid / test splits of the (global) TARGET language.
        '''
        # load all the available data on all GPUs
        self.x_train, self.y_train = load_data(TARGET, 'train')
        self.x_valid, self.y_valid = load_data(TARGET, 'valid')
        self.x_test, self.y_test = load_data(TARGET, 'test')

    def train_dataloader(self):
        '''
        dataloader over the train split, batched with self.batch_size.
        '''
        # use the language objects given to this instance (not the notebook globals)
        dataset = TransliterateDataset(self.x_train, self.y_train, src_lang=self.src_lang, tar_lang=self.tar_lang)
        # NOTE(review): shuffle is left at the DataLoader default (False), so batches
        # come in dataset order every epoch - confirm this is intended.
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    def val_dataloader(self):
        '''
        dataloader over the valid split, batched with self.batch_size.
        '''
        dataset = TransliterateDataset(self.x_valid, self.y_valid, src_lang=self.src_lang, tar_lang=self.tar_lang)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    def test_dataloader(self):
        '''
        dataloader over the test split; one word per batch.
        '''
        dataset = TransliterateDataset(self.x_test, self.y_test, src_lang=self.src_lang, tar_lang=self.tar_lang)
        # we do inference word by word. So, batch_size = 1
        dataloader = DataLoader(dataset=dataset, batch_size=1,
                                collate_fn=CollationFunction(self.src_lang, self.tar_lang))
        return dataloader

    ####################
    # INTERFACE RELATED FUNCTIONS 
    ####################

    def training_step(self, train_batch, batch_idx):
        '''
        run one training batch with the current teacher forcing ratio, record the
        words / loss needed for the epoch-level metrics and return the loss.
        '''
        batch_X, batch_y, X_lens = train_batch
        # get the logits, preds for the current batch
        logits, preds = self.model(batch_X, batch_y, X_lens, tf_ratio=self.cur_tf_ratio)
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
        logits = logits.swapaxes(1, 2) # make class logits the second dimension as needed
        loss = self.loss_criterion(logits, targets)
        # for epoch-level metrics[accuracy], log all the required data
        self.true_train_words += self.tar_lang.convert_to_words(batch_y)
        self.pred_train_words += self.tar_lang.convert_to_words(preds)
        # store a detached copy: keeping the live loss would hold every batch's
        # autograd graph in memory until the end of the epoch
        self.train_losses.append(loss.detach())
        return loss
    
    def on_train_epoch_end(self):
        '''
        compute the epoch-level train metrics, log them (wandb + progress bar),
        reset the per-epoch buffers and update the teacher forcing ratio.
        '''
        # for wandb logging
        self.wdb_logged_metrics['train_loss'] = torch.stack(self.train_losses).mean()
        self.wdb_logged_metrics['train_acc'] = self.exact_accuracy(self.pred_train_words, self.true_train_words)
        self.wdb_logged_metrics['tf_ratio'] = self.cur_tf_ratio
        self.wdb_logged_metrics['epoch'] = self.current_epoch
        self.train_losses.clear()

        # note that on train_epoch_end is actually executed after valid epoch; so we log onto wandb here
        if wandb.run is not None:
            wandb.log(self.wdb_logged_metrics)

        # for display bar
        self.log('train_loss', self.wdb_logged_metrics['train_loss'], on_epoch=True, prog_bar=True)
        self.log('train_acc', self.wdb_logged_metrics['train_acc'], on_epoch=True, prog_bar=True)
        self.pred_train_words.clear(); self.true_train_words.clear() # clear to save memory and for next epoch

        # for first 12 epochs, we dont change the tf ratio. Then we decrease it by 0.1 every epoch till
        # min_tf_ratio is reached. This is also logged.
        if (self.current_epoch >= 11):
            self.cur_tf_ratio -= 0.1
            self.cur_tf_ratio = max(self.cur_tf_ratio, self.min_tf_ratio)

    def validation_step(self, valid_batch, batch_idx):
        '''
        run one validation batch (no tf_ratio is passed to the model) and record
        the words / loss needed for the epoch-level metrics.
        '''
        batch_X, batch_y, X_lens = valid_batch
        # get the logits, preds for the current batch
        logits, preds = self.model(batch_X, batch_y, X_lens) # no teacher forcing
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
        logits = logits.swapaxes(1, 2) # make class logits the second dimension as needed
        loss = self.loss_criterion(logits, targets)
        # for epoch-level metrics[accuracy], log all the required data
        self.true_valid_words += self.tar_lang.convert_to_words(batch_y)
        self.pred_valid_words += self.tar_lang.convert_to_words(preds)
        self.valid_losses.append(loss.detach()) # to get val loss for epoch
    
    def on_validation_epoch_end(self):
        '''
        compute the epoch-level validation metrics, show them on the progress bar
        and stash them for the wandb log done in on_train_epoch_end.
        '''
        # for wandb logging
        self.wdb_logged_metrics['val_loss'] = torch.stack(self.valid_losses).mean()
        self.wdb_logged_metrics['val_acc'] = self.exact_accuracy(self.pred_valid_words, self.true_valid_words)
        self.valid_losses.clear()
        
        # for display bar
        self.log('val_loss', self.wdb_logged_metrics['val_loss'], on_epoch=True, prog_bar=True)
        self.log('val_acc', self.wdb_logged_metrics['val_acc'], on_epoch=True, prog_bar=True)

        self.true_valid_words.clear(); self.pred_valid_words.clear() # clear to free memory and for next epoch
    
    def test_step(self, test_batch, batch_idx):
        '''
        greedy-decode one test word, record input / target / prediction (and the
        attention matrix when attention is enabled) and log the test loss.
        '''
        batch_X, batch_y, X_lens = test_batch
        logits, pred_word, attn_matrix = self.model.greedy_inference(batch_X, X_lens)
        # update all the global lists
        self.pred_test_words += pred_word
        self.true_test_words += self.tar_lang.convert_to_words(batch_y)
        self.test_X_words += self.src_lang.convert_to_words(batch_X)
        # if there is attention, update the attention list also
        if (self.attention):
            self.attn_matrices += [attn_matrix]
        # ignore loss for the first time step
        targets = batch_y[:, 1:]; logits = logits[1:, :]
        # we shrink the logits to the true decoded sequence length for loss computation alone
        true_dec_len = targets.size(1)
        logits = (logits[:true_dec_len, :]).swapaxes(0,1).unsqueeze(0)
        # unsqueeze and swapping of dimensions is to meet condition needed by nn.CrossEntropyLoss()
        loss = self.loss_criterion(logits, targets)
        self.log('test_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
    
    # will prevent clearing of global test lists on test epoch end
    def track_test_predictions(self):
        '''
        keep the test lists after the next test epoch so they can be saved.
        '''
        self.save_test_preds = True

    def on_test_epoch_end(self):
        '''
        log the exact-match test accuracy; clear the test lists unless
        track_test_predictions() was called beforehand.
        '''
        self.log('test_acc', self.exact_accuracy(self.pred_test_words, self.true_test_words), 
                 on_epoch=True, prog_bar=True)
        if not self.save_test_preds:
            self.pred_test_words.clear(); self.true_test_words.clear(); self.test_X_words.clear()
            self.attn_matrices.clear()
    
    # here, we will save all the predictions made and also, return a copy of the list of attention
    # matrices for generating heatmaps
    def save_test_predictions(self, fname='pred'):
        '''
        write the tracked test predictions to './<fname>.csv' with columns
        Input, Target, Predicted and Levenshtein Distance, then clear the lists.

        returns (inputs, targets, predictions, attention matrices) copies when the
        attention layer is present, otherwise None.
        '''
        edit_distances = [edit_dist(pred,tar) for pred, tar in zip(self.pred_test_words,self.true_test_words)]
        pred_df = pd.DataFrame(zip(self.test_X_words, self.true_test_words, self.pred_test_words, edit_distances),
                               columns=['Input', 'Target', 'Predicted', 'Levenshtein Distance'])
        pred_df.to_csv('./'+fname+'.csv', index=False, encoding='utf-8')

        # if attention layer is present, we return attention matrices as well.
        ret_info = None
        if self.attention:
            ret_info = (self.test_X_words.copy(), self.true_test_words.copy(), self.pred_test_words.copy(), self.attn_matrices.copy())
        self.save_test_preds = False

        # clear after saving to save memory 
        self.pred_test_words.clear(); self.true_test_words.clear(); self.test_X_words.clear()
        self.attn_matrices.clear()
        return ret_info

### Sweep Section

In [None]:
# Smoke-test the runner end to end with the best configuration found:
# BEST CONFIG - NOTE (attn = True, 12 epochs, tf_ratio=0.8)
runner = Runner(
    src_lang=SRC_LANG,
    tar_lang=TAR_LANG,
    common_embed_size=128,
    common_num_layers=3,
    common_hidden_size=256,
    common_cell_type='LSTM',
    init_tf_ratio=0.8,
    enc_bidirect=True,
    attention=True,
    dropout=0.0,
    opt_name='Adam',
    learning_rate=2e-3,
    batch_size=128,
)
trainer = lt.Trainer(max_epochs=12)

# train, then evaluate on the test split with the weights frozen
trainer.fit(runner)
runner.freeze()
# keep the test lists alive so they can be written out after trainer.test
runner.track_test_predictions()
trainer.test(runner)
# attention is enabled, so this returns (inputs, targets, predictions, attention matrices)
a, b, c, d = runner.save_test_predictions()

In [None]:
import wandb
wandb.login()

# Bayesian hyper-parameter sweep for the model WITHOUT attention.
# Every key below is read by agent_code through wconfig[...] and forwarded to Runner.
sweep_configuration = {
    'method': 'bayes',
    'name': 'no-attention-bayes',
    # must match the key the Runner logs to wandb ('val_acc'), otherwise the
    # bayes search has no metric to optimise
    'metric': {'goal': 'maximize', 'name': 'val_acc'},
    'parameters': {
        'embedding_size' : {'values' : [16, 32, 64, 128, 192]},
        'number_of_layers' : {'values' : [1, 2, 3]},
        'hidden_size' : {'values' : [32, 64, 128, 192, 256]},
        'cell' : {'values' : ['RNN', 'GRU', 'LSTM']},
        # real booleans: the strings 'True' / 'False' are both truthy, so the
        # Runner would treat either one as "enabled"
        'bidirectional' : {'values' : [True, False]},
        'dropout' : {'values' : [0.0, 0.05, 0.1, 0.2, 0.3]},
        'initial_tf_ratio' : {'values' : [0.6, 0.7, 0.8, 0.9]},
        'batch_size' : {'values' : [32, 64, 128]},
        'attention' : {'value' : False},
        'dataset' : {'value' : 'aksharantar'},
        'optimizer' : {'value' : 'Adam'},
        'learning_rate' : {'value' : 2e-3},
        'max_epochs' : {'value' : 35},
        'patience' : {'value' : 5},
        # agent_code also reads these two keys; without them a sweep run that
        # takes its config from wandb.config fails with a KeyError
        'min_epochs' : {'value' : 12},
        'min_delta_imp' : {'value' : 0.001},
    }
}

# sweep_id = wandb.sweep(sweep=sweep_configuration, entity='cs19b021', project='cs6910-assignment3')

In [None]:
# Fixed configuration used by agent_code when it is run outside a wandb sweep
# (same keys a sweep would provide through wandb.config).
# BEST CONFIG - NOTE (attn = True, 12 epochs, tf_ratio=0.8)
# NOTE(review): max_epochs (1) is smaller than min_epochs (12) here - this looks
# like a quick-test setting; confirm before using it for a real run.

wconfig = dict(
    embedding_size=128,
    number_of_layers=3,
    hidden_size=128,
    cell='LSTM',
    bidirectional=True,
    dropout=0.0,
    initial_tf_ratio=0.8,
    batch_size=128,
    attention=True,
    dataset='aksharantar',
    optimizer='Adam',
    learning_rate=2e-3,
    max_epochs=1,
    patience=5,
    min_epochs=12,
    min_delta_imp=0.001,
)

def agent_code():
    '''
    Train one Runner with the hyper-parameters in `wconfig`, tracked as a wandb run.

    The run is named after its hyper-parameters, training uses early stopping
    and checkpointing on 'val_acc', and the best checkpoint is uploaded as a
    wandb model artifact. The wandb run is always finished, even when training
    or the upload raises, so a failed run does not stay open in the notebook.
    '''
    #wconfig = wandb.config
    wdbrun = wandb.init(job_type='testing-code', project='cs6910-assignment3', entity='cs19b021', config=wconfig)
    try:
        # descriptive run name built from the hyper-parameters
        wdbrun.name = (
            f'emb={wconfig["embedding_size"]}_layers={wconfig["number_of_layers"]}_hid={wconfig["hidden_size"]}'
            f'_cell={wconfig["cell"]}_bidirectional={wconfig["bidirectional"]}_dr={wconfig["dropout"]}'
            f'_itfr={wconfig["initial_tf_ratio"]}_bsize={wconfig["batch_size"]}_att={wconfig["attention"]}'
            f'_opt={wconfig["optimizer"]}_lr={wconfig["learning_rate"]}'
        )
        rdict = dict(
                    src_lang=SRC_LANG,
                    tar_lang=TAR_LANG,
                    common_embed_size=wconfig['embedding_size'],
                    common_num_layers=wconfig['number_of_layers'],
                    common_hidden_size=wconfig['hidden_size'],
                    common_cell_type=wconfig['cell'],
                    init_tf_ratio= wconfig['initial_tf_ratio'],
                    enc_bidirect=wconfig['bidirectional'],
                    attention=wconfig['attention'],
                    dropout=wconfig['dropout'],
                    opt_name=wconfig['optimizer'],
                    learning_rate=wconfig['learning_rate'],
                    batch_size=wconfig['batch_size'] 
            )

        runner = Runner(**rdict)
        # early stop if val_acc does not improve by min_delta_imp for `patience` epochs
        early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=wconfig['min_delta_imp'], patience=wconfig['patience'], verbose=True, mode="max")
        # keep the checkpoint with the best val_acc, named after the run
        chkCallback = ModelCheckpoint(dirpath='./', filename=f'{wdbrun.name}', monitor='val_acc', mode='max')
        trainer = lt.Trainer(min_epochs=wconfig['min_epochs'], max_epochs=wconfig['max_epochs'], callbacks=[chkCallback, early_stop_callback])
        trainer.fit(runner)
        # log the checkpoint so that we can test by loading it directly
        # ("=" is replaced by "-" in the artifact name)
        artifact = wandb.Artifact(f'{wdbrun.name}_best_ckpt'.replace("=","-"), type='model')
        artifact.add_file(chkCallback.best_model_path)
        wdbrun.log_artifact(artifact)
    finally:
        # close the run whether or not training succeeded
        wdbrun.finish()

agent_code()

In [None]:
# state_dict must be CALLED: passing the bound method would pickle the method
# (and the whole module behind it) instead of the parameter dictionary
torch.save(runner.state_dict(), 'best_model.pth')

In [None]:
# Runner constructor arguments rebuilt from wconfig; they are handed to
# load_from_checkpoint so the module can be re-created before loading weights.
rdict = {
    'src_lang': SRC_LANG,
    'tar_lang': TAR_LANG,
    'common_embed_size': wconfig['embedding_size'],
    'common_num_layers': wconfig['number_of_layers'],
    'common_hidden_size': wconfig['hidden_size'],
    'common_cell_type': wconfig['cell'],
    'init_tf_ratio': wconfig['initial_tf_ratio'],
    'enc_bidirect': wconfig['bidirectional'],
    'attention': wconfig['attention'],
    'dropout': wconfig['dropout'],
    'opt_name': wconfig['optimizer'],
    'learning_rate': wconfig['learning_rate'],
    'batch_size': wconfig['batch_size'],
}

# restore the saved checkpoint, freeze it and evaluate on the test split
# (reuses the module-level `trainer` created in an earlier cell)
runner = Runner.load_from_checkpoint('./mythical-fighter-31_chkpt.ckt.ckpt', **rdict)
runner.freeze()
trainer.test(runner)

### IGNORE SECTION -> PERFORMED BUG SEARCH

In [None]:
# NOTE -> performing BUG SEARCH
# Short fit + test run for every combination of layer count, cell type,
# encoder direction and attention, to catch configuration-specific crashes.
# RESULT -> all clear; no bugs caught
num_lay = [1,3]
ctype = ['LSTM', 'GRU', 'RNN']
bidirect = [True, False]
attn = [True, False]

# same visiting order as four nested loops: attn varies fastest, num_lay slowest
combos = [(n, c, b, a) for n in num_lay for c in ctype for b in bidirect for a in attn]

for n, c, b, a in combos:
    runner = Runner(SRC_LANG, TAR_LANG, common_embed_size=128, common_num_layers=n,
                    common_hidden_size=256, common_cell_type=c, init_tf_ratio=0.8,
                    enc_bidirect=b, attention=a, dropout=0.05, opt_name='Adam',
                    learning_rate=2e-3, batch_size=128)
    # two optimisation steps are enough to exercise the code paths
    trainer = lt.Trainer(max_steps=2)
    trainer.fit(runner)
    runner.freeze()
    trainer.test(runner)