In [2]:
import torch, random, itertools, math, json
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from collections import Counter
from jiwer import wer
import pandas as pd
# !pip install jiwer

# from google.colab import drive
# drive.mount('/content/drive')

# 1. Load three subsets

In [None]:
# from datasets import load_dataset
# import csv

# lang_ids = {"hin": "hin", "tam": "tam", "ben": "ben"}
# output_file = "train.csv"

# with open(output_file, "w", newline='') as csvfile:
#     writer = csv.DictWriter(csvfile, fieldnames=["roman", "native", "lang"])
#     writer.writeheader()

#     for tag in lang_ids:
#         ds = load_dataset("ai4bharat/Aksharantar", streaming=True, split="train")

#         def belongs_to_lang(x):
#             return x.get("unique_identifier", "").startswith(tag)

#         filtered = ds.filter(belongs_to_lang)

#         count = 0
#         for x in filtered:
#             try:
#                 roman = x["english word"]
#                 native = x["native word"]
#                 score = x.get("score", 0.0)  # default to 0.0 if missing
#                 if score < 0.35:
#                     writer.writerow({
#                         "roman": f"<LANG_{tag.upper()}> {roman}",
#                         "native": native,
#                         "lang": tag
#                     })
#                     count += 1
#             except Exception:
#                 continue

#         print(f"✅ Saved {count} samples for {tag}")


In [3]:
# Load the training CSV (expects Google Drive to be mounted at /content/drive).
df = pd.read_csv("/content/drive/MyDrive/pc_draft/train.csv")

In [4]:
# describe must be *called*; the bare attribute only echoes the bound method's repr.
df.describe()

In [5]:
# Quick schema check: list the column names of the loaded frame.
df.columns

Index(['roman', 'native', 'lang'], dtype='object')

In [6]:
# Keep only the three expected columns, drop rows with any missing value,
# and force both text columns to str in a single chained expression.
df = (
    df[['roman', 'native', 'lang']]
    .dropna()
    .astype({'roman': str, 'native': str})
)

In [7]:
# Split each source string into its <LANG_XXX> tag and the remaining romanized text.
roman_source = df['roman']
df['lang_token'] = roman_source.str.extract(r'(<LANG_\w+>)', expand=False)
df['roman_clean'] = roman_source.str.replace(r'<LANG_\w+>\s*', '', regex=True)

In [8]:
# Row count per language tag before balancing.
df['lang_token'].value_counts()


Unnamed: 0_level_0,count
lang_token,Unnamed: 1_level_1
<LANG_TAM>,3148883
<LANG_BEN>,1026393
<LANG_HIN>,968726


In [9]:
# Balance the corpus: draw the same number of rows for every language tag,
# using a fixed seed so the draw is repeatable.
import pandas as pd

sample_size = 500_000
random_state = 42

# One seeded draw per tag, concatenated in HIN, BEN, TAM order.
df_balanced = pd.concat(
    [
        df[df['lang_token'] == token].sample(n=sample_size, random_state=random_state)
        for token in ('<LANG_HIN>', '<LANG_BEN>', '<LANG_TAM>')
    ],
    ignore_index=True,
)

# Shuffle so the three languages are interleaved instead of stored back to back.
df_balanced = df_balanced.sample(frac=1, random_state=random_state).reset_index(drop=True)

# ✅ Final check: every tag should appear sample_size times.
df_balanced['lang_token'].value_counts()


Unnamed: 0_level_0,count
lang_token,Unnamed: 1_level_1
<LANG_BEN>,500000
<LANG_HIN>,500000
<LANG_TAM>,500000


In [10]:
# From this cell on, `df` names the balanced, shuffled sample (not the full corpus).
df = df_balanced
df['lang_token'].value_counts()


Unnamed: 0_level_0,count
lang_token,Unnamed: 1_level_1
<LANG_BEN>,500000
<LANG_HIN>,500000
<LANG_TAM>,500000


