In [100]:
import numpy as np
import pandas as pd
import torch
import os
import sys
from torch.utils.data import Dataset, DataLoader
import wandb
import regex as re
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [22]:
train_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.train.tsv"
valid_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.dev.tsv"
test_path = "/home/user/Documents/Courses/dakshina_dataset_v1.0/ta/lexicons/ta.translit.sampled.test.tsv"

train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')

train_df.head()

Unnamed: 0,native,latin,n_annot
0,ஃபியட்,fiat,2
1,ஃபியட்,phiyat,1
2,ஃபியட்,piyat,1
3,ஃபிரான்ஸ்,firaans,1
4,ஃபிரான்ஸ்,france,2


In [176]:
class NativeTokenizer():
    def __init__(self, train_path, valid_path, test_path, special_tokens={'START': '<start>','END':'<end>', 'PAD':'<pad>'}):
        
        self.train_df = pd.read_csv(train_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.valid_df = pd.read_csv(valid_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.test_df = pd.read_csv(test_path, sep="\t", header=None, names=["native", "latin", 'n_annot'], encoding='utf-8')
        self.special_tokens = special_tokens
        # Build vocabulary
        self._build_vocab(add_special_tokens=True)
        
        # Id to token mapping
        self.id_to_latin = {i: char for i, char in enumerate(self.latin_vocab)}
        self.id_to_native = {i: char for i, char in enumerate(self.native_vocab)}

        self.latin_vocab_size = len(self.latin_vocab)
        self.nat_vocab_size = len(self.native_vocab)

    # Build vocabulary
    def _build_vocab(self, add_special_tokens=True):
        self.nat_set = set()
        self.latin_set = set()
        for lat, nat in zip(self.train_df['latin'], self.train_df['native']):
            nat_chars = re.findall(r'\X' , nat)
            try:
                lat_chars = list(lat)
            except:
                print(f"Invalid latin string: {lat}, skipping....")
            
            for char in nat_chars:
                self.nat_set.add(char)
            for char in lat_chars:
               self.latin_set.add(char.lower())
            
        self.nat_set = sorted(list(self.nat_set))
        self.latin_set = sorted(list(self.latin_set))
        
        if add_special_tokens:
            self.nat_set = list(self.special_tokens.values()) + self.nat_set
            self.latin_set = [self.special_tokens['PAD']] + self.latin_set   

        self.latin_vocab = {char: i for i, char in enumerate(self.latin_set)}
        self.native_vocab = {char: i for i, char in enumerate(self.nat_set)}

    def tokenize(self, text, lang='latin'):
        if lang == 'latin':
            return [self.latin_vocab[char] for char in text]
        elif lang == 'native':
            return [self.native_vocab['<start>']] + [self.native_vocab[char] for char in re.findall('\X', text)] + [self.native_vocab['<end>']]
        else:
            raise ValueError("Language must be either 'latin' or 'native'.")




In [177]:
tokenizer = NativeTokenizer(train_path, valid_path, test_path)
print(f"Latin vocab size: {tokenizer.latin_vocab_size}")
print(f"Native vocab size: {tokenizer.nat_vocab_size}")

Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Invalid latin string: nan, skipping....
Latin vocab size: 27
Native vocab size: 253


In [178]:
class LatNatDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        entry = self.df.iloc[idx]
        latin_word = entry['latin']
        native_word = entry['native']
               
        # Tokenize and convert to IDs
        #latin_ids = [self.tokenizer.latin_vocab[i] for i in latin_word]
        #native_ids = [self.tokenizer.native_vocab[i] for i in re.findall(r'\X' , native_word)]
        latin_ids = self.tokenizer.tokenize(latin_word, lang='latin')
        native_ids = self.tokenizer.tokenize(native_word, lang='native')


        return (torch.tensor(latin_ids),
            torch.tensor(native_ids))

    def collate_fn(self, batch):
        x,y = zip(*batch)
        x_len = [len(seq) for seq in x]
        y_len = [len(seq) for seq in y]

        padded_x = pad_sequence(x, batch_first=True, padding_value=self.tokenizer.latin_vocab['<pad>'])
        padded_y = pad_sequence(y, batch_first=True, padding_value=self.tokenizer.native_vocab['<pad>'])
        
        x_len, perm_idx = torch.tensor(x_len).sort(0, descending=True)
        padded_x = padded_x[perm_idx]

        y_len = torch.tensor(y_len).sort(0, descending=True)
        padded_y = padded_y[perm_idx]

        return padded_x, x_len, padded_y, y_len



In [179]:
train_dataset = LatNatDataset(train_df, tokenizer)
valid_dataset = LatNatDataset(valid_df, tokenizer)
test_dataset = LatNatDataset(test_df, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=train_dataset.collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=valid_dataset.collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=test_dataset.collate_fn)

In [180]:
class Encoder(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.rnn = torch.nn.RNN(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
    
    def forward(self, seq, seq_len):
        embedding = self.embedding(input=seq)
        print("embedded")
        packed = pack_padded_sequence(input=embedding, lengths=seq_len, batch_first=True, enforce_sorted=True)
        output, hidden = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        return output, hidden

In [185]:
class Decoder(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.rnn = torch.nn.RNN(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
    
    def forward(self, seq, seq_len):
        embedding = self.embedding(input=seq)
        packed = pack_padded_sequence(input=embedding, lengths=seq_len, batch_first=True, enforce_sorted=True)
        output, hidden = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        return output

In [186]:
encoder = Encoder(tokenizer.latin_vocab_size, 128)
decoder = Decoder(tokenizer.nat_vocab_size, 128)


In [189]:
for batch in train_dataloader:
    x, x_len, y, y_len = batch
    _, hidden = encoder(x, x_len)
    for i in range(y.shape[1]):
        output, hidden = decoder(y[:, i].unsqueeze(1), y_len)
        print(output.shape)
    

embedded


TypeError: only integer tensors of a single element can be converted to an index

In [183]:
y

tensor([[  0,  82, 148, 161,  27,  15, 133, 120,  69,  58,  86,   1],
        [  0,  32, 183, 174, 184,  27,  17, 147,  28,  15, 207,   1],
        [  0,   6, 119, 174, 147, 142, 148,  82, 225,   1,   2,   2],
        [  0,  10, 212,  84, 149, 163, 106,  83, 171,   1,   2,   2],
        [  0,  83,  27,  19,  82, 186, 195,   1,   2,   2,   2,   2],
        [  0, 124,  93,  82,  15, 135,  15,   1,   2,   2,   2,   2],
        [  0, 224,  69,  66, 148,  62, 146,   1,   2,   2,   2,   2],
        [  0,  35, 215,  96, 192,  15, 207,   1,   2,   2,   2,   2],
        [  0,   4, 180, 147, 133, 120,  69,  62,   1,   2,   2,   2],
        [  0,  95, 119,  16, 149, 159, 146,   1,   2,   2,   2,   2],
        [  0, 135, 176, 121,  62,  15, 207,   1,   2,   2,   2,   2],
        [  0, 216, 159, 224,  81,  62, 146,   1,   2,   2,   2,   2],
        [  0,  83, 107, 135,  15,   1,   2,   2,   2,   2,   2,   2],
        [  0,  82, 196,  93,  90,   1,   2,   2,   2,   2,   2,   2],
        [  0,  96, 1