In [None]:
# === Standard Library Imports ===
import os            # File system utilities
import re            # Regular expressions
import random        # Randomization tools

# === Data Handling & Progress Tracking ===
import pandas as pd              # DataFrame manipulation
from tqdm.auto import tqdm       # Smart progress bar
import wandb                    # Experiment tracking

# === PyTorch Core Modules ===
import torch                    # Base tensor library
import torch.nn as nn           # Neural net layers
import torch.optim as optim     # Optimizers (SGD, Adam, etc.)
import torch.nn.functional as F # Functional layer ops
from torch.utils.data import Dataset, DataLoader  # Custom dataset/dataloader


In [None]:
# === Dataset Path Configuration ===
# Modify language code to switch lexicon, e.g. 'hi' → 'mr'
LEXICON_DIR = "/kaggle/input/dakshina/dakshina_dataset_v1.0/hi/lexicons"

# === WANDB Authentication Setup ===
# The API key is read from the environment (e.g. Kaggle Secrets or
# `export WANDB_API_KEY=...`) instead of being hardcoded: a key committed in a
# notebook is readable by anyone who can see the notebook.
# NOTE(review): the key that was previously hardcoded here should be revoked.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
if WANDB_API_KEY:
    try:
        # Force fresh login so a stale cached credential is not reused
        wandb.login(key=WANDB_API_KEY, relogin=True)
    except wandb.errors.UsageError as err:
        # Best effort: report instead of silently swallowing the failure
        print(f"W&B login failed ({err}); continuing without a forced relogin.")
else:
    print("WANDB_API_KEY is not set; skipping explicit wandb.login().")

