In [1]:
# --- Core Libraries ---
import os                # Operating system interactions
import random            # Random number generation
import re                # Regular expressions for text processing

# --- Data Handling and Analysis ---
import pandas as pd      # Data manipulation and analysis

# --- Visualization ---
import matplotlib.pyplot as plt  # Plotting library
from matplotlib import rcParams   # Global plotting configuration
import seaborn as sns     # Statistical data visualization
from tqdm import tqdm     # Progress bars for iterables

# --- PyTorch ---
import torch              # Main PyTorch library
import torch.nn as nn     # Neural network modules
import torch.nn.functional as F  # Functional interface for NN operations
import torch.optim as optim  # Optimization algorithms
from torch.utils.data import Dataset, DataLoader  # Data loading utilities

# --- Experiment Tracking ---
import wandb             # Weights & Biases for experiment logging


In [2]:
# === Configuration and Setup ===

# Base directory for Dakshina dataset lexicons
# Modify the language code ('hi' for Hindi) as needed
dataset_dir = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/hi/lexicons'

# Select computation device: GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set the Weights & Biases API key for logging experiments
WANDB_KEY = "8f2f82255a6e5ea16321da3895ae6b00d50eb5b5"
os.environ.setdefault("WANDB_API_KEY", WANDB_KEY)
try:
    # ‘relogin=True’ forces a fresh session if needed
    wandb.login(key=WANDB_KEY, relogin=True)
