In [1]:
# Downloading dakshina dataset
!yes | wget "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"

--2025-05-20 19:09:40--  https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.216.207, 173.194.217.207, 173.194.215.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.216.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2008340480 (1.9G) [application/x-tar]
Saving to: ‘dakshina_dataset_v1.0.tar’


2025-05-20 19:09:46 (296 MB/s) - ‘dakshina_dataset_v1.0.tar’ saved [2008340480/2008340480]

yes: standard output: Broken pipe


In [2]:
# Unzipping dataset
!yes | tar xopf dakshina_dataset_v1.0.tar

yes: standard output: Broken pipe


In [4]:
import wandb

In [5]:
# Never hard-code API keys in a notebook: read it from the environment instead
# (the key previously committed here should be revoked). With key=None, wandb
# falls back to its stored credentials or an interactive prompt.
import os
wandb.login(key=os.environ.get('WANDB_API_KEY'))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msurendarmohan283[0m ([33msurendarmohan283-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random

# For reproducibility - keep the same for consistency
def initialize_random_seeds(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random` (used for teacher-forcing decisions), NumPy's legacy
    global RNG, and PyTorch's CPU and CUDA generators. It also asks cuDNN for
    deterministic kernels and disables its autotuner, which may be slower.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: numpy is otherwise only imported further down the notebook.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class EncoderNetwork(nn.Module):
    """Embedding followed by an optionally bidirectional, multi-layer RNN encoder.

    Args:
        vocab_size: Number of source tokens.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of each RNN direction.
        num_layers: Number of stacked RNN layers.
        rnn_type: One of 'LSTM', 'GRU' or 'RNN'.
        dropout_rate: Inter-layer dropout; only applied when num_layers > 1.
        use_bidirectional: Run the RNN in both directions.

    forward() returns ``(outputs, hidden_states)``:
        outputs: [Batch, SrcLen, hidden_dim * num_directions]. SrcLen equals the
            padded input width.
        hidden_states: [num_layers, Batch, hidden_dim], or an ``(h, c)`` pair
            for LSTM. For bidirectional encoders, the forward and backward final
            states of each layer are averaged. This gives a shape a
            unidirectional decoder with the same depth can consume.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1, rnn_type='LSTM',
                 dropout_rate=0.0, use_bidirectional=False):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.use_bidirectional = use_bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Width of each time step in `outputs` (both directions concatenated).
        self.output_dim = hidden_dim * 2 if use_bidirectional else hidden_dim

        # Select RNN type
        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}
        selected_rnn = rnn_classes[rnn_type]

        self.rnn_layer = selected_rnn(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout_rate if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=use_bidirectional
        )

    def _merge_directions(self, state):
        """Average the forward/backward final states of every layer.

        PyTorch lays out the first dimension of h_n / c_n layer-major, i.e.
        index = layer * num_directions + direction. Slicing it into a first and
        second half would therefore mix different layers. Instead, split the
        dimension into (layer, direction) and average over direction.
        """
        _, batch_size, hidden_size = state.shape
        return state.reshape(self.num_layers, 2, batch_size, hidden_size).mean(dim=1)

    def forward(self, input_sequence, sequence_lengths):
        """Encode a padded batch.

        Args:
            input_sequence: [Batch, SeqLen] token ids (padded).
            sequence_lengths: [Batch] true lengths (moved to CPU for packing).
        """
        embedded_sequence = self.token_embedding(input_sequence)  # [Batch, SeqLen, EmbeddingDim]

        # Pack so the RNN skips padding positions.
        packed_embeddings = pack_padded_sequence(
            embedded_sequence,
            sequence_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        packed_outputs, hidden_states = self.rnn_layer(packed_embeddings)

        # total_length keeps the time axis as wide as the input, so the outputs
        # stay aligned with attention masks built from input_sequence.
        outputs, _ = pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=input_sequence.size(1)
        )  # [Batch, SeqLen, HiddenDim*Directions]

        if self.use_bidirectional:
            if self.rnn_type == 'LSTM':
                h_states, c_states = hidden_states
                hidden_states = (self._merge_directions(h_states), self._merge_directions(c_states))
            else:
                hidden_states = self._merge_directions(hidden_states)

        return outputs, hidden_states


class AttentionMechanism(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    score(s, h_j) = v^T tanh(W [s; h_j]). Positions where the mask is False are
    filled with -1e9 before the softmax, so they receive (effectively) zero weight.
    """
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        # W: projects the concatenated [decoder state; encoder output] to decoder_dim.
        self.attention_projection = nn.Linear(encoder_dim + decoder_dim, decoder_dim)
        # v: reduces every projected source position to one scalar score.
        self.energy_vector = nn.Linear(decoder_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs, attention_mask):
        """
        Args:
            decoder_hidden: [Batch, DecoderDim] query state
            encoder_outputs: [Batch, SrcLen, EncoderDim]
            attention_mask: [Batch, SrcLen], True where the source token is real
        Returns:
            [Batch, SrcLen] attention weights (softmax over SrcLen)
        """
        _, source_len, _ = encoder_outputs.size()

        # Broadcast the query along the source axis and pair it with every encoder state.
        query = decoder_hidden.unsqueeze(1).expand(-1, source_len, -1)
        paired = torch.cat([query, encoder_outputs], dim=2)

        projected = torch.tanh(self.attention_projection(paired))  # [Batch, SrcLen, DecoderDim]
        scores = self.energy_vector(projected).squeeze(2)           # [Batch, SrcLen]

        # Padding positions must not attract attention.
        masked_scores = scores.masked_fill(~attention_mask, -1e9)
        return masked_scores.softmax(dim=1)


class DecoderNetwork(nn.Module):
    """
    Single-step RNN decoder, optionally with Bahdanau attention.

    Modes:
        • with_attention=True  – attends over encoder outputs. The context vector
          is fed both to the RNN and to the output layer (default).
        • with_attention=False – plain RNN decoder driven only by its hidden state.

    forward() always returns (logits, hidden_state, attention_weights_or_None).
    """
    def __init__(self, vocab_size, embedding_dim, encoder_dim, decoder_dim,
                 num_layers=1, rnn_type="LSTM", dropout_rate=0.0, with_attention=True):
        super().__init__()
        self.with_attention = with_attention
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn_type = rnn_type

        # Width of the context vector appended to RNN/output inputs (none without attention).
        context_dim = encoder_dim if with_attention else 0
        if with_attention:
            self.attention = AttentionMechanism(encoder_dim, decoder_dim)

        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}[rnn_type]
        self.rnn_layer = rnn_cls(
            embedding_dim + context_dim,
            decoder_dim,
            num_layers=num_layers,
            dropout=0.0 if num_layers <= 1 else dropout_rate,
            batch_first=True
        )

        # The output layer sees [rnn output; context (if any); input embedding].
        self.output_projection = nn.Linear(decoder_dim + context_dim + embedding_dim, vocab_size)

    def forward(self, current_token, hidden_state, encoder_outputs, attention_mask):
        """
        Run one decoding step.

        Args:
            current_token: [Batch] - token ids fed at this step
            hidden_state: tuple|tensor - previous decoder state ((h, c) for LSTM)
            encoder_outputs: [Batch, SrcLen, EncoderDim] - all encoder outputs
            attention_mask: [Batch, SrcLen] - True at valid source positions
        Returns:
            logits [Batch, VocabSize], new hidden state, and the attention
            weights [Batch, SrcLen] (None without attention)
        """
        embedded = self.token_embedding(current_token).unsqueeze(1)  # [Batch, 1, EmbeddingDim]

        if not self.with_attention:
            rnn_out, next_hidden = self.rnn_layer(embedded, hidden_state)
            features = torch.cat((rnn_out.squeeze(1), embedded.squeeze(1)), dim=1)
            return self.output_projection(features), next_hidden, None

        # The query is the top layer's hidden state (h of the (h, c) pair for LSTM).
        h_prev = hidden_state[0] if self.rnn_type == 'LSTM' else hidden_state
        weights = self.attention(h_prev[-1], encoder_outputs, attention_mask)  # [Batch, SrcLen]
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs)            # [Batch, 1, EncoderDim]

        rnn_out, next_hidden = self.rnn_layer(torch.cat((embedded, context), dim=2), hidden_state)
        features = torch.cat((rnn_out.squeeze(1), context.squeeze(1), embedded.squeeze(1)), dim=1)
        return self.output_projection(features), next_hidden, weights


class Seq2SeqModel(nn.Module):
    """Encoder–decoder wrapper with teacher-forced training and greedy decoding.

    Args:
        encoder: called as encoder(src, lengths) -> (outputs, hidden_state)
        decoder: called as decoder(token, hidden, encoder_outputs, mask)
            -> (logits, hidden, attention). Must expose
            ``output_projection.out_features`` (the target vocabulary size).
        padding_idx: source padding id; used to build the attention mask.
        device: device on which output / start-token tensors are allocated.
    """
    def __init__(self, encoder, decoder, padding_idx, device='cpu'):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.padding_idx = padding_idx
        self.device = device

    def forward(self, src_sequence, src_lengths, tgt_sequence, teacher_forcing_ratio=0.5):
        """
        Training forward pass with teacher forcing.

        Args:
            src_sequence: Input sequence [Batch, SrcLen]
            src_lengths: Lengths of source sequences [Batch]
            tgt_sequence: Target sequence [Batch, TgtLen]; column 0 is the start token
            teacher_forcing_ratio: Per-step probability (one draw shared by the
                whole batch) of feeding the ground-truth token instead of the
                model's own prediction

        Returns:
            Logits [Batch, TgtLen-1, VocabSize]; position t-1 predicts tgt_sequence[:, t].
        """
        encoder_outputs, hidden_state = self.encoder(src_sequence, src_lengths)

        # True for real tokens, False for padding.
        attention_mask = (src_sequence != self.padding_idx)

        batch_size, target_length = tgt_sequence.size()
        output_vocab_size = self.decoder.output_projection.out_features
        outputs = torch.zeros(batch_size, target_length - 1, output_vocab_size, device=self.device)

        current_token = tgt_sequence[:, 0]  # <sos>

        for t in range(1, target_length):
            step_output, hidden_state, _ = self.decoder(
                current_token, hidden_state, encoder_outputs, attention_mask
            )
            outputs[:, t - 1] = step_output

            # NOTE: one random draw per step, even when the ratio is 0.0.
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            current_token = tgt_sequence[:, t] if use_teacher_forcing else step_output.argmax(1)

        return outputs

    def generate_greedy(self, src_sequence, src_lengths, target_vocab, max_length=50):
        """
        Greedy decoding.

        Returns a [Batch, Steps] LongTensor with Steps <= max_length.

        Once a row has emitted ``target_vocab.eos_idx``, all of its later positions
        are filled with ``target_vocab.pad_idx``. This prevents tokens generated
        after the end of the word from leaking into the decoded text. Decoding
        stops as soon as every row has emitted EOS.
        """
        encoder_outputs, hidden_state = self.encoder(src_sequence, src_lengths)
        attention_mask = (src_sequence != self.padding_idx)

        batch_size = src_sequence.size(0)
        current_token = torch.full(
            (batch_size,),
            target_vocab.sos_idx,
            device=self.device,
            dtype=torch.long
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        generated_tokens = []
        for _ in range(max_length):
            step_output, hidden_state, _ = self.decoder(
                current_token, hidden_state, encoder_outputs, attention_mask
            )

            # Rows that already produced EOS emit padding instead of new predictions.
            current_token = step_output.argmax(1).masked_fill(finished, target_vocab.pad_idx)
            generated_tokens.append(current_token.unsqueeze(1))

            finished |= (current_token == target_vocab.eos_idx)
            if finished.all():
                break

        if not generated_tokens:  # max_length <= 0
            return torch.empty((batch_size, 0), dtype=torch.long, device=self.device)
        return torch.cat(generated_tokens, dim=1)


# Setup reference only: the Dakshina archive is downloaded and extracted in the first notebook cells; the equivalent commands are kept below, commented out.
# !yes | wget "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"
# !yes | tar xopf dakshina_dataset_v1.0.tar

# Vocabulary handling class
class CharacterVocabulary:
    """Character-level vocabulary for transliteration tasks.

    Indices start with the special tokens (by default '<pad>'=0, '<sos>'=1,
    '<eos>'=2, '<unk>'=3), followed by ``token_list`` in the order given.
    """

    # Immutable default so no list object is shared or mutated between instances.
    DEFAULT_SPECIAL_TOKENS = ('<pad>', '<sos>', '<eos>', '<unk>')

    def __init__(self, token_list=None, special_tokens=None):
        """
        Args:
            token_list: Regular tokens (characters), appended after the specials.
            special_tokens: Tokens placed first; defaults to DEFAULT_SPECIAL_TOKENS.
                '<pad>', '<sos>', '<eos>' and '<unk>' must be present for the
                index properties and encode_text() to work.
        """
        # Fresh list per instance. The former `special_tokens=[...]` default list
        # was shared by every vocabulary created without explicit specials.
        if special_tokens is None:
            special_tokens = self.DEFAULT_SPECIAL_TOKENS
        self.special_tokens = list(special_tokens)
        self.idx2token = list(self.special_tokens) + list(token_list or [])
        self.token2idx = {token: idx for idx, token in enumerate(self.idx2token)}

    @classmethod
    def create_from_text_collection(cls, text_collection):
        """Build a vocabulary from the sorted set of characters in `text_collection`."""
        unique_chars = sorted({char for text in text_collection for char in text})
        return cls(token_list=unique_chars)

    @classmethod
    def create_from_data_file(cls, file_path, source_col='src', target_col='trg', is_csv_format=True):
        """
        Build a vocabulary from both columns of a data file.

        CSV files are read without a header as (source_col, target_col); missing
        values are dropped. Otherwise the file is read as tab-separated lines, and
        the first two fields of every line with at least two fields are used.
        """
        if is_csv_format:
            import pandas as pd
            df = pd.read_csv(file_path, header=None, names=[source_col, target_col])
            all_texts = df[source_col].dropna().tolist() + df[target_col].dropna().tolist()
        else:
            all_texts = []
            with open(file_path, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        all_texts.extend([parts[0], parts[1]])

        return cls.create_from_text_collection(all_texts)

    def save_to_json(self, path):
        """Save the index->token list to a UTF-8 JSON file."""
        import json
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.idx2token, f, ensure_ascii=False)

    @classmethod
    def load_from_json(cls, path):
        """Load a vocabulary saved by save_to_json (special tokens reset to the defaults)."""
        import json
        with open(path, encoding='utf-8') as f:
            idx2token = json.load(f)

        vocab = cls(token_list=[])
        vocab.idx2token = idx2token
        vocab.token2idx = {token: idx for idx, token in enumerate(idx2token)}
        return vocab

    def encode_text(self, text, add_start=False, add_end=False):
        """Convert text to token ids; unknown characters map to '<unk>'."""
        indices = []
        if add_start:
            indices.append(self.token2idx['<sos>'])

        unk = self.token2idx['<unk>']
        indices.extend(self.token2idx.get(char, unk) for char in text)

        if add_end:
            indices.append(self.token2idx['<eos>'])

        return indices

    def decode_indices(self, indices, remove_special=True, join_chars=True):
        """Convert token ids (list or tensor) back to text.

        Ids outside [0, len(vocab)) are skipped. Negative ids are no longer
        allowed to wrap around to the end of the vocabulary.
        """
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()

        vocab_size = len(self.idx2token)
        chars = [self.idx2token[idx] for idx in indices if 0 <= idx < vocab_size]

        if remove_special:
            chars = [c for c in chars if c not in self.special_tokens]

        return ''.join(chars) if join_chars else chars

    def batch_decode_indices(self, batch_indices, remove_special=True):
        """Decode a batch of index sequences into strings."""
        return [self.decode_indices(seq, remove_special=remove_special) for seq in batch_indices]

    def get_statistics(self):
        """Return total size, number of special tokens, and number of regular characters."""
        return {
            'total_size': len(self.idx2token),
            'special_tokens': len(self.special_tokens),
            'character_count': len(self.idx2token) - len(self.special_tokens)
        }

    def __len__(self):
        return len(self.idx2token)

    @property
    def pad_idx(self): return self.token2idx['<pad>']

    @property
    def sos_idx(self): return self.token2idx['<sos>']

    @property
    def eos_idx(self): return self.token2idx['<eos>']

    @property
    def unk_idx(self): return self.token2idx['<unk>']

    @property
    def size(self): return len(self.idx2token)


# Data loading utilities
import os
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd

class TransliterationDataset(Dataset):
    """Map-style dataset of (source, target) index tensors for transliteration.

    Every word is encoded with <sos>/<eos> by its vocabulary and stored as a 1-D
    LongTensor. Items are returned as (src_tensor, tgt_tensor).
    """

    def __init__(self, data_path, source_vocab, target_vocab, dataset_format='dakshina'):
        """
        Args:
            data_path: Path to a tab-separated lexicon file
            source_vocab: Source language vocabulary (provides encode_text)
            target_vocab: Target language vocabulary (provides encode_text)
            dataset_format: Only 'dakshina' is supported; anything else raises ValueError
        """
        self.examples = []
        self.dataset_format = dataset_format

        if dataset_format != 'dakshina':
            raise ValueError(f"Unsupported dataset format: {dataset_format}")

        def as_tensor(vocab, text):
            # Wrap the word in <sos> ... <eos> and convert to a LongTensor.
            ids = vocab.encode_text(text, add_start=True, add_end=True)
            return torch.tensor(ids, dtype=torch.long)

        self.examples = [
            (as_tensor(source_vocab, src_word), as_tensor(target_vocab, tgt_word))
            for src_word, tgt_word in read_transliteration_pairs(data_path)
        ]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def read_transliteration_pairs(path):
    """Yield (source, target) pairs from a UTF-8 tab-separated lexicon file.

    Each line is stripped and split on tabs. Lines with fewer than two fields are
    skipped. Dakshina lists the target word before the source word, so the
    first two fields are yielded in swapped order.
    """
    with open(path, encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) < 2:
                continue
            target_word, source_word = fields[0], fields[1]
            yield source_word, target_word


def read_transliteration_csv(path, src_col='src', tgt_col='trg'):
    """Yield (source, target) pairs from a CSV file that has a header row.

    Args:
        path: CSV file path.
        src_col / tgt_col: Column names holding source and target text.

    Values are read column-wise, so each keeps its own column's dtype.
    DataFrame.iterrows() builds a Series per row, which upcasts mixed numeric
    rows (an int column comes back as float) and is slow on large files.
    """
    df = pd.read_csv(path)
    yield from zip(df[src_col], df[tgt_col])


def batch_collate_function(batch, src_vocab, tgt_vocab):
    """Pad a list of (src_tensor, tgt_tensor) pairs into batch tensors.

    Returns:
        (src_padded [B, S], src_lengths [B] LongTensor, tgt_padded [B, T]).
        Each side is padded with its vocabulary's pad_idx. src_lengths holds
        the unpadded source lengths, which the encoder needs for packing.
    """
    sources, targets = zip(*batch)

    source_lengths = torch.tensor(list(map(len, sources)), dtype=torch.long)
    padded_sources = pad_sequence(sources, batch_first=True, padding_value=src_vocab.pad_idx)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=tgt_vocab.pad_idx)

    return padded_sources, source_lengths, padded_targets


def prepare_data_loaders(
        language='ta',  # Changed from 'te' to 'ta' for Tamil
        dataset_format='dakshina',
        base_path=None,
        batch_size=64,
        device='cpu',
        num_workers=2,
        prefetch_factor=4,
        persistent_workers=True,
        cache_dir='./cache',
        use_cached_vocab=True
    ):
    """
    Load the Dakshina lexicon splits and create train/dev/test data loaders.

    Vocabularies are built from the train and dev files (test characters not
    seen there map to <unk>). They are optionally cached as a pickle in cache_dir.

    Args:
        language: Language code ('ta' for Tamil)
        dataset_format: Format specification ('dakshina')
        base_path: Override default dataset path
        batch_size: Batch size for loaders
        device: Device to use ('cuda' enables pinned memory)
        num_workers: Data loading worker count
        prefetch_factor: Batches prefetched per worker (only used when num_workers > 0)
        persistent_workers: Keep workers alive between epochs (only used when num_workers > 0)
        cache_dir: Directory for cached vocabularies
        use_cached_vocab: Whether to read/write the vocabulary cache

    Returns:
        (data_loaders, src_vocab, tgt_vocab). data_loaders has keys 'train'
        (shuffled), 'dev' and 'test'.
    """
    from functools import partial

    if base_path is None:
        base_path = os.path.join(
            '/kaggle/working/dakshina_dataset_v1.0',
            language, 'lexicons'
        )

    def split_path(split):
        return os.path.join(base_path, f"{language}.translit.sampled.{split}.tsv")

    vocab_cache_path = None
    if use_cached_vocab:
        os.makedirs(cache_dir, exist_ok=True)
        # NOTE(review): the cache key ignores base_path, so a stale cache is
        # reused if the data location changes.
        vocab_cache_path = os.path.join(cache_dir, f"{language}_{dataset_format}_vocab.pkl")

    if vocab_cache_path is not None and os.path.exists(vocab_cache_path):
        print(f"Loading cached vocabularies from {vocab_cache_path}")
        # pickle.load can execute arbitrary code: only load caches this notebook wrote.
        with open(vocab_cache_path, 'rb') as f:
            src_vocab, tgt_vocab = pickle.load(f)
    else:
        src_texts, tgt_texts = [], []
        for split in ('train', 'dev'):
            for src, tgt in read_transliteration_pairs(split_path(split)):
                src_texts.append(src)
                tgt_texts.append(tgt)

        src_vocab = CharacterVocabulary.create_from_text_collection(src_texts)
        tgt_vocab = CharacterVocabulary.create_from_text_collection(tgt_texts)

        if vocab_cache_path is not None:
            with open(vocab_cache_path, 'wb') as f:
                pickle.dump((src_vocab, tgt_vocab), f)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=(device == 'cuda')
    )
    # prefetch_factor / persistent_workers only apply to worker processes.
    # Recent PyTorch raises ValueError if prefetch_factor is set with num_workers == 0.
    if num_workers > 0:
        loader_kwargs.update(prefetch_factor=prefetch_factor, persistent_workers=persistent_workers)

    # A partial of a named function is picklable (a lambda is not). This matters
    # when DataLoader workers are spawned rather than forked.
    collate = partial(batch_collate_function, src_vocab=src_vocab, tgt_vocab=tgt_vocab)

    data_loaders = {}
    for split in ('train', 'dev', 'test'):
        # TransliterationDataset's keyword is `dataset_format` (`format=` raised a TypeError).
        dataset = TransliterationDataset(split_path(split), src_vocab, tgt_vocab,
                                         dataset_format=dataset_format)
        data_loaders[split] = DataLoader(
            dataset,
            shuffle=(split == 'train'),
            collate_fn=collate,
            **loader_kwargs
        )

    return data_loaders, src_vocab, tgt_vocab


# Training and evaluation utilities
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm.auto import tqdm
import csv
import pandas as pd

def _tokens_before_eos(sequence, eos_idx):
    """Return the ids of `sequence` (tensor or iterable) that precede its first `eos_idx`."""
    tokens = sequence.tolist() if hasattr(sequence, 'tolist') else list(sequence)
    if eos_idx in tokens:
        tokens = tokens[:tokens.index(eos_idx)]
    return tokens


def evaluate_model_accuracy(model, data_loader, target_vocab, source_vocab, device):
    """
    Exact-match accuracy of greedy decoding, with per-example details.

    Predictions and references are cut at their first <eos> before decoding.
    Otherwise, tokens the model emits after finishing a word would turn a
    correct transliteration into an error.

    Args:
        model: exposes eval() and generate_greedy(src, lengths, vocab, max_length=...)
        data_loader: iterable of (src_batch, src_lengths, tgt_batch); each
            tgt_batch row starts with <sos>
        target_vocab / source_vocab: provide decode_indices(); target_vocab
            also provides eos_idx
        device: device the batches are moved to

    Returns:
        (accuracy,
         (correct_sources, correct_targets, correct_predictions),
         (incorrect_sources, incorrect_targets, incorrect_predictions)).
        accuracy is 0.0 for an empty loader.
    """
    model.eval()
    eos_idx = target_vocab.eos_idx
    correct_count = total_count = 0

    # Each bucket holds parallel lists: (sources, targets, predictions).
    correct = ([], [], [])
    incorrect = ([], [], [])

    with torch.no_grad():
        for src_batch, src_lengths, tgt_batch in data_loader:
            src_batch = src_batch.to(device)
            src_lengths = src_lengths.to(device)
            tgt_batch = tgt_batch.to(device)

            predictions = model.generate_greedy(
                src_batch, src_lengths, target_vocab, max_length=tgt_batch.size(1)
            )

            for idx in range(src_batch.size(0)):
                pred_text = target_vocab.decode_indices(
                    _tokens_before_eos(predictions[idx].cpu(), eos_idx)
                )
                # Skip the reference's leading <sos>.
                gold_text = target_vocab.decode_indices(
                    _tokens_before_eos(tgt_batch[idx, 1:].cpu(), eos_idx)
                )
                src_text = source_vocab.decode_indices(src_batch[idx].cpu())

                is_correct = (pred_text == gold_text)
                correct_count += int(is_correct)

                bucket = correct if is_correct else incorrect
                bucket[0].append(src_text)
                bucket[1].append(gold_text)
                bucket[2].append(pred_text)

            total_count += src_batch.size(0)

    accuracy = correct_count / total_count if total_count > 0 else 0.0
    return accuracy, correct, incorrect

def save_prediction_results(src_texts, tgt_texts, pred_texts, filename):
    """Write (source, target, prediction) triples to a UTF-8 CSV with a header row.

    Rows are paired positionally (zip semantics: the shortest input wins).
    Returns `filename` for convenient chaining.
    """
    header = ['Source', 'Target', 'Predicted']
    with open(filename, mode='w', newline='', encoding='utf-8') as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerows([header, *zip(src_texts, tgt_texts, pred_texts)])

    return filename

def train_sequence_model(
    model, 
    data_loaders, 
    src_vocab, 
    tgt_vocab, 
    device,
    config,
    model_save_path=None,
    enable_wandb=True
):
    """
    Train the sequence model, evaluating on the dev split after every epoch.
    
    Each epoch runs these steps in order:
      1. A teacher-forced training pass over data_loaders['train'].
      2. A dev-loss pass with teacher_forcing_ratio=0.0.
      3. Greedy-decoding exact-match accuracy on the FULL train and dev loaders.
      4. A checkpoint when dev accuracy reaches a new best.
      5. Optional W&B logging.
    
    Args:
        model: The Seq2Seq model. It is called as model(src, lengths, tgt,
            teacher_forcing_ratio=...) and, via evaluate_model_accuracy,
            through model.generate_greedy
        data_loaders: Dictionary of data loaders. Only 'train' and 'dev' are
            used; each yields (src_batch, src_lengths, tgt_batch)
        src_vocab, tgt_vocab: Source and target vocabularies (positions equal
            to tgt_vocab.pad_idx are ignored by the loss)
        device: Computation device
        config: Training configuration exposing .optimizer, .lr, .epochs
            and .teacher_forcing
        model_save_path: Where to save the best state_dict (by dev accuracy).
            When falsy, no checkpoint and no prediction CSV is written
        enable_wandb: Whether to log per-epoch metrics to Weights & Biases
    
    Returns:
        (model, test_accuracy):
            model: the model as it is after the LAST epoch. The best
                checkpoint is written to disk but not reloaded.
            test_accuracy: a placeholder that is always 0.
    """
    # Setup loss function - ignore padding in loss calculation
    loss_function = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)
    
    # Configure optimizer (name matched case-insensitively; any other name silently falls back to Adam)
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    elif config.optimizer.lower() == 'nadam':
        optimizer = optim.NAdam(model.parameters(), lr=config.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    
    # Track best model (only updated when model_save_path is set, see the save condition below)
    best_validation_accuracy = 0.0
    
    # Training loop
    for epoch in tqdm(range(1, config.epochs + 1), desc="Training Progress", position=0):
        # Training phase
        model.train()
        epoch_loss = 0.0

        # Process training batches
        train_iter = tqdm(data_loaders['train'], desc=f"Epoch {epoch}", leave=False, position=1)
        for src_batch, src_lengths, tgt_batch in train_iter:
            # Move data to device
            src_batch = src_batch.to(device)
            src_lengths = src_lengths.to(device)
            tgt_batch = tgt_batch.to(device)

            # Forward and backward pass
            optimizer.zero_grad()
            batch_output = model(
                src_batch, src_lengths, tgt_batch, 
                teacher_forcing_ratio=config.teacher_forcing
            )
            
            # Calculate loss (flatten predictions and targets).
            # batch_output[:, t-1] is the prediction for tgt_batch[:, t], so the
            # targets drop the leading <sos> column to line up with it.
            loss = loss_function(
                batch_output.reshape(-1, batch_output.size(-1)), 
                tgt_batch[:,1:].reshape(-1)  # Skip <sos> token
            )
            
            # Backpropagation
            loss.backward()
            
            # Gradient clipping (max total norm 1.0) to prevent explosion
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Update weights
            optimizer.step()
            
            # Track loss
            epoch_loss += loss.item()
            
        # Average training loss = mean of per-batch mean losses
        train_loss = epoch_loss / len(data_loaders['train'])

        # Validation phase
        val_loss = 0.0
        val_iter = tqdm(data_loaders['dev'], desc=f"Validation {epoch}", leave=False, position=1)
        
        model.eval()
        with torch.no_grad():
            for src_batch, src_lengths, tgt_batch in val_iter:
                # Move data to device
                src_batch = src_batch.to(device)
                src_lengths = src_lengths.to(device)
                tgt_batch = tgt_batch.to(device)
                
                # Forward pass (no teacher forcing during validation).
                # NOTE: Seq2SeqModel.forward as defined in this notebook still draws
                # random.random() at every step even with ratio 0.0, so this pass
                # advances Python's RNG state.
                batch_output = model(src_batch, src_lengths, tgt_batch, teacher_forcing_ratio=0.0)
                
                # Calculate validation loss
                val_loss += loss_function(
                    batch_output.reshape(-1, batch_output.size(-1)),
                    tgt_batch[:,1:].reshape(-1)
                ).item()
        
        # Calculate average validation loss
        val_loss /= len(data_loaders['dev'])

        # Compute accuracy metrics.
        # NOTE: this greedy-decodes the ENTIRE training set every epoch, which can
        # be expensive for large training splits.
        train_results = evaluate_model_accuracy(
            model, data_loaders['train'], tgt_vocab, src_vocab, device
        )
        train_accuracy = train_results[0]
        
        val_results = evaluate_model_accuracy(
            model, data_loaders['dev'], tgt_vocab, src_vocab, device
        )
        val_accuracy = val_results[0]
        
        # Save best model (strict improvement in dev accuracy, and only if a path was given)
        if val_accuracy > best_validation_accuracy and model_save_path:
            best_validation_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"✓ New best model saved with validation accuracy: {val_accuracy:.4f}")
            
            # Save prediction analysis for best model. Only written when the new
            # best falls on a multiple of 5 or on the final epoch.
            if epoch == config.epochs or epoch % 5 == 0:
                # Save correct and incorrect predictions
                correct_data = val_results[1]
                incorrect_data = val_results[2]
                
                save_prediction_results(
                    correct_data[0], correct_data[1], correct_data[2],
                    f"correct_preds_epoch_{epoch}.csv"
                )
                
                save_prediction_results(
                    incorrect_data[0], incorrect_data[1], incorrect_data[2],
                    f"incorrect_preds_epoch_{epoch}.csv"
                )

        # Print progress
        print(f"Epoch {epoch}/{config.epochs}:")
        print(f"  Training:   Loss={train_loss:.4f}, Accuracy={train_accuracy:.4f}")
        print(f"  Validation: Loss={val_loss:.4f}, Accuracy={val_accuracy:.4f}")
        
        # Log metrics to WandB ('val_accuracy' is the sweep's optimization metric)
        if enable_wandb:
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_accuracy': train_accuracy,
                'val_accuracy': val_accuracy
            })
    
    # Placeholder for test accuracy (not evaluating test set in this version)
    test_accuracy = 0
    return model, test_accuracy


# Hyperparameter sweeping functionality 
import wandb
import torch
from tqdm.auto import tqdm
import os
import random
import numpy as np

def objective():
    """
    Run one W&B sweep trial: build data and model from the run config, train, and log.

    Config keys read here:
        seed (optional, default 42), cell, enc_layers (used for both encoder
        and decoder depth), emb_size, hidden_size, bidirectional, dropout,
        teacher_forcing, optimizer, batch_size.
    train_sequence_model additionally reads lr and epochs. The best checkpoint
    is written to model_<run_name>.pt.
    """
    run = wandb.init()
    try:
        cfg = run.config

        # Set seeds for reproducibility
        initialize_random_seeds(cfg.get('seed', 42))

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")

        # Human-readable run name derived from the sampled hyperparameters
        run_name = (
            f"{cfg.cell}_{cfg.enc_layers}l_{cfg.emb_size}e_{cfg.hidden_size}h_"
            f"{'bid' if cfg.bidirectional else 'uni'}_{cfg.dropout}d_"
            f"{cfg.teacher_forcing}tf_{cfg.optimizer}"
        )
        run.name = run_name

        # Load data with Tamil language code
        data_loaders, src_vocab, tgt_vocab = prepare_data_loaders(
            'ta',
            batch_size=cfg.batch_size,
            device=device
        )

        encoder = EncoderNetwork(
            src_vocab.size, cfg.emb_size, cfg.hidden_size,
            cfg.enc_layers, cfg.cell, cfg.dropout,
            use_bidirectional=cfg.bidirectional
        ).to(device)

        # The decoder attends over encoder outputs, whose width doubles when bidirectional.
        enc_output_dim = cfg.hidden_size * 2 if cfg.bidirectional else cfg.hidden_size

        decoder = DecoderNetwork(
            tgt_vocab.size, cfg.emb_size, enc_output_dim, cfg.hidden_size,
            cfg.enc_layers, cfg.cell, cfg.dropout
        ).to(device)

        # Seq2SeqModel's keyword is `padding_idx`; the former `pad_idx=` raised a TypeError.
        model = Seq2SeqModel(encoder, decoder, padding_idx=src_vocab.pad_idx, device=device).to(device)

        train_sequence_model(
            model=model,
            data_loaders=data_loaders,
            src_vocab=src_vocab,
            tgt_vocab=tgt_vocab,
            device=device,
            config=cfg,
            model_save_path=f"model_{run_name}.pt",
            enable_wandb=True
        )
    finally:
        # Close the run even if the trial fails, so it is not left open.
        wandb.finish()

if __name__ == "__main__":
    # W&B sweep definition, consumed by wandb.sweep(...) in the next cell.
    # In a notebook __name__ is "__main__", so this block always runs.
    # Each key under 'parameters' becomes an attribute of run.config, read by
    # objective() / train_sequence_model().
    sweep_cfg = {
        'method': 'bayes',  # Use Bayesian optimization
        'name':'Tamil_Transliteration_with_Attention',
        # Must match the key logged by train_sequence_model via wandb.log
        'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
        'parameters': {
            
            # Model architecture
            'emb_size': {'values': [32,64,128, 256, 512]},
            'hidden_size': {'values': [32,64,128, 256, 512, 1024]},
            # objective() uses this value for both encoder and decoder depth
            'enc_layers': {'values': [1, 2, 3, 4]},
            'cell': {'values': ['RNN', 'GRU', 'LSTM']},  
            'bidirectional': {'values': [True, False]},  # Bidirectional encoder (directions are averaged into the decoder's initial state)
            
            # Training parameters
            'dropout': {'values': [0.0, 0.1, 0.2, 0.3, 0.5]},  # Inter-layer RNN dropout; no effect when enc_layers == 1
            'lr': {'values': [1e-4, 2e-4, 5e-4, 8e-4, 1e-3]},
            'batch_size': {'values': [32, 64, 128]},
            'epochs': {'values': [10, 15, 20]},
            'teacher_forcing': {'values': [0.3, 0.5, 0.7, 1.0]},  # Per-step probability of feeding the gold token during training
            'optimizer': {'values': ['Adam', 'NAdam']},  # Matched case-insensitively in train_sequence_model
            # Reproducibility
            'seed': {'values': [42, 43, 44, 45, 46]},  # Different seeds for robustness
        }
    }

In [None]:
# Register the sweep with W&B, then run up to SWEEP_TRIAL_COUNT trials of
# `objective` in this kernel.
SWEEP_PROJECT = 'DA6401_A3'
SWEEP_TRIAL_COUNT = 30

sweep_id = wandb.sweep(sweep_cfg, project=SWEEP_PROJECT)
wandb.agent(sweep_id, function=objective, count=SWEEP_TRIAL_COUNT)