In [2]:
import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

my_secret = user_secrets.get_secret("wandb_api") 

wandb.login(key=my_secret)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mda24m016[0m ([33mda24m016-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import os
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import wandb
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# ---- Dataset Definition with special tokens ----
class DakshinaTSVDataset(Dataset):
    def __init__(self, tsv_file, src_vocab=None, tgt_vocab=None, max_len=64, build_vocab=False):
        df = pd.read_csv(tsv_file, sep='\t', header=None,
                         names=['native', 'roman', 'freq'], usecols=[0, 1], dtype=str)
        # Fix the pandas warning by using a copy
        df = df.copy()
        df['native'] = df['native'].fillna('')
        df['roman'] = df['roman'].fillna('')
        self.pairs = list(zip(df['roman'], df['native']))
        print(f"Loaded {len(self.pairs)} examples from {tsv_file}")
        
        # Print a few examples
        if len(self.pairs) > 0:
            print("Sample examples:")
            for i in range(min(3, len(self.pairs))):
                print(f"  Roman: '{self.pairs[i][0]}', Native: '{self.pairs[i][1]}'")
                
        self.max_len = max_len
        
        if build_vocab:
            self.src_vocab = {'<pad>': 0, '<unk>': 1, '<eos>': 2, '<sos>': 3}
            self.tgt_vocab = {'<pad>': 0, '<unk>': 1, '<eos>': 2, '<sos>': 3}
            self._build_vocab()
        else:
            self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab
            # Ensure special tokens exist
            for v in ('<eos>', '<sos>'):
                if v not in self.src_vocab: self.src_vocab[v] = len(self.src_vocab)
                if v not in self.tgt_vocab: self.tgt_vocab[v] = len(self.tgt_vocab)

    def _build_vocab(self):
        for src, tgt in self.pairs:
            for ch in src:
                if ch not in self.src_vocab: self.src_vocab[ch] = len(self.src_vocab)
            for ch in tgt:
                if ch not in self.tgt_vocab: self.tgt_vocab[ch] = len(self.tgt_vocab)
        print(f"Vocab sizes -> src: {len(self.src_vocab)}, tgt: {len(self.tgt_vocab)}")

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        
        # Add <sos> and <eos> tokens
        src_idxs = [self.src_vocab['<sos>']] + [self.src_vocab.get(ch, self.src_vocab['<unk>']) for ch in src] + [self.src_vocab['<eos>']]
        tgt_idxs = [self.tgt_vocab['<sos>']] + [self.tgt_vocab.get(ch, self.tgt_vocab['<unk>']) for ch in tgt] + [self.tgt_vocab['<eos>']]
        
        # Pad sequences
        pad_src = [self.src_vocab['<pad>']] * max(0, self.max_len - len(src_idxs))
        pad_tgt = [self.tgt_vocab['<pad>']] * max(0, self.max_len - len(tgt_idxs))
        
        # Truncate if necessary and convert to tensor
        src_tensor = torch.tensor((src_idxs + pad_src)[:self.max_len], dtype=torch.long)
        tgt_tensor = torch.tensor((tgt_idxs + pad_tgt)[:self.max_len], dtype=torch.long)
        
        return src_tensor, tgt_tensor

# ---- Encoder with bidirectional support ----
class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, dropout=0, bidirectional=True, cell_type='lstm'):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.cell_type = cell_type.lower()
        
        self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0)
        
        if self.cell_type == 'lstm':
            self.rnn = nn.LSTM(
                embedding_size, 
                hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True
            )
        elif self.cell_type == 'gru':
            self.rnn = nn.GRU(
                embedding_size, 
                hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True
            )
        else:  # rnn
            self.rnn = nn.RNN(
                embedding_size, 
                hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True
            )
            
        # Initialize weights
        for name, param in self.named_parameters():
            if 'weight' in name and 'embedding' not in name:
                nn.init.xavier_uniform_(param)
                
    def forward(self, x):
        # x: [batch_size, seq_len]
        batch_size = x.shape[0]
        
        # Create mask for attention
        mask = (x != 0).float()  # 0 is <pad>
        
        # Embedding
        embedded = self.embedding(x)  # [batch_size, seq_len, embedding_size]
        
        # Pass through RNN
        outputs, hidden = self.rnn(embedded)
        
        # Process hidden state based on RNN type
        if self.cell_type == 'lstm':
            hidden_state, cell_state = hidden
            
            if self.bidirectional:
                # Reshape hidden from [num_layers*2, batch_size, hidden_size]
                # to [num_layers, 2, batch_size, hidden_size]
                hidden_state = hidden_state.view(self.num_layers, 2, batch_size, self.hidden_size)
                cell_state = cell_state.view(self.num_layers, 2, batch_size, self.hidden_size)
                
                # Concatenate bidirectional states
                hidden_state = torch.cat([hidden_state[:, 0], hidden_state[:, 1]], dim=2)
                cell_state = torch.cat([cell_state[:, 0], cell_state[:, 1]], dim=2)
                
                # Final hidden state is now [num_layers, batch_size, hidden_size*2]
                hidden = (hidden_state, cell_state)
            
        else:  # GRU or RNN
            if self.bidirectional:
                # Reshape hidden from [num_layers*2, batch_size, hidden_size]
                # to [num_layers, 2, batch_size, hidden_size]
                hidden = hidden.view(self.num_layers, 2, batch_size, self.hidden_size)
                
                # Concatenate bidirectional states
                hidden = torch.cat([hidden[:, 0], hidden[:, 1]], dim=2)
                
                # Final hidden state is now [num_layers, batch_size, hidden_size*2]
        
        # For bidirectional, output is [batch_size, seq_len, hidden_size*2]
        return outputs, hidden, mask

# ---- Attention Mechanism ----
class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super().__init__()
        # Create a linear layer to convert the concatenated hidden states to attention scores
        self.energy = nn.Linear(enc_hidden_size + dec_hidden_size, dec_hidden_size)
        self.v = nn.Linear(dec_hidden_size, 1, bias=False)
        
    def forward(self, hidden, encoder_outputs, mask):
        # hidden: [batch_size, dec_hidden_size]
        # encoder_outputs: [batch_size, src_len, enc_hidden_size]
        # mask: [batch_size, src_len]
        
        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]
        
        # Repeat decoder hidden state src_len times
        # [batch_size, dec_hidden_size] -> [batch_size, src_len, dec_hidden_size]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        # Create energy by concatenating encoder outputs and decoder hidden
        # [batch_size, src_len, enc_hidden_size + dec_hidden_size]
        energy = torch.cat((hidden, encoder_outputs), dim=2)
        
        # Apply attention layer
        # [batch_size, src_len, dec_hidden_size]
        energy = torch.tanh(self.energy(energy))
        
        # Get attention scores
        # [batch_size, src_len, 1]
        attention = self.v(energy)
        
        # [batch_size, src_len]
        attention = attention.squeeze(2)
        
        # Mask out padding positions
        attention = attention.masked_fill(mask == 0, -1e10)
        
        # Apply softmax to get probabilities
        # [batch_size, src_len]
        return F.softmax(attention, dim=1)

