In [1]:
# ------------------------------
# Cell 1: Imports & Setup
# ------------------------------
import os
import random
import pickle
from collections import Counter
import time
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1aab36f88f0>

In [2]:
# ------------------------------
# Cell 2: Device detection
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("CUDA device:", torch.cuda.get_device_name(0))
    torch.backends.cudnn.benchmark = True

Using device: cuda
CUDA device: NVIDIA GeForce RTX 3050 4GB Laptop GPU


In [3]:
# ------------------------------
# Cell 3: Paths & folders
# ------------------------------
os.makedirs("data", exist_ok=True)
os.makedirs("vocab", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

CORPUS_PATH = "data/all_hindi_clean.txt"
DATA_PAIRS_PATH = "data/data_pairs.pkl"
VOCAB_PATH = "vocab/hindi_vocab_100k.tsv"

In [4]:
# ------------------------------
# Cell 4: Load corpus
# ------------------------------
MAX_LINES = 50000  # limit

if not os.path.exists(CORPUS_PATH):
    raise FileNotFoundError(f"Corpus not found at {CORPUS_PATH}")

sentences = []
with open(CORPUS_PATH, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= MAX_LINES:
            break
        line = line.strip()
        if line:
            sentences.append(line)

print(f"Loaded {len(sentences)} sentences (limited to {MAX_LINES})")

Loaded 50000 sentences (limited to 50000)


In [5]:
# ------------------------------
# Cell 5: Create typos function
# ------------------------------
def create_typos(sentence, typo_prob=0.2):
    words = sentence.split()
    new_words = []
    for w in words:
        if random.random() < typo_prob:
            typo_type = random.choice(["delete", "replace", "transpose"])
            if typo_type == "delete" and len(w) > 1:
                i = random.randint(0, len(w)-1)
                w = w[:i] + w[i+1:]
            elif typo_type == "replace" and len(w) > 0:
                i = random.randint(0, len(w)-1)
                w = w[:i] + random.choice(list(w)) + w[i+1:]
            elif typo_type == "transpose" and len(w) > 1:
                i = random.randint(0, len(w)-2)
                w = w[:i] + w[i+1] + w[i] + w[i+2:]
        new_words.append(w)
    return " ".join(new_words)

In [6]:
# ------------------------------
# Cell 6: Create dataset pairs
# ------------------------------
data_pairs = [(create_typos(s), s) for s in sentences]
print("Sample pair:", data_pairs[0])

Sample pair: ('के', 'के')


In [7]:
# ------------------------------
# Cell 7: Build vocabulary
# ------------------------------
word_counter = Counter()
for _, target in data_pairs:
    word_counter.update(target.split())

PAD, SOS, EOS, UNK = "<PAD>", "<SOS>", "<EOS>", "<UNK>"

vocab = {PAD:0, SOS:1, EOS:2, UNK:3}
top_k = 100000
for word, _ in word_counter.most_common(top_k):
    if word not in vocab:
        vocab[word] = len(vocab)

rev_vocab = {idx: word for word, idx in vocab.items()}
vocab_size = len(vocab)
print("Vocab size:", vocab_size)

# Save vocab
with open(VOCAB_PATH, "w", encoding="utf-8") as f:
    for word, idx in vocab.items():
        f.write(f"{word}\t{idx}\n")
print("Vocabulary saved!")

Vocab size: 44310
Vocabulary saved!


In [8]:
# ------------------------------
# Cell 8: Save data pairs
# ------------------------------
with open(DATA_PAIRS_PATH, "wb") as f:
    pickle.dump(data_pairs, f)
print("Data pairs saved!")

Data pairs saved!


In [9]:
# ------------------------------
# Cell 9: Dataset & DataLoader
# ------------------------------
class HindiSpellDataset(Dataset):
    def __init__(self, pairs, vocab):
        self.pairs = pairs
        self.vocab = vocab
        self.SOS = vocab[SOS]
        self.EOS = vocab[EOS]
        self.UNK = vocab[UNK]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        src_ids = [self.vocab.get(w, self.UNK) for w in src.split()] + [self.EOS]
        tgt_ids = [self.SOS] + [self.vocab.get(w, self.UNK) for w in tgt.split()] + [self.EOS]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

def collate_fn(batch):
    PAD_IDX = vocab[PAD]
    src_batch, tgt_batch = zip(*batch)
    src_max = max(len(s) for s in src_batch)
    tgt_max = max(len(t) for t in tgt_batch)
    src_padded = torch.full((len(batch), src_max), PAD_IDX, dtype=torch.long)
    tgt_padded = torch.full((len(batch), tgt_max), PAD_IDX, dtype=torch.long)
    src_lengths = []
    tgt_lengths = []
    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = s
        tgt_padded[i, :len(t)] = t
        src_lengths.append(len(s))
        tgt_lengths.append(len(t))
    return src_padded, tgt_padded, torch.tensor(src_lengths), torch.tensor(tgt_lengths)

def make_dataloader(pairs, batch_size=16, shuffle=True):
    dataset = HindiSpellDataset(pairs, vocab)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn,
                        num_workers=0, pin_memory=(device.type=="cuda"))
    return loader

In [10]:
# ------------------------------
# Cell 10: Subsample dataset for laptop training
# ------------------------------
TRAIN_SUBSET = 50_000
if len(data_pairs) > TRAIN_SUBSET:
    data_pairs_subset = random.sample(data_pairs, TRAIN_SUBSET)
else:
    data_pairs_subset = data_pairs

print(f"Training on {len(data_pairs_subset)} sentence pairs (subset)")
dataloader = make_dataloader(data_pairs_subset, batch_size=16)


Training on 50000 sentence pairs (subset)


In [11]:
# ------------------------------
# Cell 11: Model definition
# ------------------------------
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        embedded = self.dropout(self.embedding(x))
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (h, c) = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return out, (h, c)

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden):
        x = x.unsqueeze(1)
        embedded = self.dropout(self.embedding(x))
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc(output.squeeze(1))
        return output, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=self.device)

        encoder_out, hidden = self.encoder(src, src_lengths)
        input = tgt[:, 0]
        for t in range(1, tgt_len):
            output, hidden = self.decoder(input, hidden)
            outputs[:, t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = tgt[:, t] if teacher_force else top1
        return outputs


In [12]:
# ------------------------------
# Cell 12: Initialize model & optimizer
# ------------------------------
embed_size = 192
hidden_size = 256
num_layers = 2
dropout = 0.1

encoder = Encoder(vocab_size, embed_size, hidden_size, num_layers, dropout)
decoder = Decoder(vocab_size, embed_size, hidden_size, num_layers, dropout)
model = Seq2Seq(encoder, decoder, device).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD])
optimizer = optim.Adam(model.parameters(), lr=0.001)

