# Question 1 - Defining the dataset (vocabulary) and seq2seq model classes

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import wandb
# Read the W&B API key from the environment rather than committing it to the notebook.
# When WANDB_API_KEY is unset, `key=None` lets wandb fall back to its own credential lookup.
# NOTE(review): the key that used to be hard-coded on this line is exposed in history and
# should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
import torch.nn.functional as F

In [None]:
class tsvtokenizer(Dataset):
    """Character-level transliteration dataset backed by a tab-separated lexicon.

    Each row of the TSV is read as ``native_text<TAB>roman_text<TAB>frequency``;
    only the first two columns are loaded. Samples are ``(roman, native)`` pairs,
    i.e. the romanized string is the source and the native string is the target.

    Args:
        tsv_file: Path of the TSV lexicon.
        src_vocab: Existing source vocabulary (char -> index); used when
            ``build_vocab`` is False.
        tgt_vocab: Existing target vocabulary (char -> index); used when
            ``build_vocab`` is False.
        max_len: Fixed length of every returned sequence, counting the
            ``<sos>`` and ``<eos>`` tokens.
        build_vocab: When True, build both vocabularies from the loaded pairs.
    """

    def __init__(self, tsv_file, src_vocab=None, tgt_vocab=None, max_len=32, build_vocab=False):
        try:
            # keep_default_na=False: without it pandas treats words such as
            # "nan", "null" or "NA" as missing values, and they would then be
            # replaced by '' below, silently corrupting those samples.
            df = pd.read_csv(tsv_file, sep='\t', header=None,
                             names=['native', 'roman', 'freq'],
                             usecols=[0, 1], dtype=str,
                             keep_default_na=False)
            print(f"Successfully loaded {len(df)} entries from {tsv_file}")

            # Rows with a missing field can still yield NaN; normalise them to ''.
            df['native'] = df['native'].fillna('')
            df['roman'] = df['roman'].fillna('')

            # Create pairs of (roman, native) for transliteration
            self.pairs = list(zip(df['roman'], df['native']))
            print(f"Sample data: {self.pairs[:2]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            # Fall back to a single empty pair so the dataset stays usable.
            self.pairs = [('', '')]

        self.max_len = max_len
        if build_vocab:
            # Both vocabularies start with the same four special tokens.
            self.src_vocab = {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}
            self.tgt_vocab = {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}
            self._build_vocab()
        else:
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def _build_vocab(self):
        """Assign the next free index to every unseen character of each side."""
        for src, tgt in self.pairs:
            for ch in src:
                if ch not in self.src_vocab:
                    self.src_vocab[ch] = len(self.src_vocab)
            for ch in tgt:
                if ch not in self.tgt_vocab:
                    self.tgt_vocab[ch] = len(self.tgt_vocab)
        print(f"Built vocabularies - Source: {len(self.src_vocab)}, Target: {len(self.tgt_vocab)}")

    def __len__(self):
        return len(self.pairs)

    def _encode(self, text, vocab):
        """Return ``<sos> chars <eos>`` as indices, truncated/padded to ``max_len``."""
        unk = vocab['<unk>']
        ids = [vocab['<sos>']]
        for ch in text:
            token = vocab.get(ch, unk)
            if token >= len(vocab):
                # Guard against a vocabulary whose indices exceed its size.
                token = unk
            ids.append(token)
        ids.append(vocab['<eos>'])

        if len(ids) > self.max_len:
            # Keep <sos>, as many characters as fit, and the final <eos>.
            ids = [ids[0]] + ids[1:self.max_len - 1] + [ids[-1]]

        return ids + [vocab['<pad>']] * (self.max_len - len(ids))

    def __getitem__(self, idx):
        """Return the ``(source, target)`` index tensors of sample ``idx``."""
        src, tgt = self.pairs[idx]

        src_idxs = self._encode(src, self.src_vocab)
        tgt_idxs = self._encode(tgt, self.tgt_vocab)

        # Make sure padding index is valid
        assert self.src_vocab['<pad>'] < len(self.src_vocab), "Padding index out of bounds for source vocab"
        assert self.tgt_vocab['<pad>'] < len(self.tgt_vocab), "Padding index out of bounds for target vocab"

        return torch.tensor(src_idxs, dtype=torch.long), torch.tensor(tgt_idxs, dtype=torch.long)

In [None]:
# ---- Model Definition ----
class Seq2Seq(nn.Module):
    """Encoder-decoder recurrent model operating on index sequences.

    Args:
        cfg: Object exposing ``embed_dim``, ``hidden_dim``, ``cell_type``
            (``'RNN'``, ``'GRU'`` or ``'LSTM'``), ``enc_layers``,
            ``dec_layers`` and ``dropout``.
        src_vocab_size: Number of rows of the source embedding table.
        tgt_vocab_size: Number of rows of the target embedding table and
            size of the output projection.

    Raises:
        ValueError: If ``cfg.cell_type`` is not one of the supported cells.
    """

    def __init__(self, cfg, src_vocab_size, tgt_vocab_size):
        super().__init__()
        # Make sure vocab sizes are valid
        assert src_vocab_size > 0, f"Invalid source vocabulary size: {src_vocab_size}"
        assert tgt_vocab_size > 0, f"Invalid target vocabulary size: {tgt_vocab_size}"

        # Hyperparameters are kept on the module because inference code reads them.
        self.embed_dim = cfg.embed_dim
        self.hidden_dim = cfg.hidden_dim
        self.cell_type = cfg.cell_type
        self.enc_layers = cfg.enc_layers
        self.dec_layers = cfg.dec_layers

        # Index 0 is the padding row in both embedding tables.
        self.src_emb = nn.Embedding(src_vocab_size, cfg.embed_dim, padding_idx=0)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, cfg.embed_dim, padding_idx=0)

        cells = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
        if cfg.cell_type not in cells:
            raise ValueError(f"Unsupported cell type: {cfg.cell_type}")
        cell = cells[cfg.cell_type]

        # Dropout is only passed when a stack has more than one layer.
        enc_dr = cfg.dropout if cfg.enc_layers > 1 else 0.0
        dec_dr = cfg.dropout if cfg.dec_layers > 1 else 0.0
        self.encoder = cell(cfg.embed_dim, cfg.hidden_dim,
                            num_layers=cfg.enc_layers, batch_first=True, dropout=enc_dr)
        self.decoder = cell(cfg.embed_dim, cfg.hidden_dim,
                            num_layers=cfg.dec_layers, batch_first=True, dropout=dec_dr)

        # Output projection
        self.fc = nn.Linear(cfg.hidden_dim, tgt_vocab_size)

        print(f"Model initialized: {cfg.cell_type} with {cfg.enc_layers} encoder layers, "
              f"{cfg.dec_layers} decoder layers, {cfg.embed_dim} embedding dim, "
              f"{cfg.hidden_dim} hidden dim")

    def _match_decoder_layers(self, state):
        """Resize an encoder state along the layer axis to ``dec_layers``.

        The first ``min(enc_layers, dec_layers)`` layers are kept; any extra
        decoder layers start from zeros.
        """
        if self.enc_layers == self.dec_layers:
            return state
        if self.enc_layers > self.dec_layers:
            return state[:self.dec_layers]
        extra = state.new_zeros(self.dec_layers - self.enc_layers, state.size(1), state.size(2))
        return torch.cat([state, extra], dim=0)

    def forward(self, src, tgt):
        """Run the encoder on ``src`` and the decoder on ``tgt``.

        Args:
            src: Source index tensor, indexed as ``[batch, src_len]``.
            tgt: Decoder input index tensor, indexed as ``[batch, tgt_len]``.

        Returns:
            Logits produced by the output projection for every decoder step.

        Errors raised by the underlying modules propagate to the caller. A
        fabricated all-zero result would be detached from the graph and would
        be mistaken for a real prediction by loss and metric code.
        """
        # Out-of-range indices are reported and clamped rather than crashing.
        if src.max() >= self.src_emb.num_embeddings:
            print(f"Warning: Source index {src.max().item()} is out of bounds for vocab size {self.src_emb.num_embeddings}")
            src = torch.clamp(src, 0, self.src_emb.num_embeddings - 1)

        if tgt.max() >= self.tgt_emb.num_embeddings:
            print(f"Warning: Target index {tgt.max().item()} is out of bounds for vocab size {self.tgt_emb.num_embeddings}")
            tgt = torch.clamp(tgt, 0, self.tgt_emb.num_embeddings - 1)

        _, state = self.encoder(self.src_emb(src))

        # An LSTM state is an (h, c) pair; RNN/GRU states are a single tensor.
        if self.cell_type == 'LSTM':
            state = tuple(self._match_decoder_layers(part) for part in state)
        else:
            state = self._match_decoder_layers(state)

        dec_out, _ = self.decoder(self.tgt_emb(tgt), state)
        return self.fc(dec_out)