# ---- Decoder with attention and teacher forcing ----
class Decoder(nn.Module):
    def __init__(self, output_size, embedding_size, enc_hidden_size, dec_hidden_size, 
                 num_layers, dropout=0, cell_type='lstm'):
        super().__init__()
        self.output_size = output_size
        self.dec_hidden_size = dec_hidden_size
        self.enc_hidden_size = enc_hidden_size
        self.num_layers = num_layers
        self.cell_type = cell_type.lower()
        
        # Initialize embedding layer
        self.embedding = nn.Embedding(output_size, embedding_size, padding_idx=0)
        
        # Initialize attention mechanism
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        
        # Context vector + embedding size as input to RNN
        rnn_input_size = embedding_size + enc_hidden_size
        
        # Initialize RNN based on cell type
        if self.cell_type == 'lstm':
            self.rnn = nn.LSTM(
                rnn_input_size, 
                dec_hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                batch_first=True
            )
        elif self.cell_type == 'gru':
            self.rnn = nn.GRU(
                rnn_input_size, 
                dec_hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                batch_first=True
            )
        else:  # rnn
            self.rnn = nn.RNN(
                rnn_input_size, 
                dec_hidden_size, 
                num_layers=num_layers, 
                dropout=dropout if num_layers > 1 else 0,
                batch_first=True
            )
        
        # Final output layer that combines decoder output, context and embedding
        self.fc_out = nn.Linear(dec_hidden_size + enc_hidden_size + embedding_size, output_size)
        
        # Initialize weights using Xavier initialization
        for name, param in self.named_parameters():
            if 'weight' in name and 'embedding' not in name:
                nn.init.xavier_uniform_(param)
                
    def forward(self, input, hidden, encoder_outputs, mask):
        # input: [batch_size]
        # hidden: [num_layers, batch_size, dec_hidden_size] or tuple for LSTM
        # encoder_outputs: [batch_size, src_len, enc_hidden_size]
        # mask: [batch_size, src_len]
        
        # Embed input token
        # [batch_size] -> [batch_size, 1, embedding_size]
        embedded = self.embedding(input).unsqueeze(1)
        
        # Get attention weights
        # [batch_size, src_len]
        if self.cell_type == 'lstm':
            h_for_attn = hidden[0][-1]  # use last layer's hidden state
        else:
            h_for_attn = hidden[-1]  # use last layer's hidden state
            
        attn_weights = self.attention(h_for_attn, encoder_outputs, mask)
        
        # Create context vector by weighting encoder outputs with attention
        # [batch_size, 1, src_len] * [batch_size, src_len, enc_hidden_size]
        # -> [batch_size, 1, enc_hidden_size]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        
        # Combine embedded token and context vector
        # [batch_size, 1, embedding_size + enc_hidden_size]
        rnn_input = torch.cat((embedded, context), dim=2)
        
        # Pass through RNN
        # output: [batch_size, 1, dec_hidden_size]
        # hidden: [num_layers, batch_size, dec_hidden_size] or tuple for LSTM
        output, hidden = self.rnn(rnn_input, hidden)
        
        # Combine output, context and embedding for final prediction
        # [batch_size, 1, dec_hidden_size + enc_hidden_size + embedding_size]
        output = torch.cat((output, context, embedded), dim=2)
        
        # Remove sequence dimension
        # [batch_size, dec_hidden_size + enc_hidden_size + embedding_size]
        output = output.squeeze(1)
        
        # Pass through final linear layer
        # [batch_size, output_size]
        prediction = self.fc_out(output)
        
        return prediction, hidden, attn_weights

# ---- Complete Seq2Seq Model with Teacher Forcing ----
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, teacher_forcing_ratio=0.7):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.teacher_forcing_ratio = teacher_forcing_ratio
        
    def forward(self, src, tgt):
        # src: [batch_size, src_len]
        # tgt: [batch_size, tgt_len]
        
        batch_size = src.shape[0]
        tgt_len = tgt.shape[1]
        tgt_vocab_size = self.decoder.output_size
        
        # Tensor to store outputs
        outputs = torch.zeros(batch_size, tgt_len-1, tgt_vocab_size).to(self.device)
        
        # Encode source
        encoder_outputs, hidden, mask = self.encoder(src)
        
        # First input to decoder is the <sos> token (already embedded in tgt)
        input = tgt[:, 0]
        
        # Teacher forcing ratio determines how often to use true target as input
        use_teacher_forcing = random.random() < self.teacher_forcing_ratio
        
        # Decode one token at a time
        for t in range(1, tgt_len):
            # Get output from decoder
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs, mask)
            
            # Store output
            outputs[:, t-1] = output
            
            # Next input is either true target (teacher forcing) or predicted token
            if use_teacher_forcing:
                input = tgt[:, t]
            else:
                # Get highest scoring token
                input = output.argmax(1)
                
        return outputs
    
    # For inference (no teacher forcing)
    def decode(self, src, max_len=100):
        # src: [batch_size, src_len]
        
        batch_size = src.shape[0]
        
        # Encode source
        encoder_outputs, hidden, mask = self.encoder(src)
        
        # First input is <sos> token
        input = torch.ones(batch_size, dtype=torch.long).to(self.device) * 3  # <sos> = 3
        
        # Track generated tokens
        outputs = [input]
        attentions = []
        
        # Track if sequence has ended
        ended = torch.zeros(batch_size, dtype=torch.bool).to(self.device)
        
        # Decode until max length or all sequences end
        for t in range(1, max_len):
            # Get output from decoder
            output, hidden, attn = self.decoder(input, hidden, encoder_outputs, mask)
            
            # Get next token
            input = output.argmax(1)
            
            # Store output
            outputs.append(input)
            attentions.append(attn)
            
            # Check if all sequences have ended
            ended = ended | (input == 2)  # 2 is <eos>
            if ended.all():
                break
                
        # Convert list of tensors to single tensor
        outputs = torch.stack(outputs, dim=1)  # [batch_size, seq_len]
        attentions = torch.stack(attentions, dim=1)  # [batch_size, seq_len-1, src_len]
        
        return outputs, attentions