# === Seed & Device Configuration ===
SEED = 42
random.seed(SEED)            # Seed Python RNG
torch.manual_seed(SEED)      # Seed PyTorch RNG
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [None]:
# === Character Embedding Layer ===
class CharEmbedder(nn.Module):
    """Lookup table that turns character indices into dense vectors."""

    def __init__(self, vocab_sz: int, emb_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_sz, embedding_dim=emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map indices of shape [batch, seq_len] to vectors of shape [batch, seq_len, emb_dim]."""
        vectors = self.embedding(x)
        return vectors


# === Sequence Encoder (RNN/GRU/LSTM) ===
class SeqEncoder(nn.Module):
    """
    Encodes a batch of token sequences with a stacked (optionally bidirectional) RNN.

    Parameters:
      vocab_sz: input vocab size
      hid_dim: hidden size of each RNN direction
      emb_dim: embedding size
      layers: number of stacked RNN layers
      rnn_kind: 'GRU', 'LSTM', or 'RNN'
      dropout_p: dropout on the embeddings; also used between RNN layers when layers > 1
      bidirectional: run the RNN in both directions

    Raises:
      ValueError: if rnn_kind is not one of the supported cell types.
    """
    _RNN_CLASSES = {"GRU": nn.GRU, "LSTM": nn.LSTM, "RNN": nn.RNN}

    def __init__(
        self,
        vocab_sz: int,
        hid_dim: int,
        emb_dim: int,
        layers: int = 1,
        rnn_kind: str = "GRU",
        dropout_p: float = 0.1,
        bidirectional: bool = False,
    ):
        super().__init__()
        if rnn_kind not in self._RNN_CLASSES:
            # A bare dict lookup would raise an opaque KeyError
            raise ValueError(
                f"rnn_kind must be one of {sorted(self._RNN_CLASSES)}, got {rnn_kind!r}"
            )
        self.hid_dim = hid_dim
        self.directions = 2 if bidirectional else 1
        self.num_layers = layers
        self.rnn_kind = rnn_kind

        # Submodules are created in the same order as before so seeded
        # initialisation and state_dict keys are unchanged.
        self.embedding = nn.Embedding(vocab_sz, emb_dim)
        self.dropout = nn.Dropout(dropout_p)
        # nn.RNN-family modules only apply `dropout` between layers (warns for 1 layer)
        rnn_dropout = dropout_p if layers > 1 else 0.0

        rnn_args = dict(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=layers,
            dropout=rnn_dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        if rnn_kind == "RNN":
            rnn_args["nonlinearity"] = "tanh"

        self.rnn = self._RNN_CLASSES[rnn_kind](**rnn_args)

    def forward(self, tokens: torch.Tensor):
        """
        tokens: [batch, seq_len] index tensor.
        Returns (outputs, hidden_state): outputs is [batch, seq_len, directions * hid_dim];
        hidden_state is h_n, or the tuple (h_n, c_n) for an LSTM.
        """
        emb = self.dropout(self.embedding(tokens))  # [batch, seq_len, emb_dim]
        outputs, hidden_state = self.rnn(emb)
        return outputs, hidden_state


# === Stepwise Decoder for Tokens ===
class SeqDecoder(nn.Module):
    """
    Step-by-step token generator: consumes one token per call and returns
    log-probabilities over the next token.

    Parameters:
      vocab_sz: target vocab size
      hid_dim: RNN hidden units
      emb_dim: embedding size
      layers: number of RNN layers
      rnn_kind: 'GRU', 'LSTM', or 'RNN'
      dropout_p: dropout on the input embedding (and between RNN layers when layers > 1)

    Raises:
      ValueError: if rnn_kind is not one of the supported cell types.
    """
    _RNN_CLASSES = {"GRU": nn.GRU, "LSTM": nn.LSTM, "RNN": nn.RNN}

    def __init__(
        self,
        vocab_sz: int,
        hid_dim: int,
        emb_dim: int,
        layers: int = 1,
        rnn_kind: str = "GRU",
        dropout_p: float = 0.1,
    ):
        super().__init__()
        if rnn_kind not in self._RNN_CLASSES:
            # A bare dict lookup would raise an opaque KeyError
            raise ValueError(
                f"rnn_kind must be one of {sorted(self._RNN_CLASSES)}, got {rnn_kind!r}"
            )
        self.hidden_size = hid_dim
        self.num_layers = layers
        self.rnn_kind = rnn_kind

        # Submodule creation order is unchanged so seeded init/state_dict match.
        self.embedding = nn.Embedding(vocab_sz, emb_dim)
        self.input_dropout = nn.Dropout(dropout_p)
        rnn_dropout = dropout_p if layers > 1 else 0.0

        rnn_args = dict(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=layers,
            dropout=rnn_dropout,
            batch_first=True,
        )
        if rnn_kind == "RNN":
            rnn_args["nonlinearity"] = "tanh"

        self.rnn = self._RNN_CLASSES[rnn_kind](**rnn_args)
        # NOTE(review): output dropout reuses the inter-layer rate, so it is p=0
        # for a single-layer decoder even when dropout_p > 0 — confirm intended.
        self.output_dropout = nn.Dropout(rnn_dropout)
        self.output_layer = nn.Linear(hid_dim, vocab_sz)

    def forward(self, current_token: torch.Tensor, prev_state):
        """
        Inputs:
          current_token: [batch, 1], current input token indices
          prev_state: prior hidden state (tuple (h, c) for an LSTM)
        Outputs:
          log_probs: [batch, vocab_sz], log-probabilities of next tokens
          next_state: updated RNN state
        """
        emb = self.input_dropout(self.embedding(current_token))  # [batch, 1, emb_dim]
        rnn_out, next_state = self.rnn(emb, prev_state)
        # Only one timestep is fed, so take index 0 of the time axis
        dropped_out = self.output_dropout(rnn_out[:, 0, :])
        logits = self.output_layer(dropped_out)
        return F.log_softmax(logits, dim=-1), next_state

In [None]:
def beam_search_decode(
    model,
    src_seq: torch.Tensor,
    start_token: int,
    end_token: int,
    max_steps: int = 30,
    beam_width: int = 3,
    device: torch.device = DEVICE,
):
    """Perform beam search decoding with a seq2seq model for ONE source sequence.

    The decoder's initial state is derived from the encoder's final state the
    same way the training-time model does it:
      - bidirectional encoder: forward/backward states of the matching encoder
        layer are concatenated and passed through the model's projection
        (`hidden_proj`, or `hidden_transform` for older models);
      - unidirectional encoder: states are trimmed or zero-padded to the
        decoder's layer count.

    Args:
      model: seq2seq model with `.encoder`/`.decoder` submodules, each holding an `.rnn`
      src_seq: a single source sequence, shape [seq_len] or [1, seq_len]
      start_token: start-of-sequence token index
      end_token: end-of-sequence token index
      max_steps: max decode length
      beam_width: number of beams to track
      device: compute device (the model is assumed to already live there)

    Returns:
      List of (token sequence, score) sorted by descending score. Sequences start
      with `start_token`; if none finished, the surviving beams are returned.

    Raises:
      ValueError: if `src_seq` contains more than one sequence.
    """
    src_seq = src_seq.to(device)
    if src_seq.dim() == 1:
        src_seq = src_seq.unsqueeze(0)
    if src_seq.size(0) != 1:
        # Beams hold one sequence each; a real batch would be silently mixed up
        raise ValueError(
            f"beam_search_decode expects a single sequence, got a batch of {src_seq.size(0)}"
        )

    enc_layers = model.encoder.rnn.num_layers
    dec_layers = model.decoder.rnn.num_layers
    is_bidir = getattr(model, "bidir", getattr(model, "bidirectional", False))
    projection = getattr(model, "hidden_proj", None)
    if projection is None:
        projection = getattr(model, "hidden_transform", None)
    if is_bidir and projection is None:
        raise AttributeError("bidirectional model has no hidden_proj/hidden_transform layer")

    def to_decoder_layers(h):
        """Reshape one encoder state tensor into [dec_layers, 1, hidden]."""
        if is_bidir:
            merged = []
            for i in range(dec_layers):
                layer = min(i, enc_layers - 1)  # reuse top encoder layer if decoder is deeper
                merged.append(projection(torch.cat((h[2 * layer], h[2 * layer + 1]), dim=1)))
            return torch.stack(merged, dim=0)
        if h.size(0) >= dec_layers:
            return h[:dec_layers]
        pad = h.new_zeros(dec_layers - h.size(0), h.size(1), h.size(2))
        return torch.cat([h, pad], dim=0)

    model.eval()
    with torch.no_grad():
        _, encoder_hidden = model.encoder(src_seq)
        if isinstance(encoder_hidden, tuple):  # LSTM: (h, c)
            decoder_state = tuple(to_decoder_layers(s) for s in encoder_hidden)
        else:  # GRU or vanilla RNN
            decoder_state = to_decoder_layers(encoder_hidden)

        beams = [([start_token], 0.0, decoder_state)]  # (sequence, log_prob, hidden_state)
        finished = []

        for _ in range(max_steps):
            candidates = []
            for seq, log_prob, hidden in beams:
                # A beam that already emitted end_token is complete
                if seq[-1] == end_token:
                    finished.append((seq, log_prob))
                    continue
                input_token = torch.tensor([[seq[-1]]], device=device)
                log_probs, next_hidden = model.decoder(input_token, hidden)
                step_scores = log_probs.squeeze(0)
                # topk fails if k exceeds the vocabulary size
                k = min(beam_width, step_scores.size(-1))
                top_probs, top_indices = torch.topk(step_scores, k)
                for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
                    candidates.append((seq + [idx], log_prob + prob, next_hidden))

            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if not beams:
                break

        # Beams selected in the final iteration were not yet inspected
        for seq, log_prob, _ in beams:
            if seq[-1] == end_token:
                finished.append((seq, log_prob))

        if not finished:
            finished = [(seq, log_prob) for seq, log_prob, _ in beams]

        return sorted(finished, key=lambda x: x[1], reverse=True)


In [None]:
# ─── Seq2Seq wrapper: combines encoder & decoder, handles bidirectionality and teacher forcing ───
class Seq2SeqModel(nn.Module):
    """
    Seq2Seq model wrapping a SeqEncoder and a SeqDecoder.

    - A bidirectional encoder's final forward/backward states are concatenated
      per layer and projected back to `hid_dim` (`hidden_proj`).
    - A unidirectional encoder's states are trimmed / zero-padded to the
      decoder's layer count.
    - Training uses teacher forcing with probability `tf_ratio` per time step.
    """
    def __init__(
        self,
        src_vocab: int,
        tgt_vocab: int,
        emb_dim: int = 256,
        hid_dim: int = 256,
        enc_layers: int = 1,
        dec_layers: int = 1,
        rnn_type: str = "GRU",
        drop_p: float = 0.2,
        bidir_enc: bool = False,
    ):
        super().__init__()
        # Keyword names must match SeqEncoder/SeqDecoder's signatures
        # (vocab_sz / rnn_kind / dropout_p / bidirectional).
        self.encoder = SeqEncoder(
            vocab_sz=src_vocab,
            hid_dim=hid_dim,
            emb_dim=emb_dim,
            layers=enc_layers,
            rnn_kind=rnn_type,
            dropout_p=drop_p,
            bidirectional=bidir_enc,
        )
        self.bidir = bidir_enc
        # Projection for bidirectional hidden states (2*hid -> hid)
        if bidir_enc:
            self.hidden_proj = nn.Linear(hid_dim * 2, hid_dim)
        self.decoder = SeqDecoder(
            vocab_sz=tgt_vocab,
            hid_dim=hid_dim,
            emb_dim=emb_dim,
            layers=dec_layers,
            rnn_kind=rnn_type,
            dropout_p=drop_p,
        )
        self.rnn_type = rnn_type

    # Adjust hidden states to decoder layer count (trim or zero-pad)
    def _match_layers(self, h: torch.Tensor, bsz: int):
        dl = self.decoder.rnn.num_layers
        if h.size(0) > dl:
            return h[:dl]
        if h.size(0) < dl:
            pad = h.new_zeros(dl - h.size(0), bsz, h.size(2))
            return torch.cat([h, pad], dim=0)
        return h

    def _project_bidir(self, h_n: torch.Tensor) -> torch.Tensor:
        """Merge forward/backward states per decoder layer -> [dec_layers, batch, hid]."""
        last_enc_layer = self.encoder.rnn.num_layers - 1
        merged = []
        for i in range(self.decoder.rnn.num_layers):
            # A decoder deeper than the encoder reuses the top encoder layer
            layer = min(i, last_enc_layer)
            merged.append(self.hidden_proj(torch.cat((h_n[2 * layer], h_n[2 * layer + 1]), dim=1)))
        return torch.stack(merged, dim=0)

    def _init_decoder_state(self, enc_state, bsz: int):
        """Convert the encoder's final state (tensor, or (h, c) for LSTM) for the decoder."""
        def adapt(h):
            return self._project_bidir(h) if self.bidir else self._match_layers(h, bsz)

        if isinstance(enc_state, tuple):  # LSTM: (h, c)
            return tuple(adapt(s) for s in enc_state)
        return adapt(enc_state)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tf_ratio: float = 0.5):
        """
        Args:
          src: [batch, src_len] source indices
          tgt: [batch, tgt_len] target indices; tgt[:, 0] is fed as the first input
          tf_ratio: per-step probability of feeding the ground-truth token
        Returns:
          [batch, tgt_len, tgt_vocab] log-probabilities; position 0 is left as zeros
          because it is never predicted.
        """
        bsz, tgt_len = src.size(0), tgt.size(1)
        out_vocab = self.decoder.output_layer.out_features
        outputs = torch.zeros(bsz, tgt_len, out_vocab, device=src.device)

        _, enc_state = self.encoder(src)
        dec_state = self._init_decoder_state(enc_state, bsz)

        input_tok = tgt[:, 0].unsqueeze(1)  # start token
        for t in range(1, tgt_len):
            log_probs, dec_state = self.decoder(input_tok, dec_state)
            outputs[:, t, :] = log_probs
            # One draw per step: the whole batch is teacher-forced or not
            if random.random() < tf_ratio:
                input_tok = tgt[:, t].unsqueeze(1)
            else:
                input_tok = log_probs.argmax(1).unsqueeze(1)

        return outputs