In [None]:
def generate(model, src, max_len=32, eos_idx=3, sos_idx=2, device=None):
    """
    Generate target sequences using greedy decoding.

    Args:
        model: Trained Seq2Seq model
        src: Source sequence tensor [batch_size, seq_len]
        max_len: Length of the generated sequences, including the leading <sos>
        eos_idx: Index of <eos> token in target vocabulary
        sos_idx: Index of <sos> token in target vocabulary
        device: Device to run inference on. Defaults to the device of ``src``;
            when given, ``src`` is moved to it first.

    Returns:
        Long tensor [batch_size, max_len]. Every row starts with <sos>,
        contains exactly one <eos> (forced into the last position when the
        model never emitted one) and is zero after that <eos>.
    """
    model.eval()  # Set model to evaluation mode

    # Follow the input batch instead of assuming CUDA is present.
    if device is None:
        device = src.device
    else:
        src = src.to(device)
    batch_size = src.size(0)

    def _fit_layers(state):
        # Copy the shared leading layers; extra decoder layers start at zero.
        if model.enc_layers == model.dec_layers:
            return state
        fitted = torch.zeros(model.dec_layers, batch_size, model.hidden_dim, device=device)
        shared = min(model.enc_layers, model.dec_layers)
        fitted[:shared] = state[:shared]
        return fitted

    with torch.no_grad():
        # Encode source sequence
        _, state = model.encoder(model.src_emb(src))
        if model.cell_type == 'LSTM':
            hidden = tuple(_fit_layers(part) for part in state)
        else:  # RNN or GRU
            hidden = _fit_layers(state)

        # Every sequence starts from <sos>.
        current_token = torch.full((batch_size, 1), sos_idx, dtype=torch.long, device=device)
        generated_sequences = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
        generated_sequences[:, 0] = sos_idx

        # Track which sequences have already produced <eos>.
        completed = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(1, max_len):
            dec_out, hidden = model.decoder(model.tgt_emb(current_token), hidden)

            # argmax over logits picks the same token as argmax over softmax.
            next_token = model.fc(dec_out).argmax(dim=-1)
            chosen = next_token.squeeze(1)

            generated_sequences[:, step] = chosen
            completed = completed | (chosen == eos_idx)
            if completed.all():
                break

            # Finished rows keep decoding; their extra tokens are erased below.
            current_token = next_token

        # Erase everything strictly after the first <eos> of each row.
        is_eos = generated_sequences == eos_idx
        eos_seen_before = (is_eos.cumsum(dim=1) - is_eos.long()) > 0
        generated_sequences.masked_fill_(eos_seen_before, 0)

        # Rows that never produced <eos> get one in the last position.
        missing_eos = ~is_eos.any(dim=1)
        generated_sequences[missing_eos, -1] = eos_idx

    return generated_sequences