# ---- Metrics & Utils ----
def compute_exact_match_accuracy(preds, targets, tgt_vocab):
    """Compute exact match accuracy between predictions and targets"""
    batch_size = preds.size(0)
    correct = 0
    
    # Convert ids to strings
    id_to_char = {v: k for k, v in tgt_vocab.items() if k not in ['<pad>', '<sos>', '<eos>', '<unk>']}
    
    for i in range(batch_size):
        # Extract character sequences (removing special tokens)
        pred_seq = ''.join([id_to_char.get(idx.item(), '') for idx in preds[i, 1:] 
                            if idx.item() not in [0, 1, 2, 3]])  # Skip <pad>, <unk>, <eos>, <sos>
        
        # For target, skip first token (<sos>) and stop at <eos> or <pad>
        tgt_seq = ''
        for idx in targets[i, 1:]:  # Skip first token
            token_id = idx.item()
            if token_id in [0, 2]:  # <pad> or <eos>
                break
            if token_id not in [1, 3]:  # Skip <unk> and <sos>
                tgt_seq += id_to_char.get(token_id, '')
        
        # Check for exact match
        if pred_seq == tgt_seq:
            correct += 1
    
    return correct / batch_size

def compute_char_accuracy(logits, targets):
    """Compute character-level accuracy between logits and targets"""
    preds = logits.argmax(dim=-1)
    mask = (targets != 0)  # Ignore padding
    correct = ((preds == targets) & mask).sum().item()
    total = mask.sum().item()
    return correct / total if total > 0 else 0

# ---- Training & Evaluation Functions ----
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    epoch_char_acc = 0
    epoch_exact_match_acc = 0
    total_batches = 0
    
    for src, tgt in tqdm(dataloader, desc="Training"):
        batch_size = src.size(0)
        src, tgt = src.to(device), tgt.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(src, tgt)
        
        # Flatten output and target tensors for loss calculation
        # Ignore the first token in target (<sos>)
        output_flat = output.reshape(-1, output.shape[-1])
        target_flat = tgt[:, 1:].reshape(-1)  # Shift right to predict next token
        
        # Calculate loss
        loss = criterion(output_flat, target_flat)
        
        # Backward pass
        loss.backward()
        
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # Update parameters
        optimizer.step()
        
        # Calculate metrics
        char_acc = compute_char_accuracy(output, tgt[:, 1:])
        
        # Decode for exact match accuracy
        with torch.no_grad():
            predictions, _ = model.decode(src)
            exact_match_acc = compute_exact_match_accuracy(predictions, tgt, dataloader.dataset.tgt_vocab)
        
        # Accumulate metrics
        epoch_loss += loss.item() * batch_size
        epoch_char_acc += char_acc * batch_size
        epoch_exact_match_acc += exact_match_acc * batch_size
        total_batches += batch_size
    
    # Return average metrics
    return {
        'loss': epoch_loss / total_batches,
        'char_acc': epoch_char_acc / total_batches,
        'exact_match_acc': epoch_exact_match_acc / total_batches
    }

def evaluate(model, dataloader, criterion, device):
    model.eval()
    epoch_loss = 0
    epoch_char_acc = 0
    epoch_exact_match_acc = 0
    total_batches = 0
    correct_predictions = 0
    total_predictions = 0
    
    with torch.no_grad():
        for src, tgt in tqdm(dataloader, desc="Evaluating"):
            batch_size = src.size(0)
            src, tgt = src.to(device), tgt.to(device)
            
            # Forward pass (use teacher forcing for loss calculation)
            output = model(src, tgt)
            
            # Flatten output and target tensors for loss calculation
            output_flat = output.reshape(-1, output.shape[-1])
            target_flat = tgt[:, 1:].reshape(-1)  # Shift right to predict next token
            
            # Calculate loss
            loss = criterion(output_flat, target_flat)
            
            # Calculate metrics
            char_acc = compute_char_accuracy(output, tgt[:, 1:])
            
            # Decode for exact match accuracy (no teacher forcing)
            predictions, _ = model.decode(src)
            exact_match_acc = compute_exact_match_accuracy(predictions, tgt, dataloader.dataset.tgt_vocab)
            
            # Count exact matches for reporting
            correct_batch = int(exact_match_acc * batch_size)
            correct_predictions += correct_batch
            total_predictions += batch_size
            
            # Accumulate metrics
            epoch_loss += loss.item() * batch_size
            epoch_char_acc += char_acc * batch_size
            epoch_exact_match_acc += exact_match_acc * batch_size
            total_batches += batch_size
    
    # Return average metrics
    return {
        'loss': epoch_loss / total_batches,
        'char_acc': epoch_char_acc / total_batches,
        'exact_match_acc': epoch_exact_match_acc / total_batches,
        'correct': correct_predictions,
        'total': total_predictions
    }

# ---- WandB Sweep Configuration ----
sweep_config = {
    "name": "Seq2Seq",
    "method": "bayes",
    'metric': {
        'name': 'validation_accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'cell_type': {
            'values': ['lstm', 'gru', 'rnn']
        },
        'dropout': {
            'values': [0, 0.1, 0.2, 0.5]
        },
        'embedding_size': {
            'values': [64, 128, 256, 512]
        },
        'num_layers': {
            'values': [2, 3, 4]
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'hidden_size': {
            'values': [128, 256, 512]
        },
        'bidirectional': {
            'values': [True, False]
        },
        'learning_rate': {
            "values": [0.001, 0.002, 0.0001, 0.0002]
        },
        'epochs': {
            'values': [15]
        },
        'optim': {
            "values": ['adam']
        },
        'teacher_forcing': {
            "values": [0.2, 0.5, 0.7]
        }
    }
}