In [None]:

# ─── Dataset for transliteration pairs ───────────────────────────────
class TransliterationDataset(Dataset):
    """
    Loads character-level transliteration pairs from a TSV file.

    Column 0 of each line is used as the target and column 1 as the source
    (presumably native script and romanized text respectively — confirm against
    the lexicon files). Extra columns are ignored; lines with fewer than two
    columns or an empty source/target are skipped.

    Vocabularies map characters to indices with specials
    <pad>=0, <sos>=1, <eos>=2, <unk>=3. They are built from this file when
    `create_vocab=True`, otherwise both must be supplied.

    Raises:
      ValueError: if `create_vocab` is False and a vocabulary is missing.
    """
    def __init__(self, file_path, src_vocab=None, tgt_vocab=None, create_vocab=False):
        # Explicit check (an assert would vanish under `python -O`)
        if not create_vocab and (src_vocab is None or tgt_vocab is None):
            raise ValueError("Vocabularies must be provided if not creating.")

        self.data_pairs = []
        with open(file_path, encoding="utf-8") as file:
            for line in file:
                # Strip only the line terminator: a full strip() would also eat a
                # leading tab and shift the columns.
                cols = line.rstrip("\r\n").split("\t")
                if len(cols) < 2:
                    continue
                tgt, src = cols[0].strip(), cols[1].strip()
                if src and tgt:
                    self.data_pairs.append((src, tgt))

        if create_vocab:
            self.src_vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
            self.tgt_vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
            # Indices are assigned in first-seen order, so the mapping is deterministic
            for src_seq, tgt_seq in self.data_pairs:
                for ch in src_seq:
                    self.src_vocab.setdefault(ch, len(self.src_vocab))
                for ch in tgt_seq:
                    self.tgt_vocab.setdefault(ch, len(self.tgt_vocab))
        else:
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Return (source indices, <sos> + target indices + <eos>) as long tensors."""
        src_seq, tgt_seq = self.data_pairs[idx]

        src_unk = self.src_vocab["<unk>"]
        tgt_unk = self.tgt_vocab["<unk>"]
        src_indices = [self.src_vocab.get(ch, src_unk) for ch in src_seq]
        tgt_indices = (
            [self.tgt_vocab["<sos>"]]
            + [self.tgt_vocab.get(ch, tgt_unk) for ch in tgt_seq]
            + [self.tgt_vocab["<eos>"]]
        )

        return torch.tensor(src_indices, dtype=torch.long), torch.tensor(tgt_indices, dtype=torch.long)


# ─── Collate function to pad batches ───────────────────────────────
def pad_batch(batch):
    """
    Collate (source, target) tensor pairs into two right-padded [batch, max_len]
    long tensors, using 0 as the padding index.
    """
    src_seqs, tgt_seqs = zip(*batch)
    pad_token = 0

    def right_pad(seqs):
        # pad_sequence pads every sequence to the longest one in the group
        stacked = nn.utils.rnn.pad_sequence(list(seqs), batch_first=True, padding_value=pad_token)
        return stacked.long()

    return right_pad(src_seqs), right_pad(tgt_seqs)


# ─── Function to create DataLoaders ───────────────────────────────
def _infer_lexicon_lang(data_dir, default="hi"):
    """Guess the language code from a Dakshina-style '<lang>/lexicons' path, else `default`."""
    candidate = os.path.basename(os.path.dirname(os.path.normpath(data_dir)))
    if candidate and os.path.exists(
        os.path.join(data_dir, f"{candidate}.translit.sampled.train.tsv")
    ):
        return candidate
    return default


def create_dataloaders(data_dir, batch_sz, build_vocab=False, src_vocab=None, tgt_vocab=None, lang=None):
    """
    Generates DataLoaders for train, validation, and test sets,
    along with vocab sizes, pad token, and vocab dictionaries.

    Args:
      data_dir: directory holding `<lang>.translit.sampled.{train,dev,test}.tsv`
      batch_sz: batch size for train/validation (test always uses batch size 1)
      build_vocab: build vocabularies from the training split. Vocabularies are
        also built when `src_vocab`/`tgt_vocab` are not both supplied, since
        the dataset cannot be created without them.
      src_vocab, tgt_vocab: optional pre-built vocabularies (used when not building)
      lang: language code of the file names; inferred from `data_dir` when None
        (falls back to 'hi')

    Returns:
      (train_loader, val_loader, test_loader, src_vocab_size, tgt_vocab_size,
       pad_idx, src_vocab, tgt_vocab)
    """
    if lang is None:
        lang = _infer_lexicon_lang(data_dir)
    train_path = os.path.join(data_dir, f"{lang}.translit.sampled.train.tsv")
    valid_path = os.path.join(data_dir, f"{lang}.translit.sampled.dev.tsv")
    test_path  = os.path.join(data_dir, f"{lang}.translit.sampled.test.tsv")

    if build_vocab or src_vocab is None or tgt_vocab is None:
        train_dataset = TransliterationDataset(train_path, create_vocab=True)
    else:
        train_dataset = TransliterationDataset(train_path, src_vocab=src_vocab, tgt_vocab=tgt_vocab)
    src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab

    # Validation/test reuse the training vocabularies so indices line up
    val_dataset = TransliterationDataset(valid_path, src_vocab=src_vocab, tgt_vocab=tgt_vocab)
    test_dataset = TransliterationDataset(test_path, src_vocab=src_vocab, tgt_vocab=tgt_vocab)

    train_loader = DataLoader(train_dataset, batch_size=batch_sz, shuffle=True, collate_fn=pad_batch)
    val_loader = DataLoader(val_dataset, batch_size=batch_sz, shuffle=False, collate_fn=pad_batch)
    # Batch size 1 so per-sample beam search sees unpadded sources
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=pad_batch)

    return (
        train_loader,
        val_loader,
        test_loader,
        len(src_vocab),
        len(tgt_vocab),
        src_vocab["<pad>"],
        src_vocab,
        tgt_vocab,
    )


In [None]:
class EarlyStopping:
    """
    Tracks a higher-is-better metric and reports when it has stopped improving.

    Args:
      patience: consecutive non-improving checks allowed before signalling a stop
      min_delta: increase over the best score required to count as an improvement
    """
    def __init__(self, patience: int = 5, min_delta: float = 1e-4):
        self.patience, self.min_delta = patience, min_delta
        self.best_score = None
        self.counter = 0

    def check(self, metric: float) -> bool:
        """
        Record one observation of the metric.

        Args:
          metric: latest value of the monitored metric
        Returns:
          True once `patience` consecutive checks have failed to improve, else False
        """
        first_check = self.best_score is None
        if first_check or metric > self.best_score + self.min_delta:
            self.best_score, self.counter = metric, 0
        else:
            self.counter = self.counter + 1
        should_stop = self.counter >= self.patience
        return should_stop


In [None]:
# ─── WandB Sweep Configuration ────────────────────────────────────────────────
sweep_cfg = {
    "method": "bayes",  # Bayesian hyperparameter search
    # The metric name must match a key passed to wandb.log() in train_model
    # ("val_accuracy"); otherwise Bayes search and hyperband have nothing to act on.
    "metric": {
        "name": "val_accuracy",
        "goal": "maximize"
    },
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 2,
        "max_iter": 8,
        "s": 2
    },
    "parameters": {
        # Model & training sizes
        "batch_size":            {"values": [16, 32, 64, 128, 256]},
        "num_epochs":            {"values": [10]},
        "encoder_layers":        {"values": [1, 2, 3]},
        "decoder_layers":        {"values": [1, 2, 3]},
        "hidden_size":           {"values": [16, 32, 64, 128, 256, 512, 1024]},
        "embedding_dim":         {"values": [16, 32, 64, 256, 512]},
        "dropout_rate":          {"values": [0.2, 0.3, 0.4]},
        "bi_directional":        {"values": [True, False]},
        # Search & decoding
        # NOTE(review): train_model evaluates greedily and never reads
        # length_penalty; beam_width only appears in the run name.
        "beam_width":            {"values": [1, 3, 5]},
        "teacher_forcing_ratio": {"values": [0.0, 0.3, 0.5, 0.7, 1.0]},
        "length_penalty":        {"values": [0, 0.4, 0.5, 0.6]},
        # Optimization
        "optimizer":             {"values": ["adam", "sgd", "rmsprop", "adagrad"]},
        "learning_rate":         {"values": [0.005, 0.001, 0.01, 0.1]},
        # RNN cell variants
        "rnn_cell":              {"values": ["RNN", "GRU", "LSTM"]},
    }
}


In [None]:
# ─── Optimized Sweep Parameters ───────────────────────────────────────────────
best_cfg = {
    # A W&B sweep spec requires "method"; a grid over single-valued
    # parameters reruns exactly the chosen configuration.
    "method": "grid",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        # batch & epochs
        "batch_size":            {"values": [64]},
        "num_epochs":            {"values": [10]},
        # model architecture
        "encoder_layers":        {"values": [2]},
        "decoder_layers":        {"values": [2]},
        "hidden_size":           {"values": [256]},
        "embedding_dim":         {"values": [64]},
        "dropout_rate":          {"values": [0.4]},
        "bi_directional":        {"values": [False]},
        # decoding & loss
        "beam_width":            {"values": [5]},
        # NOTE(review): was [0, 4]; 4 is outside the searched range
        # [0, 0.4, 0.5, 0.6] and looks like a typo for 0.4 — confirm.
        "length_penalty":        {"values": [0.4]},
        "teacher_forcing_ratio": {"values": [1.0]},
        # optimization
        "optimizer":             {"values": ["adam"]},
        "learning_rate":         {"values": [0.001]},
        # RNN variant
        "rnn_cell":              {"values": ["LSTM"]},
    }
}


In [None]:
def _build_optimizer(name, params, lr):
    """Instantiate the torch optimizer named by the sweep's `optimizer` value."""
    choices = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "rmsprop": torch.optim.RMSprop,
        "adagrad": torch.optim.Adagrad,
    }
    key = str(name).lower()
    if key not in choices:
        raise ValueError(f"Unsupported optimizer {name!r}; expected one of {sorted(choices)}")
    return choices[key](params, lr=lr)


