This notebook trains an RNN model, with or without attention, for Latin-to-Tamil transliteration on the Dakshina dataset.

In [None]:
import numpy as np
import pandas as pd
import torch
import os
from torch import nn
import sys
from torch.utils.data import Dataset, DataLoader
import wandb
import regex as re
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import wandb
import lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger
from torch.nn.functional import pad
import gc
from ..src.models import RNN_light, RNN_light_attention
import matplotlib.pyplot as plt
import seaborn as sns
from ..src.dataloader import NativeTokenizer, LatNatDataset

In [None]:
# Authenticate with Weights & Biases.
# The key is read from the environment instead of being hard-coded here: the
# previous placeholder assignment overwrote any real WANDB_API_KEY that was
# already set. If the variable is unset, key=None lets wandb fall back to its
# default credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Batch size shared by all three DataLoaders. It was referenced below but never
# defined anywhere in the notebook (NameError). 64 is a placeholder default;
# tune as needed.
batch_size = 64
num_workers = 2

data_dir = "/kaggle/input/dakshina/dakshina_dataset_v1.0/ta/lexicons"
train_path = f"{data_dir}/ta.translit.sampled.train.tsv"
valid_path = f"{data_dir}/ta.translit.sampled.dev.tsv"
test_path = f"{data_dir}/ta.translit.sampled.test.tsv"


def load_lexicon(path):
    """Read a headerless Dakshina lexicon TSV and drop rows with no Latin text.

    Columns are named (native, latin, n_annot). Only genuinely empty fields
    are treated as missing. pandas' default NA list would otherwise turn
    Latin words such as "nan" or "null" into NaN, and they would be silently
    dropped by the filter below.
    """
    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["native", "latin", "n_annot"],
        encoding="utf-8",
        keep_default_na=False,
        na_values=[""],
    )
    return df.dropna(subset=["latin"])


def make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader that uses the dataset's own collate_fn."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=dataset.collate_fn,
        num_workers=num_workers,
    )


train_df = load_lexicon(train_path)
valid_df = load_lexicon(valid_path)
test_df = load_lexicon(test_path)

tokenizer = NativeTokenizer(train_path, valid_path, test_path)
print(f"Latin vocab size: {tokenizer.latin_vocab_size}")
print(f"Native vocab size: {tokenizer.nat_vocab_size}")

train_dataset = LatNatDataset(train_df, tokenizer)
valid_dataset = LatNatDataset(valid_df, tokenizer)
test_dataset = LatNatDataset(test_df, tokenizer)

# Only the training set is shuffled; validation and test keep file order.
train_dataloader = make_loader(train_dataset, shuffle=True)
valid_dataloader = make_loader(valid_dataset, shuffle=False)
test_dataloader = make_loader(test_dataset, shuffle=False)

# With attention

In [None]:
# Ids of the special tokens in the native (target) vocabulary, passed to the model.
special_tokens = {key: val for key, val in tokenizer.native_vocab.items() if key in ['<start>', '<end>', '<pad>']}

model = RNN_light_attention(
            input_sizes=(tokenizer.latin_vocab_size, tokenizer.nat_vocab_size),
            embedding_size=128,
            hidden_size=256,
            cell='LSTM',
            layers=3,
            dropout=0.25,
            activation='tanh',
            beam_size=3,
            optim='adam',
            special_tokens=special_tokens,
            lr=0.0002)

# Close any wandb run left open by an earlier (or re-run) cell. Otherwise the
# new WandbLogger re-uses the active run instead of starting a fresh one.
wandb.finish()
# NOTE(review): `pl` is the `lightning` package while WandbLogger is imported
# from `pytorch_lightning`. Mixing the two namespaces can make the Trainer
# reject the logger/model; confirm which package src.models builds on.
logger = WandbLogger(project='DLA3_sweeps', name="bestmodel-attention")
trainer = pl.Trainer(max_epochs=1, accelerator="auto", logger=logger, profiler='simple', precision="16-mixed")
trainer.fit(model, train_dataloader, valid_dataloader)
trainer.test(model, dataloaders=test_dataloader)