except wandb.errors.UsageError:
    # Already authenticated or bad key—ignore quietly
    pass

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcs24m020[0m ([33mcs24m020-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class CharEmbedder(nn.Module):
    """
    Converts sequences of character indices into dense embeddings.

    Args:
        vocab_size (int): Number of unique characters in the vocabulary.
        embed_dim (int): Dimensionality of each character embedding.
    """
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        # Embedding layer: maps each input index to an embedding vector
        self.char_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, char_seq: torch.Tensor) -> torch.Tensor:
        """
        Produce embeddings for an input batch of character sequences.

        Args:
            char_seq (Tensor): Shape (batch_size, seq_len) with character indices.

        Returns:
            Tensor: Shape (batch_size, seq_len, embed_dim) of embeddings.
        """
        return self.char_embedding(char_seq)

In [4]:
class SequenceEncoder(nn.Module):
    """
    Encodes token sequences into contextual representations using RNN variants.
    Supports GRU, LSTM, or vanilla RNN with optional bidirectionality and dropout.
    """
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int = 1,
        cell: str = 'GRU',
        dropout: float = 0.1,
        bidirectional: bool = False
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        rnn_dropout = dropout if num_layers > 1 else 0.0

        # Token embedding + pre-RNN dropout
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.drop  = nn.Dropout(dropout)

        # Choose RNN cell
        cell_map = {
            'GRU': nn.GRU,
            'LSTM': nn.LSTM,
            'RNN': lambda *args, **kwargs: nn.RNN(*args, nonlinearity='tanh', **kwargs)
        }
        rnn_cls = cell_map.get(cell, nn.GRU)
        self.rnn = rnn_cls(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=rnn_dropout,
            bidirectional=bidirectional,
            batch_first=True
        )

    def forward(self, tokens: torch.Tensor) -> tuple:
        """
        Args:
            tokens (Tensor): (batch, seq_len) token indices
        Returns:
            outputs: (batch, seq_len, hidden_dim * directions)
            hidden: final hidden state tuple or tensor
        """
        emb = self.drop(self.embed(tokens))
        outputs, hidden = self.rnn(emb)
        return outputs, hidden


class SoftDotAttention(nn.Module):
    """
    Computes soft attention weights and context vectors using dot-product style.
    """
    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Parameter(torch.Tensor(hidden_dim))
        nn.init.uniform_(self.v, -0.1, 0.1)

    def forward(self, h_t: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor:
        # h_t: (batch, hidden), encoder_outputs: (batch, seq_len, hidden)
        batch, seq_len, _ = encoder_outputs.size()
        # Expand h_t for concatenation
        h_exp = h_t.unsqueeze(1).expand(-1, seq_len, -1)
        # Compute energy scores
        energy = torch.tanh(self.linear(torch.cat([h_exp, encoder_outputs], dim=2)))
        energy = energy.transpose(1, 2)  # (batch, hidden, seq_len)
        v_exp = self.v.unsqueeze(0).expand(batch, -1).unsqueeze(1)  # (batch,1,hidden)
        scores = torch.bmm(v_exp, energy).squeeze(1)  # (batch, seq_len)
        return F.softmax(scores, dim=1)  # attention weights


class AttentiveDecoder(nn.Module):
    """
    Decoder with attention: generates output tokens step-by-step.
    """
    def __init__(
        self,
        out_vocab: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int = 1,
        cell: str = 'GRU',
        dropout: float = 0.1
    ) -> None:
        super().__init__()
        rnn_dropout = dropout if num_layers > 1 else 0.0

        self.embed = nn.Embedding(out_vocab, embed_dim)
        self.drop  = nn.Dropout(dropout)

        # RNN input will include context from attention
        rnn_input_dim = embed_dim + hidden_dim
        cell_map = {'GRU': nn.GRU, 'LSTM': nn.LSTM, 'RNN': nn.RNN}
        rnn_cls = cell_map.get(cell, nn.GRU)
        self.rnn = rnn_cls(
            input_size=rnn_input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=rnn_dropout,
            batch_first=True
        )

        self.attn   = SoftDotAttention(hidden_dim)
        self.out_dp = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_dim, out_vocab)

    def forward(
        self,
        input_step: torch.Tensor,
        prev_hidden,
        encoder_outputs: torch.Tensor
    ) -> tuple:
        # input_step: (batch, 1), encoder_outputs: (batch, seq_len, hidden)
        emb = self.drop(self.embed(input_step.squeeze(1))).unsqueeze(1)
        # Get last hidden state
        if isinstance(prev_hidden, tuple):  # LSTM
            h_last = prev_hidden[0][-1]
        else:
            h_last = prev_hidden[-1]
        # Compute attention and context
        attn_w = self.attn(h_last, encoder_outputs).unsqueeze(1)  # (batch,1,seq_len)
        context = torch.bmm(attn_w, encoder_outputs)  # (batch,1,hidden)
        # RNN input
        rnn_in = torch.cat([emb, context], dim=2)
        output, hidden = self.rnn(rnn_in, prev_hidden)
        # Generate final output distribution
        logits = self.fc_out(self.out_dp(output.squeeze(1)))  # (batch, out_vocab)
        return F.log_softmax(logits, dim=1), hidden, attn_w.squeeze(1)


In [5]:
def beam_search(
    model,
    src_seq: torch.Tensor,
    sos_idx: int,
    eos_idx: int,
    max_len: int = 30,
    beam_width: int = 3,
    device: str = 'cuda'
) -> list:
    """
    Performs beam search decoding on a seq2seq model with attention.

    Args:
        model: Seq2Seq model with encoder and decoder attributes.
        src_seq (Tensor): Input sequence tensor (1, seq_len).
        sos_idx (int): Start-of-sequence token index.
        eos_idx (int): End-of-sequence token index.
        max_len (int): Maximum decoding steps.
        beam_width (int): Number of beams to maintain.
        device (str): Device to perform computations.

    Returns:
        List of tuples: (token sequence list, cumulative log-probability).
    """
    model.eval()
    with torch.no_grad():
        # Encode input
        enc_outs, enc_hidden = model.encoder(src_seq.to(device))

        # Initialize decoder hidden state
        if model.bidirectional:
            # Combine bidirectional states
            if model.cell == 'LSTM':
                h_n, c_n = enc_hidden
                layers = model.encoder.num_layers
                h_dec = torch.zeros(layers, 1, model.decoder.hidden_dim, device=device)
                c_dec = torch.zeros(layers, 1, model.decoder.hidden_dim, device=device)
                for i in range(layers):
                    h_cat = torch.cat([h_n[2*i], h_n[2*i+1]], dim=1)
                    c_cat = torch.cat([c_n[2*i], c_n[2*i+1]], dim=1)
                    h_dec[i] = model.hidden_transform(h_cat)
                    c_dec[i] = model.hidden_transform(c_cat)
                dec_hidden = (h_dec, c_dec)
            else:
                layers = model.encoder.num_layers
                h_dec = torch.zeros(layers, 1, model.decoder.hidden_dim, device=device)
                for i in range(layers):
                    h_cat = torch.cat([enc_hidden[2*i], enc_hidden[2*i+1]], dim=1)
                    h_dec[i] = model.hidden_transform(h_cat)
                dec_hidden = h_dec
        else:
            dec_hidden = enc_hidden

        # Beam candidates: (sequence, score, hidden)
        beams = [([sos_idx], 0.0, dec_hidden)]
        completed = []

        for _ in range(max_len):
            candidates = []
            for seq, score, hidden in beams:
                # If EOS reached, collect result
                if seq[-1] == eos_idx:
                    completed.append((seq, score))
                    continue
                inp = torch.tensor([[seq[-1]]], device=device)
                # Decode step
                output, hidden_n, _ = model.decoder(inp, hidden, enc_outs)
                topk = torch.topk(output.squeeze(0), beam_width)
                for logp, idx in zip(topk.values, topk.indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + logp.item()
                    # Detach hidden to prevent graph growth
                    if isinstance(hidden_n, tuple):
                        hn = tuple(h.detach() for h in hidden_n)
                    else:
                        hn = hidden_n.detach()
                    candidates.append((new_seq, new_score, hn))
            # Select top beams
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if not beams:
                break

        # Include any beams ended with EOS
        completed += [(seq, score) for seq, score, _ in beams if seq[-1] == eos_idx]
        # If none completed, use current beams
        if not completed:
            completed = beams
        # Sort by highest score
        return sorted(completed, key=lambda x: x[1], reverse=True)


In [6]:
class Seq2SeqModel(nn.Module):
    """
    End-to-end seq2seq model with optional bidirectional encoder and attention decoder.
    """
    def __init__(
        self,
        src_vocab: int,
        tgt_vocab: int,
        embed_dim: int = 256,
        hidden_dim: int = 256,
        enc_layers: int = 1,
        dec_layers: int = 1,
        cell_type: str = 'GRU',
        dropout: float = 0.2,
        bidirectional: bool = False
    ) -> None:
        super().__init__()
        self.encoder = SequenceEncoder(src_vocab, embed_dim, hidden_dim, enc_layers, cell_type, dropout, bidirectional)
        self.decoder = AttentiveDecoder(tgt_vocab, embed_dim, hidden_dim, dec_layers, cell_type, dropout)
        self.bidirectional = bidirectional
        if bidirectional:
            self.hidden_transform = nn.Linear(hidden_dim * 2, hidden_dim)
        self.cell_type = cell_type

    def _match_decoder(self, hidden, batch_size: int):
        """Match encoder hidden dims to decoder layers."""
        if isinstance(hidden, tuple):  # LSTM
            h, c = hidden
            return (self._pad_or_trim(h, batch_size), self._pad_or_trim(c, batch_size))
        return self._pad_or_trim(hidden, batch_size)

    def _pad_or_trim(self, h: torch.Tensor, batch_size: int) -> torch.Tensor:
        layers = self.decoder.rnn.num_layers
        if h.size(0) > layers:
            return h[:layers]
        if h.size(0) < layers:
            pad = h.new_zeros(layers - h.size(0), batch_size, h.size(2))
            return torch.cat([h, pad], dim=0)
        return h

    def _combine_bidirectional(self, hidden):
        """Transform bidirectional encoder states for decoder initialization."""
        if self.cell_type == 'LSTM':
            h, c = hidden
            return (self._merge_dirs(h), self._merge_dirs(c))
        return self._merge_dirs(hidden)

    def _merge_dirs(self, h: torch.Tensor) -> torch.Tensor:
        # h: (2*layers, batch, hidden)
        layers = h.size(0) // 2
        merged = []
        for i in range(layers):
            cat = torch.cat([h[2*i], h[2*i+1]], dim=1)
            merged.append(self.hidden_transform(cat))
        return torch.stack(merged)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        teacher_forcing: float = 0.5,
        return_attn: bool = False
    ):
        batch, tgt_len = tgt.size()
        outputs = src.new_zeros(batch, tgt_len, self.decoder.fc_out.out_features, dtype=torch.float)
        attn_weights = [] if return_attn else None

        enc_outs, enc_hidden = self.encoder(src)
        dec_hidden = self._combine_bidirectional(enc_hidden) if self.bidirectional else self._match_decoder(enc_hidden, batch)

        input_tok = tgt[:, 0].unsqueeze(1)
        for t in range(1, tgt_len):
            out, dec_hidden, attn = self.decoder(input_tok, dec_hidden, enc_outs)
            outputs[:, t] = out
            if return_attn:
                attn_weights.append(attn.unsqueeze(1))
            if random.random() < teacher_forcing:
                input_tok = tgt[:, t].unsqueeze(1)
            else:
                input_tok = out.argmax(1).unsqueeze(1)

        if return_attn:
            return outputs, torch.cat(attn_weights, dim=1)
        return outputs


In [7]:
class LexiconDataset(Dataset):
    """
    Dataset for transliteration pairs loaded from a TSV file.
    Each line should contain target and source strings separated by a tab.

    Args:
        file_path (str): Path to the TSV file.
        src_vocab (dict, optional): Pre-built source vocabulary mapping.
        tgt_vocab (dict, optional): Pre-built target vocabulary mapping.
        build_vocab (bool): Whether to construct vocabularies from data.
    """
    def __init__(
        self,
        file_path: str,
        src_vocab: dict = None,
        tgt_vocab: dict = None,
        build_vocab: bool = False
    ) -> None:
        super().__init__()
        self.pairs = []
        # Read TSV and collect (source, target) pairs
        with open(file_path, encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    continue
                tgt_text, src_text = parts[0], parts[1]
                self.pairs.append((src_text, tgt_text))

        if build_vocab:
            # Initialize special tokens
            self.src_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
            self.tgt_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
            # Populate with characters from data
            for src_text, tgt_text in self.pairs:
                for ch in src_text:
                    self.src_vocab.setdefault(ch, len(self.src_vocab))
                for ch in tgt_text:
                    self.tgt_vocab.setdefault(ch, len(self.tgt_vocab))
        else:
            assert src_vocab is not None and tgt_vocab is not None, \
                "Provide src_vocab and tgt_vocab if not building them."
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple:
        src_text, tgt_text = self.pairs[idx]
        # Encode source chars
        src_indices = [self.src_vocab.get(ch, self.src_vocab['<unk>']) for ch in src_text]
        # Encode target with <sos> and <eos>
        tgt_indices = [self.tgt_vocab['<sos>']] + \
                      [self.tgt_vocab.get(ch, self.tgt_vocab['<unk>']) for ch in tgt_text] + \
                      [self.tgt_vocab['<eos>']]
        return torch.tensor(src_indices, dtype=torch.long), torch.tensor(tgt_indices, dtype=torch.long)


def pad_collate(batch: list) -> tuple:
    """
    Pads source and target sequences in a batch to the maximum lengths.

    Args:
        batch (list): List of (src_tensor, tgt_tensor) pairs.

    Returns:
        padded_src (Tensor): (batch_size, max_src_len)
        padded_tgt (Tensor): (batch_size, max_tgt_len)
    """
    srcs, tgts = zip(*batch)
    max_src = max(s.size(0) for s in srcs)
    max_tgt = max(t.size(0) for t in tgts)

    padded_src = torch.full((len(batch), max_src), fill_value=0, dtype=torch.long)
    padded_tgt = torch.full((len(batch), max_tgt), fill_value=0, dtype=torch.long)
    for i, (s, t) in enumerate(zip(srcs, tgts)):
        padded_src[i, :s.size(0)] = s
        padded_tgt[i, :t.size(0)] = t
    return padded_src, padded_tgt


def get_dataloaders(
    base_dir: str,
    batch_size: int,
    build_vocab: bool = False
) -> tuple:
    """
    Creates DataLoaders for train, validation, and test splits.

    Args:
        base_dir (str): Directory containing 'hi.translit.sampled.*.tsv' files.
        batch_size (int): Batch size for DataLoaders.
        build_vocab (bool): Whether to build vocabulary from training data.

    Returns:
        train_loader, val_loader, test_loader,
        src_vocab_size, tgt_vocab_size, pad_token_idx,
        src_vocab, tgt_vocab
    """
    # Paths for splits
    train_fp = os.path.join(base_dir, 'hi.translit.sampled.train.tsv')
    val_fp   = os.path.join(base_dir, 'hi.translit.sampled.dev.tsv')
    test_fp  = os.path.join(base_dir, 'hi.translit.sampled.test.tsv')

    # Instantiate datasets
    train_ds = LexiconDataset(train_fp, build_vocab=build_vocab)
    src_vocab, tgt_vocab = train_ds.src_vocab, train_ds.tgt_vocab
    val_ds = LexiconDataset(val_fp, src_vocab, tgt_vocab)
    test_ds = LexiconDataset(test_fp, src_vocab, tgt_vocab)

    # Dataloaders
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=pad_collate)
    test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=pad_collate)

    return (
        train_loader, val_loader, test_loader,
        len(src_vocab), len(tgt_vocab), src_vocab['<pad>'],
        src_vocab, tgt_vocab
    )


In [8]:
class EarlyStopper:
    """
    Monitors a metric and stops training if it does not improve within patience epochs.
    
    Args:
        patience (int): Number of epochs with no improvement before stopping.
        min_delta (float): Minimum change to qualify as improvement.
    """
    def __init__(self, patience: int = 5, min_delta: float = 1e-4) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None

    def should_stop(self, current_score: float) -> bool:
        """
        Returns True if training should be stopped based on the monitored metric.
        """
        if self.best_score is None or current_score > self.best_score + self.min_delta:
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

In [9]:
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val_accuracy',
        'goal': 'maximize'
    },
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 2,
        'max_iter': 8,
        's': 2
    },
    'parameters': {
        'embedding_dim': {'values': [16, 32, 64, 256]},
        'hidden_size':    {'values': [16, 32, 64, 256]},
        'encoder_layers': {'values': [1, 2, 3]},
        'decoder_layers': {'values': [1, 2, 3]},
        'cell_type':      {'values': ['RNN', 'GRU', 'LSTM']},
        'dropout_p':      {'values': [0.2, 0.3, 0.4]},
        'beam_width':     {'values': [1, 3, 5]},
        'teacher_forcing_ratio': {'values': [0.0, 0.3, 0.5, 0.7, 1.0]}
    }
}

