# CS6910 Assignment 3 (RNN Frameworks for transliteration) - With Bahdanau attention

In [70]:
# importing required libraries for the notebook

import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
import wandb
import numpy as np
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import random

In [71]:
# set the device to 'cuda' if available
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# define the source and target languages
TARGET = 'hin'
SOURCE = 'eng'
# define the special tokens that stand for start of seq, end of seq, 
# an unknown symbol.
SOS_SYM = '@'
EOS_SYM = '$'
UNK_SYM = '!'
# define a special token for padding - this helps with batch processing 
PAD_SYM = '%'

cuda


## Preprocessing Functions and Helpers

In [72]:
# function to load the 'cat' (= train/val/test) data of language 'lang'
def load_data(lang, cat):
    fcontents = open(f'aksharantar_sampled/{lang}/{lang}_{cat}.csv','r', encoding='utf-8').readlines()
    pairs = [tuple(l.strip().split(',')) for l in fcontents]
    x_data, y_data = list(map(list,zip(*pairs)))
    return x_data, y_data

# class for a language with useful functions.
class Language:
    def __init__(self, name):
        self.lname = name
    
    # function to create the vocabulary(set of tokens) using the words in 'data'
    # here, a token is either a special token or a lang character
    def create_vocabulary(self, *data):
        symbols = set()
        for wd in data:
            for c in wd:
                symbols.add(c)
        self.symbols = symbols
    
    # function to generate the index2sym (a number to a token) and 
    # sym2index (a token to a number) mappings using the vocabulary
    def generate_mappings(self):
        self.index2sym = {0: SOS_SYM, 1 : EOS_SYM, 2 : UNK_SYM, 3 : PAD_SYM}
        self.sym2index = {SOS_SYM : 0, EOS_SYM : 1, UNK_SYM : 2, PAD_SYM : 3}
        self.symbols = list(self.symbols)
        self.symbols.sort()

        for i, sym in enumerate(self.symbols):
            self.sym2index[sym] = i + 4
            self.index2sym[i+4] = sym
        
        self.num_tokens = len(self.index2sym.keys())
    
    # function to tokenize a word and convert all the tokens to
    # their corr. numbers using sym2index
    def convert_to_numbers(self, word):
        enc = [self.sym2index[SOS_SYM]]
        for ch in word:
            if ch in self.sym2index.keys():
                enc.append(self.sym2index[ch])
            else:
                enc.append(self.sym2index[UNK_SYM])
        enc.append(self.sym2index[EOS_SYM])
        return enc
    
    # convert a list of predictions (each prediction is a list of numbers)
    # to the corresponding list of words using index2sym
    # pred should be numpy array of shape (number_of_words, max_word_length)
    # tokens after EOS_SYM are discarded
    def convert_to_words(self, preds):
        num = preds.shape[0]
        words = [] 
        for i in range(num):
            wd = ''
            for idx in preds[i][1:]: # 1: -> ignore SOS token
                ch = self.index2sym[idx]
                if ch != EOS_SYM:
                    wd += ch
            words.append(wd)
        return words

    # get the number assigned to a token
    def get_index(self, sym):
        return self.sym2index[sym]
    
    # get the number of tokens in the vocabulary
    def get_size(self):
        return self.num_tokens

In [73]:
# load all the available data and print sample counts for each set
x_train, y_train = load_data(TARGET, 'train')
x_valid, y_valid = load_data(TARGET, 'valid')
x_test, y_test = load_data(TARGET, 'test')

print(f'Number of train samples = {len(x_train)}')
print(f'Number of valid samples = {len(x_valid)}')
print(f'Number of test samples = {len(x_test)}')

Number of train samples = 51200
Number of valid samples = 4096
Number of test samples = 4096


In [74]:
# create language objects for storing vocabulary, index2sym and sym2index
SRC_LANG = Language(SOURCE)
TAR_LANG = Language(TARGET)

# creating vocabulary using train data only
SRC_LANG.create_vocabulary(*(x_train))
TAR_LANG.create_vocabulary(*(y_train))

# otherwise, use unicode characters (assigned codepoints) in the script's range
# src_lang.create_vocabulary_range()
# tar_lang.create_vocabulary_range()

# generate mappings from characters to numbers and vice versa
SRC_LANG.generate_mappings()
TAR_LANG.generate_mappings()

# print the source and target vocabularies
print(f'Source Vocabulary Size = {len(SRC_LANG.symbols)}')
print(f'Source Vocabulary = {SRC_LANG.symbols}')
print(f'Source Mapping {SRC_LANG.index2sym}')
print(f'Target Vocabulary Size = {len(TAR_LANG.symbols)}')
print(f'Target Vocabulary = {TAR_LANG.symbols}')
print(f'Target Mapping {TAR_LANG.index2sym}')