In [11]:
def build_vocab_char(df):
    """Build one character-level vocabulary shared by source and target text.

    Every distinct character in the ``roman_clean`` and ``native`` columns
    becomes a symbol, and each distinct ``lang_token`` value is added as a
    single whole-string symbol (not split into characters).

    Args:
        df: DataFrame with string columns ``roman_clean``, ``native`` and
            ``lang_token``.

    Returns:
        ``(stoi, itos)``: ``stoi`` maps symbol -> integer id and ``itos`` is
        its inverse. Id 0 is ``<pad>``, the sorted symbols take ids 1..N,
        ``<sos>`` is N+1 and ``<eos>`` is N+2.
    """
    charset = set()

    # Walk the two text columns directly rather than DataFrame.iterrows():
    # iterrows materialises a Series per row, which is needlessly slow on a
    # frame with over a million rows, and the collected characters are the same.
    for column in ('roman_clean', 'native'):
        for text in df[column]:
            charset.update(text)

    # ✅ Add all <LANG_XXX> tokens explicitly (not character-by-character!)
    lang_tokens = df['lang_token'].unique()
    charset.update(lang_tokens)
    print(f"found land is {lang_tokens}")

    charset = sorted(charset)

    stoi = {c: i + 1 for i, c in enumerate(charset)}  # leave 0 for <pad>
    stoi['<pad>'] = 0
    # len(stoi) grows with each insertion, so <sos> and <eos> get consecutive ids.
    stoi['<sos>'] = len(stoi)
    stoi['<eos>'] = len(stoi)

    itos = {i: c for c, i in stoi.items()}
    return stoi, itos


In [12]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class MultilingualTranslitDataset(Dataset):
    """In-memory dataset of (source ids, target ids) tensor pairs.

    Source sequence: the id of the row's ``lang_token`` followed by the ids of
    the characters of ``roman_clean``. Target sequence: ``<sos>``, the ids of
    the characters of ``native``, then ``<eos>``. Characters missing from
    ``vocab`` are silently skipped, and pairs where either sequence is longer
    than ``max_len`` are dropped.

    Args:
        df: DataFrame with columns ``lang_token``, ``roman_clean``, ``native``.
        vocab: Mapping from symbol to integer id; must contain ``<sos>``,
            ``<eos>`` and every language token present in ``df``.
        max_len: Maximum allowed length (in ids) for source and target.
    """

    def __init__(self, df, vocab, max_len=40):
        self.vocab = vocab
        self.samples = []

        # Zip the three columns instead of DataFrame.iterrows(): iterrows
        # builds a Series per row, which dominates start-up time on a large
        # frame, while the per-row values seen here are identical.
        rows = zip(df['lang_token'], df['roman_clean'], df['native'])
        for lang_token, roman, native in rows:
            src = [vocab[lang_token]] + [vocab[c] for c in roman if c in vocab]
            tgt = [vocab['<sos>']] + [vocab[c] for c in native if c in vocab] + [vocab['<eos>']]

            if len(src) <= max_len and len(tgt) <= max_len:
                self.samples.append((torch.tensor(src), torch.tensor(tgt)))

    def __len__(self):
        """Number of pairs that survived the length filter."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(src_tensor, tgt_tensor)`` pair at position ``idx``."""
        return self.samples[idx]

def collate_fn(batch):
    """Pad a list of (src, tgt) tensor pairs into two batch-first tensors.

    Shorter sequences are right-padded with 0 up to the longest sequence in
    the batch; sources and targets are padded independently.

    Returns:
        ``(src_pad, tgt_pad)``.
    """
    sources, targets = map(list, zip(*batch))
    padded = [
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (sources, targets)
    ]
    return padded[0], padded[1]


In [13]:
import torch.nn as nn
import torch