# ---- WandB Sweep Function ----
def sweep_run():
    # Initialize WandB run
    run = wandb.init()
    
    # Get hyperparameters from sweep
    config = wandb.config
    
    # Create run name
    run_name = f"{config.cell_type}-e{config.embedding_size}-h{config.hidden_size}-n{config.num_layers}-d{config.dropout}-b{config.bidirectional}-tf{config.teacher_forcing}-lr{config.learning_rate}-bs{config.batch_size}-{config.optim}"
    wandb.run.name = run_name
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data paths
    train_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv'
    dev_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.dev.tsv'
    test_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.test.tsv'
    vocab_dir = '/kaggle/working/vocab'
    model_dir = '/kaggle/working/models'
    
    # Create directories
    os.makedirs(vocab_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    
    # Load or build vocabulary
    vocab_file = os.path.join(vocab_dir, 'src_vocab.json')
    if os.path.exists(vocab_file):
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'r') as f:
            src_vocab = json.load(f)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'r') as f:
            tgt_vocab = json.load(f)
        print("Loaded existing vocabulary")
    else:
        print("Building new vocabulary")
        train_dataset = DakshinaTSVDataset(train_tsv, build_vocab=True)
        src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab
        
        # Save vocabulary
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(src_vocab, f, ensure_ascii=False)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(tgt_vocab, f, ensure_ascii=False)
        print("Saved vocabulary")
    
    # Create datasets
    train_dataset = DakshinaTSVDataset(train_tsv, src_vocab, tgt_vocab)
    val_dataset = DakshinaTSVDataset(dev_tsv, src_vocab, tgt_vocab)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
    
    # Create model components
    encoder = Encoder(
        input_size=len(src_vocab),
        embedding_size=config.embedding_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        dropout=config.dropout,
        bidirectional=config.bidirectional,
        cell_type=config.cell_type
    )
    
    # Calculate encoder output size (doubled if bidirectional)
    enc_hidden_size = config.hidden_size * 2 if config.bidirectional else config.hidden_size
    
    decoder = Decoder(
        output_size=len(tgt_vocab),
        embedding_size=config.embedding_size,
        enc_hidden_size=enc_hidden_size,
        dec_hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        dropout=config.dropout,
        cell_type=config.cell_type
    )
    
    # Create full model
    model = Seq2Seq(encoder, decoder, device, teacher_forcing_ratio=config.teacher_forcing)
    model = model.to(device)
    
    # Print model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters ({trainable_params:,} trainable)")
    
    # Loss function (ignore padding token)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    # Optimizer
    if config.optim == 'nadam':
        try:
            optimizer = optim.NAdam(model.parameters(), lr=config.learning_rate)
        except AttributeError:
            print("NAdam optimizer not available, falling back to Adam")
            optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    else:
        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    
    # Training loop
    best_val_acc = 0
    
    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch+1}/{config.epochs}")
        
        # Train
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Evaluate
        val_metrics = evaluate(model, val_loader, criterion, device)
        
        # Print metrics
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Char Acc: {train_metrics['char_acc']:.4f}, "
              f"Exact Match: {train_metrics['exact_match_acc']:.4f}")
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Char Acc: {val_metrics['char_acc']:.4f}, "
              f"Exact Match: {val_metrics['exact_match_acc']:.4f} ({val_metrics['correct']}/{val_metrics['total']})")
        
        # Convert exact match to percentage for wandb
        val_accuracy_percent = val_metrics['exact_match_acc'] * 100
        
        # Log to WandB
        wandb.log({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_char_accuracy': train_metrics['char_acc'],
            'train_exact_match': train_metrics['exact_match_acc'],
            'val_loss': val_metrics['loss'],
            'val_char_accuracy': val_metrics['char_acc'],
            'val_exact_match': val_metrics['exact_match_acc'],
            'validation_accuracy': val_accuracy_percent  # This matches the metric name in sweep_config
        })
        
        # Save best model
        if val_metrics['exact_match_acc'] > best_val_acc:
            best_val_acc = val_metrics['exact_match_acc']
            
            # Save model
            model_path = os.path.join(model_dir, f"{run_name}_best.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['exact_match_acc'],
                'config': {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
            }, model_path)
            
            # Create a new artifact for this model
            artifact_name = f"model-{run.id}-epoch{epoch+1}"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(model_path)
            run.log_artifact(artifact)
            
            print(f"Saved new best model with validation accuracy: {best_val_acc:.4f}")



In [None]:
# ---- Main Function ----
if __name__ == "__main__":
    # Initialize a new sweep
    sweep_id = wandb.sweep(sweep_config, project="dakshina-transliteration")
    
    # Start the sweep agent
    # You can adjust count to determine how many runs to do
    wandb.agent(sweep_id, sweep_run, count=50)  # Run 10 trials
    
    print("Sweep completed!.")

In [None]:
# gru-e512-h256-n2-d0.1-bFalse-tf0.7-lr0.002-bs128-adam