Source Vocabulary Size = 26
Source Vocabulary = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Source Mapping {0: '@', 1: '$', 2: '!', 3: '%', 4: 'a', 5: 'b', 6: 'c', 7: 'd', 8: 'e', 9: 'f', 10: 'g', 11: 'h', 12: 'i', 13: 'j', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'q', 21: 'r', 22: 's', 23: 't', 24: 'u', 25: 'v', 26: 'w', 27: 'x', 28: 'y', 29: 'z'}
Target Vocabulary Size = 64
Target Vocabulary = ['ँ', 'ं', 'ः', 'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ए', 'ऐ', 'ऑ', 'ओ', 'औ', 'क', 'ख', 'ग', 'घ', 'ङ', 'च', 'छ', 'ज', 'झ', 'ञ', 'ट', 'ठ', 'ड', 'ढ', 'ण', 'त', 'थ', 'द', 'ध', 'न', 'प', 'फ', 'ब', 'भ', 'म', 'य', 'र', 'ल', 'ळ', 'व', 'श', 'ष', 'स', 'ह', '़', 'ऽ', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॅ', 'े', 'ै', 'ॉ', 'ो', 'ौ', '्']
Target Mapping {0: '@', 1: '$', 2: '!', 3: '%', 4: 'ँ', 5: 'ं', 6: 'ः', 7: 'अ', 8: 'आ', 9: 'इ', 10: 'ई', 11: 'उ', 12: 'ऊ', 13: 'ऋ', 14: 'ए', 15: 'ऐ', 16: 'ऑ', 17: 'ओ', 18: 'औ',

In [75]:
# A Dataset class to help with creating the dataset. Converts the data into numbers
# using the source and target languages' sym2index and index2sym dictionaries
class TransliterateDataset(Dataset):
    def __init__(self, x_data, y_data, src_lang : Language, tar_lang : Language):
        # save all the data points and the language objects
        self.x_data = x_data
        self.y_data = y_data
        self.src_lang = src_lang
        self.tar_lang = tar_lang
        
    def __len__(self):
        # needs to be implemented for a pytorch `Dataset`
        return len(self.y_data)

    def __getitem__(self, idx):
        # gives the data point (X, y) at index = idx
        # we convert them to a tensor of numbers using the Language objects
        # also returns the word y for ease of computing accuracy later
        x_enc, y_enc = self.x_data[idx], self.y_data[idx]
        x_enc = self.src_lang.convert_to_numbers(x_enc)
        y_enc = self.tar_lang.convert_to_numbers(y_enc) 
        return torch.tensor(x_enc, dtype=int), torch.tensor(y_enc,dtype=int), self.y_data[idx]

# This is a collation function for post-processing a batch in a DataLoader. We sort the instances (X,y) in a batch
# based on seq length of X in desc order and create a padded batch to help with batch-processing in recurrent
# networks
class CollationFunction:
    def __init__(self, src_lang : Language, tar_lang : Language):
        self.src_lang = src_lang
        self.tar_lang = tar_lang
    
    def __call__(self, batch):
        # reasoning : https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
        batch.sort(key = lambda x : len(x[0]), reverse=True)
        src, tar, tar_words = zip(*batch)
        src_lens = torch.tensor([len(x) for x in src], dtype=int)
        # pad both the X part(src) and y part(tar) with PAD_SYM
        src = nn.utils.rnn.pad_sequence(list(src), batch_first=True, padding_value=self.src_lang.get_index(PAD_SYM))
        tar = nn.utils.rnn.pad_sequence(list(tar), batch_first=True, padding_value=self.tar_lang.get_index(PAD_SYM))
        # return padded batch_X (src), padded batch_Y (tar), X_lens (needed for unpacking) and y words(tar_words)
        # each entry in tar_words is a string
        return src, tar, src_lens, tar_words

## Encoder Network

In [124]:
'''
    Class for the Encoder Network
'''
class EncoderNet(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, hid_size, cell_type, 
                 bidirect=False, dropout=0):
        '''
            input_vocab_size (V) = number of tokens in the input language dictionary
            embed_size = dim of embedding for each input token
            num_layers = number of layers in the encoder network
            hidden_size = dim of hidden state of each cell
            cell_type = RNN/GRU/LSTM
            bidirect = True for bidirectional network and False otherwise
            dropout = dropout probability
        '''
        # save all the necessary arch information
        super(EncoderNet, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hid_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        self.bidirect = bidirect

        # create the embedding layer and a dropout layer for it
        self.embedding = nn.Embedding(self.vocab_size, self.embed_size)
        self.dropout = nn.Dropout(p=dropout)

        # we create the kw args using the received parameters and use it to create network layer stack
        kwargs = {'input_size':embed_size, 'hidden_size':hid_size, 'num_layers':num_layers, 
                 'bidirectional':bidirect, 'batch_first':True}
        if num_layers > 1:
            kwargs['dropout'] = dropout
        if cell_type == 'RNN':
            self.network = nn.RNN(**kwargs)
        elif cell_type == 'LSTM':
            self.network = nn.LSTM(**kwargs)
        else:
            self.network = nn.GRU(**kwargs)
        
        # for combining the final layer's forward and reverse directions' final hidden state
        # we create linear layers to do this for each encoder layer
        if (self.bidirect):
            self.combine_forward_backward = [nn.Linear(2 * hid_size, hid_size) for _ in range(num_layers)]
            self.combine_forward_backward = nn.ModuleList(self.combine_forward_backward)

    def forward(self, batch_X, X_lens):
        '''
            batch_X - padded input batch of examples. shape = (batch_size, max_batch_seq_length)
                      [padding is already taken care of by collate function of DataLoader]
            X_lens - length of each input sequence. A python list of `batch_size` many integers
        '''
        # pass the batch through the embedding with dropout and pack the batch.
        # packing is for efficiency 
        batch_X = self.embedding(batch_X)
        batch_X = self.dropout(batch_X)
        packed_batch_x = nn.utils.rnn.pack_padded_sequence(batch_X, lengths=X_lens, batch_first=True, 
                                                           enforce_sorted=True)
        # send the batch through the network correctly based on cell type
        # packed_outputs = packed sequence of outputs from the final layer.
        # hidden_outputs = hidden outputs from every layer. shape = (D * num_layers, batch_size, hidden_size)
        # D = 2 if bidirectional; else 1
        if self.cell_type == 'LSTM':
            packed_outputs, (hidden_outputs, _) = self.network(packed_batch_x)
        else:
            packed_outputs, hidden_outputs = self.network(packed_batch_x)
        
        # unpack the packed sequence. outputs has shape (batch_size, max_seq_len, D * hidden_size)
        # without attention, outputs is NOT USED.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)

        # shape of `hidden_state_all_layers`= (num_layers, batch_size, hidden_size)
        hidden_state_all_layers = hidden_outputs[:, :, :]
        if self.bidirect:
            # here, we need to process bidirectional final hidden states of each layer through a linear layer
            req_shape = (hidden_outputs.shape[0]//2, hidden_outputs.shape[1], hidden_outputs.shape[2])
            hidden_state_all_layers = torch.zeros(req_shape).to(device)
            for i in range(self.num_layers):
                xidx = 2 * i
                # concatenate the forward and reverse directions outputs along the hidden_size's dimension
                # shape = (batch_size, 2 * hidden_size) now
                concat_hidden_state_cur_layer = torch.cat([hidden_outputs[xidx, :, :], hidden_outputs[xidx + 1, :, :]], dim=1)
                hidden_state_cur_layer = self.combine_forward_backward[i](concat_hidden_state_cur_layer)
                hidden_state_cur_layer = torch.tanh(hidden_state_cur_layer)
                hidden_state_all_layers[i, :, :] = hidden_state_cur_layer

        # outputs - shape (batch_size, max_seq_len, D * hidden_size) -> NOT USED without attention
        # hidden_state_all_layers - shape (num_layers, batch_size, hidden_size)
        return outputs, hidden_state_all_layers

## Attention Class

In [None]:
'''
    Class for Attention
'''
class Attention(nn.Module):
    def __init__(self, hidden_dim, bidirect = False):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        self.bidirect = bidirect
        self.attn_matrix_indim = 3 * hidden_dim if bidirect == True else 2 * hidden_dim
        self.U = nn.Linear(self.attn_matrix_indim, hidden_dim) # to be sent to tanh layer
        self.V = nn.Linear(self.hidden_dim, 1, bias=False) # dotted with tanh layer's output to get weights
        self.softmaxlayer = nn.Softmax(dim=1)

    def forward(self, prev_dec_hidden, padded_enc_outputs, mask):
        ''' 
            prev_dec_hidden -> shape (batch_size, hidden_size) - first decoder layer's hidden state
            padded_enc_outputs -> shape (batch_size, max_seq_len, hidden_size)
            mask -> shape (batch_size, max_seq_len) with 1 in locations where pad token is present
            max_seq_len = max_seq_len in batch_X [that was processed by encoder]
        '''
        batch_size, max_seq_len = padded_enc_outputs.shape
        hidden_extended = prev_dec_hidden.unsqueeze(1).repeat(1, max_seq_len, 1)
        # hidden_extended shape = (batch_size, max_seq_len, hidden_size)
        U_input = torch.cat([hidden_extended, padded_enc_outputs], dim=2).to(device)
        # U_input shape = (batch_size, max_seq_len, [2 or 3] * hidden_size)
        tanh_output = torch.tanh(self.U(U_input))
        # tanh_output shape = (batch_size, max_seq_len, hidden_size)
        attn_weights = self.V(tanh_output).squeeze(2)
        # attn_weights shape = (batch_size, max_seq_len)
        attn_weights = torch.masked_fill(attn_weights, mask==1, -1e12)
        # fill pad locations with very small values to be zeroed by softmax
        attn_weights = self.softmaxlayer(attn_weights)
        # convert weights to probabilities over max_seq_length dimension
        return attn_weights


## Decoder Network


In [92]:
'''
    Class for the Decoder Network
'''
class DecoderNet(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, hid_size, cell_type, attention,
                 dropout=0, enc_bidirect=False):
        super(DecoderNet, self).__init__()
        # store all the network arch information
        self.hidden_size = hid_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.cell_type = cell_type
        # create the embedding layer with a dropout layer for it
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(p=dropout)
        # create the linear layer for producing output logits (need to be sent through softmax to get 
        # char probabilities). But, we can use CrossEntropyLoss() on this logits directly
        self.out_layer = nn.Linear(hid_size, vocab_size)

       # we create the required architecture using the received parameters
       # now, input_size will be [embed_size + 2 * hidden_size] if enc_bidirect
       #                    and [embed_size + hidden_size] otherwise
        kwargs = {'hidden_size':hid_size, 'num_layers':num_layers, 
                 'batch_first':True}
        kwargs['input_size'] = embed_size + 2 * hid_size if enc_bidirect else embed_size + hid_size
        # create the network using kwargs
        if num_layers > 1:
            kwargs['dropout'] = dropout
        if cell_type == 'RNN':
            self.network = nn.RNN(**kwargs)
        elif cell_type == 'LSTM':
            self.network = nn.LSTM(**kwargs)
        else:
            self.network = nn.GRU(**kwargs)
        
        # save the attention object
        self.attention = attention

    # will always go 1 step forward in time (seqlen = L = 1)
    def forward(self, batch_y, prev_decoder_state, padded_enc_outputs, mask):
        ''' 
            batch_y -> shape (batch_size) = decoder input with vocabulary indices from target language
            prev_decoder_state(RNN, GRU) -> shape (num_layers, batch_size, hidden_size) 
            prev_decoder_state(LSTM) -> tuple of prev_hidden_state, prev_cell_state
                                        both have shape (num_layers, batch_size, hidden_size) ()
            padded_enc_outputs -> shape (batch_size, max_enc_seq_len, (2 or 1) * hid_size)
                                        2 if encoder is bidirectional
            mask -> shape (batch_size, max_enc_seq_len) -> 1 where pad token is present
        '''
        # we add a dummy dimension for seqlen = 1. new shape of batch_y is (batch_size, 1)
        batch_y = batch_y.unsqueeze(1)

        # pass through embedding and dropout layers. new shape of batch_y is (batch_size, 1, hidden_size)
        embedded_batch_y = self.embedding(batch_y)
        embedded_batch_y = self.dropout(embedded_batch_y)

        # unpack the decoder state and compute attention
        if self.cell_type == 'LSTM':
            decoder_hidden_state, _ = prev_decoder_state
        else:
            decoder_hidden_state = prev_decoder_state

        attn_weights = self.attention(decoder_hidden_state[0, :, :], padded_enc_outputs, mask)
        # recall attn_weights shape = (batch_size, max_enc_seq_len)
        attn_weights = attn_weights.unsqueeze(1)
        attn_weighted_enc_outputs = torch.bmm(attn_weights, padded_enc_outputs)
        # attn_weighted_enc_outputs shape = (batch_size, 1, (2 or 1) * hid_size)

        dec_input = torch.cat([attn_weighted_enc_outputs, embedded_batch_y], dim=2)
        # dec_input shape (batch_size, 1, (2 or 1) * hid_size + embed_size)

        # pass dec_input through the network
        outputs, new_decoder_state = self.network(dec_input, prev_decoder_state)
        outputs = outputs.squeeze(1) # remove seqlen dimension. shape = (batch_size, hidden_size)
        # pass outputs through the linear layer to get the logits (shape = (batch_size, out_vocab_size))
        logits = self.out_layer(outputs)
        # attention weights is required for visualiztion
        return logits, new_decoder_state, attn_weights

## Seq2Seq(Encoder-Decoder) Model

In [115]:
''' 
Class for encapsulating the encoder and decoder networks
'''
class EncoderDecoder(nn.Module):
    def __init__(self, encoder :EncoderNet, decoder : DecoderNet, src_lang : Language, tar_lang : Language) -> None:
        super(EncoderDecoder, self).__init__()
        # store the encoder and decoder along with language objects in the class
        self.enc_model = encoder
        self.dec_model = decoder
        self.src_lang = src_lang
        self.tar_lang = tar_lang
        self.cell_type = self.dec_model.cell_type
        # we require num of enc layers == num of dec layers as we connect encoder and decoder
        # layer by layer
        assert(self.enc_model.num_layers == self.dec_model.num_layers)
        self.num_layers = self.enc_model.num_layers
    
    # function to make mask for batch_X where mask == 1 iff pad token is present in that location
    def make_mask(self, batch_X):
        return torch.where(batch_X == self.tar_lang.sym2index[PAD_SYM], 1, 0)
    
    def forward(self, batch_X, batch_y, X_lens, tf_ratio=None):
        ''' 
            batch_X -> shape (batch_size, max_batch_X_seq_len) - padded input to encoder
            batch_y -> shape (batch_size, max_batch_y_seq_len) - padded input to decoder
            X_lens -> list of true (unpadded) lengths of the sequences in batch_X
        '''
        # compute batch_size and send batch_X through the encoder
        batch_size = batch_X.size(0)
        enc_outputs, final_enc_hidden_state = self.enc_model(batch_X, X_lens)
        # recall final_enc_hidden_state -> shape (num_layers, batch_size, hidden_size)
        
        # make padding mask for batch_X
        pad_mask = self.make_mask(batch_X)
        tarlength = batch_y.size(1) # max seq length of batch_y

        # outlogits -> tensor for storing the output logits (softmax to get post prob.)
        outlogits = torch.zeros(batch_size, tarlength, self.dec_model.vocab_size).to(device)
        # preds -> tensor for storing argmax(logits) over the target vocab for each example and each time step
        preds = torch.zeros(batch_size, tarlength).to(device)

        dec_input = batch_y[:,0] # get initial input for decoder -> SOS tokens with shape (batch_size)
        decoder_state = final_enc_hidden_state # initially decoder hidden state = final_enc_hidden_state

        if (self.cell_type == 'LSTM'):
            # for LSTM, cell state is initialized to zero and is added to decoder state
            init_dec_cell_state = torch.zeros_like(decoder_state, device = device)
            decoder_state = (decoder_state, init_dec_cell_state)
        
        # for each timestep
        for tstep in range(1, tarlength):
            # send the dec_input through the decoder. we ignore the attn_weights here.
            curlogits, decoder_state, _ = self.dec_model(dec_input, decoder_state, enc_outputs, pad_mask)
            # recall curlogits -> shape (batch_size, out_vocab_size); decoder_state -> shape invariant.
            tf_input = batch_y[:, tstep] # dec input for next time step if teacher forcing is chosen
            # pred -> argmax along vocab_size (dim = 1) to get class labels. shape = (batch_size)
            pred = torch.argmax(curlogits, dim=1).to(device)
            # greedy dec input is whatever set of words was predicted previously. shape = (batch_size)
            dec_input = pred
            # change dec input to tf input with prob = tf_ratio
            if tf_ratio != None and torch.rand(1)[0] < tf_ratio:
                dec_input = tf_input
            # store curlogits (for loss backprop) and pred (for predicted word construction)
            # for the current timestep
            outlogits[:, tstep, :] = curlogits 
            preds[:, tstep] = pred
        # NOTE - outlogits[:, 0, :] -> is a dummy tensor. should be discarded in loss computation
        # Similarly preds[:, 0] -> is also to be ignored. It has only 0s (=SOS_SYM).
        return outlogits, preds

    ### incomplete here onwards...
    def inference_forward(self, batch_X, X_lens, max_dec_length=25):
        ''' 
            batch_X -> shape (batch_size, max_batch_X_seq_len) - padded input to encoder
            X_lens -> list of true (unpadded) lengths of the sequences in batch_X
            max_dec_length -> length beyond which decoding is stopped
        '''
        # compute batch_size and send batch_X through the encoder
        batch_size = batch_X.size(0)
        _, final_enc_hidden_state = self.enc_model(batch_X, X_lens)
        # recall final_enc_hidden_state -> shape (num_layers, batch_size, hidden_size)
        
        # outlogits -> tensor for storing the output logits (softmax to get post prob.)
        outlogits = torch.zeros(batch_size, max_dec_length, self.dec_model.vocab_size).to(device)
        # preds -> tensor for storing argmax(logits) over the target vocab for each example at each time step
        preds = torch.zeros(batch_size, max_dec_length).to(device)
        dec_input = torch.tensor([self.tar_lang.sym2index[SOS_SYM] for _ in range(batch_size)], device=device)
        # get initial input for decoder -> SOS tokens with shape (batch_size)
        decoder_state = final_enc_hidden_state # initially decoder hidden state = final_enc_hidden_state

        if (self.cell_type == 'LSTM'):
            # for LSTM, cell state is initialized to zero and is added to decoder state
            init_dec_cell_state = torch.zeros_like(decoder_state, device = device)
            decoder_state = (decoder_state, init_dec_cell_state)

        # for each timestep
        for tstep in range(1, max_dec_length):
            # send the dec_input through the decoder.
            curlogits, decoder_state = self.dec_model(dec_input, decoder_state)
            # recall curlogits -> shape (batch_size, out_vocab_size); decoder_state -> shape invariant.
            # pred -> argmax along vocab_size (dim = 1) to get class labels. shape = (batch_size)
            pred = torch.argmax(curlogits, dim=1).to(device)
            # greedy dec input is whatever set of words was predicted previously. shape = (batch_size)
            dec_input = pred
            # store curlogits and pred (for predicted word construction) for the current timestep
            outlogits[:, tstep, :] = curlogits 
            preds[:, tstep] = pred

        # generate predicted words using preds tensor and return it.
        pred_words = self.tar_lang.convert_to_words(preds.cpu().numpy())
        return pred_words

## Train/Evaluation/Inference Class

In [135]:
class Runner():
    def __init__(self, src_lang : Language, tar_lang : Language, common_embed_size, common_num_layers, 
                 common_hidden_size, common_cell_type, enc_bidirect, dropout, opt_name='Adam',
                 learning_rate=1e-3):
        # save the language objects
        self.src_lang = src_lang
        self.tar_lang = tar_lang

        # create all the sub-networks and the main model
        self.encoder = EncoderNet(vocab_size=src_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type= common_cell_type, bidirect=enc_bidirect, dropout=dropout)
        self.decoder = DecoderNet(vocab_size=tar_lang.get_size(), embed_size=common_embed_size,
                             num_layers=common_num_layers, hid_size=common_hidden_size,
                             cell_type=common_cell_type, dropout=dropout)
        self.model = EncoderDecoder(encoder=self.encoder, decoder=self.decoder, src_lang=src_lang, 
                                    tar_lang=tar_lang)
        
        # move model to the torch device
        self.model.to(device)
        # for reproducibility - seed everything with 42
        torch.manual_seed(42); torch.cuda.manual_seed(42); np.random.seed(42); random.seed(42)

        self.model.apply(self.init_weights) # initialize model weights

        # initialize the data loaders
        self.trainLoader, self.validLoader, self.testLoader = None, None, None

        # optimizer for the model and loss function [that ignores locs where target = PAD token]
        self.optimizer = None
        self.loss_criterion = nn.CrossEntropyLoss(ignore_index=tar_lang.sym2index[PAD_SYM])
        if opt_name == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    @staticmethod
    def init_weights(m):
        '''
        function to initialize the weights of the model parameters
        '''
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.05)
            else:
                nn.init.constant_(param.data, 0)
    
    def generate_data_loaders(self, data_X, data_y, batch_size):
        '''
        Create the pytorch Dataset and use it to make the dataloader. The dataloader
        post-processes every batch to add padding
        '''
        dataset = TransliterateDataset(data_X, data_y, src_lang=SRC_LANG, tar_lang=TAR_LANG)
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                                collate_fn=CollationFunction(SRC_LANG, TAR_LANG))
        return dataloader

    def make_all_loaders(self, train_data, valid_data, test_data, batch_size):
        ''' 
        get train, valid, test data and create dataloaders of batches with size=batch_size
        '''
        train_X, train_y = train_data
        valid_X, valid_y = valid_data
        test_X, test_y = test_data

        self.trainLoader = self.generate_data_loaders(train_X, train_y, batch_size)
        self.validLoader = self.generate_data_loaders(valid_X, valid_y, batch_size)
        self.testLoader = self.generate_data_loaders(test_X, test_y, batch_size)

    def get_accuracy(self, pred_words, tar_words):
        ''' 
        compute the accuracy using (predicted words, target words) and return it.
        exact word matching is used.
        '''
        assert(len(pred_words) == len(tar_words))
        count = 0
        for i in range(len(pred_words)):
            if pred_words[i] == tar_words[i]:
                count += 1
        return count / len(pred_words)

    def train(self, epoch_number, tf_ratio=0.6):
        '''
        train the model for 1 epoch with teacher forcing ratio = tf_ratio.
        epoch_number is displayed in the output
        '''
        # sanity checks 
        assert(self.trainLoader != None); assert(self.optimizer != None)

        # set model in training mode for autograd to be activated
        self.model.train(); self.optimizer.zero_grad()
        # maintain training loss and lists of predicted and true words
        train_loss = 0.0
        pred_words, true_words = [], []
        with tqdm(self.trainLoader, unit=' batch') as tqdmLoader:
            for batch_X, batch_y, X_lens, y_words in tqdmLoader:
                tqdmLoader.set_description(f'Epoch {epoch_number}')
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                # get the logits, preds for the current batch
                logits, preds = self.model(batch_X, batch_y, X_lens, tf_ratio=tf_ratio)
                # ignore loss for the first time step
                targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
                logits = logits.swapaxes(1, 2) # make class logits the second dimension as needed
                loss = self.loss_criterion(logits, targets); loss.backward(); train_loss += loss.item()
                self.optimizer.step(); self.optimizer.zero_grad()
                batch_pred_words = self.tar_lang.convert_to_words(preds.cpu().numpy())
                tqdmLoader.set_postfix(loss=loss.item())
                true_words += y_words; pred_words += batch_pred_words
        train_loss /= len(self.trainLoader); train_acc = self.get_accuracy(pred_words, true_words)
        print(f'Train Loss = {train_loss}; Train Accuracy = {train_acc * 100}')
        return train_loss, train_acc, pred_words, true_words

    def evaluate(self):
        assert(self.validLoader != None)
        self.model.eval()
        valid_loss = 0.0
        true_words, pred_words = [], []
        with torch.no_grad():
            with tqdm(self.validLoader, unit=' batch') as tqdmLoader:
                for batch_X, batch_y, X_lens, y_words in tqdmLoader:
                    tqdmLoader.set_description(f'Validation')
                    batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                    # get the logits, preds for the current batch
                    logits, preds = self.model(batch_X, batch_y, X_lens)
                    # ignore loss for the first time step
                    targets = batch_y[:, 1:]; logits = logits[:, 1:, :]
                    logits = logits.swapaxes(1, 2) # make class logits the second dimension as needed
                    loss = self.loss_criterion(logits, targets)
                    valid_loss += loss.item()
                    batch_pred_words = self.tar_lang.convert_to_words(preds.cpu().numpy())
                    tqdmLoader.set_postfix(loss=loss.item())
                    true_words += y_words
                    pred_words += batch_pred_words
        valid_loss /= len(self.validLoader)
        valid_acc = self.get_accuracy(pred_words, true_words)
        print(f'Validation Loss = {valid_loss}; Validation Accuracy = {valid_acc * 100}')
        return valid_loss, valid_acc, pred_words, true_words

    def inference(self):
        assert(self.testLoader != None)
        self.model.eval()
        pred_words, true_words = [], []
        with torch.no_grad():
            with tqdm(self.testLoader, unit=' batch') as tqdmLoader:
                for batch_X, _, X_lens, y_words in tqdmLoader:
                    tqdmLoader.set_description(f'Testing')
                    batch_X = batch_X.to(device)
                    # get the predicted words for the current batch
                    batch_pred_words = self.model.inference_forward(batch_X, X_lens, 25)
                    true_words += y_words
                    pred_words += batch_pred_words
        test_acc = self.get_accuracy(pred_words, true_words)
        print(f'Test Accuracy = {test_acc * 100}')
        return pred_words, true_words
    
    def beam_search_inference(self, ):
        pass

In [164]:
# testing runner
# keep embedding small (around 32) -> important to get dense embedding
# also, adjust learning rate reasonably
runner = Runner(SRC_LANG, TAR_LANG, 32, 1, 1024, 'LSTM', True, 0.1, learning_rate=1e-3)

In [165]:
runner.make_all_loaders((x_train, y_train), (x_valid, y_valid), (x_test, y_test), 128)

In [166]:
for i in range(10):
    _, _, p, w = runner.train(i, 0.6)
    runner.evaluate(); runner.inference()
    for x, y in list(zip(p,w))[:10]:
        print(x,y)

Epoch 0: 100%|██████████| 400/400 [00:50<00:00,  7.87batch/s, loss=2.12]


Train Loss = 2.5933389937877656; Train Accuracy = 0.03515625


Validation: 100%|██████████| 32/32 [00:01<00:00, 25.46batch/s, loss=2.12]


Validation Loss = 2.1170311346650124; Validation Accuracy = 0.634765625


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.82batch/s]


Test Accuracy = 1.0009765625
आईक्थंउउभउउउेआणषझ देशभरामध्ये
ससअनमररररउरउतशउउझ पेशनधारियों
टॅब!ररीइउउशशशँऔऔग अनुक्रमानुपात
ङसजऽररघरशममऽथँनऩ वाहिन्यांसाठी
ढसढौ्ौौञएयबछझझझझझ लोकरचनाओं
ईयसहऽउउवेँपपतो ट्रांसक्रिप्टेज़
ङ@ङङङँेठवजठँजजगौझ गंतिविधियों
ओईकगौठयउउथभउउउझझझ शासनाबद्दल
टईगउउगठररतररररररझ सर्वसंग्रह
@ङेसॅपथँँननचओ%ःरऑ तुमच्यापैकी


Epoch 1: 100%|██████████| 400/400 [00:49<00:00,  8.04batch/s, loss=1.66]


Train Loss = 1.7483454313874245; Train Accuracy = 1.365234375


Validation: 100%|██████████| 32/32 [00:01<00:00, 25.32batch/s, loss=1.54]


Validation Loss = 1.6516394838690758; Validation Accuracy = 8.0322265625


Testing: 100%|██████████| 32/32 [00:01<00:00, 19.83batch/s]


Test Accuracy = 6.7138671875
दिष्रामाचचा देशभरामध्ये
प्र्व्व्यांर पेशनधारियों
अनुत्रुुप् अनुक्रमानुपात
वावाव्यातत वाहिन्यांसाठी
लोखानाां लोकरचनाओं
ट्र्रस््रोस्स ट्रांसक्रिप्टेज़
गंत्वाव्यां गंतिविधियों
शिद्रदुुल शासनाबद्दल
सर्ता्गारा सर्वसंग्रह
तुमाममामां तुमच्यापैकी


Epoch 2: 100%|██████████| 400/400 [00:49<00:00,  8.02batch/s, loss=1.11] 


Train Loss = 1.3108171017467976; Train Accuracy = 5.705078125


Validation: 100%|██████████| 32/32 [00:01<00:00, 23.51batch/s, loss=1.36]


Validation Loss = 1.4206744208931923; Validation Accuracy = 13.8427734375


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.54batch/s]