class Encoder(nn.Module):
    """Embedding layer followed by a single bidirectional, batch-first LSTM.

    Args:
        input_dim: Number of rows in the embedding table (vocabulary size).
        emb_dim: Embedding width.
        hidden_dim: LSTM hidden size per direction.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim):
        super().__init__()
        # Index 0 is the padding id, so its embedding stays fixed at zero.
        self.embedding = nn.Embedding(
            num_embeddings=input_dim, embedding_dim=emb_dim, padding_idx=0
        )
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, src):
        """Encode a batch of id sequences.

        Returns:
            ``(outputs, hidden, cell)`` where ``outputs`` has last dimension
            ``2 * hidden_dim`` (both directions concatenated) and
            ``hidden`` / ``cell`` are the LSTM's final states.
        """
        token_vectors = self.embedding(src)  # [B, T] -> [B, T, emb_dim]
        outputs, final_state = self.lstm(token_vectors)
        hidden, cell = final_state
        return outputs, hidden, cell

class Attention(nn.Module):
    """Additive-style attention scored by a single linear layer.

    Each encoder position is scored from the concatenation of its encoder
    output and the decoder hidden state; the scores are softmax-normalised
    over the time axis and used to average the encoder outputs.

    Args:
        enc_hid_dim: Encoder hidden size per direction (outputs are 2x this).
        dec_hid_dim: Decoder hidden size.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear(2 * enc_hid_dim + dec_hid_dim, 1)

    def forward(self, enc_outputs, dec_hidden):
        """Return the attention-weighted context vector, shape [B, 2*H]."""
        _, seq_len, _ = enc_outputs.size()
        # Tile the decoder state along time so it can be paired with every position.
        tiled_hidden = dec_hidden.unsqueeze(1).repeat(1, seq_len, 1)  # [B, T, H]
        scorer_input = torch.cat((enc_outputs, tiled_hidden), dim=2)
        energy = self.attn(scorer_input).squeeze(2)  # [B, T]
        attn_weights = torch.softmax(energy, dim=1)  # [B, T]
        weighted = torch.bmm(attn_weights.unsqueeze(1), enc_outputs)  # [B, 1, 2*H]
        return weighted.squeeze(1)

class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    Args:
        output_dim: Size of the output vocabulary (also the embedding table size).
        emb_dim: Embedding width.
        enc_hid_dim: Encoder hidden size per direction.
        dec_hid_dim: Decoder LSTM hidden size.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=0)
        self.attention = Attention(enc_hid_dim, dec_hid_dim)
        # The LSTM consumes the token embedding concatenated with the attention context.
        self.lstm = nn.LSTM(emb_dim + 2 * enc_hid_dim, dec_hid_dim, batch_first=True)
        self.fc = nn.Linear(dec_hid_dim, output_dim)

    def forward(self, input_token, hidden, cell, enc_outputs):
        """Run one decoding step.

        Returns:
            ``(prediction, hidden, cell)``: vocabulary logits for this step
            and the updated LSTM state.
        """
        step_embedding = self.embedding(input_token).unsqueeze(1)  # [B] -> [B, 1, emb_dim]
        # Attention is queried with the last layer's hidden state.
        top_hidden = hidden[-1]
        context = self.attention(enc_outputs, top_hidden).unsqueeze(1)  # [B, 1, 2*enc_hid_dim]
        rnn_input = torch.cat((step_embedding, context), dim=2)  # [B, 1, emb+2*H]
        output, state = self.lstm(rnn_input, (hidden, cell))
        hidden, cell = state
        prediction = self.fc(output.squeeze(1))  # [B, output_dim]
        return prediction, hidden, cell


In [14]:
# Character vocabulary (and its inverse) built from the balanced frame.
vocab, inv_vocab = build_vocab_char(df)

# Convert every row to id tensors up front, then batch with padding.
dataset = MultilingualTranslitDataset(df, vocab)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64, collate_fn=collate_fn)

# Encoder and decoder share one vocabulary, so both are sized by len(vocab).
vocab_size = len(vocab)
emb_dim, hid_dim = 128, 256
encoder = Encoder(vocab_size, emb_dim, hid_dim)
decoder = Decoder(vocab_size, emb_dim, hid_dim, hid_dim)

# The training loop itself lives in train(), defined further down.


found land is ['<LANG_BEN>' '<LANG_HIN>' '<LANG_TAM>']


In [15]:
# Dump the full symbol -> id mapping for a visual sanity check.
print(vocab)