def decode_predictions(sequences, idx_to_char, eos_idx=3):
    """
    Turn rows of token indices into strings, cutting each row at its first <eos>.

    Args:
        sequences: Tensor of token indices [batch_size, seq_len]
        idx_to_char: Dictionary mapping indices to characters
        eos_idx: Index of <eos> token

    Returns:
        List with one decoded string per row. The literal '<eos>' is appended
        when the row contains the end token; index 0 (padding) is dropped and
        indices missing from ``idx_to_char`` become '<unk>'.
    """
    texts = []

    for row in range(sequences.size(0)):
        pieces = []
        for token in sequences[row].tolist():
            if token == eos_idx:
                # Keep the marker itself, discard whatever follows it.
                pieces.append('<eos>')
                break
            if token == 0:
                continue  # padding
            pieces.append(idx_to_char.get(token, '<unk>'))

        texts.append(''.join(pieces))

    return texts

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Train the model for one epoch.

    Args:
        model: The Seq2Seq model
        loader: DataLoader with training data
        criterion: Loss function
        optimizer: Optimizer
        device: Device to use for computation

    Returns:
        Mean loss over the batches that were actually optimised. Batches that
        are skipped (out-of-range indices) or that raise are excluded from the
        average; 0.0 is returned when no batch was processed.
    """
    model.train()
    total_loss = 0.0
    processed_batches = 0
    num_batches = len(loader)

    for batch_idx, (src, tgt) in enumerate(loader):
        try:
            # Move data to device
            src, tgt = src.to(device), tgt.to(device)

            # Verify indices are within vocabulary size range
            src_max = src.max().item()
            tgt_max = tgt.max().item()
            if src_max >= model.src_emb.num_embeddings or tgt_max >= model.tgt_emb.num_embeddings:
                print(f"Batch {batch_idx}/{num_batches}: Invalid indices - "
                      f"Source max: {src_max}, Target max: {tgt_max}, "
                      f"Source vocab: {model.src_emb.num_embeddings}, Target vocab: {model.tgt_emb.num_embeddings}")
                # Skip this batch
                continue

            optimizer.zero_grad()

            # Teacher forcing: the decoder sees tgt[:-1] and must predict tgt[1:].
            logits = model(src, tgt[:, :-1])

            logits_flat = logits.reshape(-1, logits.size(-1))
            tgt_flat = tgt[:, 1:].reshape(-1)

            loss = criterion(logits_flat, tgt_flat)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            processed_batches += 1

        except Exception as e:
            # Best effort: report the failing batch and keep training.
            print(f"Error in batch {batch_idx}/{num_batches}: {e}")
            continue

    # Dividing by len(loader) would deflate the loss whenever batches were skipped.
    if processed_batches == 0:
        return 0.0
    return total_loss / processed_batches

In [None]:
def decode_batch(batch_sequences, idx_to_char, eos_idx=3, join=True, sos_idx=2):
    """
    Decode a batch of token sequences back to text.

    Args:
        batch_sequences: Tensor of token indices [batch_size, seq_len]
        idx_to_char: Dictionary mapping indices to characters
        eos_idx: Index of <eos> token
        join: Whether to join characters into strings
        sos_idx: Index of <sos> token; it is dropped only while nothing has
            been decoded yet for the row

    Returns:
        List of decoded strings or character lists. Decoding of a row stops at
        the first <eos>; index 0 (padding) is skipped and indices missing from
        ``idx_to_char`` become '<unk>'.
    """
    decoded = []

    for b in range(batch_sequences.size(0)):
        chars = []
        for token in batch_sequences[b].tolist():
            # A leading <sos> is not part of the text.
            if not chars and token == sos_idx:
                continue

            if token == eos_idx:
                break

            if token == 0:  # padding
                continue

            chars.append(idx_to_char.get(token, '<unk>'))

        decoded.append(''.join(chars) if join else chars)

    return decoded


def transliterate(model, text, src_vocab, tgt_vocab, device, max_len=32):
    """
    Transliterate a single text input.

    Args:
        model: Trained Seq2Seq model
        text: Input text string
        src_vocab: Source vocabulary (char -> idx)
        tgt_vocab: Target vocabulary (char -> idx)
        device: Device to run inference on
        max_len: Length the source is padded/truncated to, and maximum length
            of the generated sequence

    Returns:
        Transliterated text
    """
    model.eval()

    # Create inverse vocab for decoding
    idx_to_char = {idx: char for char, idx in tgt_vocab.items()}

    # Encode as <sos> chars <eos>; special-token indices fall back to the
    # conventional 0-3 layout when a vocabulary lacks them.
    unk_idx = src_vocab.get('<unk>', 1)
    indices = [src_vocab.get('<sos>', 2)]
    indices.extend(src_vocab.get(ch, unk_idx) for ch in text)
    indices.append(src_vocab.get('<eos>', 3))

    if len(indices) < max_len:
        indices += [src_vocab.get('<pad>', 0)] * (max_len - len(indices))
    else:
        # If too long, keep <sos>, as much content as fits, and <eos>
        indices = [indices[0]] + indices[1:max_len-1] + [indices[-1]]

    # Convert to tensor and add batch dimension
    src_tensor = torch.tensor([indices], dtype=torch.long).to(device)

    tgt_eos = tgt_vocab.get('<eos>', 3)
    with torch.no_grad():
        output = generate(
            model,
            src_tensor,
            max_len=max_len,
            eos_idx=tgt_eos,
            sos_idx=tgt_vocab.get('<sos>', 2),
            device=device
        )

        # Decode the single generated row
        result = decode_batch(output, idx_to_char, eos_idx=tgt_eos)[0]

    return result

## Metric functions

In [None]:
def exact_match_accuracy(model, loader, device, tgt_vocab, sos_idx=2, eos_idx=3):
    """
    Percentage of samples whose greedy prediction equals the reference string.

    Args:
        model: The Seq2Seq model
        loader: DataLoader with evaluation data
        device: Device to use for computation
        tgt_vocab: Target vocabulary (char -> idx)
        sos_idx: Index of the start-of-sequence token
        eos_idx: Index of the end-of-sequence token

    Returns:
        Exact match accuracy percentage (0.0 when the loader yields nothing)
    """
    model.eval()

    # Reverse vocabulary mapping for decoding
    idx_to_char = {index: symbol for symbol, index in tgt_vocab.items()}

    def _as_text(tokens):
        # `tokens` already excludes <sos>: stop at <eos>, drop padding, and
        # ignore indices that have no character.
        kept = []
        for token in tokens.tolist():
            if token == eos_idx:
                break
            if token != 0:
                kept.append(token)
        return ''.join([idx_to_char.get(token, '') for token in kept])

    hits = 0
    seen = 0

    with torch.no_grad():
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)

            # Decode as many steps as the reference is long.
            predictions = generate(
                model,
                src,
                max_len=tgt.size(1),
                eos_idx=eos_idx,
                sos_idx=sos_idx,
                device=device
            )

            for row in range(src.size(0)):
                reference = _as_text(tgt[row, 1:])
                hypothesis = _as_text(predictions[row, 1:])
                if reference == hypothesis:
                    hits += 1
                seen += 1

    if seen > 0:
        return 100.0 * hits / seen
    return 0.0


def compute_metrics(model, loader, device, src_vocab, tgt_vocab, sos_idx=2, eos_idx=3):
    """
    Compute comprehensive evaluation metrics for the model.

    Args:
        model: The Seq2Seq model
        loader: DataLoader with evaluation data
        device: Device to use for computation
        src_vocab: Source vocabulary (char -> idx), used to render the inputs
        tgt_vocab: Target vocabulary (char -> idx)
        sos_idx: Index of the start-of-sequence token
        eos_idx: Index of the end-of-sequence token

    Returns:
        Dictionary with 'exact_match_accuracy' and 'character_accuracy'
        (both percentages) and 'results', the first ten per-sample records.
    """
    model.eval()

    exact_matches = 0
    total_samples = 0
    char_correct = 0
    char_total = 0
    results = []

    # Reverse mappings are built once instead of scanning a vocabulary for
    # every character. setdefault keeps the first character seen per index.
    idx_to_char = {idx: char for char, idx in tgt_vocab.items()}
    idx_to_src_char = {}
    for char, idx in src_vocab.items():
        idx_to_src_char.setdefault(idx, char)

    def _content(tokens):
        # `tokens` already excludes <sos>: stop at <eos> and drop padding.
        kept = []
        for token in tokens.tolist():
            if token == eos_idx:
                break
            if token != 0:
                kept.append(token)
        return kept

    with torch.no_grad():
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)
            batch_size = src.size(0)

            # Decode as many steps as the reference is long.
            predictions = generate(
                model,
                src,
                max_len=tgt.size(1),
                eos_idx=eos_idx,
                sos_idx=sos_idx,
                device=device
            )

            for i in range(batch_size):
                # Render the source: skip <sos> anywhere, stop at <eos>.
                src_chars = []
                for token in src[i].tolist():
                    if token == sos_idx:
                        continue
                    if token == eos_idx:
                        break
                    if token != 0:
                        src_chars.append(idx_to_src_char.get(token, ''))
                src_text = ''.join(src_chars)

                target_text = ''.join([idx_to_char.get(t, '') for t in _content(tgt[i, 1:])])
                pred_text = ''.join([idx_to_char.get(t, '') for t in _content(predictions[i, 1:])])

                is_exact_match = target_text == pred_text
                if is_exact_match:
                    exact_matches += 1

                # Position-wise character matches, normalised by reference length.
                char_correct += sum(1 for a, b in zip(target_text, pred_text) if a == b)
                char_total += len(target_text)

                results.append({
                    'source': f"<sos>{src_text}<eos>",
                    'target': f"{target_text}<eos>",
                    'prediction': f"{pred_text}<eos>",
                    'is_correct': is_exact_match
                })

                total_samples += 1

    exact_match_acc = (exact_matches / total_samples) * 100 if total_samples > 0 else 0
    char_acc = (char_correct / char_total) * 100 if char_total > 0 else 0

    return {
        'exact_match_accuracy': exact_match_acc,
        'character_accuracy': char_acc,
        'results': results[:10]
    }

In [None]:
def eval_model(model, loader, criterion, device, sos_idx=2, eos_idx=3):
    """
    Evaluate the model on validation data.

    Args:
        model: The Seq2Seq model
        loader: DataLoader with validation data
        criterion: Loss function
        device: Device to use for computation
        sos_idx: Index of <sos> token
        eos_idx: Index of <eos> token

    Returns:
        Dictionary with "val_loss" (mean teacher-forced loss over the batches
        whose loss could be computed; 0.0 when there were none),
        "token_accuracy" and "sequence_accuracy" (fractions computed from
        greedy decoding).
    """
    model.eval()
    total_loss = 0.0
    loss_batches = 0
    num_batches = len(loader)

    # Token-level accuracy
    total_correct = 0
    total_tokens = 0

    # Sequence-level accuracy
    correct_sequences = 0
    total_sequences = 0

    with torch.no_grad():
        for batch_idx, (src, tgt) in enumerate(loader):
            try:
                src, tgt = src.to(device), tgt.to(device)
                batch_size = src.size(0)

                # Teacher-forced pass for the loss
                logits = model(src, tgt[:, :-1])
                logits_flat = logits.reshape(-1, logits.size(-1))
                tgt_flat = tgt[:, 1:].reshape(-1)

                loss = criterion(logits_flat, tgt_flat)
                total_loss += loss.item()
                loss_batches += 1

                # Greedy decoding for the accuracy metrics
                predictions = generate(
                    model,
                    src,
                    max_len=tgt.size(1),  # Use same max length as target
                    eos_idx=eos_idx,
                    sos_idx=sos_idx,
                    device=device
                )

                for b in range(batch_size):
                    pred_seq = predictions[b, 1:]  # Skip <sos>
                    tgt_seq = tgt[b, 1:]  # Skip <sos>

                    # Find position of first <eos> in each sequence
                    pred_eos_pos = (pred_seq == eos_idx).nonzero(as_tuple=True)[0]
                    tgt_eos_pos = (tgt_seq == eos_idx).nonzero(as_tuple=True)[0]

                    # Effective length includes the <eos> itself when present.
                    pred_len = pred_eos_pos[0].item() + 1 if len(pred_eos_pos) > 0 else len(pred_seq)
                    tgt_len = tgt_eos_pos[0].item() + 1 if len(tgt_eos_pos) > 0 else len(tgt_seq)

                    pred_trimmed = pred_seq[:pred_len]
                    tgt_trimmed = tgt_seq[:tgt_len]

                    # Compare position-wise over the common prefix length.
                    min_len = min(len(pred_trimmed), len(tgt_trimmed))
                    correct_tokens = (pred_trimmed[:min_len] == tgt_trimmed[:min_len]).sum().item()

                    total_correct += correct_tokens
                    total_tokens += len(tgt_trimmed)

                    # A sequence is correct only when length and every token agree.
                    if len(pred_trimmed) == len(tgt_trimmed) and correct_tokens == len(tgt_trimmed):
                        correct_sequences += 1

                    total_sequences += 1

            except Exception as e:
                # Best effort: report the failing batch and keep evaluating.
                print(f"Error in evaluation batch {batch_idx}/{num_batches}: {e}")
                continue

    # Average only over batches that produced a loss; an empty loader must
    # not raise ZeroDivisionError.
    avg_loss = total_loss / loss_batches if loss_batches > 0 else 0.0
    token_accuracy = total_correct / total_tokens if total_tokens > 0 else 0
    sequence_accuracy = correct_sequences / total_sequences if total_sequences > 0 else 0

    return {
        "val_loss": avg_loss,
        "token_accuracy": token_accuracy,
        "sequence_accuracy": sequence_accuracy
    }

# Question-2: WandB sweeps with val_loss

In [None]:
# ---- Sweep Configuration ----
# Bayesian search that minimises the logged `val_loss`; each hyperparameter is
# drawn from a discrete list of candidate values.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {
        name: {'values': candidates}
        for name, candidates in (
            ('embed_dim', [16, 32, 64, 256]),
            ('hidden_dim', [16, 32, 64, 256]),
            ('cell_type', ['GRU', 'RNN', 'LSTM']),
            ('enc_layers', [1, 2, 3]),
            ('dec_layers', [1, 2, 3]),
            ('dropout', [0.2, 0.3]),
            ('learning_rate', [1e-3, 1e-4]),
            ('batch_size', [32, 64]),
            ('beam_size', [1, 3, 5]),
        )
    },
}

In [None]:
# ---- Helper Functions ----
def save_vocab(vocab_path, src_vocab, tgt_vocab):
    """Write both vocabularies as 'src.json' and 'tgt.json' inside vocab_path.

    The directory is created when missing; characters are stored unescaped
    (UTF-8) with a two-space indent.
    """
    os.makedirs(vocab_path, exist_ok=True)
    for filename, vocab in (('src.json', src_vocab), ('tgt.json', tgt_vocab)):
        destination = os.path.join(vocab_path, filename)
        with open(destination, 'w', encoding='utf-8') as handle:
            json.dump(vocab, handle, ensure_ascii=False, indent=2)

def load_vocab(vocab_path):
    """Read 'src.json' and 'tgt.json' from vocab_path and return them as a pair."""
    loaded = []
    for filename in ('src.json', 'tgt.json'):
        with open(os.path.join(vocab_path, filename), 'r', encoding='utf-8') as handle:
            loaded.append(json.load(handle))
    src_vocab, tgt_vocab = loaded
    return src_vocab, tgt_vocab

# ---- Simple test method to check for CUDA issues ----
def test_cuda():
    """Smoke-test CUDA by moving a small tensor to the GPU and adding it to itself.

    Returns:
        True when CUDA is unavailable (CPU will be used) or when the GPU
        round-trip succeeded; False if any step raised an exception.
    """
    print("Testing CUDA...")
    try:
        probe = torch.rand(10, 10)
        if not torch.cuda.is_available():
            print("CUDA is not available, using CPU")
            return True

        print("CUDA is available")
        probe = probe.cuda()
        print("Successfully moved tensor to CUDA")
        # The sum is discarded; only whether the operation raises matters.
        _ = probe + probe
        print("Successfully performed CUDA operation")
    except Exception as err:
        print(f"CUDA test failed: {err}")
        return False
    return True

In [None]:
# Record whether the CUDA smoke test passed; sweep_run() reads this flag when
# choosing its device.
cuda_ok = test_cuda()

# File paths
train_tsv = '/kaggle/input/dakshina-dataset/hindi/hi/lexicons/hi.translit.sampled.train.tsv'
dev_tsv = '/kaggle/input/dakshina-dataset/hindi/hi/lexicons/hi.translit.sampled.dev.tsv'
vocab_path = '/kaggle/working/vocab'
# NOTE(review): sweep_run() has its own epochs=10 default and wandb.agent calls
# it without arguments, so this global does not control the sweep length.
epochs = 10

# Build vocabulary
# build_vocab=True makes the dataset create both vocabularies from the
# training file; the dataset object itself is only used for that here.
print("Building vocabulary...")
train_dataset = tsvtokenizer(train_tsv, build_vocab=True)
src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab

# Save vocabulary
# Written as src.json / tgt.json inside vocab_path so later cells can reload it.
save_vocab(vocab_path, src_vocab, tgt_vocab)
print(f"Vocabulary sizes: Source = {len(src_vocab)}, Target = {len(tgt_vocab)}")

# Print sample entries from vocabulary
# First ten (token, index) pairs in dict order, as a quick sanity check.
print("Sample source vocabulary entries:")
sample_src = list(src_vocab.items())[:10]
for char, idx in sample_src:
    print(f"  '{char}': {idx}")

print("Sample target vocabulary entries:")
sample_tgt = list(tgt_vocab.items())[:10]
for char, idx in sample_tgt:
    print(f"  '{char}': {idx}")

In [None]:
# Initialize W&B sweep
# Registers sweep_config under the given project and returns the sweep id
# that wandb.agent() needs to pull configurations for individual runs.
sweep_id = wandb.sweep(sweep_config, project='DA6401_Assignment_03')
def _best_loss_on_disk(checkpoint_path):
    """Return the ``best_val_loss`` stored in a checkpoint, or ``inf`` if none is readable."""
    if not os.path.exists(checkpoint_path):
        return float('inf')
    try:
        saved = torch.load(checkpoint_path, map_location='cpu')
        return float(saved.get('best_val_loss', float('inf')))
    except Exception as e:
        # An unreadable file must not block writing a fresh checkpoint over it.
        print(f"Could not read existing checkpoint {checkpoint_path}: {e}")
        return float('inf')


def sweep_run(epochs=10):
    """Train and validate one sweep configuration and log the results to W&B.

    Reads the module-level ``cuda_ok``, ``src_vocab``, ``tgt_vocab``,
    ``train_tsv`` and ``dev_tsv``. Hyperparameters come from the config of the
    run created by ``wandb.init()``.

    Checkpointing:
        * Whenever the run improves its own validation loss, a checkpoint is
          written to a per-run file and logged as a W&B artifact.
        * ``/kaggle/working/best_model_sweep.pt`` is overwritten only when the
          new loss is lower than the ``best_val_loss`` already stored in that
          file. Previously every run overwrote it, so after the sweep it held
          the hyperparameters of the last run rather than the best one.

    Errors raised during training are printed with a traceback and swallowed
    so that one failing configuration does not stop the sweep agent.

    Args:
        epochs: Number of training epochs for this run.
    """
    run = wandb.init()
    cfg = run.config

    # Create a descriptive run name
    run.name = f"{cfg.cell_type}-e{cfg.embed_dim}-h{cfg.hidden_dim}-enc{cfg.enc_layers}-dec{cfg.dec_layers}-d{cfg.dropout}-lr{cfg.learning_rate}-b{cfg.batch_size}-beam{cfg.beam_size}"

    # Set device - force CPU if the CUDA smoke test failed earlier
    if cuda_ok:
        try:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        except Exception:
            # Was a bare except, which would also swallow KeyboardInterrupt.
            device = torch.device('cpu')
    else:
        device = torch.device('cpu')

    print(f"Using device: {device}")

    # Initialize model
    try:
        # First create on CPU
        model = Seq2Seq(cfg, len(src_vocab), len(tgt_vocab))
        print("Model created on CPU, trying to move to device...")
        # Then try to move to target device
        model = model.to(device)
        print("Model successfully moved to device.")
    except Exception as e:
        print(f"Error initializing model on {device}: {e}")
        print("Falling back to CPU")
        device = torch.device('cpu')
        model = Seq2Seq(cfg, len(src_vocab), len(tgt_vocab)).to(device)

    # Load datasets (reusing the vocabularies built once at module level)
    print("Loading datasets...")
    train_dataset = tsvtokenizer(train_tsv, src_vocab, tgt_vocab)
    dev_dataset = tsvtokenizer(dev_tsv, src_vocab, tgt_vocab)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=cfg.batch_size)
    print(f"Loaded {len(train_dataset)} training examples and {len(dev_dataset)} validation examples")

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is padding index
    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)

    # Checkpoint locations: one file per run, plus the shared sweep-wide best.
    shared_model_path = '/kaggle/working/best_model_sweep.pt'
    run_model_path = f'/kaggle/working/best_model_{run.id}.pt'

    # Training loop
    best_val_loss = float('inf')
    step = 0  # number of training batches completed in earlier epochs

    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            # Training
            model.train()
            total_loss = 0
            batches = 0

            for batch_idx, (src, tgt) in enumerate(train_loader):
                batches += 1
                src, tgt = src.to(device), tgt.to(device)

                optimizer.zero_grad()
                output = model(src, tgt)

                # Drop the last prediction and the first target token so the
                # prediction at position t is scored against target token t+1.
                output = output[:, :-1, :].reshape(-1, output.shape[-1])
                target = tgt[:, 1:].reshape(-1)

                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                # Log an occasional batch-level loss on the global step axis
                if batch_idx % 500 == 0 and batch_idx > 0:
                    wandb.log({
                        'batch_train_loss': loss.item(),
                    }, step=step + batch_idx)

            step += batches
            train_loss = total_loss / batches
            print(f"Train loss: {train_loss:.4f}")

            # Validation
            model.eval()
            total_val_loss = 0
            val_batches = 0

            with torch.no_grad():
                for src, tgt in dev_loader:
                    val_batches += 1
                    src, tgt = src.to(device), tgt.to(device)

                    output = model(src, tgt)

                    # Same one-position shift as in training
                    output = output[:, :-1, :].reshape(-1, output.shape[-1])
                    target = tgt[:, 1:].reshape(-1)

                    loss = criterion(output, target)
                    total_val_loss += loss.item()

            val_loss = total_val_loss / val_batches
            print(f"Validation loss: {val_loss:.4f}")

            # Log metrics ('val_loss' is the metric the sweep minimizes)
            wandb.log({
                'train_loss': train_loss,
                'val_loss': val_loss,
                'epoch': epoch+1,
            }, step=step)

            # Checkpoint whenever this run improves on its own best loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'config': {k: v for k, v in cfg.items()},
                    'src_vocab_size': len(src_vocab),
                    'tgt_vocab_size': len(tgt_vocab),
                    'epoch': epoch+1,
                    'step': step,
                    'best_val_loss': best_val_loss,
                }
                torch.save(checkpoint, run_model_path)

                # Create a NEW artifact with a unique name including the epoch
                artifact_name = f'model-epoch-{epoch+1}'
                artifact = wandb.Artifact(artifact_name, type='model')
                artifact.add_file(run_model_path)
                run.log_artifact(artifact)

                print(f"Saved new best model at {run_model_path} with validation loss: {val_loss:.4f}")

                # Replace the shared checkpoint only when this run beats
                # whatever an earlier run already stored there.
                if val_loss < _best_loss_on_disk(shared_model_path):
                    torch.save(checkpoint, shared_model_path)
                    print(f"Updated sweep-wide best model at {shared_model_path}")

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()  # This will print the full stack trace

In [None]:
# Run the sweep
# The agent calls sweep_run() once per configuration, up to 30 runs; sweep_run
# is invoked without arguments, so its epochs=10 default applies.
wandb.agent(sweep_id, function=sweep_run, count=30)

# Question-3: Observations based on sweeps.

# Question-4: Running on the test data, saving predictions to the predictions_vanilla folder, and analyzing and commenting on the errors.

In [None]:
# Load the best model from your sweep
def load_best_model(model_path, src_vocab, tgt_vocab):
    """Rebuild a Seq2Seq model from a checkpoint written during the sweep.

    Args:
        model_path: Path to a checkpoint containing ``'config'`` and
            ``'model_state_dict'`` entries.
        src_vocab: Source vocabulary; only its size is used.
        tgt_vocab: Target vocabulary; only its size is used.

    Returns:
        Tuple ``(model, cfg)``: the model with restored weights (on CPU, not
        switched to eval mode) and a class exposing the saved hyperparameters
        as attributes.
    """
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine. The model is created on the CPU anyway, so this changes nothing
    # for callers that move it to a device afterwards.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Extract configuration from checkpoint
    config_dict = checkpoint['config']
    # Convert to a config object with attributes
    cfg = type('Config', (), config_dict)

    # Initialize model with saved config
    model = Seq2Seq(cfg, len(src_vocab), len(tgt_vocab))
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, cfg

In [None]:
import time
def get_predictions(model, loader, device, src_vocab, tgt_vocab, num_samples=10):
    """
    Get sample predictions from the model using the generate function.

    Batches are decoded in loader order and collection stops as soon as
    ``num_samples`` predictions have been gathered.

    Args:
        model: Trained Seq2Seq model
        loader: DataLoader with evaluation data
        device: Device to run inference on
        src_vocab: Source vocabulary (char -> idx)
        tgt_vocab: Target vocabulary (char -> idx)
        num_samples: Number of samples to collect

    Returns:
        List of dictionaries with source, target, prediction, and correctness
    """
    model.eval()
    predictions = []

    # Create reverse vocabulary mappings
    idx_to_src = {idx: char for char, idx in src_vocab.items()}
    idx_to_tgt = {idx: char for char, idx in tgt_vocab.items()}

    # Get indices for special tokens
    eos_idx = tgt_vocab.get('<eos>', 3)
    sos_idx = tgt_vocab.get('<sos>', 2)
    # Source text is decoded with the source vocabulary's own <eos> index.
    # The target's index was used before, which is only right when both
    # vocabularies assign <eos> the same id.
    src_eos_idx = src_vocab.get('<eos>', 3)

    with torch.no_grad():
        for src, tgt in loader:
            # Move data to device
            src, tgt = src.to(device), tgt.to(device)

            # Generate predictions using the specialized generate function
            generated_sequences = generate(
                model,
                src,
                max_len=tgt.size(1),  # Use same length as target sequences
                eos_idx=eos_idx,
                sos_idx=sos_idx,
                device=device
            )

            # Decode all sequences
            src_texts = decode_batch(src, idx_to_src, eos_idx=src_eos_idx)
            tgt_texts = decode_batch(tgt, idx_to_tgt, eos_idx=eos_idx)
            pred_texts = decode_batch(generated_sequences, idx_to_tgt, eos_idx=eos_idx)

            # Process batch
            for i in range(src.size(0)):
                # A prediction is correct only on an exact match of the decoded strings
                is_correct = (tgt_texts[i] == pred_texts[i])

                predictions.append({
                    'source': src_texts[i],
                    'target': tgt_texts[i],
                    'prediction': pred_texts[i],
                    'is_correct': is_correct
                })

                if len(predictions) >= num_samples:
                    return predictions

    return predictions

In [None]:
def retrain_model(model_path, train_tsv, dev_tsv, vocab_path, epochs=20, max_len=32, early_stop_patience=5):
    """
    Train a fresh model using the hyperparameters stored in a sweep checkpoint.

    Only the ``'config'`` entry of the checkpoint is used; the weights are
    initialized from scratch. The best checkpoint (lowest validation loss) is
    written to ``/kaggle/working/final_retrained_model.pt`` and its weights are
    restored into the returned model.

    Args:
        model_path: Path to the best model checkpoint
        train_tsv: Path to the training data
        dev_tsv: Path to the validation data
        vocab_path: Directory holding ``src.json`` / ``tgt.json`` (the layout
            written by ``save_vocab``). The older ``{vocab_path}_src.json`` /
            ``{vocab_path}_tgt.json`` files are still accepted as a fallback.
        epochs: Number of epochs to train for
        max_len: Maximum sequence length
        early_stop_patience: Number of epochs to wait without improvement before stopping

    Returns:
        model: The trained model
        best_val_loss: The best validation loss achieved
        src_vocab: Source vocabulary
        tgt_vocab: Target vocabulary

    Raises:
        Exception: Whatever ``torch.load`` raises if the checkpoint cannot be
            read, or ``KeyError`` if it has no ``'config'`` entry.
    """
    def _print_samples(samples):
        """Print source / target / prediction / correctness for each sample."""
        for i, sample in enumerate(samples):
            print(f"Sample {i+1}:")
            print(f"  Source: {sample['source']}")
            print(f"  Target: {sample['target']}")
            print(f"  Prediction: {sample['prediction']}")
            print(f"  Correct: {sample['is_correct']}")

    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load vocabulary from JSON files. The directory layout written by
    # save_vocab is tried first. Only the flat "<prefix>_src.json" files were
    # read before, so the saved vocabulary was never found and was rebuilt.
    dir_src_file = os.path.join(vocab_path, 'src.json')
    dir_tgt_file = os.path.join(vocab_path, 'tgt.json')
    vocab_layouts = [
        (dir_src_file, dir_tgt_file),
        (f"{vocab_path}_src.json", f"{vocab_path}_tgt.json"),  # legacy flat layout
    ]
    src_vocab = tgt_vocab = None
    print("Loading saved vocabulary...")
    for src_file, tgt_file in vocab_layouts:
        try:
            with open(src_file, 'r', encoding='utf-8') as f:
                loaded_src = json.load(f)
            with open(tgt_file, 'r', encoding='utf-8') as f:
                loaded_tgt = json.load(f)
        except (OSError, ValueError) as e:
            # Missing or malformed files: try the next layout
            print(f"Error loading vocabulary: {e}")
            continue
        src_vocab, tgt_vocab = loaded_src, loaded_tgt
        print("Vocabulary loaded successfully")
        break

    if src_vocab is None:
        print("Building vocabulary from training data...")
        train_dataset = tsvtokenizer(train_tsv, max_len=max_len, build_vocab=True)
        src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab

        # Save vocabulary (best effort: training can proceed without the files)
        try:
            print("Saving vocabulary...")
            os.makedirs(vocab_path, exist_ok=True)
            with open(dir_src_file, 'w', encoding='utf-8') as f:
                json.dump(src_vocab, f, ensure_ascii=False, indent=2)
            with open(dir_tgt_file, 'w', encoding='utf-8') as f:
                json.dump(tgt_vocab, f, ensure_ascii=False, indent=2)
            print("Vocabulary saved successfully")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving vocabulary: {e}")

    print(f"Vocabulary sizes: Source = {len(src_vocab)}, Target = {len(tgt_vocab)}")

    # Load the best model configuration
    print(f"Loading best model configuration from: {model_path}")
    try:
        checkpoint = torch.load(model_path, map_location=device)
        config_dict = checkpoint['config']
        print("Successfully loaded checkpoint")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise

    # Initialize a new run with the same config. ``run`` stays None when W&B
    # is unavailable so that later logging can be skipped cleanly.
    run = None
    try:
        run = wandb.init(project='DA6401_Assignment_03_Retrain',
                         config=config_dict,
                         name=f"Retrain_{os.path.basename(model_path).replace('best_model_', '').replace('.pt', '')}")
    except Exception as e:
        print(f"Error initializing wandb: {e}. Continuing without logging.")

    # Create a Config object with attributes for model initialization
    class Config:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)

    config_obj = Config(config_dict)

    # Print configuration
    print("Training with the following configuration:")
    for key, value in config_dict.items():
        print(f"  {key}: {value}")

    # Initialize model with the best hyperparameters
    print("Initializing model with best hyperparameters")
    model = Seq2Seq(config_obj, len(src_vocab), len(tgt_vocab)).to(device)

    # Load datasets
    print("Loading datasets...")
    train_dataset = tsvtokenizer(train_tsv, src_vocab, tgt_vocab, max_len=max_len)
    dev_dataset = tsvtokenizer(dev_tsv, src_vocab, tgt_vocab, max_len=max_len)

    # Create data loaders
    batch_size = config_dict.get('batch_size', 64)  # Use default if not available
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    print(f"Loaded {len(train_dataset)} training examples and {len(dev_dataset)} validation examples")

    # Initialize loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is padding index
    lr = config_dict.get('learning_rate', 0.001)  # Use default if not available
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Halve the learning rate after 2 epochs without validation improvement.
    # ``verbose`` is no longer passed: it is deprecated in recent PyTorch, and
    # the current learning rate is printed every epoch below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=2, factor=0.5
    )

    # Training loop
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    model_path_epoch = '/kaggle/working/final_retrained_model.pt'
    start_time = time.time()

    # Track some sample predictions (from the untrained model, as a baseline)
    print("Collecting initial sample predictions...")
    _print_samples(get_predictions(model, dev_loader, device, src_vocab, tgt_vocab, num_samples=5))

    for epoch in range(epochs):
        epoch_start = time.time()
        print(f"Epoch {epoch+1}/{epochs}")

        # Training
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Validation
        val_result = eval_model(model, dev_loader, criterion, device)
        val_loss = val_result['val_loss']

        # Update learning rate scheduler
        scheduler.step(val_loss)

        # Calculate epoch time
        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch+1} complete in {epoch_time:.2f}s")
        print(f"  Train loss: {train_loss:.4f}")
        print(f"  Validation loss: {val_loss:.4f}")
        print(f"  Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Log metrics if wandb is available (one W&B step per epoch)
        if run is not None:
            try:
                wandb.log({
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'epoch': epoch+1,
                    'learning_rate': optimizer.param_groups[0]['lr'],
                    'epoch_time': epoch_time
                }, step=epoch)
            except Exception as e:
                print(f"Error logging to wandb: {e}. Continuing without logging.")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone every tensor. state_dict().copy() is shallow and shares
            # storage with the live parameters, so the snapshot would follow
            # later training and restoring it would be a no-op.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

            # Save the model
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': config_dict,
                'src_vocab_size': len(src_vocab),
                'tgt_vocab_size': len(tgt_vocab),
                'epoch': epoch+1,
                'step': epoch,
                'best_val_loss': best_val_loss,
            }, model_path_epoch)

            # Log artifact if wandb is available
            if run is not None:
                try:
                    # Create a NEW artifact with a unique name
                    artifact_name = f'retrained-model-epoch-{epoch+1}'
                    artifact = wandb.Artifact(artifact_name, type='model')
                    artifact.add_file(model_path_epoch)
                    run.log_artifact(artifact)
                except Exception as e:
                    print(f"Error logging artifact to wandb: {e}. Continuing without logging.")

            print(f"Saved new best model with validation loss: {val_loss:.4f}")

            # Reset patience counter
            patience_counter = 0

            # Get and print some sample predictions
            if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
                print("Collecting sample predictions...")
                _print_samples(get_predictions(model, dev_loader, device, src_vocab, tgt_vocab, num_samples=5))
        else:
            # Increment patience counter if no improvement
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print(f"Early stopping triggered at epoch {epoch+1} after {patience_counter} epochs without improvement")
                break

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f}s")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # If we have a best model state, load it back
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Exact-match accuracy on the first 100 dev examples (not the whole set)
    print("Evaluating model on development set...")
    samples = get_predictions(model, dev_loader, device, src_vocab, tgt_vocab, num_samples=100)
    correct = sum(1 for sample in samples if sample['is_correct'])
    accuracy = correct / len(samples) if samples else 0
    print(f"Development set accuracy: {accuracy:.4f} ({correct}/{len(samples)})")

    if run is not None:
        try:
            wandb.log({'final_accuracy': accuracy})
            wandb.finish()
        except Exception as e:
            print(f"Error finishing wandb run: {e}")

    # Test with a few examples using the transliterate function
    print("\nTesting with a few examples:")
    test_examples = ["hello", "world", "python", "transliteration"]
    for example in test_examples:
        result = transliterate(model, example, src_vocab, tgt_vocab, device, max_len=max_len)
        print(f"  '{example}' -> '{result}'")

    return model, best_val_loss, src_vocab, tgt_vocab

In [None]:
# Checkpoint written by sweep_run(); retrain_model() reads its hyperparameters.
best_model_path = "/kaggle/working/best_model_sweep.pt"
# Held-out test split, used only for the final evaluation cells.
test_tsv = '/kaggle/input/dakshina-dataset/hindi/hi/lexicons/hi.translit.sampled.test.tsv'
# Directory where save_vocab() stored src.json / tgt.json.
vocab_path = '/kaggle/working/vocab'

In [None]:
print("Retraining model with best hyperparameters...")
epochs = 50  # Increase for better results
retrained_model, best_val_loss, src_vocab, tgt_vocab = retrain_model(
    best_model_path, train_tsv, dev_tsv, vocab_path, epochs=epochs
)
# Store the hyperparameters the model was actually built with, i.e. the config
# saved in the sweep checkpoint. This used to come from a fresh wandb.init()
# run, whose config holds none of them, so the saved checkpoint could not be
# used to rebuild the model.
config_dict = torch.load(best_model_path, map_location='cpu')['config']
retrained_model_path = '/kaggle/working/final_retrained_model.pt'
torch.save({
    'model_state_dict': retrained_model.state_dict(),
    'val_loss': best_val_loss,
    'src_vocab_size': len(src_vocab),
    'tgt_vocab_size': len(tgt_vocab),
    'config': config_dict,  # Hyperparameters needed to re-create the model
}, retrained_model_path)
print(f"Retrained model saved to {retrained_model_path}")

In [None]:
def load_model_for_inference(model_path, src_vocab, tgt_vocab):
    """
    Restore a trained model from a checkpoint and switch it to eval mode.

    Args:
        model_path: Checkpoint file with 'config' and 'model_state_dict' entries
        src_vocab: Source vocabulary (only its size is used)
        tgt_vocab: Target vocabulary (only its size is used)

    Returns:
        model: The restored model, on CUDA when available, otherwise on CPU
        config: Object exposing the saved configuration as attributes
    """
    class Config:
        # Thin attribute-style wrapper around the saved configuration dict
        def __init__(self, config_dict):
            for key in config_dict:
                setattr(self, key, config_dict[key])

    # Pick the device and report it
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # Read the checkpoint straight onto that device
    state = torch.load(model_path, map_location=device)
    saved_config = state['config']
    config_obj = Config(saved_config)

    # Show the hyperparameters the model is rebuilt with
    print("Model configuration:")
    for name, setting in saved_config.items():
        print(f"  {name}: {setting}")

    # Rebuild the architecture, then restore the trained weights
    model = Seq2Seq(config_obj, len(src_vocab), len(tgt_vocab))
    model = model.to(device)
    model.load_state_dict(state['model_state_dict'])

    # Inference only, so leave training mode
    model.eval()
    return model, config_obj

# Usage example:
# Rebuild the retrained model from the checkpoint saved in the previous cell.
model, cfg = load_model_for_inference(retrained_model_path, src_vocab, tgt_vocab)
# The device is recomputed here only for the message; it mirrors the choice
# made inside load_model_for_inference().
print(f"Model loaded successfully on {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

In [None]:
# Load vocabulary
# Reuse the load_vocab() helper defined earlier instead of repeating its
# file handling; it reads src.json / tgt.json from vocab_path.
src_vocab, tgt_vocab = load_vocab(vocab_path)

# Load test dataset
# shuffle=False keeps predictions in the same order as the test file.
test_dataset = tsvtokenizer(test_tsv, src_vocab, tgt_vocab)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load best model
# Rebuilds the retrained model from retrained_model_path using the vocabulary
# loaded in the previous cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(retrained_model_path, map_location=device)
# An absent 'config' entry yields an empty dict, i.e. a Config class with no
# hyperparameter attributes.
# NOTE(review): confirm Seq2Seq can be constructed in that case.
config_dict = checkpoint.get('config', {})
# Convert to a config object with attributes
cfg = type('Config', (), config_dict)
model = Seq2Seq(cfg, len(src_vocab), len(tgt_vocab))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# Inference mode; as the last expression of the cell its value is also displayed.
model.eval() 

In [None]:
# Evaluate model using exact_match_accuracy
# The result is printed with a '%' sign.
# NOTE(review): confirm exact_match_accuracy returns a percentage, not a 0-1 fraction.
accuracy = exact_match_accuracy(model, test_loader, device, tgt_vocab)
print(f"Test accuracy: {accuracy:.2f}%")

In [None]:
# Print whatever compute_metrics() returns for the test set (helper defined
# earlier in the notebook).
print(compute_metrics(model, test_loader, device,src_vocab, tgt_vocab))

In [None]:
# Get sample predictions
# First 20 test examples in file order (test_loader is not shuffled).
sample_predictions = get_predictions(model, test_loader, device, src_vocab, tgt_vocab, num_samples=20)

In [None]:
# Display the collected prediction dictionaries as the cell output.
sample_predictions

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib.font_manager import FontProperties
import os
import urllib.request
from pathlib import Path


# Get the Hindi font
# Loaded from an attached Kaggle dataset and passed to the plotting calls
# below that render Devanagari text.
hindi_font = FontProperties(fname='/kaggle/input/ttf-fonts/NotoSansDevanagari-VariableFont_wdth,wght.ttf', size=18)

In [None]:
# Show the first nine sample predictions as a 3x3 grid of colored text cards
plt.figure(figsize=(15, 10))
for i, pred in enumerate(sample_predictions[:9]):
    plt.subplot(3, 3, i + 1)

    # One line each for input, reference, model output and verdict
    text = "\n".join([
        f"Input: {pred['source']}",
        f"Target: {pred['target']}",
        f"Prediction: {pred['prediction']}",
        f"Status: {'True' if pred['is_correct'] else 'False'}",
    ])

    # Green card for a correct prediction, red card for a wrong one
    if pred['is_correct']:
        bg_color = '#e6ffe6'
    else:
        bg_color = '#ffe6e6'

    # Centered text box; the Hindi font is needed for the Devanagari strings
    plt.text(
        0.5, 0.5, text,
        ha='center', va='center', wrap=True,
        bbox=dict(facecolor=bg_color, alpha=0.8, boxstyle='round,pad=1'),
        fontproperties=hindi_font,
    )
    plt.axis('off')

plt.tight_layout()
plt.savefig('sample_predictions.png')
plt.show()

In [None]:
def save_all_predictions(model, loader, device, src_vocab, tgt_vocab, output_file):
    """
    Save all predictions to a file using the generate function.

    Args:
        model: Trained Seq2Seq model
        loader: DataLoader with evaluation data
        device: Device to run inference on
        src_vocab: Source vocabulary (char -> idx)
        tgt_vocab: Target vocabulary (char -> idx)
        output_file: Path to output JSON file. Parent directories are created
            when the path has any; a bare file name is written to the current
            working directory.

    Returns:
        List of dictionaries with source, target, prediction, and correctness,
        in loader order (the same data that is written to ``output_file``).
    """
    model.eval()
    all_predictions = []

    # Create reverse vocabulary mappings
    idx_to_src = {idx: char for char, idx in src_vocab.items()}
    idx_to_tgt = {idx: char for char, idx in tgt_vocab.items()}

    # Get indices for special tokens
    eos_idx = tgt_vocab.get('<eos>', 3)
    sos_idx = tgt_vocab.get('<sos>', 2)
    # Source text is decoded with the source vocabulary's own <eos> index.
    # The target's index was used before, which is only right when both
    # vocabularies assign <eos> the same id.
    src_eos_idx = src_vocab.get('<eos>', 3)

    with torch.no_grad():
        for src, tgt in loader:
            # Move data to device
            src, tgt = src.to(device), tgt.to(device)

            # Generate predictions using the specialized generate function
            generated_sequences = generate(
                model,
                src,
                max_len=tgt.size(1),  # Use same length as target sequences
                eos_idx=eos_idx,
                sos_idx=sos_idx,
                device=device
            )

            # Decode all sequences
            src_texts = decode_batch(src, idx_to_src, eos_idx=src_eos_idx)
            tgt_texts = decode_batch(tgt, idx_to_tgt, eos_idx=eos_idx)
            pred_texts = decode_batch(generated_sequences, idx_to_tgt, eos_idx=eos_idx)

            # Process batch
            for i in range(src.size(0)):
                # A prediction is correct only on an exact match of the decoded strings
                is_correct = (tgt_texts[i] == pred_texts[i])

                all_predictions.append({
                    'source': src_texts[i],
                    'target': tgt_texts[i],
                    'prediction': pred_texts[i],
                    'is_correct': is_correct
                })

    # Save to file. os.makedirs('') raises FileNotFoundError, so the directory
    # is only created when output_file actually has a directory component.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_predictions, f, ensure_ascii=False, indent=2)

    return all_predictions

In [None]:
# Save all predictions
# Writes one JSON record per test example under predictions_vanilla/ and keeps
# the same list in memory for the error analysis below.
all_predictions = save_all_predictions(model, test_loader, device, src_vocab, tgt_vocab, 
                                      'predictions_vanilla/test_predictions.json')

In [None]:
from collections import Counter
# Analyze errors
def analyze_errors(predictions):
    """Summarize exact-match accuracy and error patterns for a list of predictions.

    Args:
        predictions: Iterable-with-length of dicts that each have string
            ``'target'`` and ``'prediction'`` entries.

    Returns:
        Dict with:
            ``'accuracy'``: percentage of exact matches (0.0 for empty input),
            ``'length_error_rates'``: target length -> percentage of wrong
                predictions among targets of that length,
            ``'char_error_counts'``: Counter of ``(target_char, predicted_char)``
                pairs that differ at the same position. Only the overlapping
                prefix of the two strings is compared, so pure insertions or
                deletions at the end are not counted here.
    """
    total = len(predictions)
    correct = 0
    length_totals = Counter()   # target length -> number of examples
    length_errors = Counter()   # target length -> number of wrong examples
    char_error_counts = Counter()

    for p in predictions:
        target, predicted = p['target'], p['prediction']
        length_totals[len(target)] += 1
        if target == predicted:
            correct += 1
            continue

        length_errors[len(target)] += 1
        # zip stops at the shorter string, i.e. compares the overlapping prefix
        char_error_counts.update(
            (t_char, p_char)
            for t_char, p_char in zip(target, predicted)
            if t_char != p_char
        )

    # Convert to error rates (every recorded length has at least one example)
    length_error_rates = {
        length: length_errors[length] / count * 100
        for length, count in length_totals.items()
    }

    return {
        # Guard against an empty prediction list instead of dividing by zero
        'accuracy': correct / total * 100 if total else 0.0,
        'length_error_rates': length_error_rates,
        'char_error_counts': char_error_counts
    }

In [None]:
# Analyze errors
# Accuracy, per-length error rates and character confusions over the full
# test set; the plots and the summary below read from this dict.
error_analysis = analyze_errors(all_predictions)

In [None]:
# Plot error rate by sequence length
plt.figure(figsize=(10, 6))
# Sort the target lengths so the bars appear in increasing order on the x axis
lengths = sorted(error_analysis['length_error_rates'].keys())
error_rates = [error_analysis['length_error_rates'][length] for length in lengths]

plt.bar(lengths, error_rates)
plt.xlabel('Sequence Length')
plt.ylabel('Error Rate (%)')
plt.title('Error Rate by Sequence Length')
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Save before show(): the figure is written to the working directory
plt.savefig('error_by_length.png')
plt.show()

In [None]:
# Bar chart of the ten most frequent character confusions
plt.figure(figsize=(12, 6))
top_errors = error_analysis['char_error_counts'].most_common(10)
# Each entry is ((true_char, predicted_char), count); unpack it directly
error_chars = [f"{true}->{pred}" for (true, pred), _ in top_errors]
error_counts = [occurrences for _, occurrences in top_errors]

plt.bar(error_chars, error_counts)
# Every label uses the Hindi font so Devanagari tick labels render correctly
plt.xlabel('Character Error (True->Predicted)', fontproperties=hindi_font)
plt.ylabel('Count', fontproperties=hindi_font)
plt.title('Top 10 Character-Level Errors', fontproperties=hindi_font)
plt.xticks(rotation=45, fontproperties=hindi_font)
plt.tight_layout()
plt.savefig('top_char_errors.png')
plt.show()

In [None]:
# Text summary: overall exact-match accuracy plus the five most frequent
# (target character -> predicted character) confusions.
print(f"Overall accuracy: {error_analysis['accuracy']:.2f}%")
print("Top 5 character-level errors:")
for i, ((true_char, pred_char), count) in enumerate(error_analysis['char_error_counts'].most_common(5)):
    print(f"  {i+1}. '{true_char}' -> '{pred_char}': {count} times")