Test Accuracy = 12.3046875
दिध््भभा्याी देशभरामध्ये
पक््््र्यों पेशनधारियों
अनुक्रानंपप्र अनुक्रमानुपात
वह्वय्यांतठी वाहिन्यांसाठी
लोक्ाााओं लोकरचनाओं
ट्रासस्पस्क्र ट्रांसक्रिप्टेज़
गंविविधियों गंतिविधियों
शश्ंदददलल शासनाबद्दल
सर्वसंगररार सर्वसंग्रह
तुक्ायामीकी तुमच्यापैकी


Epoch 3: 100%|██████████| 400/400 [00:50<00:00,  7.99batch/s, loss=0.839]


Train Loss = 1.0676872895658016; Train Accuracy = 10.28125


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.34batch/s, loss=1.28]


Validation Loss = 1.3050520941615105; Validation Accuracy = 15.966796875


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.13batch/s]


Test Accuracy = 11.4501953125
देवारभा्यया देशभरामध्ये
पेशार्यियों पेशनधारियों
अनुक्रमां्द अनुक्रमानुपात
वह्यानससाहहठी वाहिन्यांसाठी
लोकराकाां लोकरचनाओं
ट्रास्प्पिप्टान ट्रांसक्रिप्टेज़
गंतिविधियों गंतिविधियों
शससंदद्ल शासनाबद्दल
सर्वसंगढ़ सर्वसंग्रह
तुक्चायमीकीी तुमच्यापैकी