def _shifted_nll(loss_fn, output, tgt_batch):
    """NLL over target positions 1..T-1 (position 0 is the <sos> input, never predicted)."""
    vocab = output.size(-1)
    # reshape, not view: the [:, 1:] slice is non-contiguous
    return loss_fn(output[:, 1:].reshape(-1, vocab), tgt_batch[:, 1:].reshape(-1))


def _exact_match_counts(output, tgt_batch, pad_idx):
    """Return (#sequences predicted exactly on all non-pad positions 1.., batch size)."""
    preds = output.argmax(dim=2)[:, 1:]
    gold = tgt_batch[:, 1:]
    # Padding positions always "match"; a row counts if every position matches
    row_ok = ((preds == gold) | (gold == pad_idx)).all(dim=1)
    return int(row_ok.sum().item()), tgt_batch.size(0)


def train_model():
    """
    Train Seq2Seq with W&B tracking, early stopping, and evaluation.

    Reads hyperparameters from `wandb.config`, logs per-epoch mean train/val
    loss and exact-match accuracy, then logs greedy-decoding test accuracy.
    """
    with wandb.init():
        cfg = wandb.config

        # Compose run name for clarity in W&B (round avoids 0.3*100 -> 29 style truncation)
        run_name = (
            f"{cfg.rnn_cell.lower()}_dp{round(cfg.dropout_rate * 100)}"
            f"_bw{cfg.beam_width}_tf{round(cfg.teacher_forcing_ratio * 100)}"
            f"_emb{cfg.embedding_dim}_hid{cfg.hidden_size}"
            f"_enc{cfg.encoder_layers}_dec{cfg.decoder_layers}"
        )
        wandb.run.name = run_name

        # Load datasets and vocabularies
        train_dl, val_dl, test_dl, src_vocab_size, tgt_vocab_size, pad_idx, _, _ = create_dataloaders(
            LEXICON_DIR, batch_sz=cfg.batch_size, build_vocab=True
        )

        model = Seq2SeqModel(
            src_vocab=src_vocab_size,
            tgt_vocab=tgt_vocab_size,
            emb_dim=cfg.embedding_dim,
            hid_dim=cfg.hidden_size,
            enc_layers=cfg.encoder_layers,
            dec_layers=cfg.decoder_layers,
            rnn_type=cfg.rnn_cell,
            drop_p=cfg.dropout_rate,
            bidir_enc=cfg.bi_directional,
        ).to(DEVICE)

        # Honour the swept optimizer instead of always using Adam
        optimizer = _build_optimizer(cfg.get("optimizer", "adam"), model.parameters(), cfg.learning_rate)
        loss_fn = nn.NLLLoss(ignore_index=pad_idx)
        early_stop = EarlyStopping(patience=5, min_delta=1e-4)
        best_val_accuracy = 0.0

        for epoch in range(1, cfg.num_epochs + 1):
            model.train()
            train_loss_sum, train_correct, train_total = 0.0, 0, 0

            for src_batch, tgt_batch in tqdm(train_dl, desc=f"[Epoch {epoch}] Training", leave=False):
                src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)

                optimizer.zero_grad()
                output = model(src_batch, tgt_batch, tf_ratio=cfg.teacher_forcing_ratio)
                loss = _shifted_nll(loss_fn, output, tgt_batch)
                loss.backward()
                optimizer.step()

                train_loss_sum += loss.item()
                correct, total = _exact_match_counts(output.detach(), tgt_batch, pad_idx)
                train_correct += correct
                train_total += total

            train_accuracy = 100 * train_correct / max(train_total, 1)
            # Mean per batch, consistent with the validation loss below
            avg_train_loss = train_loss_sum / max(len(train_dl), 1)

            # Validation phase (no teacher forcing)
            model.eval()
            val_loss_sum, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for src_batch, tgt_batch in tqdm(val_dl, desc=f"[Epoch {epoch}] Validation", leave=False):
                    src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)
                    output = model(src_batch, tgt_batch, tf_ratio=0.0)
                    val_loss_sum += _shifted_nll(loss_fn, output, tgt_batch).item()
                    correct, total = _exact_match_counts(output, tgt_batch, pad_idx)
                    val_correct += correct
                    val_total += total

            val_accuracy = 100 * val_correct / max(val_total, 1)
            avg_val_loss = val_loss_sum / max(len(val_dl), 1)
            best_val_accuracy = max(best_val_accuracy, val_accuracy)

            wandb.log({
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "train_accuracy": train_accuracy,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy,
                "best_val_accuracy": best_val_accuracy,
            })

            print(f"[Epoch {epoch}] Train Loss={avg_train_loss:.3f} Train Acc={train_accuracy:.2f}% | Val Loss={avg_val_loss:.3f} Val Acc={val_accuracy:.2f}%")

            # Feed every epoch to the stopper so its patience counter is meaningful
            if early_stop.check(val_accuracy):
                print("Early stopping triggered.")
                break

        # Test evaluation (greedy decoding)
        model.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():
            for src_batch, tgt_batch in tqdm(test_dl, desc="Testing", leave=False):
                src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)
                output = model(src_batch, tgt_batch, tf_ratio=0.0)
                correct, total = _exact_match_counts(output, tgt_batch, pad_idx)
                test_correct += correct
                test_total += total

        test_accuracy = 100 * test_correct / max(test_total, 1)
        print(f"\nFinal Test Accuracy: {test_accuracy:.2f}%")
        wandb.log({"test_accuracy": test_accuracy})