In [1]:
# ---- Function to run with best parameters ----
def run_best_params():
    # Best parameters as provided
    params = {
        'cell_type': 'gru',
        'dropout': 0.1,
        'num_layers': 2,
        'batch_size': 128,
        'hidden_size': 256,
        'embedding_size': 512,
        'bidirectional': False,
        'learning_rate': 0.002,
        'epochs': 15,
        'optim': 'adam',
        'teacher_forcing': 0.7
    }
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data paths
    train_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv'
    dev_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.dev.tsv'
    test_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.test.tsv'
    vocab_dir = '/kaggle/working/vocab_best'
    model_dir = '/kaggle/working/models_best'
    
    # Create directories
    os.makedirs(vocab_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    
    # Initialize wandb if needed
    use_wandb = False  # Set to True if you want to use wandb
    if use_wandb:
        wandb.init(project="dakshina-transliteration", config=params)
    
    # Load or build vocabulary
    vocab_file = os.path.join(vocab_dir, 'src_vocab.json')
    if os.path.exists(vocab_file):
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'r') as f:
            src_vocab = json.load(f)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'r') as f:
            tgt_vocab = json.load(f)
        print("Loaded existing vocabulary")
    else:
        print("Building new vocabulary")
        train_dataset = DakshinaTSVDataset(train_tsv, build_vocab=True)
        src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab
        
        # Save vocabulary
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(src_vocab, f, ensure_ascii=False)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(tgt_vocab, f, ensure_ascii=False)
        print("Saved vocabulary")
    
    # Create datasets
    train_dataset = DakshinaTSVDataset(train_tsv, src_vocab, tgt_vocab)
    val_dataset = DakshinaTSVDataset(dev_tsv, src_vocab, tgt_vocab)
    test_dataset = DakshinaTSVDataset(test_tsv, src_vocab, tgt_vocab)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params['batch_size'])
    test_loader = DataLoader(test_dataset, batch_size=params['batch_size'])
    
    print(f"Loaded datasets - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    # Create model components
    encoder = Encoder(
        input_size=len(src_vocab),
        embedding_size=params['embedding_size'],
        hidden_size=params['hidden_size'],
        num_layers=params['num_layers'],
        dropout=params['dropout'],
        bidirectional=params['bidirectional'],
        cell_type=params['cell_type']
    )
    
    # Calculate encoder output size (doubled if bidirectional)
    enc_hidden_size = params['hidden_size'] * 2 if params['bidirectional'] else params['hidden_size']
    
    decoder = Decoder(
        output_size=len(tgt_vocab),
        embedding_size=params['embedding_size'],
        enc_hidden_size=enc_hidden_size,
        dec_hidden_size=params['hidden_size'],
        num_layers=params['num_layers'],
        dropout=params['dropout'],
        cell_type=params['cell_type']
    )
    
    # Create full model
    model = Seq2Seq(encoder, decoder, device, teacher_forcing_ratio=params['teacher_forcing'])
    model = model.to(device)
    
    # Print model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters ({trainable_params:,} trainable)")
    
    # Loss function (ignore padding token)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    # Optimizer
    if params['optim'] == 'nadam':
        try:
            optimizer = optim.NAdam(model.parameters(), lr=params['learning_rate'])
        except AttributeError:
            print("NAdam optimizer not available, falling back to Adam")
            optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
    else:
        optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
    
    # Training loop
    best_val_acc = 0
    best_epoch = 0
    model_path = os.path.join(model_dir, "best_model.pt")
    
    for epoch in range(params['epochs']):
        print(f"\nEpoch {epoch+1}/{params['epochs']}")
        
        # Train
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Evaluate
        val_metrics = evaluate(model, val_loader, criterion, device)
        
        # Print metrics
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Char Acc: {train_metrics['char_acc']:.4f}, "
              f"Exact Match: {train_metrics['exact_match_acc']:.4f}")
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Char Acc: {val_metrics['char_acc']:.4f}, "
              f"Exact Match: {val_metrics['exact_match_acc']:.4f} ({val_metrics['correct']}/{val_metrics['total']})")
        
        # Log to WandB
        if use_wandb:
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'],
                'train_char_accuracy': train_metrics['char_acc'],
                'train_exact_match': train_metrics['exact_match_acc'],
                'val_loss': val_metrics['loss'],
                'val_char_accuracy': val_metrics['char_acc'],
                'val_exact_match': val_metrics['exact_match_acc']
            })
        
        # Save best model
        if val_metrics['exact_match_acc'] > best_val_acc:
            best_val_acc = val_metrics['exact_match_acc']
            best_epoch = epoch + 1
            
            # Save model
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['exact_match_acc'],
                'params': params
            }, model_path)
            
            print(f"Saved new best model with validation accuracy: {best_val_acc:.4f}")
    
    print(f"\nTraining complete. Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
    
    # Load best model for testing
    print("\nLoading best model for testing...")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Evaluate on test set
    test_metrics = evaluate(model, test_loader, criterion, device)
    
    print(f"\nTest Results:")
    print(f"Loss: {test_metrics['loss']:.4f}")
    print(f"Character Accuracy: {test_metrics['char_acc']:.4f}")
    print(f"Exact Match Accuracy: {test_metrics['exact_match_acc']:.4f} "
          f"({test_metrics['correct']}/{test_metrics['total']})")
    
    # Log final results to WandB
    if use_wandb:
        wandb.log({
            'test_loss': test_metrics['loss'],
            'test_char_accuracy': test_metrics['char_acc'],
            'test_exact_match': test_metrics['exact_match_acc'],
            'best_val_accuracy': best_val_acc
        })
    
    # Display examples of correct and incorrect predictions
    print("\nAnalyzing predictions on test set...")
    model.eval()
    all_predictions = []
    all_targets = []
    all_sources = []
    
    with torch.no_grad():
        for src, tgt in test_loader:
            src, tgt = src.to(device), tgt.to(device)
            predictions, _ = model.decode(src)
            
            # Convert to readable strings
            id_to_char = {v: k for k, v in test_dataset.tgt_vocab.items() 
                         if k not in ['<pad>', '<sos>', '<eos>', '<unk>']}
            
            for i in range(src.size(0)):
                # Source (roman)
                src_str = ''.join([test_dataset.src_vocab.get(str(idx.item()), '') 
                                  for idx in src[i] if idx.item() not in [0, 1, 2, 3]])
                
                # Target (native)
                tgt_str = ''.join([id_to_char.get(idx.item(), '') 
                                  for idx in tgt[i, 1:] if idx.item() not in [0, 1, 2, 3]])
                
                # Prediction
                pred_str = ''.join([id_to_char.get(idx.item(), '') 
                                   for idx in predictions[i, 1:] if idx.item() not in [0, 1, 2, 3]])
                
                all_sources.append(src_str)
                all_targets.append(tgt_str)
                all_predictions.append(pred_str)
    
    # Get correct and incorrect examples
    correct_examples = [(s, t, p) for s, t, p in zip(all_sources, all_targets, all_predictions) if t == p]
    incorrect_examples = [(s, t, p) for s, t, p in zip(all_sources, all_targets, all_predictions) if t != p]
    
    # Display some correct examples
    print(f"\nCorrect Examples ({len(correct_examples)} total):")
    for i, (src, tgt, pred) in enumerate(correct_examples[:5]):
        print(f"{i+1}. Roman: '{src}'")
        print(f"   Native: '{tgt}'")
    
    # Display some incorrect examples
    print(f"\nIncorrect Examples ({len(incorrect_examples)} total):")
    for i, (src, tgt, pred) in enumerate(incorrect_examples[:5]):
        print(f"{i+1}. Roman: '{src}'")
        print(f"   Native (correct): '{tgt}'")
        print(f"   Prediction: '{pred}'")
    
    return {
        'val_accuracy': best_val_acc,
        'test_accuracy': test_metrics['exact_match_acc'],
        'correct': test_metrics['correct'],
        'total': test_metrics['total']
    }



In [4]:
run_best_params()

Using device: cuda
Building new vocabulary
Loaded 58550 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv
Sample examples:
  Roman: 'amkita', Native: 'అంకిత'
  Roman: 'ankita', Native: 'అంకిత'
  Roman: 'ankitha', Native: 'అంకిత'
Vocab sizes -> src: 30, tgt: 67
Saved vocabulary
Loaded 58550 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv
Sample examples:
  Roman: 'amkita', Native: 'అంకిత'
  Roman: 'ankita', Native: 'అంకిత'
  Roman: 'ankitha', Native: 'అంకిత'
Loaded 5683 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.dev.tsv
Sample examples:
  Roman: 'amka', Native: 'అంక'
  Roman: 'anka', Native: 'అంక'
  Roman: 'amkam', Native: 'అంకం'
Loaded 5747 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.test.tsv
Sample examples:
  Roman: 'amkamlo', Native: 'అంకంలో'
  Roman: 'ank

Training: 100%|██████████| 458/458 [02:05<00:00,  3.65it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.28it/s]


Train - Loss: 0.7351, Char Acc: 0.8005, Exact Match: 0.3558
Val - Loss: 0.4398, Char Acc: 0.8769, Exact Match: 0.4392 (2496/5683)
Saved new best model with validation accuracy: 0.4392

Epoch 2/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.73it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.26it/s]


Train - Loss: 0.3903, Char Acc: 0.8963, Exact Match: 0.5273
Val - Loss: 0.3757, Char Acc: 0.8952, Exact Match: 0.4788 (2721/5683)
Saved new best model with validation accuracy: 0.4788

Epoch 3/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.32it/s]


Train - Loss: 0.3335, Char Acc: 0.9119, Exact Match: 0.5953
Val - Loss: 0.4505, Char Acc: 0.8772, Exact Match: 0.4767 (2709/5683)

Epoch 4/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.75it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.03it/s]


Train - Loss: 0.2941, Char Acc: 0.9224, Exact Match: 0.6007
Val - Loss: 0.4068, Char Acc: 0.8922, Exact Match: 0.5147 (2925/5683)
Saved new best model with validation accuracy: 0.5147

Epoch 5/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.24it/s]


Train - Loss: 0.2846, Char Acc: 0.9250, Exact Match: 0.6311
Val - Loss: 0.3864, Char Acc: 0.8985, Exact Match: 0.4753 (2701/5683)

Epoch 6/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  6.80it/s]