Epoch 4: 100%|██████████| 400/400 [00:50<00:00,  7.99batch/s, loss=0.658]


Train Loss = 0.8852376410365105; Train Accuracy = 13.802734375


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.62batch/s, loss=1.24]


Validation Loss = 1.2865384854376316; Validation Accuracy = 19.7998046875


Testing: 100%|██████████| 32/32 [00:01<00:00, 19.02batch/s]


Test Accuracy = 13.37890625
देशभरांे्ये देशभरामध्ये
पशशााधधियों पेशनधारियों
अनुक्रमानपपात अनुक्रमानुपात
वाह्न्यासााठी वाहिन्यांसाठी
लोकराचाां लोकरचनाओं
ट्रास्प्रसस्टस ट्रांसक्रिप्टेज़
गणतिविधियों गंतिविधियों
शससााबद्दललल शासनाबद्दल
सर्वसंगररह सर्वसंग्रह
पुमच्यापिका तुमच्यापैकी


Epoch 5: 100%|██████████| 400/400 [00:50<00:00,  7.94batch/s, loss=0.594]


Train Loss = 0.7818373070657253; Train Accuracy = 18.478515625


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.82batch/s, loss=1.24]


Validation Loss = 1.2697770446538925; Validation Accuracy = 21.97265625


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.39batch/s]