{'<LANG_BEN>': 1, '<LANG_HIN>': 2, '<LANG_TAM>': 3, 'a': 4, 'b': 5, 'c': 6, 'd': 7, 'e': 8, 'f': 9, 'g': 10, 'h': 11, 'i': 12, 'j': 13, 'k': 14, 'l': 15, 'm': 16, 'n': 17, 'o': 18, 'p': 19, 'q': 20, 'r': 21, 's': 22, 't': 23, 'u': 24, 'v': 25, 'w': 26, 'x': 27, 'y': 28, 'z': 29, 'ँ': 30, 'ं': 31, 'ः': 32, 'अ': 33, 'आ': 34, 'इ': 35, 'ई': 36, 'उ': 37, 'ऊ': 38, 'ऋ': 39, 'ए': 40, 'ऐ': 41, 'ऑ': 42, 'ओ': 43, 'औ': 44, 'क': 45, 'ख': 46, 'ग': 47, 'घ': 48, 'ङ': 49, 'च': 50, 'छ': 51, 'ज': 52, 'झ': 53, 'ञ': 54, 'ट': 55, 'ठ': 56, 'ड': 57, 'ढ': 58, 'ण': 59, 'त': 60, 'थ': 61, 'द': 62, 'ध': 63, 'न': 64, 'प': 65, 'फ': 66, 'ब': 67, 'भ': 68, 'म': 69, 'य': 70, 'र': 71, 'ल': 72, 'ळ': 73, 'व': 74, 'श': 75, 'ष': 76, 'स': 77, 'ह': 78, '़': 79, 'ा': 80, 'ि': 81, 'ी': 82, 'ु': 83, 'ू': 84, 'ृ': 85, 'ॅ': 86, 'े': 87, 'ै': 88, 'ॉ': 89, 'ॊ': 90, 'ो': 91, 'ौ': 92, '्': 93, 'ঁ': 94, 'ং': 95, 'ঃ': 96, 'অ': 97, 'আ': 98, 'ই': 99, 'ঈ': 100, 'উ': 101, 'ঊ': 102, 'ঋ': 103, 'এ': 104, 'ঐ': 105, 'ও': 106, 'ঔ': 107, 'ক': 108, 

In [16]:
def predict_multilingual(word_with_lang_token, encoder, decoder, vocab, itos, device, max_len=40):
    """Greedy-decode one "<LANG_XXX> word" string and print how long it took.

    Args:
        word_with_lang_token: Language tag and romanized word separated by whitespace.
        encoder, decoder: Modules called as in training; both are switched to eval mode.
        vocab: Symbol -> id mapping; characters absent from it are skipped.
        itos: Id -> symbol mapping used to detect ``<eos>`` and render the output.
        device: Device on which the input tensors are created.
        max_len: Maximum number of decoding steps.

    Returns:
        The decoded string (without ``<eos>``).
    """
    import time
    t0 = time.time()

    for module in (encoder, decoder):
        module.eval()

    decoded_ids = []
    with torch.no_grad():
        parts = word_with_lang_token.strip().split()
        char_ids = [vocab[ch] for ch in ' '.join(parts[1:]) if ch in vocab]
        src = torch.tensor([vocab[parts[0]], *char_ids]).unsqueeze(0).to(device)

        enc_out, h, c = encoder(src)

        # Sum the forward (even index) and backward (odd index) final states
        # so the state has one entry per layer.
        h = h[0::2] + h[1::2]
        c = c[0::2] + c[1::2]

        next_input = torch.tensor([vocab['<sos>']], device=device)
        step = 0
        while step < max_len:
            logits, h, c = decoder(next_input, h, c, enc_out)
            token_id = logits.argmax(1).item()
            if itos.get(token_id) == '<eos>':
                break
            decoded_ids.append(token_id)
            next_input = torch.tensor([token_id], device=device)
            step += 1

    elapsed = time.time() - t0
    print(f"Inference took {elapsed:.2f} seconds")
    return ''.join(itos[i] for i in decoded_ids if i in itos)


In [17]:
from jiwer import cer, wer

def evaluate_multilang(df, encoder, decoder, vocab, itos, device, n_samples=200, random_state=None):
    """Estimate per-language CER and WER on a random subset of ``df``.

    For every distinct value of ``df['lang']`` a subset of rows is drawn, each
    row is transliterated with ``predict_multilingual`` and scored against
    ``native`` with jiwer's ``cer`` and ``wer``.

    Args:
        df: DataFrame with columns ``lang``, ``lang_token``, ``roman_clean``
            and ``native``.
        encoder, decoder, vocab, itos, device: Forwarded to
            ``predict_multilingual``.
        n_samples: Maximum rows to score per language (default 200, the
            previously hard-coded value).
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            evaluation subset can be reproduced. ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        A list of ``(lang, avg_cer, avg_wer)`` tuples, one per language.
    """
    results = []
    for lang in df['lang'].unique():
        lang_rows = df[df['lang'] == lang]
        # DataFrame.sample raises when n exceeds the number of rows, so clamp
        # the request for languages that have fewer than n_samples rows.
        lang_df = lang_rows.sample(n=min(n_samples, len(lang_rows)), random_state=random_state)
        if len(lang_df) == 0:
            # Nothing to score (and nothing to divide by) for this language.
            continue

        total_cer, total_wer = 0, 0
        for _, row in lang_df.iterrows():
            input_seq = f"{row['lang_token']} {row['roman_clean']}"
            pred = predict_multilingual(input_seq, encoder, decoder, vocab, itos, device)
            total_cer += cer(row['native'], pred)
            total_wer += wer(row['native'], pred)

        avg_cer = total_cer / len(lang_df)
        avg_wer = total_wer / len(lang_df)
        results.append((lang, avg_cer, avg_wer))
    return results


In [18]:
import os
from tqdm import tqdm
import time
import torch

def train(encoder, decoder, dataloader, vocab, criterion, optimizer, device='cpu', num_epochs=10, checkpoint_dir="checkpoints", start_epoch=0):
    """Train the encoder/decoder with full teacher forcing, checkpointing each epoch.

    Args:
        encoder, decoder: Modules to train; moved to ``device`` in place.
        dataloader: Yields ``(src_batch, tgt_batch)`` padded id tensors.
        vocab: Symbol -> id mapping, stored inside every checkpoint.
        criterion: Loss applied to each step's logits against the target ids.
        optimizer: Optimizer over the encoder and decoder parameters.
        device: Device on which training runs.
        num_epochs: Epoch index (exclusive) at which training stops.
        checkpoint_dir: Directory that receives ``checkpoint_epoch_<n>.pth``
            after each epoch; created if missing.
        start_epoch: Zero-based epoch to start from (for resuming).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    encoder.to(device)
    decoder.to(device)

    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()

        encoder.train()
        decoder.train()
        total_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for src_batch, tgt_batch in progress_bar:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)

            optimizer.zero_grad()
            enc_outputs, h, c = encoder(src_batch)

            # Merge bidirectional states: sum forward (even) and backward (odd)
            # entries so the state matches the unidirectional decoder LSTM.
            h = h[::2] + h[1::2]
            c = c[::2] + c[1::2]

            # Teacher forcing: the decoder is always fed the gold previous token.
            dec_input = tgt_batch[:, 0]
            loss = 0

            for t in range(1, tgt_batch.size(1)):
                output, h, c = decoder(dec_input, h, c, enc_outputs)
                loss += criterion(output, tgt_batch[:, t])
                dec_input = tgt_batch[:, t]

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        epoch_time = time.time() - start_time
        print(f"\nEpoch {epoch+1} finished in {epoch_time:.2f}s | Total Loss: {total_loss:.4f}")

        # ✅ Save checkpoint into the requested directory. Previously the path
        # was hard-coded and the checkpoint_dir argument was silently ignored.
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'vocab': vocab
        }, checkpoint_path)

        print(f"✅ Saved checkpoint to {checkpoint_path}")


In [19]:
# Sanity check: every language tag must exist in the vocabulary as a whole symbol.
lang_keys = [symbol for symbol in vocab if symbol.startswith("<LANG")]
for symbol in lang_keys:
    print("Found:", symbol)

Found: <LANG_BEN>
Found: <LANG_HIN>
Found: <LANG_TAM>


In [20]:
# Target id 0 is <pad>; ignoring it keeps padded steps out of the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
# Use the GPU when one is present, but fall back to CPU instead of failing
# on runtimes without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Name the Drive folder explicitly so per-epoch checkpoints land next to the training CSV.
train(
    encoder, decoder, dataloader, vocab, criterion, optimizer,
    device=device, num_epochs=10,
    checkpoint_dir="/content/drive/MyDrive/pc_draft",
)


Epoch 1/10: 100%|██████████| 23438/23438 [18:28<00:00, 21.15it/s, loss=2.06]



Epoch 1 finished in 1108.29s | Total Loss: 114475.3078
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_1.pth


Epoch 2/10: 100%|██████████| 23438/23438 [18:19<00:00, 21.32it/s, loss=1.84]



Epoch 2 finished in 1099.20s | Total Loss: 54512.2484
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_2.pth


Epoch 3/10: 100%|██████████| 23438/23438 [18:11<00:00, 21.47it/s, loss=1.64]



Epoch 3 finished in 1091.69s | Total Loss: 46772.6660
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_3.pth


Epoch 4/10: 100%|██████████| 23438/23438 [18:04<00:00, 21.62it/s, loss=2.97]



Epoch 4 finished in 1084.18s | Total Loss: 42918.3472
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_4.pth


Epoch 5/10: 100%|██████████| 23438/23438 [18:04<00:00, 21.62it/s, loss=1.22]



Epoch 5 finished in 1084.04s | Total Loss: 40292.9706
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_5.pth


Epoch 6/10: 100%|██████████| 23438/23438 [18:06<00:00, 21.57it/s, loss=1.13]



Epoch 6 finished in 1086.60s | Total Loss: 38538.0547
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_6.pth


Epoch 7/10: 100%|██████████| 23438/23438 [18:12<00:00, 21.46it/s, loss=2.68]



Epoch 7 finished in 1092.37s | Total Loss: 37331.7633
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_7.pth


Epoch 8/10: 100%|██████████| 23438/23438 [18:31<00:00, 21.09it/s, loss=1.91]



Epoch 8 finished in 1111.34s | Total Loss: 36290.7890
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_8.pth


Epoch 9/10: 100%|██████████| 23438/23438 [18:02<00:00, 21.65it/s, loss=2.31]



Epoch 9 finished in 1082.74s | Total Loss: 35340.5804
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_9.pth


Epoch 10/10: 100%|██████████| 23438/23438 [18:04<00:00, 21.62it/s, loss=2.07]



Epoch 10 finished in 1084.15s | Total Loss: 34698.9414
✅ Saved checkpoint to /content/drive/MyDrive/pc_draft/checkpoint_epoch_10.pth


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [21]:
# Invert the vocabulary so predicted ids can be turned back into symbols.
itos = dict(zip(vocab.values(), vocab.keys()))

# Store the weights together with both lookup tables so the file is
# sufficient on its own for inference.
final_checkpoint = {
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'vocab': vocab,
    'itos': itos,
}
torch.save(final_checkpoint, "/content/drive/MyDrive/pc_draft/translit_model_checkpoint.pth")
print("✅ Model saved.")


✅ Model saved.


In [None]:
def predict_multilingual(word_with_lang_token, encoder, decoder, vocab, itos, device, max_len=40):
    """Greedy-decode one "<LANG_XXX> word" string into native script.

    Args:
        word_with_lang_token: Language tag and romanized word separated by whitespace.
        encoder, decoder: Modules called as in training; both are switched to eval mode.
        vocab: Symbol -> id mapping; characters absent from it are skipped.
        itos: Id -> symbol mapping used to detect ``<eos>`` and render the output.
        device: Device on which the input tensors are created.
        max_len: Maximum number of decoding steps.

    Returns:
        The decoded string (without ``<eos>``).
    """
    encoder.eval(); decoder.eval()
    with torch.no_grad():
        tokens = word_with_lang_token.strip().split()
        lang_token = tokens[0]
        word = ' '.join(tokens[1:])
        src_seq = [vocab[lang_token]] + [vocab[c] for c in word if c in vocab]
        src = torch.tensor(src_seq).unsqueeze(0).to(device)

        enc_out, h, c = encoder(src)

        # Merge bi-directional encoder states exactly as train() does: the
        # encoder returns one state per direction, but the decoder LSTM is
        # unidirectional and rejects a state with the extra direction entry.
        h = h[0::2] + h[1::2]
        c = c[0::2] + c[1::2]

        dec_input = torch.tensor([vocab['<sos>']], device=device)
        preds = []
        for _ in range(max_len):
            output, h, c = decoder(dec_input, h, c, enc_out)
            pred_token = output.argmax(1).item()
            # .get() so an id missing from itos is treated as "not <eos>"
            # rather than raising KeyError mid-decode.
            if itos.get(pred_token) == '<eos>':
                break
            preds.append(pred_token)
            dec_input = torch.tensor([pred_token], device=device)
        return ''.join(itos[i] for i in preds if i in itos)