Train - Loss: 0.2517, Char Acc: 0.9337, Exact Match: 0.6358
Val - Loss: 0.4362, Char Acc: 0.8837, Exact Match: 0.4742 (2695/5683)

Epoch 7/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.26it/s]


Train - Loss: 0.2649, Char Acc: 0.9304, Exact Match: 0.6496
Val - Loss: 0.4431, Char Acc: 0.8867, Exact Match: 0.5170 (2938/5683)
Saved new best model with validation accuracy: 0.5170

Epoch 8/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.21it/s]


Train - Loss: 0.2599, Char Acc: 0.9322, Exact Match: 0.6409
Val - Loss: 0.4238, Char Acc: 0.8870, Exact Match: 0.5061 (2876/5683)

Epoch 9/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.30it/s]


Train - Loss: 0.2530, Char Acc: 0.9334, Exact Match: 0.6291
Val - Loss: 0.4194, Char Acc: 0.8960, Exact Match: 0.5179 (2943/5683)
Saved new best model with validation accuracy: 0.5179

Epoch 10/15


Training: 100%|██████████| 458/458 [02:02<00:00,  3.74it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.21it/s]


Train - Loss: 0.2951, Char Acc: 0.9241, Exact Match: 0.5785
Val - Loss: 0.4505, Char Acc: 0.8881, Exact Match: 0.5098 (2897/5683)

Epoch 11/15


Training: 100%|██████████| 458/458 [02:01<00:00,  3.76it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  6.83it/s]


Train - Loss: 0.2555, Char Acc: 0.9329, Exact Match: 0.6267
Val - Loss: 0.4584, Char Acc: 0.8803, Exact Match: 0.4630 (2631/5683)

Epoch 12/15


Training: 100%|██████████| 458/458 [02:04<00:00,  3.68it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.14it/s]


Train - Loss: 0.2359, Char Acc: 0.9378, Exact Match: 0.6457
Val - Loss: 0.4066, Char Acc: 0.8952, Exact Match: 0.5094 (2895/5683)

Epoch 13/15


Training: 100%|██████████| 458/458 [02:04<00:00,  3.67it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.13it/s]


Train - Loss: 0.2397, Char Acc: 0.9368, Exact Match: 0.6254
Val - Loss: 0.4044, Char Acc: 0.8949, Exact Match: 0.4496 (2555/5683)

Epoch 14/15


Training: 100%|██████████| 458/458 [02:04<00:00,  3.67it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.11it/s]


Train - Loss: 0.2271, Char Acc: 0.9396, Exact Match: 0.6352
Val - Loss: 0.4307, Char Acc: 0.8942, Exact Match: 0.4804 (2730/5683)

Epoch 15/15


Training: 100%|██████████| 458/458 [02:04<00:00,  3.68it/s]
Evaluating: 100%|██████████| 45/45 [00:06<00:00,  6.79it/s]


Train - Loss: 0.2390, Char Acc: 0.9372, Exact Match: 0.6195
Val - Loss: 0.4555, Char Acc: 0.8836, Exact Match: 0.4774 (2713/5683)

Training complete. Best validation accuracy: 0.5179 at epoch 9

Loading best model for testing...


Evaluating: 100%|██████████| 45/45 [00:06<00:00,  7.09it/s]



Test Results:
Loss: 0.4370
Character Accuracy: 0.8912
Exact Match Accuracy: 0.5072 (2915/5747)

Analyzing predictions on test set...

Correct Examples (2915 total):
1. Roman: ''
   Native: 'అంకంలో'
2. Roman: ''
   Native: 'అంకంలో'
3. Roman: ''
   Native: 'అంకంలో'
4. Roman: ''
   Native: 'అంకితమై'
5. Roman: ''
   Native: 'అంకితమై'

Incorrect Examples (2832 total):
1. Roman: ''
   Native (correct): 'అంకెలను'
   Prediction: 'అంకేళను'
2. Roman: ''
   Native (correct): 'అంటారని'
   Prediction: 'అంతారణి'
3. Roman: ''
   Native (correct): 'అంటారని'
   Prediction: 'అంతరణీ'
4. Roman: ''
   Native (correct): 'అంటావా'
   Prediction: 'అంతావా'
5. Roman: ''
   Native (correct): 'అంటావా'
   Prediction: 'అంతవ'


{'val_accuracy': 0.5178602850607074,
 'test_accuracy': 0.507221158865495,
 'correct': 2915,
 'total': 5747}

In [None]:
# rnn-e512-h512-n4-d0.5-bFalse-tf0.7-lr0.0001-bs64-adam

