In [100]:
import numpy as np
import pandas as pd
import torch
import os
import sys
from torch.utils.data import Dataset, DataLoader
import wandb
import regex as re
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [22]:
train_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.train.tsv"
valid_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.dev.tsv"
test_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.test.tsv"

train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')

train_df.head()

Unnamed: 0,native,latin,n_annot
0,ஃபியட்,fiat,2
1,ஃபியட்,phiyat,1
2,ஃபியட்,piyat,1
3,ஃபிரான்ஸ்,firaans,1
4,ஃபிரான்ஸ்,france,2


In [87]:
class NativeTokenizer():
    def __init__(self, train_path, valid_path, test_path, special_tokens={'START': '<start>','END':'<end>', 'PAD':'<pad>'}):
        
        self.train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.special_tokens = special_tokens
        # Build vocabulary
        self._build_vocab(add_special_tokens=True)
        
        # Id to token mapping
        self.id_to_latin = {i: char for i, char in enumerate(self.latin_vocab)}
        self.id_to_native = {i: char for i, char in enumerate(self.native_vocab)}

        self.latin_vocab_size = len(self.latin_vocab)
        self.nat_vocab_size = len(self.native_vocab)

    # Build vocabulary
    def _build_vocab(self, add_special_tokens=True):
        self.nat_set = set()
        self.latin_set = set()
        for lat, nat in zip(self.train_df['latin'], self.train_df['native']):
            nat_chars = re.findall(r'\X' , nat)
            try:
                lat_chars = list(lat)
            except:
                print(f"Invalid latin string: {lat}, skipping....")
            
            for char in nat_chars:
                self.nat_set.add(char)
            for char in lat_chars:
               self.latin_set.add(char.lower())
            
        self.nat_set = sorted(list(self.nat_set))
        self.latin_set = sorted(list(self.latin_set))
        
        if add_special_tokens:
            self.nat_set = list(self.special_tokens.values()) + self.nat_set
            self.latin_set = [self.special_tokens['PAD']] + self.latin_set   

        self.latin_vocab = {char: i for i, char in enumerate(self.latin_set)}
        self.native_vocab = {char: i for i, char in enumerate(self.nat_set)}




In [88]:
tokenizer = NativeTokenizer(train_path, valid_path, test_path)
print(f"Latin vocab size: {tokenizer.latin_vocab_size}")
print(f"Native vocab size: {tokenizer.nat_vocab_size}")

Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Latin vocab size: 27
Native vocab size: 253


In [112]:
class LatNatDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        entry = self.df.iloc[idx]
        latin_word = entry['latin']
        native_word = entry['native']
               
        # Tokenize and convert to IDs
        latin_ids = [self.tokenizer.latin_vocab[i] for i in latin_word]
        native_ids = [self.tokenizer.native_vocab[i] for i in re.findall(r'\X' , native_word)]

        return (torch.tensor(latin_ids),
            torch.tensor(native_ids))

    def collate_fn(self, batch):
        x,y = zip(*batch)
        x_len = [len(seq) for seq in x]
        y_len = [len(seq) for seq in y]

        padded_x = pad_sequence(x, batch_first=True, padding_value=self.tokenizer.latin_vocab['<pad>'])
        padded_y = pad_sequence(y, batch_first=True, padding_value=self.tokenizer.native_vocab['<pad>'])
    
        return padded_x, x_len, padded_y, y_len



In [113]:
train_dataset = LatNatDataset(train_df, tokenizer)
valid_dataset = LatNatDataset(valid_df, tokenizer)
test_dataset = LatNatDataset(test_df, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=train_dataset.collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=valid_dataset.collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=test_dataset.collate_fn)

In [None]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

tensor([[ 3,  8,  5, 12,  1, 22,  1,  1, 11, 21, 13,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  5, 20, 18, 20, 18, 21, 11, 11, 15, 12, 22,  1, 20,  8, 21,  0,  0],
        [ 1,  1, 18, 17, 20,  9, 17,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [20,  8, 21, 18,  1, 14,  4,  8,  1,  1, 18,  0,  0,  0,  0,  0,  0,  0],
        [22,  9, 12,  1,  9, 22, 21,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [16,  1, 11, 21, 20,  8,  9, 25,  1,  1, 14,  1,  0,  0,  0,  0,  0,  0],
        [13,  1, 20,  1, 12, 11,  1, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [22,  5,  5, 18,  1, 18, 11,  1, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [14,  1,  4, 21, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [10,  1, 25,  1, 14,  1,  7,  1, 18,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [25,  5,  5, 20,  9,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  5,  5, 16,  9, 11,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [19, 15,