use_amp = (device.type=="cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


In [13]:
## ------------------------------
# Cell 13: Training loop (subset + gradient accumulation) (with timing & stats or tqdm)
# ------------------------------

num_epochs = 3
max_grad_norm = 1.0
accum_steps = 2  # simulate batch size 32

model.train()
for epoch in range(1, num_epochs+1):
    epoch_loss = 0.0
    start_time = time.time()
    
    batch_iterator = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}")
    
    for batch_idx, (src, tgt, src_lengths, tgt_lengths) in batch_iterator:
        batch_start = time.time()
        
        src, tgt = src.to(device), tgt.to(device)
        src_lengths, tgt_lengths = src_lengths.to(device), tgt_lengths.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            output = model(src, src_lengths, tgt, teacher_forcing_ratio=0.5)
            output_dim = output.shape[-1]
            output = output[:,1:].reshape(-1, output_dim)
            tgt_target = tgt[:,1:].reshape(-1)
            loss = criterion(output, tgt_target) / accum_steps

        scaler.scale(loss).backward()

        if (batch_idx + 1) % accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        epoch_loss += loss.item() * accum_steps
        
        batch_time = time.time() - batch_start
        elapsed = time.time() - start_time
        avg_batch_time = elapsed / (batch_idx + 1)
        remaining_batches = len(dataloader) - batch_idx - 1
        eta = remaining_batches * avg_batch_time
        
        batch_iterator.set_postfix({
            "Loss": f"{loss.item()*accum_steps:.4f}",
            "Batch Time": f"{batch_time:.2f}s",
            "ETA": f"{eta/60:.1f}m"
        })
    
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")

    ckpt_path = f"checkpoints/seq2seq_epoch{epoch}.pt"
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "vocab": vocab
    }, ckpt_path)
    print(f"Checkpoint saved: {ckpt_path}")

print("Training complete!")

Epoch 1: 100%|████████████████████████████| 3125/3125 [03:29<00:00, 14.94it/s, Loss=5.4196, Batch Time=0.25s, ETA=0.0m]


Epoch 1 Average Loss: 5.5091
Checkpoint saved: checkpoints/seq2seq_epoch1.pt


Epoch 2: 100%|████████████████████████████| 3125/3125 [03:27<00:00, 15.03it/s, Loss=5.5305, Batch Time=0.04s, ETA=0.0m]


Epoch 2 Average Loss: 5.4761
Checkpoint saved: checkpoints/seq2seq_epoch2.pt


Epoch 3: 100%|████████████████████████████| 3125/3125 [03:25<00:00, 15.24it/s, Loss=5.2702, Batch Time=0.04s, ETA=0.0m]


Epoch 3 Average Loss: 5.4606
Checkpoint saved: checkpoints/seq2seq_epoch3.pt
Training complete!


In [15]:
torch.save(encoder.state_dict(), 'encoder_state_dict.h5')
torch.save(decoder.state_dict(), 'decoder_state_dict.h5')