Test Accuracy = 18.212890625
देशाराबाधध्येे देशभरामध्ये
पेशाधारियोंं पेशनधारियों
अनुक्रमापपुुतू अनुक्रमानुपात
वाहिन्यांसाठीी वाहिन्यांसाठी
लोकररचचाओ लोकरचनाओं
ट्रांसप्रेप्टस ट्रांसक्रिप्टेज़
गण्िविधियोंं गंतिविधियों
शासानदद्दललललल शासनाबद्दल
सर्वसंग्रह सर्वसंग्रह
तुम्चिपाकी तुमच्यापैकी


Epoch 6: 100%|██████████| 400/400 [00:50<00:00,  8.00batch/s, loss=0.594]


Train Loss = 0.6630961132049561; Train Accuracy = 22.81640625


Validation: 100%|██████████| 32/32 [00:01<00:00, 22.89batch/s, loss=1.16]


Validation Loss = 1.2794351652264595; Validation Accuracy = 22.1435546875


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.64batch/s]


Test Accuracy = 15.7958984375
देशभरममध्येे देशभरामध्ये
पेशाधारियों पेशनधारियों
अनुक्रमानुपा अनुक्रमानुपात
वाहिन्यांसाठीीी वाहिन्यांसाठी
लोकरचााओं लोकरचनाओं
ट्रामस्प्रस्पेन ट्रांसक्रिप्टेज़
गंतिविधियोंं गंतिविधियों
शससंबबद्दलल शासनाबद्दल
सर्वसंग्रह सर्वसंग्रह
तुमच्यपााीीीी तुमच्यापैकी