In [None]:
# ─── Initiate Hyperparameter Sweep ─────────────────────────────────────────────
# The agent calls the training entry point defined above (train_model).
sweep_id = wandb.sweep(sweep_cfg, project="cs24m020_dl_a3_v2")
wandb.agent(sweep_id, function=train_model, count=100)

# ─── Optional Alternate Sweep Setup ────────────────────────────────────────────
# Resume an existing sweep instead of creating a new one:
# alt_sweep = "8espi10w"
# wandb.agent(
#     sweep_id=alt_sweep,
#     function=train_model,
#     count=100,
#     entity="cs24m020-indian-institute-of-technology-madras",
#     project="cs24m020_dl_a3_v2",
# )

# ─── Finalize W&B Session ──────────────────────────────────────────────────────
wandb.finish()


In [None]:
# Get data loaders and vocabs (same loader factory and data directory as training)
train_loader, val_loader, test_loader, src_size, tgt_size, pad_idx, src_vocab, tgt_vocab = create_dataloaders(
    LEXICON_DIR, batch_sz=64, build_vocab=True
)

# Reverse vocab mappings for easy idx-to-char conversion
IDX2CHAR_TGT = {idx: ch for ch, idx in tgt_vocab.items()}
IDX2CHAR_SRC = {idx: ch for ch, idx in src_vocab.items()}