In [11]:
def train_run():
    """
    Executes one run of training, validation, and testing within a W&B sweep context.
    """
    with wandb.init():
        cfg = wandb.config
        # Construct a descriptive run name
        run_name = (
            f"emb{cfg.embedding_dim}_hid{cfg.hidden_size}"
            f"_enc{cfg.encoder_layers}_dec{cfg.decoder_layers}_"
            f"{cfg.cell_type.lower()}_do{int(cfg.dropout_p*100)}"
            f"_beam{cfg.beam_width}_tf{int(cfg.teacher_forcing_ratio*100)}"
        )
        wandb.run.name = run_name

        # Load data and build vocabularies
        train_loader, val_loader, test_loader, src_size, tgt_size, \
        pad_idx, _, tgt_vocab = get_dataloaders(
            dataset_dir, batch_size=64, build_vocab=True
        )
        idx2char = {i: ch for ch, i in tgt_vocab.items()}

        # Initialize model, optimizer, loss, and early stopper
        model = Seq2SeqModel(
            src_vocab=src_size,
            tgt_vocab=tgt_size,
            embed_dim=cfg.embedding_dim,
            hidden_dim=cfg.hidden_size,
            enc_layers=cfg.encoder_layers,
            dec_layers=cfg.decoder_layers,
            cell_type=cfg.cell_type,
            dropout=cfg.dropout_p,
            bidirectional=False
        ).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.NLLLoss(ignore_index=pad_idx)
        stopper = EarlyStopper(patience=5)

        best_val_acc = 0.0
        # Training loop
        for epoch in range(1, 11):
            model.train()
            epoch_loss = 0.0
            for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False):
                src, tgt = src.to(device), tgt.to(device)
                optimizer.zero_grad()
                outputs = model(src, tgt, teacher_forcing=cfg.teacher_forcing_ratio)
                loss = criterion(outputs.view(-1, tgt_size), tgt.view(-1))
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            # Validation
            model.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for src, tgt in tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False):
                    src, tgt = src.to(device), tgt.to(device)
                    preds = model(src, tgt, teacher_forcing=0.0).argmax(dim=2)
                    for p_seq, t_seq in zip(preds, tgt):
                        pred_tokens = p_seq[1:][t_seq[1:] != pad_idx]
                        true_tokens = t_seq[1:][t_seq[1:] != pad_idx]
                        if torch.equal(pred_tokens, true_tokens):
                            correct += 1
                        total += 1
            val_acc =100* correct / total
            wandb.log({'epoch': epoch, 'train_loss': epoch_loss, 'val_accuracy': val_acc})

            print(f"Epoch {epoch} | Loss: {epoch_loss:.4f} | Val Acc: {val_acc:.4f}")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
            elif stopper.should_stop(val_acc):
                print("Early stopping triggered.")
                break

        # Final Test Evaluation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for src, tgt in tqdm(test_loader, desc="Test Evaluation", leave=False):
                src, tgt = src.to(device), tgt.to(device)
                preds = model(src, tgt, teacher_forcing=0.0).argmax(dim=2)
                for p_seq, t_seq in zip(preds, tgt):
                    pt = p_seq[1:][t_seq[1:] != pad_idx]
                    tt = t_seq[1:][t_seq[1:] != pad_idx]
                    if torch.equal(pt, tt):
                        correct += 1
                    total += 1
        test_acc = 100* correct / total
        print(f"Final Test Accuracy: {test_acc:.4f}")
        wandb.log({'final_test_accuracy': test_acc})

