In [3]:
# Downloading dakshina dataset
!yes | wget "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"

--2025-05-20 05:19:56--  https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.134.207, 172.217.203.207, 172.217.204.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.134.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2008340480 (1.9G) [application/x-tar]
Saving to: ‘dakshina_dataset_v1.0.tar’


2025-05-20 05:20:02 (313 MB/s) - ‘dakshina_dataset_v1.0.tar’ saved [2008340480/2008340480]

yes: standard output: Broken pipe


In [4]:
# Unzipping dataset
!yes | tar xopf dakshina_dataset_v1.0.tar

yes: standard output: Broken pipe


In [8]:
wandb.login(key='49f8f505158ee3693f0cacf0a82118bd4e636e8c')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msurendarmohan283[0m ([33msurendarmohan283-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random

# For reproducibility
def set_random_seeds(seed_value=42):
    """
    Set deterministic behavior for reproducibility across runs.
    Configures random number generators for all libraries.
    """
    # Set Python's built-in random module seed
    random.seed(seed_value)
    
    # Set PyTorch CPU operations seed
    torch.manual_seed(seed_value)
    
    # Set PyTorch GPU operations seeds if available
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    
    # Disable CUDA backend auto-tuning
    torch.backends.cudnn.benchmark = False

class SourceEncoder(nn.Module):
    """Encodes the input sequence into a context representation."""
    def __init__(self, vocabulary_size, embedding_dim, hidden_dim, num_layers=1, 
                 rnn_type='LSTM', dropout_rate=0.0, use_bidirectional=False):
        super().__init__()
        # Create embedding layer
        self.token_embeddings = nn.Embedding(vocabulary_size, embedding_dim)
        
        # Store configuration
        self.is_bidirectional = use_bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        
        # Calculate output dimension
        self.context_dim = hidden_dim * 2 if use_bidirectional else hidden_dim
        
        # Select RNN implementation
        rnn_options = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}
        rnn_class = rnn_options[rnn_type]
        
        # Create recurrent layer
        self.recurrent = rnn_class(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout_rate if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=use_bidirectional
        )

    def forward(self, input_sequence, sequence_lengths):
        """
        Process source sequence and return context representation.
        
        Args:
            input_sequence: Tensor of token IDs [batch_size, seq_len]
            sequence_lengths: Tensor of sequence lengths [batch_size]
            
        Returns:
            context_states: Sequence of context states [batch_size, seq_len, hidden_dim*dirs]
            final_state: Final hidden state(s) for decoder initialization
        """
        # Get embeddings for input tokens
        embedded_seq = self.token_embeddings(input_sequence)
        
        # Pack sequence to handle variable lengths
        packed_input = pack_padded_sequence(
            embedded_seq, sequence_lengths.cpu(), 
            batch_first=True, enforce_sorted=False
        )
        
        # Pass through recurrent network
        packed_outputs, final_states = self.recurrent(packed_input)
        
        # Unpack outputs
        context_states, _ = pad_packed_sequence(packed_outputs, batch_first=True)
        
        # Process final states for bidirectional case
        if self.is_bidirectional:
            if self.rnn_type == 'LSTM':
                # Process both hidden and cell states
                h_final, c_final = final_states
                # Average forward and backward states for each layer
                h_final = torch.mean(
                    torch.stack([h_final[:self.num_layers], h_final[self.num_layers:]]), dim=0
                )
                c_final = torch.mean(
                    torch.stack([c_final[:self.num_layers], c_final[self.num_layers:]]), dim=0
                )
                final_states = (h_final, c_final)
            else:
                # For GRU/RNN, process only hidden state
                final_states = torch.mean(
                    torch.stack([final_states[:self.num_layers], final_states[self.num_layers:]]), dim=0
                )
                
        return context_states, final_states


class AdditiveAttention(nn.Module):
    """
    Implements Bahdanau-style additive attention mechanism.
    Calculates attention weights based on encoder states and decoder state.
    """
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        # Scoring network components
        self.score_projection = nn.Linear(encoder_dim + decoder_dim, decoder_dim)
        self.weight_vector = nn.Linear(decoder_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_states, attention_mask):
        """
        Calculate attention weights for each encoder state.
        
        Args:
            decoder_state: Current decoder state [batch_size, decoder_dim]
            encoder_states: All encoder states [batch_size, seq_len, encoder_dim]
            attention_mask: Boolean mask for padding [batch_size, seq_len]
            
        Returns:
            attention_weights: Softmax distribution over encoder states [batch_size, seq_len]
        """
        # Get sequence dimensions
        batch_size, seq_len, _ = encoder_states.size()
        
        # Repeat decoder state for each position in sequence
        expanded_state = decoder_state.unsqueeze(1).expand(-1, seq_len, -1)
        
        # Concatenate encoder and decoder states
        combined = torch.cat((expanded_state, encoder_states), dim=2)
        
        # Calculate attention scores
        attention_features = torch.tanh(self.score_projection(combined))
        attention_scores = self.weight_vector(attention_features).squeeze(2)
        
        # Apply mask to prevent attention to padding
        attention_scores = attention_scores.masked_fill(~attention_mask, float('-inf'))
        
        # Get probability distribution
        attention_weights = torch.softmax(attention_scores, dim=1)
        
        return attention_weights


class TargetDecoder(nn.Module):
    """
    Decoder for generating target sequence.
    
    Supports two operating modes:
    1. With attention (using context from encoder)
    2. Without attention (simple RNN decoder)
    """
    def __init__(self, target_vocab_size, embedding_dim, encoder_dim, decoder_dim,
                 num_layers=1, rnn_type="LSTM", dropout_rate=0.0, with_attention=True):
        super().__init__()
        
        # Configuration
        self.with_attention = with_attention
        self.rnn_type = rnn_type
        
        # Token embedding layer
        self.token_embeddings = nn.Embedding(target_vocab_size, embedding_dim)
        
        # Determine input dimensions based on attention usage
        if with_attention:
            # Create attention mechanism
            self.attention_module = AdditiveAttention(encoder_dim, decoder_dim)
            
            # When using attention, we concatenate context vector with token embedding
            recurrent_input_dim = embedding_dim + encoder_dim
            
            # Output combines decoder state, context vector, and current embedding
            output_input_dim = decoder_dim + encoder_dim + embedding_dim
        else:
            # Without attention, just use embeddings as input
            recurrent_input_dim = embedding_dim
            
            # Output combines decoder state and current embedding
            output_input_dim = decoder_dim + embedding_dim

        # Select appropriate RNN implementation
        rnn_options = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}
        selected_rnn = rnn_options[rnn_type]
        
        # Create recurrent network
        self.recurrent = selected_rnn(
            input_size=recurrent_input_dim,
            hidden_size=decoder_dim,
            num_layers=num_layers,
            dropout=dropout_rate if num_layers > 1 else 0.0,
            batch_first=True
        )
        
        # Output projection
        self.output_projection = nn.Linear(output_input_dim, target_vocab_size)

    def forward(self, current_token, hidden_state, encoder_states, padding_mask):
        """
        Process one decoding step.
        
        Args:
            current_token: Current input token ids [batch_size]
            hidden_state: Previous decoder hidden state
            encoder_states: All encoder hidden states [batch_size, src_len, encoder_dim]
            padding_mask: Mask for padding in source [batch_size, src_len]
            
        Returns:
            token_logits: Probability distribution over target vocabulary
            new_hidden: Updated hidden state
            attention_weights: Attention distribution over source (or None)
        """
        # Embed current token and add time dimension
        token_embedding = self.token_embeddings(current_token).unsqueeze(1)  # [B,1,E]

        # Process differently based on attention mode
        if self.with_attention:
            # Extract current hidden state for attention
            if self.rnn_type == 'LSTM':
                # For LSTM, get the hidden state (not cell state)
                current_decoder_state = hidden_state[0][-1]
            else:
                # For GRU/RNN, directly use the hidden state
                current_decoder_state = hidden_state[-1]
            
            # Calculate attention weights
            attention_weights = self.attention_module(
                current_decoder_state, encoder_states, padding_mask
            )
            
            # Get context vector by weighted sum of encoder states
            context_vector = torch.bmm(
                attention_weights.unsqueeze(1), encoder_states
            )
            
            # Combine token embedding with context for recurrent input
            recurrent_input = torch.cat((token_embedding, context_vector), dim=2)
        else:
            # Without attention, just use token embedding
            context_vector = None
            attention_weights = None
            recurrent_input = token_embedding

        # Process through RNN
        recurrent_output, new_hidden = self.recurrent(recurrent_input, hidden_state)
        
        # Remove time dimension
        recurrent_output = recurrent_output.squeeze(1)  # [B,H]
        token_embedding = token_embedding.squeeze(1)    # [B,E]

        # Generate output logits
        if self.with_attention:
            # Combine decoder output, context vector, and token embedding
            context_vector = context_vector.squeeze(1)  # [B,C]
            token_logits = self.output_projection(
                torch.cat((recurrent_output, context_vector, token_embedding), dim=1)
            )
        else:
            # Combine decoder output and token embedding
            token_logits = self.output_projection(
                torch.cat((recurrent_output, token_embedding), dim=1)
            )

        return token_logits, new_hidden, attention_weights


class TransliterationModel(nn.Module):
    """
    Sequence-to-sequence model for transliteration tasks.
    Combines encoder and decoder with optional attention.
    """
    def __init__(self, source_encoder, target_decoder, padding_idx, device_name='cpu'):
        super().__init__()
        self.source_encoder = source_encoder
        self.target_decoder = target_decoder
        self.padding_idx = padding_idx
        self.device = device_name

    def forward(self, source_tokens, source_lengths, target_tokens, teacher_forcing_prob=0.5):
        """
        Training forward pass with teacher forcing.
        
        Args:
            source_tokens: Input token IDs [batch_size, src_len]
            source_lengths: Length of each input sequence [batch_size]
            target_tokens: Target token IDs [batch_size, tgt_len]
            teacher_forcing_prob: Probability of using teacher forcing
            
        Returns:
            prediction_logits: Logits for each target position [batch_size, tgt_len-1, vocab_size]
        """
        # Encode source sequence
        encoder_states, encoder_final = self.source_encoder(source_tokens, source_lengths)
        
        # Create mask for attention
        attention_mask = (source_tokens != self.padding_idx)
        
        # Prepare storage for decoder outputs
        batch_size, target_length = target_tokens.size()
        vocab_size = self.target_decoder.output_projection.out_features
        prediction_logits = torch.zeros(batch_size, target_length-1, vocab_size, device=self.device)
        
        # Initial input is the start token
        current_input = target_tokens[:, 0]  # <sos> token
        
        # Generate each target token
        for timestep in range(1, target_length):
            # Get prediction for current position
            step_output, encoder_final, _ = self.target_decoder(
                current_input, encoder_final, encoder_states, attention_mask
            )
            
            # Store prediction
            prediction_logits[:, timestep-1] = step_output
            
            # Determine next input (teacher forcing vs. predicted)
            use_teacher_forcing = random.random() < teacher_forcing_prob
            if use_teacher_forcing:
                # Use ground truth as next input
                current_input = target_tokens[:, timestep]
            else:
                # Use model's prediction as next input
                current_input = step_output.argmax(1)
                
        return prediction_logits

    def generate_sequence(self, source_tokens, source_lengths, target_vocab, max_length=50):
        """
        Generate target sequence using greedy decoding.
        
        Args:
            source_tokens: Input token IDs [batch_size, src_len]
            source_lengths: Length of each input sequence [batch_size]
            target_vocab: Target vocabulary object
            max_length: Maximum generation length
            
        Returns:
            generated_tokens: Sequence of generated token IDs [batch_size, gen_len]
        """
        # Encode source sequence
        encoder_states, encoder_final = self.source_encoder(source_tokens, source_lengths)
        
        # Create mask for attention
        attention_mask = (source_tokens != self.padding_idx)
        
        # Initialize generation with start token
        batch_size = source_tokens.size(0)
        current_input = torch.full(
            (batch_size,), target_vocab.sos_idx, 
            device=self.device, dtype=torch.long
        )
        
        # Store for generated tokens
        generated_tokens = []
        
        # Generate sequence
        for _ in range(max_length):
            # Get prediction for current position
            step_output, encoder_final, _ = self.target_decoder(
                current_input, encoder_final, encoder_states, attention_mask
            )
            
            # Select most likely token
            current_input = step_output.argmax(1)
            
            # Store generated token
            generated_tokens.append(current_input.unsqueeze(1))
            
            # Check if all sequences have generated end token
            if (current_input == target_vocab.eos_idx).all():
                break
                
        # Combine all generated tokens
        return torch.cat(generated_tokens, dim=1)


# Character-level vocabulary handling
class CharacterVocabulary:
    """
    Character-level vocabulary for transliteration tasks.
    Maps characters to integer indices and vice versa.
    """
    def __init__(self, character_list=None, special_tokens=['<pad>', '<sos>', '<eos>', '<unk>']):
        self.special_tokens = special_tokens
        
        # Initialize mapping dictionaries
        self.index_to_char = list(special_tokens) + (character_list or [])
        self.char_to_index = {char: idx for idx, char in enumerate(self.index_to_char)}
    
    @classmethod
    def build_from_text_corpus(cls, text_corpus):
        """
        Build vocabulary from a corpus of texts.
        
        Args:
            text_corpus: List of text strings
            
        Returns:
            CharacterVocabulary: New vocabulary instance
        """
        # Collect unique characters from all texts
        unique_chars = sorted(set(char for text in text_corpus for char in text))
        return cls(character_list=unique_chars)
    
    @classmethod
    def build_from_data_file(cls, file_path, source_col='src', target_col='trg', file_format='tsv'):
        """
        Build vocabulary from a data file.
        
        Args:
            file_path: Path to the data file
            source_col: Name/index of source column
            target_col: Name/index of target column
            file_format: File format ('tsv' or 'csv')
            
        Returns:
            CharacterVocabulary: New vocabulary instance
        """
        texts = []
        
        if file_format == 'csv':
            import pandas as pd
            df = pd.read_csv(file_path, header=None, names=[source_col, target_col])
            texts = df[source_col].dropna().tolist() + df[target_col].dropna().tolist()
        else:  # tsv
            with open(file_path, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        texts.extend([parts[0], parts[1]])
        
        return cls.build_from_text_corpus(texts)
    
    def save(self, file_path):
        """Save vocabulary to JSON file."""
        import json
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.index_to_char, f, ensure_ascii=False)
    
    @classmethod
    def load(cls, file_path):
        """Load vocabulary from JSON file."""
        import json
        with open(file_path, encoding='utf-8') as f:
            index_to_char = json.load(f)
        
        # Create instance with empty character list
        instance = cls(character_list=[])
        
        # Replace mapping dictionaries
        instance.index_to_char = index_to_char
        instance.char_to_index = {char: idx for idx, char in enumerate(index_to_char)}
        
        return instance
    
    def convert_text_to_indices(self, text, add_start=False, add_end=False):
        """
        Convert text string to sequence of token indices.
        
        Args:
            text: Input text string
            add_start: Whether to add start-of-sequence token
            add_end: Whether to add end-of-sequence token
            
        Returns:
            list: Sequence of token indices
        """
        indices = []
        
        # Add start token if requested
        if add_start:
            indices.append(self.char_to_index['<sos>'])
        
        # Convert each character to its index
        for char in text:
            # Use unknown token index for out-of-vocabulary characters
            indices.append(self.char_to_index.get(char, self.char_to_index['<unk>']))
        
        # Add end token if requested
        if add_end:
            indices.append(self.char_to_index['<eos>'])
            
        return indices
    
    def convert_indices_to_text(self, indices, remove_special=True, join_chars=True):
        """
        Convert sequence of indices back to text.
        
        Args:
            indices: Sequence of token indices
            remove_special: Whether to remove special tokens
            join_chars: Whether to join characters into a string
            
        Returns:
            str or list: Decoded text as string or character list
        """
        # Convert tensor to list if needed
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()
        
        # Convert indices to characters, filtering out-of-range indices
        characters = [self.index_to_char[idx] for idx in indices if idx < len(self.index_to_char)]
        
        # Remove special tokens if requested
        if remove_special:
            characters = [char for char in characters if char not in self.special_tokens]
        
        # Return as string or list
        return ''.join(characters) if join_chars else characters
    
    def batch_decode(self, batch_indices, remove_special=True):
        """
        Decode a batch of index sequences.
        
        Args:
            batch_indices: Batch of index sequences
            remove_special: Whether to remove special tokens
            
        Returns:
            list: List of decoded strings
        """
        return [
            self.convert_indices_to_text(sequence, remove_special=remove_special) 
            for sequence in batch_indices
        ]
    
    def get_statistics(self):
        """Get vocabulary statistics."""
        return {
            'total_size': len(self.index_to_char),
            'special_token_count': len(self.special_tokens),
            'character_count': len(self.index_to_char) - len(self.special_tokens)
        }
    
    def __len__(self):
        return len(self.index_to_char)
    
    # Convenience properties for special token indices
    @property
    def pad_idx(self): 
        return self.char_to_index['<pad>']
    
    @property
    def sos_idx(self): 
        return self.char_to_index['<sos>']
    
    @property
    def eos_idx(self): 
        return self.char_to_index['<eos>']
    
    @property
    def unk_idx(self): 
        return self.char_to_index['<unk>']
    
    @property
    def size(self): 
        return len(self.index_to_char)


# Data loading functionality
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import pickle


class TransliterationDataset(Dataset):
    """Dataset for transliteration tasks."""
    
    def __init__(self, data_path, source_vocab, target_vocab, dataset_type='dakshina'):
        """
        Initialize dataset.
        
        Args:
            data_path: Path to the data file
            source_vocab: Source vocabulary
            target_vocab: Target vocabulary
            dataset_type: Dataset format
        """
        self.examples = []
        self.dataset_type = dataset_type
        
        if dataset_type == 'dakshina':
            # Load examples from Dakshina format (tab-separated)
            for source_text, target_text in self._read_dakshina_format(data_path):
                # Convert text to indices and prepare tensors
                source_indices = source_vocab.convert_text_to_indices(
                    source_text, add_start=True, add_end=True
                )
                target_indices = target_vocab.convert_text_to_indices(
                    target_text, add_start=True, add_end=True
                )
                
                # Store as tensors
                self.examples.append((
                    torch.tensor(source_indices, dtype=torch.long),
                    torch.tensor(target_indices, dtype=torch.long)
                ))
        else:
            raise ValueError(f"Unsupported dataset type: {dataset_type}")
    
    def _read_dakshina_format(self, file_path):
        """Read Dakshina format file (tab-separated)."""
        with open(file_path, encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 2:
                    # In Dakshina, format is: native_script, latin_script
                    yield parts[1], parts[0]
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]


def collate_with_padding(batch, source_vocab, target_vocab):
    """
    Collate function to handle variable-length sequences.
    Pads sequences to the maximum length in the batch.
    """
    # Unzip the batch of examples
    source_sequences, target_sequences = zip(*batch)
    
    # Pad sequences to the same length
    padded_sources = pad_sequence(
        source_sequences, batch_first=True, 
        padding_value=source_vocab.pad_idx
    )
    padded_targets = pad_sequence(
        target_sequences, batch_first=True, 
        padding_value=target_vocab.pad_idx
    )
    
    # Store original sequence lengths for packed sequence
    source_lengths = torch.tensor([len(seq) for seq in source_sequences], dtype=torch.long)
    
    return padded_sources, source_lengths, padded_targets


def load_transliteration_datasets(
    language_code='ta',  # Tamil
    dataset_format='dakshina',
    data_root=None,
    batch_size=64,
    device_name='cpu',
    num_workers=2,
    prefetch_factor=4,
    persistent_workers=True,
    cache_directory='./cache',
    use_cached_vocabulary=True
):
    """
    Load transliteration datasets for a language.
    
    Args:
        language_code: Language code (e.g., 'ta' for Tamil)
        dataset_format: Dataset format
        data_root: Override default dataset path
        batch_size: Batch size for dataloaders
        device_name: Device for data loading
        num_workers: Number of data loading workers
        prefetch_factor: Batches to prefetch per worker
        persistent_workers: Keep workers alive between epochs
        cache_directory: Directory for vocabulary cache
        use_cached_vocabulary: Whether to use cached vocabulary
        
    Returns:
        tuple: (dataloaders, source_vocab, target_vocab)
    """
    # Set up data paths
    if data_root is None:
        data_root = os.path.join(
            '/kaggle/working/dakshina_dataset_v1.0',
            language_code, 'lexicons'
        )
    
    # Set up vocabulary cache
    if use_cached_vocabulary:
        os.makedirs(cache_directory, exist_ok=True)
        vocab_cache_path = os.path.join(
            cache_directory, f"{language_code}_{dataset_format}_vocab.pkl"
        )
    
    # Try to load cached vocabularies
    if use_cached_vocabulary and os.path.exists(vocab_cache_path):
        print(f"Loading cached vocabularies from {vocab_cache_path}")
        with open(vocab_cache_path, 'rb') as f:
            source_vocab, target_vocab = pickle.load(f)
    else:
        # Build vocabularies from training and validation data
        source_texts, target_texts = [], []
        
        # Process training and validation data for vocabulary creation
        for split in ['train', 'dev']:
            file_path = os.path.join(
                data_root, f"{language_code}.translit.sampled.{split}.tsv"
            )
            
            # Read file in Dakshina format
            with open(file_path, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        # In Dakshina: native_script, latin_script
                        target_texts.append(parts[0])
                        source_texts.append(parts[1])
        
        # Create vocabularies
        source_vocab = CharacterVocabulary.build_from_text_corpus(source_texts)
        target_vocab = CharacterVocabulary.build_from_text_corpus(target_texts)
        
        # Cache vocabularies
        if use_cached_vocabulary:
            with open(vocab_cache_path, 'wb') as f:
                pickle.dump((source_vocab, target_vocab), f)
    
    # Common DataLoader configurations
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'prefetch_factor': prefetch_factor,
        'persistent_workers': persistent_workers and num_workers > 0,
        'pin_memory': (device_name == 'cuda')
    }
    
    # Create data loaders for each split
    dataloaders = {}
    
    # Map split names to file suffixes
    split_mapping = {'train': 'train', 'dev': 'dev', 'test': 'test'}
    
    for split_name, file_suffix in split_mapping.items():
        file_path = os.path.join(
            data_root, f"{language_code}.translit.sampled.{file_suffix}.tsv"
        )
        
        # Create dataset
        dataset = TransliterationDataset(
            file_path, source_vocab, target_vocab, dataset_type='dakshina'
        )
        
        # Create dataloader
        dataloaders[split_name] = DataLoader(
            dataset,
            shuffle=(split_name == 'train'),  # Only shuffle training data
            collate_fn=lambda batch: collate_with_padding(batch, source_vocab, target_vocab),
            **loader_kwargs
        )
    
    return dataloaders, source_vocab, target_vocab


# Training and evaluation functions
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import csv
import wandb


def calculate_accuracy_metrics(model, dataloader, target_vocab, source_vocab, device):
    """
    Calculate accuracy and store prediction details.
    
    Returns:
        tuple: (accuracy, correct_predictions, incorrect_predictions)
    """
    model.eval()
    total_correct = 0
    total_examples = 0
    
    # Store prediction details for analysis
    correct_sources = []
    correct_targets = []
    correct_predictions = []
    
    incorrect_sources = []
    incorrect_targets = []
    incorrect_predictions = []
    
    with torch.no_grad():
        for source_batch, source_lengths, target_batch in dataloader:
            # Move data to device
            source_batch = source_batch.to(device)
            source_lengths = source_lengths.to(device)
            target_batch = target_batch.to(device)
            
            # Generate predictions
            predicted_batch = model.generate_sequence(
                source_batch, source_lengths, target_vocab, 
                max_length=target_batch.size(1)
            )

            # Evaluate each example in batch
            for idx in range(source_batch.size(0)):
                # Convert to text for comparison
                predicted_text = target_vocab.convert_indices_to_text(predicted_batch[idx])
                target_text = target_vocab.convert_indices_to_text(
                    target_batch[idx, 1:]  # Skip <sos> token
                )
                source_text = source_vocab.convert_indices_to_text(source_batch[idx])
                
                # Check if prediction matches target
                is_correct = (predicted_text == target_text)
                total_correct += int(is_correct)
                
                # Store details for analysis
                if is_correct:
                    correct_sources.append(source_text)
                    correct_targets.append(target_text)
                    correct_predictions.append(predicted_text)
                else:
                    incorrect_sources.append(source_text)
                    incorrect_targets.append(target_text)
                    incorrect_predictions.append(predicted_text)
                
            total_examples += source_batch.size(0)

    # Calculate overall accuracy
    accuracy = total_correct / total_examples if total_examples > 0 else 0.0
    
    return (
        accuracy, 
        (correct_sources, correct_targets, correct_predictions),
        (incorrect_sources, incorrect_targets, incorrect_predictions)
    )


def save_predictions_to_file(sources, targets, predictions, output_path):
    """Save prediction details to CSV file for analysis."""
    with open(output_path, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['Source', 'Target', 'Prediction'])
        writer.writerows(zip(sources, targets, predictions))
    
    return output_path


def train_transliteration_model(
    model, 
    dataloaders, 
    source_vocab, 
    target_vocab, 
    device,
    config,
    model_save_path=None,
    enable_wandb_logging=True
):
    """
    Train transliteration model.
    
    Args:
        model: Transliteration model
        dataloaders: Dictionary of data loaders
        source_vocab: Source vocabulary
        target_vocab: Target vocabulary
        device: Device to use
        config: Training configuration
        model_save_path: Path to save best model
        enable_wandb_logging: Whether to log to W&B
        
    Returns:
        tuple: (trained_model, test_accuracy)
    """
    # Define loss function
    loss_function = nn.CrossEntropyLoss(ignore_index=target_vocab.pad_idx)
    
    # Configure optimizer
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    elif config.optimizer.lower() == 'nadam':
        optimizer = optim.NAdam(model.parameters(), lr=config.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    
    # Track best validation performance
    best_validation_accuracy = 0.0
    
    # Training loop
    for epoch in tqdm(range(1, config.epochs + 1), desc="Epochs", position=0):
        # Training phase
        model.train()
        epoch_loss = 0.0

        # Process training batches
        train_loader = tqdm(
            dataloaders['train'], desc=f"Training epoch {epoch}", 
            leave=False, position=1
        )
        
        for source_batch, source_lengths, target_batch in train_loader:
            # Move data to device
            source_batch = source_batch.to(device)
            source_lengths = source_lengths.to(device)
            target_batch = target_batch.to(device)

            # Forward pass with teacher forcing
            optimizer.zero_grad()
            output_logits = model(
                source_batch, source_lengths, target_batch, 
                teacher_forcing_prob=config.teacher_forcing
            )
            
            # Calculate loss
            # Reshape outputs and targets to match cross-entropy requirements
            loss = loss_function(
                output_logits.reshape(-1, output_logits.size(-1)),
                target_batch[:, 1:].reshape(-1)  # Skip <sos> token
            )
            
            # Backward pass and optimize
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            
            epoch_loss += loss.item()
        
        train_loader.close()
        average_train_loss = epoch_loss / len(dataloaders['train'])

        # Validation phase
        validation_loss = 0.0
        validation_loader = tqdm(
            dataloaders['dev'], desc=f"Validation epoch {epoch}", 
            leave=False, position=1
        )
        
        model.eval()
        with torch.no_grad():
            for source_batch, source_lengths, target_batch in validation_loader:
                # Move data to device
                source_batch = source_batch.to(device)
                source_lengths = source_lengths.to(device)
                target_batch = target_batch.to(device)
                
                # Forward pass without teacher forcing
                output_logits = model(
                    source_batch, source_lengths, target_batch, 
                    teacher_forcing_prob=0.0
                )
                
                # Calculate loss
                loss = loss_function(
                    output_logits.reshape(-1, output_logits.size(-1)),
                    target_batch[:, 1:].reshape(-1)
                )
                
                validation_loss += loss.item()
        
        validation_loader.close()
        average_validation_loss = validation_loss / len(dataloaders['dev'])

        # Calculate accuracy metrics
        train_results = calculate_accuracy_metrics(
            model, dataloaders['train'], target_vocab, source_vocab, device
        )
        train_accuracy = train_results[0]
        
        validation_results = calculate_accuracy_metrics(
            model, dataloaders['dev'], target_vocab, source_vocab, device
        )
        validation_accuracy = validation_results[0]
        
        # Save best model
        if validation_accuracy > best_validation_accuracy and model_save_path:
            best_validation_accuracy = validation_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved new best model with validation accuracy: {validation_accuracy:.4f}")
            
            # Save prediction analysis periodically
            if epoch == config.epochs or epoch % 5 == 0:
                # Save correct predictions
                correct_data = validation_results[1]
                save_predictions_to_file(
                    correct_data[0], correct_data[1], correct_data[2],
                    f"correct_predictions_epoch_{epoch}.csv"
                )
                
                # Save incorrect predictions
                incorrect_data = validation_results[2]
                save_predictions_to_file(
                    incorrect_data[0], incorrect_data[1], incorrect_data[2],
                    f"incorrect_predictions_epoch_{epoch}.csv"
                )

        # Log metrics
        print(f"Epoch {epoch}/{config.epochs}:")
        print(f"  Train Loss: {average_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"  Validation Loss: {average_validation_loss:.4f}, Validation Accuracy: {validation_accuracy:.4f}")
        
        # Log to W&B if enabled
        if enable_wandb_logging:
            wandb.log({
                'epoch': epoch,
                'train_loss': average_train_loss,
                'validation_loss': average_validation_loss,
                'train_accuracy': train_accuracy,
                'validation_accuracy': validation_accuracy
            })

    # For simplicity, we skip test set evaluation here
    test_accuracy = 0
    
    return model, test_accuracy


# Hyperparameter sweep configuration
import wandb
import torch
from tqdm.auto import tqdm
import os
import random
import numpy as np


def run_model_with_config():
    """Run a model training with the current W&B configuration."""
    # Initialize W&B run
    run = wandb.init()
    config = run.config
    
    # Set seeds for reproducibility
    set_random_seeds(config.seed if hasattr(config, 'seed') else 42)
    
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Create a unique run name based on configuration
    run_name = (
        f"{config.cell_type}_{config.encoder_layers}l_{config.embedding_dim}e_"
        f"{config.hidden_dim}h_{'bid' if config.bidirectional else 'uni'}_"
        f"{config.dropout}d_{config.teacher_forcing}tf_{config.optimizer}"
    )
    wandb.run.name = run_name
    
    # Load datasets
    dataloaders, source_vocab, target_vocab = load_transliteration_datasets(
        language_code='ta',  # Tamil
        batch_size=config.batch_size,
        device_name=device
    )
    
    # Create encoder
    encoder = SourceEncoder(
        vocabulary_size=source_vocab.size,
        embedding_dim=config.embedding_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.encoder_layers,
        rnn_type=config.cell_type,
        dropout_rate=config.dropout,
        use_bidirectional=config.bidirectional
    ).to(device)
    
    # Calculate encoder output dimension
    encoder_output_dim = config.hidden_dim * 2 if config.bidirectional else config.hidden_dim
    
    # Create decoder
    decoder = TargetDecoder(
        target_vocab_size=target_vocab.size,
        embedding_dim=config.embedding_dim,
        encoder_dim=encoder_output_dim,
        decoder_dim=config.hidden_dim,
        num_layers=config.encoder_layers,
        rnn_type=config.cell_type,
        dropout_rate=config.dropout
    ).to(device)
    
    # Create full model
    model = TransliterationModel(
        encoder, decoder, 
        padding_idx=source_vocab.pad_idx, 
        device_name=device
    ).to(device)
    
    # Train model
    best_model_path = f"model_{run_name}.pt"
    _, test_accuracy = train_transliteration_model(
        model=model,
        dataloaders=dataloaders,
        source_vocab=source_vocab,
        target_vocab=target_vocab,
        device=device,
        config=config,
        model_save_path=best_model_path,
        enable_wandb_logging=True
    )
    
    # Finish the run
    wandb.finish()


# Main script to run sweep
if __name__ == "__main__":
    # Define sweep configuration
    sweep_cfg = {
        
        'method': 'bayes',  # Use Bayesian optimization
        'name':'Transliteration_without_Attention',
        'metric': {'name': 'val_acc', 'goal': 'maximize'},
        'parameters': {
            
            # Model architecture
            'emb_size': {'values': [32,64,128, 256, 512]},
            'hidden_size': {'values': [32,64,128, 256, 512, 1024]},
            'enc_layers': {'values': [1, 2, 3, 4]},
            'cell': {'values': ['RNN', 'GRU', 'LSTM']},  
            'bidirectional': {'values': [True, False]},  # Bidirectional encode
            
            # Training parameters
            'dropout': {'values': [0.0, 0.1, 0.2, 0.3, 0.5]},
            'lr': {'values': [1e-4, 2e-4, 5e-4, 8e-4, 1e-3]},
            'batch_size': {'values': [32, 64, 128]},
            'epochs': {'values': [10, 15, 20]},
            'teacher_forcing': {'values': [0.3, 0.5, 0.7, 1.0]},  # Explicit teacher forcing
            'optimizer': {'values': ['Adam', 'NAdam']},  # Added optimizer options
            # Reproducibility
            'seed': {'values': [42, 43, 44, 45, 46]},  # Different seeds for robustness
        }
    }

    # Initialize and run sweep
    sweep_id = wandb.sweep(
        sweep_config,
        project='DA6401_A3'
    )
    
    # Run sweep agent
    wandb.agent(sweep_id, function=run_model_with_config, count=30)

Using device: cuda
Loading cached vocabularies from ./cache/te_dakshina_vocab.pkl


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Train 1:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 1:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.3924
Epoch 1/10:
  Train Loss: 1.1262, Train Acc: 0.4864
  Val Loss: 0.7815, Val Acc: 0.3924


Train 2:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 2:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4000
Epoch 2/10:
  Train Loss: 0.5373, Train Acc: 0.4988
  Val Loss: 0.7040, Val Acc: 0.4000


Train 3:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 3:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4695
Epoch 3/10:
  Train Loss: 0.4352, Train Acc: 0.6503
  Val Loss: 0.6903, Val Acc: 0.4695


Train 4:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 4:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 4/10:
  Train Loss: 0.3838, Train Acc: 0.6760
  Val Loss: 0.6739, Val Acc: 0.4695


Train 5:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 5:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 5/10:
  Train Loss: 0.3589, Train Acc: 0.6470
  Val Loss: 0.6958, Val Acc: 0.4491


Train 6:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 6:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4901
Epoch 6/10:
  Train Loss: 0.3403, Train Acc: 0.7171
  Val Loss: 0.6802, Val Acc: 0.4901


Train 7:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 7:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 7/10:
  Train Loss: 0.3299, Train Acc: 0.6936
  Val Loss: 0.6989, Val Acc: 0.4705


Train 8:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 8:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4918
Epoch 8/10:
  Train Loss: 0.3206, Train Acc: 0.7254
  Val Loss: 0.6857, Val Acc: 0.4918


Train 9:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 9:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 9/10:
  Train Loss: 0.3195, Train Acc: 0.6401
  Val Loss: 0.6949, Val Acc: 0.4552


Train 10:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 10:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 10/10:
  Train Loss: 0.3170, Train Acc: 0.7376
  Val Loss: 0.7352, Val Acc: 0.4760


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_acc,▁▁▆▆▅▇▇█▅█
train_loss,█▃▂▂▁▁▁▁▁▁
val_acc,▁▂▆▆▅█▇█▅▇
val_loss,█▃▂▁▂▁▃▂▂▅

0,1
epoch,10.0
train_acc,0.73763
train_loss,0.31698
val_acc,0.47598
val_loss,0.73522


Using device: cuda
Loading cached vocabularies from ./cache/te_dakshina_vocab.pkl


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Train 1:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 1:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.3924
Epoch 1/10:
  Train Loss: 0.8821, Train Acc: 0.4717
  Val Loss: 0.7612, Val Acc: 0.3924


Train 2:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 2:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 2/10:
  Train Loss: 0.3274, Train Acc: 0.4733
  Val Loss: 0.6899, Val Acc: 0.3621


Train 3:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 3:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.3970
Epoch 3/10:
  Train Loss: 0.2234, Train Acc: 0.5558
  Val Loss: 0.7137, Val Acc: 0.3970


Train 4:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 4:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4198
Epoch 4/10:
  Train Loss: 0.1635, Train Acc: 0.5828
  Val Loss: 0.7027, Val Acc: 0.4198


Train 5:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 5:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 5/10:
  Train Loss: 0.1273, Train Acc: 0.3911
  Val Loss: 0.7471, Val Acc: 0.3097


Train 6:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 6:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 6/10:
  Train Loss: 0.0999, Train Acc: 0.5882
  Val Loss: 0.7745, Val Acc: 0.3817


Train 7:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 7:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 7/10:
  Train Loss: 0.0814, Train Acc: 0.5473
  Val Loss: 0.7889, Val Acc: 0.3613


Train 8:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 8:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 8/10:
  Train Loss: 0.0717, Train Acc: 0.5645
  Val Loss: 0.8248, Val Acc: 0.3616


Train 9:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 9:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4394
Epoch 9/10:
  Train Loss: 0.0616, Train Acc: 0.7116
  Val Loss: 0.8680, Val Acc: 0.4394


Train 10:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 10:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4531
Epoch 10/10:
  Train Loss: 0.0554, Train Acc: 0.7633
  Val Loss: 0.8919, Val Acc: 0.4531


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_acc,▃▃▄▅▁▅▄▄▇█
train_loss,█▃▂▂▂▁▁▁▁▁
val_acc,▅▄▅▆▁▅▄▄▇█
val_loss,▃▁▂▁▃▄▄▆▇█

0,1
epoch,10.0
train_acc,0.76331
train_loss,0.05543
val_acc,0.45311
val_loss,0.89188


Using device: cuda
Loading cached vocabularies from ./cache/te_dakshina_vocab.pkl


Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Train 1:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 1:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.2958
Epoch 1/20:
  Train Loss: 0.7627, Train Acc: 0.3480
  Val Loss: 0.8287, Val Acc: 0.2958


Train 2:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 2:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 2/20:
  Train Loss: 0.2182, Train Acc: 0.2371
  Val Loss: 0.7490, Val Acc: 0.1943


Train 3:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 3:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 3/20:
  Train Loss: 0.1369, Train Acc: 0.3998
  Val Loss: 0.7798, Val Acc: 0.2736


Train 4:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 4:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 4/20:
  Train Loss: 0.0935, Train Acc: 0.3428
  Val Loss: 0.8113, Val Acc: 0.2395


Train 5:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 5:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 5/20:
  Train Loss: 0.0697, Train Acc: 0.3405
  Val Loss: 0.8425, Val Acc: 0.2515


Train 6:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 6:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 6/20:
  Train Loss: 0.0575, Train Acc: 0.3145
  Val Loss: 0.8661, Val Acc: 0.2277


Train 7:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 7:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 7/20:
  Train Loss: 0.0490, Train Acc: 0.3029
  Val Loss: 0.9856, Val Acc: 0.2259


Train 8:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 8:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.3188
Epoch 8/20:
  Train Loss: 0.0430, Train Acc: 0.4541
  Val Loss: 0.9369, Val Acc: 0.3188


Train 9:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 9:   0%|          | 0/178 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.4425
Epoch 9/20:
  Train Loss: 0.0390, Train Acc: 0.6670
  Val Loss: 0.9226, Val Acc: 0.4425


Train 10:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 10:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 10/20:
  Train Loss: 0.0343, Train Acc: 0.4532
  Val Loss: 0.9878, Val Acc: 0.2898


Train 11:   0%|          | 0/1830 [00:00<?, ?it/s]

Val 11:   0%|          | 0/178 [00:00<?, ?it/s]

Epoch 11/20:
  Train Loss: 0.0336, Train Acc: 0.4677
  Val Loss: 1.0406, Val Acc: 0.2967


Train 12:   0%|          | 0/1830 [00:00<?, ?it/s]