sos_idx = tgt_vocab['<sos>']
eos_idx = tgt_vocab['<eos>']
pad_idx = tgt_vocab['<pad>']

# Best config from tuning/sweep results
BEST_CONFIG = {
    'embedding_dim': 64,
    'hidden_size': 256,
    'encoder_layers': 2,
    'decoder_layers': 2,
    'cell_type': 'LSTM',
    'dropout_p': 0.4,
    'beam_width': 5,
    'teacher_forcing_ratio': 1.0,
    'bidirectional_encoder': False
}

def fast_beam_search(model, src, sos_idx, eos_idx, max_len=30, beam_width=5, device='cuda'):
    """Beam search for a single input sequence; returns the best token-id list.

    Args:
      model: seq2seq model with `.encoder` and `.decoder` submodules that each
        hold an `.rnn`; the decoder is called as
        `decoder(token[1, 1], state) -> (log_probs[1, V], state)`. A
        bidirectional model must expose `bidir` plus `hidden_proj`
        (or the older `bidirectional_encoder` / `hidden_transform`).
      src: one source sequence, shape [seq_len] (or [1, seq_len])
      sos_idx, eos_idx: start / end token indices
      max_len: maximum number of decoding steps
      beam_width: number of beams kept per step
      device: device the model lives on; `src` is moved there

    Returns:
      list[int] starting with `sos_idx`; it ends with `eos_idx` only if one was
      generated within `max_len` steps.
    """
    src = src.to(device)
    if src.dim() == 1:
        src = src.unsqueeze(0)

    # Read configuration from the submodules themselves so any model layout works
    enc_layers = model.encoder.rnn.num_layers
    dec_layers = model.decoder.rnn.num_layers
    bidirectional = getattr(model, "bidir", getattr(model, "bidirectional_encoder", False))
    projection = getattr(model, "hidden_proj", None)
    if projection is None:
        projection = getattr(model, "hidden_transform", None)
    if bidirectional and projection is None:
        raise AttributeError("bidirectional model has no hidden_proj/hidden_transform layer")

    def bridge(h):
        """Turn one encoder state tensor into the decoder's [dec_layers, 1, hidden]."""
        if bidirectional:
            merged = []
            for layer in range(dec_layers):
                enc_layer = min(layer, enc_layers - 1)
                merged.append(projection(torch.cat((h[2 * enc_layer], h[2 * enc_layer + 1]), dim=1)))
            return torch.stack(merged, dim=0)
        # Unidirectional: trim or zero-pad to the decoder's depth
        if h.size(0) >= dec_layers:
            return h[:dec_layers]
        pad = h.new_zeros(dec_layers - h.size(0), h.size(1), h.size(2))
        return torch.cat([h, pad], dim=0)

    model.eval()
    with torch.no_grad():
        _, encoder_hidden = model.encoder(src)
        if isinstance(encoder_hidden, tuple):  # LSTM: (h, c)
            decoder_hidden = tuple(bridge(state) for state in encoder_hidden)
        else:
            decoder_hidden = bridge(encoder_hidden)

        beams = [([sos_idx], 0.0, decoder_hidden)]
        completed = []

        for _ in range(max_len):
            new_beams = []
            for seq, score, hidden in beams:
                if seq[-1] == eos_idx:
                    completed.append((seq, score))
                    continue

                input_char = torch.tensor([[seq[-1]]], device=device)
                output, hidden_new = model.decoder(input_char, hidden)
                log_probs = output.squeeze(0)
                # topk fails if k exceeds the vocabulary size
                k = min(beam_width, log_probs.size(-1))
                topk_log_probs, topk_indices = torch.topk(log_probs, k)

                for lp, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                    new_beams.append((seq + [idx], score + lp, hidden_new))

            beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
            if not beams:
                break

        # Beams chosen in the last iteration have not been inspected yet
        completed += [(seq, score) for seq, score, _ in beams if seq[-1] == eos_idx]
        if not completed:
            completed = [(seq, score) for seq, score, _ in beams]

        return max(completed, key=lambda x: x[1])[0]


