# CS6910 Assignment 3 (RNN Frameworks for transliteration)

In [8]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import tqdm
import wandb
import unicodedata
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd

In [20]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
TARGET = 'tam'
SOURCE = 'eng'
SOS_SYM = '@'
EOS_SYM = '$'
UNK_SYM = '!'
PAD_SYM = '%'

unicode_ranges = {'tam' : [0x0B80, 0x0BFF], 
                  'eng' : [0x0061, 0x007A],
                  'hin' : [0x0900, 0x097F]}

cuda


## Preprocessing Functions and Helpers

In [18]:
# function to load the 'cat' (= train/val/test) data of language 'lang'
def load_data(lang, cat):
    fcontents = open(f'aksharantar_sampled/{lang}/{lang}_{cat}.csv','r', encoding='utf-8').readlines()
    pairs = [tuple(l.strip().split(',')) for l in fcontents]
    x_data, y_data = list(map(list,zip(*pairs)))
    return x_data, y_data

class Language:
    def __init__(self, name):
        self.lname = name
    
    # function to create the vocabulary using the words in 'data'
    def create_vocabulary(self, *data):
        symbols = set()
        for wd in data:
            for c in wd:
                symbols.add(c)
        self.symbols = symbols

    # function to use unicode ranges for creating the character set
    def create_vocabulary_range(self):
        symbols = set()
        begin, end = unicode_ranges[self.lname]
        for i in range(begin, end+1):
            if (unicodedata.category(chr(i)) != 'Cn'):
                symbols.add(chr(i))
        self.symbols = symbols
    
    def generate_mappings(self):
        self.index2sym = {0: SOS_SYM, 1 : EOS_SYM, 2 : UNK_SYM, 3 : PAD_SYM}
        self.sym2index = {SOS_SYM : 0, EOS_SYM : 1, UNK_SYM : 2, PAD_SYM : 3}
        self.symbols = list(self.symbols)
        self.symbols.sort()

        for i, sym in enumerate(self.symbols):
            self.sym2index[sym] = i + 3
            self.index2sym[i+3] = sym
        
        self.num_tokens = len(self.index2sym.keys())
    
    def convert_to_numbers(self, word):
        enc = [self.sym2index[SOS_SYM]]
        for ch in word:
            if ch in self.sym2index.keys():
                enc.append(self.sym2index[ch])
            else:
                enc.append(self.sym2index[UNK_SYM])
        enc.append(self.sym2index[EOS_SYM])
        return enc

    def get_index(self, sym):
        return self.sym2index[sym]

In [12]:
x_train, y_train = load_data(TARGET, 'train')
x_valid, y_valid = load_data(TARGET, 'valid')
x_test, y_test = load_data(TARGET, 'test')

print(f'Number of train samples = {len(x_train)}')
print(f'Number of valid samples = {len(x_valid)}')
print(f'Number of test samples = {len(x_test)}')

Number of train samples = 51200
Number of valid samples = 4096
Number of test samples = 4096


In [22]:
# create language objects for storing vocabulary, index2sym and sym2index
src_lang = Language(SOURCE)
tar_lang = Language(TARGET)

# creating vocabulary using all data
src_lang.create_vocabulary(*(x_train), *(x_valid), *(x_test))
tar_lang.create_vocabulary(*(y_train), *(y_valid), *(y_test))

# otherwise, use unicode characters (assigned codepoints) in the script's range
# src_lang.create_vocabulary_range()
# tar_lang.create_vocabulary_range()

# generate mappings from characters to numbers and vice versa
src_lang.generate_mappings()
tar_lang.generate_mappings()

print(f'Source Vocabulary Size = {len(src_lang.symbols)}')
print(f'Source Vocabulary = {src_lang.symbols}')
print(f'Source Mapping {src_lang.index2sym}')
print(f'Target Vocabulary Size = {len(tar_lang.symbols)}')
print(f'Target Vocabulary = {tar_lang.symbols}')
print(f'Target Mapping {tar_lang.index2sym}')

Source Vocabulary Size = 26
Source Vocabulary = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Source Mapping {0: '@', 1: '$', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'x', 26: 'y', 27: 'z'}
Target Vocabulary Size = 72
Target Vocabulary = ['ஂ', 'ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'ஔ', 'க', 'ங', 'ச', 'ஜ', 'ஞ', 'ட', 'ண', 'த', 'ந', 'ன', 'ப', 'ம', 'ய', 'ர', 'ற', 'ல', 'ள', 'ழ', 'வ', 'ஶ', 'ஷ', 'ஸ', 'ஹ', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ', '்', 'ௐ', 'ௗ', '௦', '௧', '௨', '௩', '௪', '௫', '௬', '௭', '௮', '௯', '௰', '௱', '௲', '௳', '௴', '௵', '௶', '௷', '௸', '௹', '௺']
Target Mapping {0: '@', 1: '$', 2: 'ஂ', 3: 'ஃ', 4: 'அ', 5: 'ஆ', 6: 'இ', 7: 'ஈ', 8: 'உ', 9: 'ஊ', 10: 'எ', 11: 'ஏ', 12: 'ஐ', 13: 'ஒ', 14: 'ஓ', 15: 'ஔ', 16: 

In [None]:
class TransliterateDataset(Dataset):
    def __init__(self, x_data, y_data, src_lang : Language, tar_lang : Language):
        self.x_data = x_data
        self.y_data = y_data
        self.src_lang = src_lang
        self.tar_lang = tar_lang
        
    def __len__(self):
        return len(self.y_data)

    def __getitem__(self, idx):
        x, y = self.x_data[idx], self.y_data[idx]
        x = self.src_lang.convert_to_numbers(x)
        y = self.tar_lang.convert_to_numbers(y) 
        return torch.Tensor(x), torch.Tensor(y)

class CollationFunction:
    def __init__(self, src_lang : Language, tar_lang : Language):
        self.src_lang = src_lang
        self.tar_lang = tar_lang
    
    def __call__(self, batch):
        src, tar = zip(*batch)
        src = pad_sequence(list(src), batch_first=True, padding_value=src_lang.get_index(PAD_SYM))
        tar = pad_sequence(list(tar), batch_first=True, padding_value=tar_lang.get_index(PAD_SYM))
        return src, tar

## Encoder Model

In [None]:
class EncoderNet(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, hid_size, cell_type, 
                 bidirect=False, dropout=0):
        super(EncoderNet, self).__init__()
        self.hidden_size = hid_size
        self.embed_size = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(p=dropout)
        
        # we create the required architecture using the received parameters
        if cell_type == 'RNN':
            self.network = nn.RNN(input_size=embed_size, hidden_size=hid_size, num_layers=num_layers, 
                               dropout=dropout, bidirectional=bidirect)
        elif cell_type == 'LSTM':
            self.network = nn.LSTM(input_size=embed_size, hidden_size=hid_size, num_layers=num_layers, 
                               dropout=dropout, bidirectional=bidirect)
        else:
            self.network = nn.GRU(input_size=embed_size, hidden_size=hid_size, num_layers=num_layers, 
                               dropout=dropout, bidirectional=bidirect)
        
        self.cell_type = cell_type
        self.bidirect = bidirect
        
    def forward():