Epoch 7: 100%|██████████| 400/400 [00:50<00:00,  7.97batch/s, loss=0.402]


Train Loss = 0.5808846519142389; Train Accuracy = 27.658203125000004


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.62batch/s, loss=1.09]


Validation Loss = 1.2992070019245148; Validation Accuracy = 26.806640625


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.13batch/s]


Test Accuracy = 22.8759765625
देशभरममध्येय देशभरामध्ये
पेशाधारियोंं पेशनधारियों
अनुक्रमानुपतत अनुक्रमानुपात
वाहिन्यांसाठीीी वाहिन्यांसाठी
लोकरचनाओं लोकरचनाओं
ट्रांसस्रिप्टेसट ट्रांसक्रिप्टेज़
गंतिविधियों गंतिविधियों
शासनाबब्दललल शासनाबद्दल
सर्वसंग्र सर्वसंग्रह
तुमच्यापाकीीीी तुमच्यापैकी


Epoch 8: 100%|██████████| 400/400 [00:50<00:00,  7.90batch/s, loss=0.374]


Train Loss = 0.5141593818366528; Train Accuracy = 32.470703125


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.13batch/s, loss=1.08]


Validation Loss = 1.3377272933721542; Validation Accuracy = 26.1474609375


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.08batch/s]


Test Accuracy = 20.7275390625
देशभरामध्ये देशभरामध्ये
पेशनधारियों पेशनधारियों
अनुक्रमानुपात अनुक्रमानुपात
वाहिन्यांसाठीीीी वाहिन्यांसाठी
लोकरानाओं लोकरचनाओं
ट्रांसक्रेप्सेन ट्रांसक्रिप्टेज़
गणतिविधियोंं गंतिविधियों
शासनाबद्दल शासनाबद्दल
सर्वसंग्र सर्वसंग्रह
तुमच्यापाकीी तुमच्यापैकी