def generate_predictions_csv(model, test_loader, output_file='predictions.csv'):
    """Generate beam-search predictions for the test set and save them to a CSV file.

    Relies on the notebook-level IDX2CHAR_SRC / IDX2CHAR_TGT, sos_idx / eos_idx /
    pad_idx, BEST_CONFIG and fast_beam_search.

    Args:
      model: trained seq2seq model (its parameters' device is used)
      test_loader: yields (src, tgt) index batches
      output_file: CSV path; missing parent directories are created

    Returns:
      DataFrame with columns input, prediction, target, is_correct,
      input_length, prediction_length.
    """
    model.eval()
    device = next(model.parameters()).device

    # to_csv fails if the target directory does not exist yet
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    columns = ['input', 'prediction', 'target', 'is_correct', 'input_length', 'prediction_length']
    records = []

    with torch.no_grad():
        for src, tgt in tqdm(test_loader, desc="Generating Predictions"):
            src = src.to(device)

            for i in range(src.size(0)):
                # Source string, ignoring padding
                src_str = ''.join(IDX2CHAR_SRC.get(idx, '<unk>') for idx in src[i].tolist() if idx != pad_idx)

                # Target string: skip <sos>, stop at <eos>/padding
                tgt_ids = []
                for idx in tgt[i].tolist():
                    if idx == sos_idx:
                        continue
                    if idx == eos_idx or idx == pad_idx:
                        break
                    tgt_ids.append(idx)
                tgt_str = ''.join(IDX2CHAR_TGT.get(idx, '<unk>') for idx in tgt_ids)

                pred_seq = fast_beam_search(
                    model, src[i], sos_idx, eos_idx,
                    beam_width=BEST_CONFIG['beam_width'],
                    device=device
                )
                # Drop the leading <sos>; drop the final token only if it really is <eos>
                # (a beam that hit max_len ends with a real character).
                pred_ids = pred_seq[1:]
                if pred_ids and pred_ids[-1] == eos_idx:
                    pred_ids = pred_ids[:-1]
                pred_str = ''.join(IDX2CHAR_TGT.get(idx, '<unk>') for idx in pred_ids)

                records.append({
                    'input': src_str,
                    'prediction': pred_str,
                    'target': tgt_str,
                    'is_correct': pred_str == tgt_str,
                    'input_length': len(src_str),
                    'prediction_length': len(pred_str),
                })

    df = pd.DataFrame(records, columns=columns)
    df.to_csv(output_file, index=False)

    accuracy = df['is_correct'].mean() * 100
    avg_input_len = df['input_length'].mean()
    avg_pred_len = df['prediction_length'].mean()

    print(f"\nPrediction Generation Complete")
    print(f"Saved to: {output_file}")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Avg Input Length: {avg_input_len:.1f} chars")
    print(f"Avg Prediction Length: {avg_pred_len:.1f} chars")

    return df