In [12]:
sweep_id = wandb.sweep(sweep_config, project='cs24m020_dl_a3_att')
wandb.agent(sweep_id, function=train_run, count=100)

Create sweep with ID: 4yok39ch
Sweep URL: https://wandb.ai/cs24m020-indian-institute-of-technology-madras/cs24m020_dl_a3_att/sweeps/4yok39ch


[34m[1mwandb[0m: Agent Starting Run: vsu5bz6p with config:
[34m[1mwandb[0m: 	beam_width: 5
[34m[1mwandb[0m: 	cell_type: RNN
[34m[1mwandb[0m: 	decoder_layers: 2
[34m[1mwandb[0m: 	dropout_p: 0.2
[34m[1mwandb[0m: 	embedding_dim: 64
[34m[1mwandb[0m: 	encoder_layers: 2
[34m[1mwandb[0m: 	hidden_size: 64
[34m[1mwandb[0m: 	teacher_forcing_ratio: 0.5


                                                                   

Epoch 1 | Loss: 1760.0845 | Val Acc: 0.0032


                                                                   

Epoch 2 | Loss: 1128.9210 | Val Acc: 0.1239


                                                                   

Epoch 3 | Loss: 884.5545 | Val Acc: 0.1641


                                                                   

Epoch 4 | Loss: 803.2041 | Val Acc: 0.1804


                                                                   

Epoch 5 | Loss: 745.5470 | Val Acc: 0.2118


                                                                   

Epoch 6 | Loss: 710.2273 | Val Acc: 0.2299


                                                                   

Epoch 7 | Loss: 691.8183 | Val Acc: 0.2393


                                                                   

Epoch 8 | Loss: 672.3760 | Val Acc: 0.2480


                                                                   

Epoch 9 | Loss: 652.8452 | Val Acc: 0.2643


                                                                    

Epoch 10 | Loss: 640.2279 | Val Acc: 0.2616


                                                                     

Final Test Accuracy: 0.0355




0,1
epoch,▁▂▃▃▄▅▆▆▇█
final_test_accuracy,▁
train_loss,█▄▃▂▂▁▁▁▁▁
val_accuracy,▁▄▅▆▇▇▇███

0,1
epoch,10.0
final_test_accuracy,0.03554
train_loss,640.22792
val_accuracy,0.26159


[34m[1mwandb[0m: Agent Starting Run: nm51vjma with config:
[34m[1mwandb[0m: 	beam_width: 5
[34m[1mwandb[0m: 	cell_type: GRU
[34m[1mwandb[0m: 	decoder_layers: 2
[34m[1mwandb[0m: 	dropout_p: 0.2
[34m[1mwandb[0m: 	embedding_dim: 16
[34m[1mwandb[0m: 	encoder_layers: 2
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	teacher_forcing_ratio: 0.7


Epoch 1 Training:  94%|█████████▍| 652/691 [00:22<00:01, 30.54it/s][34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
                                                                   

Epoch 1 | Loss: 2009.0010 | Val Acc: 0.0000


Epoch 2 Training:  35%|███▌      | 244/691 [00:08<00:17, 26.04it/s]

In [None]:
# === Dataset and Vocabulary ===
CHAR2IDX_SRC = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3, **{ch: i + 4 for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz")}}
IDX2CHAR_SRC = {i: ch for ch, i in CHAR2IDX_SRC.items()}

train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = get_dataloaders(
    BASE_DIR, batch_size=64, build_vocab=True
)
IDX2CHAR_TGT = {i: ch for ch, i in tgt_vocab.items()}

# === Model Instantiation ===
best_model = Seq2Seq(
    input_size=src_size,
    output_size=tgt_size,
    embedding_dim=64,
    hidden_size=256,
    encoder_layers=3,
    decoder_layers=1,
    cell_type='LSTM',
    dropout_p=0.4,
    bidirectional_encoder=False
).to(DEVICE)

optimizer = optim.Adam(best_model.parameters(), lr=1e-3)
criterion = nn.NLLLoss(ignore_index=pad_idx)
stopper = EarlyStopper(patience=5)

# === Training Routine ===
best_val_acc = 0.0
for epoch in range(1, 11):
    best_model.train()
    total_loss = 0.0
    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch} - Training", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        optimizer.zero_grad()
        outputs = best_model(src, tgt, teacher_forcing_ratio=1.0)
        loss = criterion(outputs.view(-1, tgt_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # === Validation ===
    best_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"Epoch {epoch} - Validation", leave=False):
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            outputs = best_model(src, tgt, teacher_forcing_ratio=0.0)
            predictions = outputs.argmax(dim=2)
            for pred_seq, true_seq in zip(predictions, tgt):
                pred_clean = pred_seq[1:][true_seq[1:] != pad_idx]
                true_clean = true_seq[1:][true_seq[1:] != pad_idx]
                if torch.equal(pred_clean, true_clean):
                    correct += 1
                total += 1

    val_acc = correct / total
    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | Validation Accuracy: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
    elif stopper.should_stop(val_acc):
        print("Early stopping criteria met.")
        break

# === Test Evaluation ===
correct, total = 0, 0
predictions_all, targets_all = [], []

best_model.eval()
with torch.no_grad():
    for src, tgt in tqdm(test_loader, desc="Testing", leave=False):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        outputs = best_model(src, tgt, teacher_forcing_ratio=0.0)
        predictions = outputs.argmax(dim=2)

        for pred_seq, true_seq in zip(predictions, tgt):
            pred_clean = pred_seq[1:][true_seq[1:] != pad_idx]
            true_clean = true_seq[1:][true_seq[1:] != pad_idx]
            if torch.equal(pred_clean, true_clean):
                correct += 1
            total += 1
            predictions_all.append(pred_clean)
            targets_all.append(true_clean)

print(f"\nFinal Exact Match Accuracy on Test Set: {correct / total:.4f}")


In [13]:
# === Sample Predictions ===
roman_inputs = []
with open(os.path.join(BASE_DIR, 'te.translit.sampled.test.tsv'), encoding='utf-8') as test_file:
    for line in test_file:
        _, romanized_word, _ = line.strip().split()
        roman_inputs.append(romanized_word)

predicted_samples = []
best_model.eval()
with torch.no_grad():
    for batch_idx, (src_batch, tgt_batch) in enumerate(test_loader):
        src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)
        logits = best_model(src_batch, tgt_batch, teacher_forcing_ratio=0.0)
        decoded = logits.argmax(dim=2)

        for i in range(src_batch.size(0)):
            predicted = ''.join(
                IDX2CHAR_TGT[idx.item()] for idx in decoded[i][1:] if idx.item() != pad_idx
            )
            ground_truth = ''.join(
                IDX2CHAR_TGT[idx.item()] for idx in tgt_batch[i][1:] if idx.item() != pad_idx
            )
            predicted_samples.append({
                'Romanized Input': roman_inputs[batch_idx * src_batch.size(0) + i],
                'True Output': ground_truth,
                'Model Output': predicted
            })

# === Display Sample Results ===
preview_samples = random.sample(predicted_samples, min(10, len(predicted_samples)))
df = pd.DataFrame(preview_samples)
print(df.to_markdown(index=False))

Epoch 2 Training:  36%|███▌      | 247/691 [00:08<00:17, 25.76it/s]

NameError: name 'BASE_DIR' is not defined

                                                                   

Epoch 2 | Loss: 1821.2137 | Val Acc: 0.0000


Epoch 3 Training:  41%|████▏     | 286/691 [00:10<00:14, 27.85it/s]

In [22]:
# === Display Sample Results with Highlighting ===
def style_mismatches(row):
    styling = [''] * len(row)
    try:
        col_names = list(row.index)
        idx = col_names.index('Model Output')
        if row['Model Output'] == row['True Output']:
            styling[idx] = 'background-color: #d4edda; font-weight: bold;'
        else:
            styling[idx] = 'background-color: #f8d7da; font-weight: bold;'
    except ValueError:
        pass
    return styling

sample_subset = pd.DataFrame(random.sample(predicted_samples, min(10, len(predicted_samples))))
styled_df = (
    sample_subset.style
        .apply(style_mismatches, axis=1)
        .set_table_styles([
            {'selector': 'td, th', 'props': [('text-align', 'center'), ('padding', '6px')]},
            {'selector': 'th', 'props': [('background-color', '#4F81BD'), ('color', 'white'), ('font-weight', 'bold'), ('padding', '8px')]}
        ])
        .set_caption("✨ Sample Transliteration Predictions (Green = Correct, Red = Wrong) ✨")
)

# Display in Jupyter Notebook or similar environment
display(styled_df)

Unnamed: 0,Input,True Hindi,Predicted Hindi
0,manchit,मंचित,माचित
1,sanskrit,संस्कृत,संस्कृत
2,girijagharon,गिरजाघरों,गिरिजाघरों
3,mem,मेम,म
4,uniyal,उनियाल,यूनियल
5,vishhin,विषहीन,विशियों
6,majin,माजिन,मीन
7,bhairavnath,भैरवनाथ,भैरवनाथ
8,bokaro,बोकारो,बोकरों
9,amitao,अमिताव,अमीताओं


In [40]:
wandb.init(project="cs24m020_dl_a3_att")

In [41]:
# Set the font family to Noto Sans Telugu
rcParams['font.family'] = ['Noto Sans', 'Noto Sans Hindi', 'sans-serif']

In [42]:
# === Attention Heatmap Logging ===
wandb.init(project="cs24m020_dl_a3_att")
rcParams['font.family'] = ['Noto Sans', 'Noto Sans Hindi', 'sans-serif']
os.makedirs("predictions_vanilla/attention_maps", exist_ok=True)

best_model.eval()
attention_imgs = []
max_samples = 12
collected = 0

with torch.no_grad():
    for src, tgt in test_loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        logits, attention = best_model(src, tgt, return_attention=True)

        for i in range(src.size(0)):
            src_seq = [IDX2CHAR_SRC[idx.item()] for idx in src[i] if idx.item() != pad_idx]
            true_seq = [IDX2CHAR_TGT[idx.item()] for idx in tgt[i][1:] if idx.item() != pad_idx]
            pred_seq = logits.argmax(dim=2)[i][1:len(true_seq)+1]
            pred_chars = [IDX2CHAR_TGT[idx.item()] for idx in pred_seq]
            attn_matrix = attention[i][:len(pred_chars), :len(src_seq)]

            fig, ax = plt.subplots(figsize=(6, 4))
            sns.heatmap(attn_matrix.cpu().numpy(),
                        xticklabels=src_seq,
                        yticklabels=pred_chars,
                        cmap='coolwarm',
                        cbar=False,
                        linewidths=0.5,
                        ax=ax)
            ax.set_xlabel("Input")
            ax.set_ylabel("Predicted Output (Hindi)")
            ax.set_title(f"Attention Heatmap {collected + 1}")
            plt.tight_layout()

            save_path = f"predictions_vanilla/attention_maps/sample_{collected + 1}.png"
            fig.savefig(save_path)
            plt.close(fig)

            attention_imgs.append(wandb.Image(save_path, caption=f"Sample {collected + 1}"))
            collected += 1
            if collected >= max_samples:
                break
        if collected >= max_samples:
            break

wandb.log({"attention_maps": attention_imgs})

  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.savefig(path)
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.savefig(path)
  fig.savefig(path)


In [43]:
df = pd.DataFrame(samples)
output_dir = '/kaggle/working/predictions-vanilla'
os.makedirs(output_dir, exist_ok=True) 
csv_path = '/kaggle/working/predictions-vanilla/output.csv'
df.to_csv(csv_path, index=False, encoding='utf-8-sig')

print(f"\n✅ Saved all predictions to: {csv_path}")


✅ Saved all predictions to: /kaggle/working/predictions-vanilla/output.csv


In [45]:
# === Interactive HTML Attention Viewer ===
html_snippets = []
for idx, (src_tok, pred_tok, weights) in enumerate(random_samples):
    js_input = json.dumps(src_tok, ensure_ascii=False)
    js_output = json.dumps(pred_tok, ensure_ascii=False)
    js_matrix = json.dumps(weights)

    snippet = f'''
    <div class="sample-block">
      <h2>Sample {idx + 1}</h2>
      <div class="token-label">Input (Romanized):</div>
      <div class="token-container input-tokens" id="input-{idx}"></div>
      <div class="token-label">Predicted Hindi Output:</div>
      <div class="token-container output-tokens" id="output-{idx}"></div>
      <script>
        const input_{idx} = {js_input};
        const output_{idx} = {js_output};
        const attn_{idx} = {js_matrix};

        const input_div_{idx} = d3.select("#input-{idx}");
        const output_div_{idx} = d3.select("#output-{idx}");

        input_{idx}.forEach((tok, i) => {{
          input_div_{idx}.append("span")
            .attr("class", "token input")
            .attr("id", "tok-in-{idx}-" + i)
            .text(tok);
        }});

        output_{idx}.forEach((tok, i) => {{
          output_div_{idx}.append("span")
            .attr("class", "token output")
            .attr("title", "Hover to see alignment")
            .text(tok)
            .on("mouseover", () => {{
              d3.selectAll(".token.input").style("background-color", "#f9f9f9");
              attn_{idx}[i].forEach((val, j) => {{
                const color = d3.interpolateBlues(val);
                d3.select("#tok-in-{idx}-" + j).style("background-color", color);
              }});
            }})
            .on("mouseout", () => {{
              d3.selectAll(".token.input").style("background-color", "#f9f9f9");
            }});
        }});
      </script>
    </div>
    '''
    html_snippets.append(snippet)

full_html = f'''
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Seq2Seq Attention Viewer</title>
  <script src="https://d3js.org/d3.v7.min.js"></script>
  <style>
    body {{ font-family: 'Segoe UI', sans-serif; background: #f2f4f8; margin: 40px; color: #333; }}
    .sample-block {{ background: #fff; padding: 20px; margin-bottom: 40px; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.06); }}
    .token-label {{ font-weight: bold; font-size: 16px; margin-top: 10px; color: #555; }}
    .token-container {{ display: flex; flex-wrap: wrap; margin: 10px 0; }}
    .token {{ padding: 8px 12px; margin: 4px; border-radius: 6px; font-size: 18px; cursor: pointer; border: 1px solid #ccc; background-color: #f9f9f9; transition: background-color 0.3s; }}
    .token.output {{ background-color: #e8f0fe; border-color: #a0c4ff; }}
  </style>
</head>
<body>
  <h1>Attention Visualizations for 10 Random Samples</h1>
  {''.join(html_snippets)}
</body>
</html>
'''

wandb.log({"att_visual": wandb.Html(full_html)})

In [48]:
import wandb

# 1. Start your run
run = wandb.init(
    project="my-project",
    job_type="html-report",
    name="attention-visualization"
)

# 2. Write the HTML string to a file
with open("att_visual.html", "w", encoding="utf-8") as f:
    f.write(full_html)

# 3. (Optional) Log it for in‐UI preview this run
run.log({"att_visual_preview": wandb.Html("att_visual.html")})

# 4. Create an Artifact and attach the file
artifact = wandb.Artifact(
    name="attention-viz-report",
    type="html-report",
    description="Permanent HTML attention visualization"
)
artifact.add_file("att_visual.html")

# 5. Log the Artifact so it never expires
run.log_artifact(artifact)

# 6. Finish the run
run.finish()


In [None]:
wandb.finish()