Epoch 9: 100%|██████████| 400/400 [00:50<00:00,  7.95batch/s, loss=0.373]


Train Loss = 0.45127433896064756; Train Accuracy = 34.099609375


Validation: 100%|██████████| 32/32 [00:01<00:00, 24.32batch/s, loss=1.2] 


Validation Loss = 1.3311801850795746; Validation Accuracy = 27.1484375


Testing: 100%|██████████| 32/32 [00:01<00:00, 20.89batch/s]

Test Accuracy = 21.19140625
देशभरामध्ये देशभरामध्ये
पेशनधारियों पेशनधारियों
अनुक्रमानुततत अनुक्रमानुपात
वाहिन्यांसाठीी वाहिन्यांसाठी
लोकरचनाओं लोकरचनाओं
ट्रांस्ककिक्टिस ट्रांसक्रिप्टेज़
गततिविधियों गंतिविधियों
शासंमद्ददल शासनाबद्दल
सर्वसंग्रह सर्वसंग्रह
तुमच्यापाकी तुमच्यापैकी





In [167]:
for i in range(10):
    _, _, p, w = runner.train(i, 0.2)
    runner.evaluate(); runner.inference()
    for x, y in list(zip(p,w))[:10]:
        print(x,y)

Epoch 0:  68%|██████▊   | 273/400 [00:34<00:15,  8.16batch/s, loss=0.641]