In [5]:
# ---- Function to run with best parameters ----
def run_best_params():
    # Best parameters as provided
    params = {
        'cell_type': 'rnn',
        'dropout': 0.5,
        'num_layers': 4,
        'batch_size': 64,
        'hidden_size': 512,
        'embedding_size': 512,
        'bidirectional': False,
        'learning_rate': 0.0001,
        'epochs': 15,
        'optim': 'adam',
        'teacher_forcing': 0.7
    }
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data paths
    train_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv'
    dev_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.dev.tsv'
    test_tsv = '/kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.test.tsv'
    vocab_dir = '/kaggle/working/vocab_best'
    model_dir = '/kaggle/working/models_best'
    
    # Create directories
    os.makedirs(vocab_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    
    # Initialize wandb if needed
    use_wandb = False  # Set to True if you want to use wandb
    if use_wandb:
        wandb.init(project="dakshina-transliteration", config=params)
    
    # Load or build vocabulary
    vocab_file = os.path.join(vocab_dir, 'src_vocab.json')
    if os.path.exists(vocab_file):
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'r') as f:
            src_vocab = json.load(f)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'r') as f:
            tgt_vocab = json.load(f)
        print("Loaded existing vocabulary")
    else:
        print("Building new vocabulary")
        train_dataset = DakshinaTSVDataset(train_tsv, build_vocab=True)
        src_vocab, tgt_vocab = train_dataset.src_vocab, train_dataset.tgt_vocab
        
        # Save vocabulary
        with open(os.path.join(vocab_dir, 'src_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(src_vocab, f, ensure_ascii=False)
        with open(os.path.join(vocab_dir, 'tgt_vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(tgt_vocab, f, ensure_ascii=False)
        print("Saved vocabulary")
    
    # Create datasets
    train_dataset = DakshinaTSVDataset(train_tsv, src_vocab, tgt_vocab)
    val_dataset = DakshinaTSVDataset(dev_tsv, src_vocab, tgt_vocab)
    test_dataset = DakshinaTSVDataset(test_tsv, src_vocab, tgt_vocab)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params['batch_size'])
    test_loader = DataLoader(test_dataset, batch_size=params['batch_size'])
    
    print(f"Loaded datasets - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    # Create model components
    encoder = Encoder(
        input_size=len(src_vocab),
        embedding_size=params['embedding_size'],
        hidden_size=params['hidden_size'],
        num_layers=params['num_layers'],
        dropout=params['dropout'],
        bidirectional=params['bidirectional'],
        cell_type=params['cell_type']
    )
    
    # Calculate encoder output size (doubled if bidirectional)
    enc_hidden_size = params['hidden_size'] * 2 if params['bidirectional'] else params['hidden_size']
    
    decoder = Decoder(
        output_size=len(tgt_vocab),
        embedding_size=params['embedding_size'],
        enc_hidden_size=enc_hidden_size,
        dec_hidden_size=params['hidden_size'],
        num_layers=params['num_layers'],
        dropout=params['dropout'],
        cell_type=params['cell_type']
    )
    
    # Create full model
    model = Seq2Seq(encoder, decoder, device, teacher_forcing_ratio=params['teacher_forcing'])
    model = model.to(device)
    
    # Print model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters ({trainable_params:,} trainable)")
    
    # Loss function (ignore padding token)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    # Optimizer
    if params['optim'] == 'nadam':
        try:
            optimizer = optim.NAdam(model.parameters(), lr=params['learning_rate'])
        except AttributeError:
            print("NAdam optimizer not available, falling back to Adam")
            optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
    else:
        optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
    
    # Training loop
    best_val_acc = 0
    best_epoch = 0
    model_path = os.path.join(model_dir, "best_model.pt")
    
    for epoch in range(params['epochs']):
        print(f"\nEpoch {epoch+1}/{params['epochs']}")
        
        # Train
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Evaluate
        val_metrics = evaluate(model, val_loader, criterion, device)
        
        # Print metrics
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Char Acc: {train_metrics['char_acc']:.4f}, "
              f"Exact Match: {train_metrics['exact_match_acc']:.4f}")
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Char Acc: {val_metrics['char_acc']:.4f}, "
              f"Exact Match: {val_metrics['exact_match_acc']:.4f} ({val_metrics['correct']}/{val_metrics['total']})")
        
        # Log to WandB
        if use_wandb:
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_metrics['loss'],
                'train_char_accuracy': train_metrics['char_acc'],
                'train_exact_match': train_metrics['exact_match_acc'],
                'val_loss': val_metrics['loss'],
                'val_char_accuracy': val_metrics['char_acc'],
                'val_exact_match': val_metrics['exact_match_acc']
            })
        
        # Save best model
        if val_metrics['exact_match_acc'] > best_val_acc:
            best_val_acc = val_metrics['exact_match_acc']
            best_epoch = epoch + 1
            
            # Save model
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['exact_match_acc'],
                'params': params
            }, model_path)
            
            print(f"Saved new best model with validation accuracy: {best_val_acc:.4f}")
    
    print(f"\nTraining complete. Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
    
    # Load best model for testing
    print("\nLoading best model for testing...")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Evaluate on test set
    test_metrics = evaluate(model, test_loader, criterion, device)
    
    print(f"\nTest Results:")
    print(f"Loss: {test_metrics['loss']:.4f}")
    print(f"Character Accuracy: {test_metrics['char_acc']:.4f}")
    print(f"Exact Match Accuracy: {test_metrics['exact_match_acc']:.4f} "
          f"({test_metrics['correct']}/{test_metrics['total']})")
    
    # Log final results to WandB
    if use_wandb:
        wandb.log({
            'test_loss': test_metrics['loss'],
            'test_char_accuracy': test_metrics['char_acc'],
            'test_exact_match': test_metrics['exact_match_acc'],
            'best_val_accuracy': best_val_acc
        })
    
    # Display examples of correct and incorrect predictions
    print("\nAnalyzing predictions on test set...")
    model.eval()
    all_predictions = []
    all_targets = []
    all_sources = []
    
    with torch.no_grad():
        for src, tgt in test_loader:
            src, tgt = src.to(device), tgt.to(device)
            predictions, _ = model.decode(src)
            
            # Convert to readable strings
            id_to_char = {v: k for k, v in test_dataset.tgt_vocab.items() 
                         if k not in ['<pad>', '<sos>', '<eos>', '<unk>']}
            
            for i in range(src.size(0)):
                # Source (roman)
                src_str = ''.join([test_dataset.src_vocab.get(str(idx.item()), '') 
                                  for idx in src[i] if idx.item() not in [0, 1, 2, 3]])
                
                # Target (native)
                tgt_str = ''.join([id_to_char.get(idx.item(), '') 
                                  for idx in tgt[i, 1:] if idx.item() not in [0, 1, 2, 3]])
                
                # Prediction
                pred_str = ''.join([id_to_char.get(idx.item(), '') 
                                   for idx in predictions[i, 1:] if idx.item() not in [0, 1, 2, 3]])
                
                all_sources.append(src_str)
                all_targets.append(tgt_str)
                all_predictions.append(pred_str)
    
    # Get correct and incorrect examples
    correct_examples = [(s, t, p) for s, t, p in zip(all_sources, all_targets, all_predictions) if t == p]
    incorrect_examples = [(s, t, p) for s, t, p in zip(all_sources, all_targets, all_predictions) if t != p]
    
    # Display some correct examples
    print(f"\nCorrect Examples ({len(correct_examples)} total):")
    for i, (src, tgt, pred) in enumerate(correct_examples[:5]):
        print(f"{i+1}. Roman: '{src}'")
        print(f"   Native: '{tgt}'")
    
    # Display some incorrect examples
    print(f"\nIncorrect Examples ({len(incorrect_examples)} total):")
    for i, (src, tgt, pred) in enumerate(incorrect_examples[:5]):
        print(f"{i+1}. Roman: '{src}'")
        print(f"   Native (correct): '{tgt}'")
        print(f"   Prediction: '{pred}'")
    
    return {
        'val_accuracy': best_val_acc,
        'test_accuracy': test_metrics['exact_match_acc'],
        'correct': test_metrics['correct'],
        'total': test_metrics['total']
    }



In [6]:
run_best_params()

Using device: cuda
Loaded existing vocabulary
Loaded 58550 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.train.tsv
Sample examples:
  Roman: 'amkita', Native: 'అంకిత'
  Roman: 'ankita', Native: 'అంకిత'
  Roman: 'ankitha', Native: 'అంకిత'
Loaded 5683 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.dev.tsv
Sample examples:
  Roman: 'amka', Native: 'అంక'
  Roman: 'anka', Native: 'అంక'
  Roman: 'amkam', Native: 'అంకం'
Loaded 5747 examples from /kaggle/input/dakshina-dataset/dakshina_dataset_v1.0/te/lexicons/te.translit.sampled.test.tsv
Sample examples:
  Roman: 'amkamlo', Native: 'అంకంలో'
  Roman: 'ankamlo', Native: 'అంకంలో'
  Roman: 'ankamloo', Native: 'అంకంలో'
Loaded datasets - Train: 58550, Val: 5683, Test: 5747
Model has 5,142,595 parameters (5,142,595 trainable)

Epoch 1/15


Training: 100%|██████████| 915/915 [05:32<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.26it/s]


Train - Loss: 1.7257, Char Acc: 0.5244, Exact Match: 0.0500
Val - Loss: 0.7365, Char Acc: 0.7911, Exact Match: 0.2467 (1402/5683)
Saved new best model with validation accuracy: 0.2467

Epoch 2/15


Training: 100%|██████████| 915/915 [05:33<00:00,  2.74it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.34it/s]


Train - Loss: 0.8121, Char Acc: 0.7763, Exact Match: 0.2127
Val - Loss: 0.5601, Char Acc: 0.8433, Exact Match: 0.3648 (2073/5683)
Saved new best model with validation accuracy: 0.3648

Epoch 3/15


Training: 100%|██████████| 915/915 [05:27<00:00,  2.80it/s]
Evaluating: 100%|██████████| 89/89 [00:13<00:00,  6.41it/s]


Train - Loss: 0.6301, Char Acc: 0.8288, Exact Match: 0.2997
Val - Loss: 0.5599, Char Acc: 0.8537, Exact Match: 0.4086 (2322/5683)
Saved new best model with validation accuracy: 0.4086

Epoch 4/15


Training: 100%|██████████| 915/915 [05:25<00:00,  2.81it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.36it/s]


Train - Loss: 0.5675, Char Acc: 0.8474, Exact Match: 0.3513
Val - Loss: 0.4673, Char Acc: 0.8710, Exact Match: 0.4373 (2485/5683)
Saved new best model with validation accuracy: 0.4373

Epoch 5/15


Training: 100%|██████████| 915/915 [05:31<00:00,  2.76it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.30it/s]


Train - Loss: 0.5157, Char Acc: 0.8611, Exact Match: 0.3884
Val - Loss: 0.4846, Char Acc: 0.8731, Exact Match: 0.4765 (2708/5683)
Saved new best model with validation accuracy: 0.4765

Epoch 6/15


Training: 100%|██████████| 915/915 [05:32<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:13<00:00,  6.36it/s]


Train - Loss: 0.4792, Char Acc: 0.8714, Exact Match: 0.4212
Val - Loss: 0.4621, Char Acc: 0.8739, Exact Match: 0.4746 (2697/5683)

Epoch 7/15


Training: 100%|██████████| 915/915 [05:33<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.28it/s]


Train - Loss: 0.4471, Char Acc: 0.8800, Exact Match: 0.4388
Val - Loss: 0.4388, Char Acc: 0.8815, Exact Match: 0.4918 (2795/5683)
Saved new best model with validation accuracy: 0.4918

Epoch 8/15


Training: 100%|██████████| 915/915 [05:32<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.32it/s]


Train - Loss: 0.4241, Char Acc: 0.8860, Exact Match: 0.4623
Val - Loss: 0.4341, Char Acc: 0.8848, Exact Match: 0.4872 (2769/5683)

Epoch 9/15


Training: 100%|██████████| 915/915 [05:32<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.28it/s]


Train - Loss: 0.4017, Char Acc: 0.8919, Exact Match: 0.4806
Val - Loss: 0.4071, Char Acc: 0.8895, Exact Match: 0.5131 (2916/5683)
Saved new best model with validation accuracy: 0.5131

Epoch 10/15


Training: 100%|██████████| 915/915 [05:33<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.35it/s]


Train - Loss: 0.3847, Char Acc: 0.8960, Exact Match: 0.4912
Val - Loss: 0.4126, Char Acc: 0.8921, Exact Match: 0.5106 (2902/5683)

Epoch 11/15


Training: 100%|██████████| 915/915 [05:30<00:00,  2.77it/s]
Evaluating: 100%|██████████| 89/89 [00:13<00:00,  6.44it/s]


Train - Loss: 0.3823, Char Acc: 0.8977, Exact Match: 0.5063
Val - Loss: 0.3798, Char Acc: 0.8978, Exact Match: 0.5286 (3004/5683)
Saved new best model with validation accuracy: 0.5286

Epoch 12/15


Training: 100%|██████████| 915/915 [05:29<00:00,  2.78it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.34it/s]


Train - Loss: 0.3596, Char Acc: 0.9036, Exact Match: 0.5179
Val - Loss: 0.4191, Char Acc: 0.8920, Exact Match: 0.5465 (3106/5683)
Saved new best model with validation accuracy: 0.5465

Epoch 13/15


Training: 100%|██████████| 915/915 [05:33<00:00,  2.74it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.29it/s]


Train - Loss: 0.3474, Char Acc: 0.9061, Exact Match: 0.5309
Val - Loss: 0.4164, Char Acc: 0.8925, Exact Match: 0.5367 (3050/5683)

Epoch 14/15


Training: 100%|██████████| 915/915 [05:33<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:13<00:00,  6.36it/s]


Train - Loss: 0.3343, Char Acc: 0.9099, Exact Match: 0.5408
Val - Loss: 0.3845, Char Acc: 0.8987, Exact Match: 0.5453 (3099/5683)

Epoch 15/15


Training: 100%|██████████| 915/915 [05:32<00:00,  2.75it/s]
Evaluating: 100%|██████████| 89/89 [00:14<00:00,  6.27it/s]


Train - Loss: 0.3284, Char Acc: 0.9113, Exact Match: 0.5531
Val - Loss: 0.3450, Char Acc: 0.9066, Exact Match: 0.5437 (3090/5683)

Training complete. Best validation accuracy: 0.5465 at epoch 12

Loading best model for testing...


Evaluating: 100%|██████████| 90/90 [00:14<00:00,  6.28it/s]



Test Results:
Loss: 0.3768
Character Accuracy: 0.8969
Exact Match Accuracy: 0.5333 (3065/5747)

Analyzing predictions on test set...

Correct Examples (3065 total):
1. Roman: ''
   Native: 'అంకంలో'
2. Roman: ''
   Native: 'అంకంలో'
3. Roman: ''
   Native: 'అంకంలో'
4. Roman: ''
   Native: 'అంకితమై'
5. Roman: ''
   Native: 'అంకితమై'

Incorrect Examples (2682 total):
1. Roman: ''
   Native (correct): 'అంకెల'
   Prediction: 'అంకేల'
2. Roman: ''
   Native (correct): 'అంకెలను'
   Prediction: 'అంకేలను'
3. Roman: ''
   Native (correct): 'అంగీకరించాలి'
   Prediction: 'అంగికరించాలి'
4. Roman: ''
   Native (correct): 'అంగీకరించి'
   Prediction: 'అంగికరించి'
5. Roman: ''
   Native (correct): 'అంగీకరించే'
   Prediction: 'అంగికరించే'


{'val_accuracy': 0.5465423191976069,
 'test_accuracy': 0.5333217330781277,
 'correct': 3065,
 'total': 5747}