def _to_numpy(x):
    """Return *x* as a NumPy array, detaching/moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        return x.detach().float().cpu().numpy()
    return np.asarray(x)


# Sample up to 9 test examples (one per subplot), without replacement.
n_examples = min(9, len(model.test_preds))
rand_ind = np.random.choice(len(model.test_preds), size=n_examples, replace=False)
# Index the Python lists directly: np.array() on ragged sequences raises in
# NumPy >= 1.24.
attention_map = [model.attention_maps[i] for i in rand_ind]
src = [model.test_inputs[i] for i in rand_ind]
pred = [model.test_preds[i] for i in rand_ind]
tgt = [model.test_labels[i] for i in rand_ind]

fig, axes = plt.subplots(3, 3, figsize=(15, 12))
fig.suptitle("Attention map", fontsize=16)

# Hide subplots left empty when fewer than 9 examples are available.
for ax in axes.flat[n_examples:]:
    ax.axis("off")

for ax, attn, s, p in zip(axes.flat, attention_map, src, pred):
    # Crop to the prediction/source lengths. Assumes attn is laid out as
    # (target_len, source_len); TODO confirm against the model.
    attn_map = _to_numpy(attn)[0:len(p), 0:len(s)]
    sns.heatmap(attn_map, ax=ax, xticklabels=list(s), yticklabels=list(p),
                cmap="Blues", cbar=True)
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(axis='y', labelsize=8)

# Figure-level labels and layout are applied once, after all subplots are
# drawn. Showing inside the loop displayed only the first heatmap.
fig.supxlabel("Latin script", fontsize=14)
fig.supylabel("Tamil script", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()




# Without attention

In [None]:
# Recomputed here so this cell does not depend on the attention cell having run.
special_tokens = {key: val for key, val in tokenizer.native_vocab.items() if key in ['<start>', '<end>', '<pad>']}

model = RNN_light(
            input_sizes=(tokenizer.latin_vocab_size, tokenizer.nat_vocab_size),
            embedding_size=128,
            hidden_size=256,
            cell='LSTM',
            layers=3,
            dropout=0.25,
            activation='tanh',
            beam_size=3,
            optim='adam',
            special_tokens=special_tokens,
            lr=0.0002)

# Close the run started by the previous cell (if any). Otherwise this
# WandbLogger silently re-uses it and both models log into the same run.
wandb.finish()
# NOTE(review): `pl` is the `lightning` package while WandbLogger is imported
# from `pytorch_lightning`. Mixing the two namespaces can make the Trainer
# reject the logger/model; confirm which package src.models builds on.
logger = WandbLogger(project='DLA3_sweeps', name="bestmodel-no-attention")
trainer = pl.Trainer(max_epochs=1, accelerator="auto", logger=logger, profiler='simple', precision="16-mixed")
trainer.fit(model, train_dataloader, valid_dataloader)
trainer.test(model, dataloaders=test_dataloader)


def _as_text(seq):
    """Render *seq* as a table-cell string.

    Strings pass through unchanged. Other sequences (presumably lists of
    character tokens; verify against the model) are joined into one string.
    """
    return seq if isinstance(seq, str) else "".join(map(str, seq))


# Sample up to 9 test examples. Index the Python lists directly, because
# np.array() on ragged sequences raises in NumPy >= 1.24.
n_examples = min(9, len(model.test_preds))
rand_ind = np.random.choice(len(model.test_preds), size=n_examples, replace=False)
src = [model.test_inputs[i] for i in rand_ind]
tgt = [model.test_labels[i] for i in rand_ind]
preds = [model.test_preds[i] for i in rand_ind]

# Table for comparison: input, reference and prediction side by side.
fig, ax = plt.subplots(figsize=(10, max(n_examples, 1) * 1.5))
ax.axis("off")

# The header is kept as the first cellText row so the table is never empty.
table_data = [["Input", "Actual", "Prediction"]]
table_data.extend([_as_text(inp), _as_text(true), _as_text(pred)]
                  for inp, true, pred in zip(src, tgt, preds))

table = ax.table(cellText=table_data, colLabels=None, loc='center', cellLoc='left')
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1, 1.5)
plt.show()