# Initialize model with best config (Seq2SeqModel's own keyword names)
best_model = Seq2SeqModel(
    src_vocab=len(src_vocab),
    tgt_vocab=len(tgt_vocab),
    emb_dim=BEST_CONFIG['embedding_dim'],
    hid_dim=BEST_CONFIG['hidden_size'],
    enc_layers=BEST_CONFIG['encoder_layers'],
    dec_layers=BEST_CONFIG['decoder_layers'],
    rnn_type=BEST_CONFIG['cell_type'],
    drop_p=BEST_CONFIG['dropout_p'],
    bidir_enc=BEST_CONFIG['bidirectional_encoder'],
).to(DEVICE)

# Load trained weights; without them predictions come from random initialisation.
# torch.load unpickles the file — only load checkpoints you produced yourself.
CHECKPOINT_PATH = 'best_model.pth'
if os.path.exists(CHECKPOINT_PATH):
    best_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
else:
    print(f"WARNING: {CHECKPOINT_PATH} not found; predictions use an untrained model.")

# Generate predictions CSV (create the output directory first)
PREDICTIONS_PATH = '/kaggle/working/predictions_vanilla/output.csv'
os.makedirs(os.path.dirname(PREDICTIONS_PATH), exist_ok=True)
predictions_df = generate_predictions_csv(best_model, test_loader, PREDICTIONS_PATH)

# Show sample predictions (seeded so reruns show the same rows)
print("\nSample Predictions:")
sample_df = predictions_df.sample(min(5, len(predictions_df)), random_state=SEED)
for _, row in sample_df.iterrows():
    print(f"\nInput: {row['input']}")
    print(f"Target: {row['target']}")
    print(f"Prediction: {row['prediction']}")
    print(f"Correct: {'✓' if row['is_correct'] else '✗'}")
