In [19]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random

# For reproducibility
def seed_everything(seed=42):
    """Set random seed for all major libraries"""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, layers=1, cell='LSTM', dropout=0.0, bidirectional=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.bidirectional = bidirectional
        self.cell_type = cell
        self.layers = layers
        self.hidden_size = hid_size
        
        # Output size will be doubled if bidirectional
        self.output_size = hid_size * 2 if bidirectional else hid_size
        
        rnn_cls = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}[cell]
        self.rnn = rnn_cls(emb_size,
                         hid_size,
                         num_layers=layers,
                         dropout=dropout if layers>1 else 0.0,
                         batch_first=True,
                         bidirectional=bidirectional)

    def forward(self, src, lengths):
        # src: [B, T], lengths: [B]
        embedded = self.embedding(src)  # [B, T, E]
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, hidden = self.rnn(packed)
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True)  # [B, T, H*dirs]
        
        # If bidirectional, we need to process hidden state properly
        if self.bidirectional:
            if self.cell_type == 'LSTM':
                # For LSTM we have both hidden and cell states
                h_n, c_n = hidden
                # Combine forward and backward states by averaging
                h_n = torch.add(h_n[0:self.layers], h_n[self.layers:]) / 2
                c_n = torch.add(c_n[0:self.layers], c_n[self.layers:]) / 2
                hidden = (h_n, c_n)
            else:
                # For GRU/RNN we only have hidden state
                hidden = torch.add(hidden[0:self.layers], hidden[self.layers:]) / 2
                
        return outputs, hidden


class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid, dec_hid):
        super().__init__()
        self.attn = nn.Linear(enc_hid + dec_hid, dec_hid)
        self.v = nn.Linear(dec_hid, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        # hidden: [B, H], encoder_outputs: [B, T, H], mask: [B, T]
        B, T, H = encoder_outputs.size()
        hidden = hidden.unsqueeze(1).repeat(1, T, 1)               # [B, T, H]
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [B, T, H]
        scores = self.v(energy).squeeze(2)                        # [B, T]
        scores = scores.masked_fill(~mask, -1e9)
        return torch.softmax(scores, dim=1)                       # [B, T]


class Decoder(nn.Module):
    """
    One class, two modes:
        • use_attn=True  – Bahdanau attention (default)
        • use_attn=False – Plain RNN decoder (no attention)

    Forward always returns (logits, hidden, attn_weights_or_None),
    so Seq2Seq code stays unchanged.
    """
    def __init__(self, vocab_size, emb_size, enc_hid, dec_hid,
                 layers=1, cell="LSTM", dropout=0.0, use_attn=False):
        super().__init__()
        self.use_attn = use_attn
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.cell_type = cell

        # ----- dimensions depend on whether we concatenate context -----
        if use_attn:
            self.attention = BahdanauAttention(enc_hid, dec_hid)
            rnn_input_dim = emb_size + enc_hid            # [E ⊕ Henc]
            fc_input_dim  = dec_hid + enc_hid + emb_size  # [Hdec ⊕ Henc ⊕ E]
        else:
            rnn_input_dim = emb_size
            fc_input_dim  = dec_hid + emb_size

        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}[cell]
        self.rnn = rnn_cls(rnn_input_dim, dec_hid,
                           num_layers=layers,
                           dropout=dropout if layers > 1 else 0.0,
                           batch_first=True)
        self.fc = nn.Linear(fc_input_dim, vocab_size)

    def forward(self, input_token, hidden, encoder_outputs, mask):
        """
        input_token : [B]
        hidden      : tuple|tensor  initial state for this step
        encoder_outputs : [B, Tenc, Henc]
        mask        : [B, Tenc]  (ignored when use_attn=False)
        """
        emb = self.embedding(input_token).unsqueeze(1)     # [B,1,E]

        if self.use_attn:
            # ---- additive attention ----
            if self.cell_type == 'LSTM':
                dec_h = hidden[0][-1]
            else:
                dec_h = hidden[-1]
                
            attn_w = self.attention(dec_h, encoder_outputs, mask)          # [B,Tenc]
            ctx    = torch.bmm(attn_w.unsqueeze(1), encoder_outputs)        # [B,1,Henc]
            rnn_in = torch.cat((emb, ctx), dim=2)                           # [B,1,E+Henc]
        else:
            ctx = None
            attn_w = None
            rnn_in = emb                                                    # [B,1,E]

        out, hidden = self.rnn(rnn_in, hidden)       # [B,1,Hdec]
        out = out.squeeze(1)                         # [B,Hdec]
        emb = emb.squeeze(1)                         # [B,E]

        if self.use_attn:
            ctx = ctx.squeeze(1)                     # [B,Henc]
            logits = self.fc(torch.cat((out, ctx, emb), dim=1))
        else:
            logits = self.fc(torch.cat((out, emb), dim=1))

        return logits, hidden, attn_w


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device='cpu'):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, src, src_lens, tgt, teacher_forcing_ratio=0.5):
        """
        Enhanced forward with explicit teacher forcing ratio control
        """
        enc_out, hidden = self.encoder(src, src_lens)
        mask = (src != self.pad_idx)
        B, T = tgt.size()
        outputs = torch.zeros(B, T-1, self.decoder.fc.out_features, device=self.device)
        input_tok = tgt[:, 0]  # <sos>
        
        for t in range(1, T):
            out, hidden, _ = self.decoder(input_tok, hidden, enc_out, mask)
            outputs[:, t-1] = out
            
            # Teacher forcing: with probability, use ground truth as next input
            # Otherwise use predicted token
            teacher_force = random.random() < teacher_forcing_ratio
            if teacher_force:
                input_tok = tgt[:, t]
            else:
                input_tok = out.argmax(1)
                
        return outputs

    def infer_greedy(self, src, src_lens, tgt_vocab, max_len=50):
        enc_out, hidden = self.encoder(src, src_lens)
        mask = (src != self.pad_idx)
        B = src.size(0)
        input_tok = torch.full((B,), tgt_vocab.sos_idx, device=self.device, dtype=torch.long)
        generated = []
        
        for _ in range(max_len):
            out, hidden, _ = self.decoder(input_tok, hidden, enc_out, mask)
            input_tok = out.argmax(1)
            generated.append(input_tok.unsqueeze(1))
            if (input_tok == tgt_vocab.eos_idx).all():
                break
                
        return torch.cat(generated, dim=1)

In [20]:
# Downloading dakshina dataset
!yes | wget "https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar"

--2025-05-19 12:13:12--  https://storage.googleapis.com/gresearch/dakshina/dakshina_dataset_v1.0.tar
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.31.207, 74.125.143.207, 142.250.145.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.31.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2008340480 (1.9G) [application/x-tar]
Saving to: ‘dakshina_dataset_v1.0.tar.1’


2025-05-19 12:14:01 (39.5 MB/s) - ‘dakshina_dataset_v1.0.tar.1’ saved [2008340480/2008340480]

yes: standard output: Broken pipe


In [21]:
# Unzipping dataset
!yes | tar xopf dakshina_dataset_v1.0.tar

yes: standard output: Broken pipe


In [22]:
# vocab.py - Enhanced vocabulary handling

import json
import os
import pickle

class CharVocab:
    """
    Character-level vocabulary class with additional features from satyabhagwan.ipynb
    """
    def __init__(self, tokens=None, specials=['<pad>','<sos>','<eos>','<unk>']):
        self.specials = specials
        self.idx2char = list(specials) + (tokens or [])
        self.char2idx = {ch:i for i,ch in enumerate(self.idx2char)}

    @classmethod
    def build_from_texts(cls, texts):
        """Build vocabulary from a list of texts"""
        chars = sorted({c for line in texts for c in line})
        return cls(tokens=chars)
    
    @classmethod
    def build_from_file(cls, file_path, src_col='src', tgt_col='trg', is_csv=True):
        """
        Build vocabulary from a data file (CSV or TSV)
        
        Args:
            file_path (str): Path to the data file
            src_col (str): Name of the source column (for CSV)
            tgt_col (str): Name of the target column (for CSV)
            is_csv (bool): Whether the file is CSV (True) or TSV (False)
        """
        if is_csv:
            import pandas as pd
            df = pd.read_csv(file_path, header=None, names=[src_col, tgt_col])
            texts = df[src_col].dropna().tolist() + df[tgt_col].dropna().tolist()
        else:
            texts = []
            with open(file_path, encoding='utf-8') as f:
                for ln in f:
                    parts = ln.strip().split('\t')
                    if len(parts) >= 2:
                        texts.extend([parts[0], parts[1]])
        
        return cls.build_from_texts(texts)

    def save(self, path):
        """Save vocabulary to JSON file"""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.idx2char, f, ensure_ascii=False)

    @classmethod
    def load(cls, path):
        """Load vocabulary from JSON file"""
        with open(path, encoding='utf-8') as f:
            idx2char = json.load(f)
        
        inst = cls(tokens=[])
        inst.idx2char = idx2char
        inst.char2idx = {c:i for i,c in enumerate(idx2char)}
        return inst

    def encode(self, text, add_sos=False, add_eos=False):
        """
        Convert text to a sequence of indices
        
        Args:
            text (str): Input text
            add_sos (bool): Whether to add start-of-sequence token
            add_eos (bool): Whether to add end-of-sequence token
        
        Returns:
            list: Sequence of token indices
        """
        seq = []
        if add_sos: seq.append(self.char2idx['<sos>'])
        for c in text:
            seq.append(self.char2idx.get(c, self.char2idx['<unk>']))
        if add_eos: seq.append(self.char2idx['<eos>'])
        return seq

    def decode(self, idxs, strip_specials=True, join=True):
        """
        Convert a sequence of indices back to text
        
        Args:
            idxs (list or tensor): Sequence of indices
            strip_specials (bool): Whether to remove special tokens
            join (bool): Whether to join characters into a string
            
        Returns:
            str or list: Decoded text as string (if join=True) or list of characters
        """
        # Convert tensor to list if needed
        if hasattr(idxs, 'tolist'):
            idxs = idxs.tolist()
            
        # Convert indices to characters
        chars = [self.idx2char[i] for i in idxs if i < len(self.idx2char)]
        
        # Remove special tokens if requested
        if strip_specials:
            chars = [c for c in chars if c not in self.specials]
            
        # Return as string or list
        return ''.join(chars) if join else chars
    
    def batch_decode(self, batch_idxs, strip_specials=True):
        """
        Decode a batch of index sequences
        
        Args:
            batch_idxs (list of lists or tensor): Batch of index sequences
            strip_specials (bool): Whether to remove special tokens
            
        Returns:
            list: List of decoded strings
        """
        return [self.decode(seq, strip_specials=strip_specials) for seq in batch_idxs]
    
    def get_stats(self):
        """Get vocabulary statistics"""
        return {
            'size': len(self.idx2char),
            'num_specials': len(self.specials),
            'num_chars': len(self.idx2char) - len(self.specials)
        }
    
    def __len__(self):
        return len(self.idx2char)

    @property
    def pad_idx(self): return self.char2idx['<pad>']
    
    @property
    def sos_idx(self): return self.char2idx['<sos>']
    
    @property
    def eos_idx(self): return self.char2idx['<eos>']
    
    @property
    def unk_idx(self): return self.char2idx['<unk>']
    
    @property
    def size(self): return len(self.idx2char)

In [23]:
# data_loader.py - Enhanced data loading with support for multiple datasets

import os
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
# from vocab import CharVocab

class TransliterationDataset(Dataset):
    """A flexible dataset class that can handle  Dakshina  dataset"""
    
    def __init__(self, path, src_vocab, tgt_vocab, format='dakshina'):
        """
        Initialize the dataset
        
        Args:
            path (str): Path to the data file
            src_vocab (CharVocab): Source vocabulary
            tgt_vocab (CharVocab): Target vocabulary
            format (str): Dataset format - 'dakshina'
        """
        self.examples = []
        self.format = format
        
        if format == 'dakshina':
            # Dakshina format: tab-separated without header
            for src, tgt in read_tsv(path):
                src_ids = src_vocab.encode(src, add_sos=True, add_eos=True)
                tgt_ids = tgt_vocab.encode(tgt, add_sos=True, add_eos=True)
                self.examples.append((
                    torch.tensor(src_ids, dtype=torch.long),
                    torch.tensor(tgt_ids, dtype=torch.long)
                ))
        
        else:
            raise ValueError(f"Unknown format: {format}. Use 'dakshina'")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def read_tsv(path):
    """Read a tab-separated file with source and target text"""
    with open(path, encoding='utf-8') as f:
        for ln in f:
            parts = ln.strip().split('\t')
            if len(parts) >= 2:
                yield parts[1], parts[0]  # Dakshina format has target, source


def read_csv(path, src_col='src', tgt_col='trg'):
    """Read a CSV file with source and target columns"""
    df = pd.read_csv(path)
    for _, row in df.iterrows():
        yield row[src_col], row[tgt_col]


def collate_fn(batch, src_vocab, tgt_vocab):
    """Collate function to handle variable-length sequences"""
    srcs, tgts = zip(*batch)
    srcs_p = pad_sequence(srcs, batch_first=True, padding_value=src_vocab.pad_idx)
    tgts_p = pad_sequence(tgts, batch_first=True, padding_value=tgt_vocab.pad_idx)
    src_lens = torch.tensor([len(s) for s in srcs], dtype=torch.long)
    return srcs_p, src_lens, tgts_p


def get_dataloaders(
        language='te', 
        dataset_format='dakshina',
        base_path=None,
        batch_size=64,
        device='cpu',
        num_workers=2,
        prefetch_factor=4,
        persistent_workers=True,
        cache_dir='./cache',
        use_cached_vocab=True
    ):
    """
    Enhanced function to load transliteration datasets with support for multiple formats
    
    Args:
        language (str): Language code (e.g., 'te' for Telugu)
        dataset_format (str): 'dakshina' 
        base_path (str): Override the default dataset path
        batch_size (int): Batch size
        device (str): Device to use ('cuda' or 'cpu')
        num_workers (int): Number of data loading workers
        prefetch_factor (int): Number of batches to prefetch
        persistent_workers (bool): Keep workers alive between epochs
        cache_dir (str): Directory to cache vocabularies
        use_cached_vocab (bool): Whether to use cached vocabularies if available
        
    Returns:
        tuple: (loaders dict, src_vocab, tgt_vocab)
    """
    # Set up paths based on dataset format
    if base_path is None:
        base_path = os.path.join(
            '/kaggle/working/dakshina_dataset_v1.0',
            language, 'lexicons'
        )

    # Create cache directory if it doesn't exist
    if use_cached_vocab:
        os.makedirs(cache_dir, exist_ok=True)
        vocab_cache_path = os.path.join(cache_dir, f"{language}_{dataset_format}_vocab.pkl")
    
    # Try to load cached vocabularies
    if use_cached_vocab and os.path.exists(vocab_cache_path):
        print(f"Loading cached vocabularies from {vocab_cache_path}")
        with open(vocab_cache_path, 'rb') as f:
            src_vocab, tgt_vocab = pickle.load(f)
    else:
        # Build vocabularies from data
        all_src, all_tgt = [], []
        
        for split in ['train', 'dev']:
            path = os.path.join(base_path, f"{language}.translit.sampled.{split}.tsv")
            for s, t in read_tsv(path):
                all_src.append(s)
                all_tgt.append(t)
        
        
        # Build vocabularies
        src_vocab = CharVocab.build_from_texts(all_src)
        tgt_vocab = CharVocab.build_from_texts(all_tgt)
        
        # Cache vocabularies
        if use_cached_vocab:
            with open(vocab_cache_path, 'wb') as f:
                pickle.dump((src_vocab, tgt_vocab), f)
    
    # Common DataLoader arguments
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        persistent_workers=persistent_workers and num_workers > 0,
        pin_memory=(device == 'cuda')
    )
    
    # Create data loaders for each split
    loaders = {}
    
    
    splits = {'train': 'train', 'dev': 'dev', 'test': 'test'}
    for split_name, file_split in splits.items():
        path = os.path.join(base_path, f"{language}.translit.sampled.{file_split}.tsv")
        ds = TransliterationDataset(path, src_vocab, tgt_vocab, format='dakshina')
        loaders[split_name] = DataLoader(
            ds,
            shuffle=(split_name == 'train'),
            collate_fn=lambda b: collate_fn(b, src_vocab, tgt_vocab),
            **loader_kwargs
        )
    
    
    return loaders, src_vocab, tgt_vocab

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm.auto import tqdm
import csv
import pandas as pd

def compute_detailed_accuracy(model, loader, tgt_vocab, src_vocab, device):
    """
    Enhanced accuracy function that returns:
    - Overall accuracy
    - Lists of correct and incorrect predictions for analysis
    """
    model.eval()
    correct = total = 0
    
    # Lists to store detailed results
    correct_srcs = []
    correct_tgts = []
    correct_preds = []
    
    incorrect_srcs = []
    incorrect_tgts = []
    incorrect_preds = []
    
    with torch.no_grad():
        for src, src_lens, tgt in loader:
            src, src_lens, tgt = (x.to(device) for x in (src, src_lens, tgt))
            pred = model.infer_greedy(src, src_lens, tgt_vocab, max_len=tgt.size(1))

            # iterate over the batch
            for b in range(src.size(0)):
                # Convert indices to strings
                pred_str = tgt_vocab.decode(pred[b].cpu().tolist())
                gold_str = tgt_vocab.decode(tgt[b, 1:].cpu().tolist())  # skip <sos>
                src_str = src_vocab.decode(src[b].cpu().tolist())
                
                # Check if prediction is correct
                is_correct = (pred_str == gold_str)
                correct += is_correct
                
                # Store detailed results
                if is_correct:
                    correct_srcs.append(src_str)
                    correct_tgts.append(gold_str)
                    correct_preds.append(pred_str)
                else:
                    incorrect_srcs.append(src_str)
                    incorrect_tgts.append(gold_str)
                    incorrect_preds.append(pred_str)
                    
            total += src.size(0)

    accuracy = correct / total if total else 0.0
    return (
        accuracy, 
        (correct_srcs, correct_tgts, correct_preds),
        (incorrect_srcs, incorrect_tgts, incorrect_preds)
    )

def save_predictions_to_csv(src_list, tgt_list, pred_list, file_name):
    """Save prediction details to CSV file for further analysis"""
    rows = zip(src_list, tgt_list, pred_list)
    
    with open(file_name, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Source', 'Target', 'Predicted'])
        writer.writerows(rows)
    
    return file_name

def train_model(
    model, 
    loaders, 
    src_vocab, 
    tgt_vocab, 
    device,
    config,
    save_path=None,
    log_to_wandb=True
):
    """
    Enhanced training function with:
    - Teacher forcing control
    - Detailed accuracy tracking
    - Progress bars
    - Optional WandB logging
    """
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)
    
    # Select optimizer based on config
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    elif config.optimizer.lower() == 'nadam':
        optimizer = optim.NAdam(model.parameters(), lr=config.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=config.lr)
    
    # Track best validation accuracy
    best_val_acc = 0.0
    
    # Main training loop
    for epoch in tqdm(range(1, config.epochs + 1), desc="Epochs", position=0):
        model.train()
        total_loss = 0.0

        # Training batches with progress bar
        train_loader = tqdm(loaders['train'], desc=f"Train {epoch}", leave=False, position=1)
        for src, src_lens, tgt in train_loader:
            src, src_lens, tgt = src.to(device), src_lens.to(device), tgt.to(device)

            optimizer.zero_grad()
            # Use teacher forcing ratio from config
            output = model(src, src_lens, tgt, teacher_forcing_ratio=config.teacher_forcing)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt[:,1:].reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            
        train_loader.close()
        train_loss = total_loss / len(loaders['train'])

        # Validation loss
        val_loss = 0.0
        val_loader = tqdm(loaders['dev'], desc=f"Val {epoch}", leave=False, position=1)
        model.eval()
        with torch.no_grad():
            for src, src_lens, tgt in val_loader:
                src, src_lens, tgt = src.to(device), src_lens.to(device), tgt.to(device)
                output = model(src, src_lens, tgt, teacher_forcing_ratio=0.0)  # No teacher forcing during validation
                val_loss += criterion(output.reshape(-1, output.size(-1)),
                                    tgt[:,1:].reshape(-1)).item()
        val_loader.close()
        val_loss /= len(loaders['dev'])

        # Compute detailed accuracy metrics
        train_results = compute_detailed_accuracy(model, loaders['train'], tgt_vocab, src_vocab, device)
        train_acc = train_results[0]
        
        val_results = compute_detailed_accuracy(model, loaders['dev'], tgt_vocab, src_vocab, device)
        val_acc = val_results[0]
        
        # Save model if it's the best so far
        if val_acc > best_val_acc and save_path:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
            
            # Save prediction analysis CSVs for best model
            if epoch == config.epochs or epoch % 5 == 0:  # Save at last epoch or every 5 epochs
                correct_data = val_results[1]
                incorrect_data = val_results[2]
                
                save_predictions_to_csv(
                    correct_data[0], correct_data[1], correct_data[2],
                    f"correct_predictions_epoch_{epoch}.csv"
                )
                
                save_predictions_to_csv(
                    incorrect_data[0], incorrect_data[1], incorrect_data[2],
                    f"incorrect_predictions_epoch_{epoch}.csv"
                )

        # Log metrics
        print(f"Epoch {epoch}/{config.epochs}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        if log_to_wandb:
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc
            })
    
    # Final evaluation on test set
    # test_results = compute_detailed_accuracy(model, loaders['test'], tgt_vocab, src_vocab, device)
    # test_acc = test_results[0]
    # print(f"Final test accuracy: {test_acc:.4f}")
    
    # if log_to_wandb:
    #     wandb.log({'test_acc': test_acc})
    
    # # Save final prediction analysis
    # correct_data = test_results[1]
    # incorrect_data = test_results[2]
    
    # save_predictions_to_csv(
    #     correct_data[0], correct_data[1], correct_data[2],
    #     "correct_predictions_final.csv"
    # )
    
    # save_predictions_to_csv(
    #     incorrect_data[0], incorrect_data[1], incorrect_data[2],
    #     "incorrect_predictions_final.csv"
    # )
    test_acc=0
    
    return model, test_acc

In [25]:
wandb.login(key='4e3ff854ed5182f4bc02df0482250c121f645ae5')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# sweep_config.py - Enhanced with improved hyperparameter sweep

import wandb
import torch
from tqdm.auto import tqdm
import os
import random
import numpy as np


# Import enhanced modules
# from models import Encoder, Decoder, Seq2Seq, seed_everything
# from training import train_model
# from data import get_dataloaders

def objective():
    # Initialize WandB run
    run = wandb.init()
    cfg = run.config
    
    # Set seeds for reproducibility
    seed_everything(cfg.seed if hasattr(cfg, 'seed') else 42)
    
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Create a unique run name based on config
    run_name = f"{cfg.cell}_{cfg.enc_layers}l_{cfg.emb_size}e_{cfg.hidden_size}h_" \
               f"{'bid' if cfg.bidirectional else 'uni'}_{cfg.dropout}d_" \
               f"{cfg.teacher_forcing}tf_{cfg.optimizer}"
    wandb.run.name = run_name
    
    # Load data
    loaders, src_vocab, tgt_vocab = get_dataloaders(
        'te',
        batch_size=cfg.batch_size,
        device=device
    )
    
    # Create model components
    enc = Encoder(
        src_vocab.size, cfg.emb_size, cfg.hidden_size,
        cfg.enc_layers, cfg.cell, cfg.dropout, 
        bidirectional=cfg.bidirectional
    ).to(device)
    
    # Calculate encoder output dimension (doubled if bidirectional)
    enc_out_dim = cfg.hidden_size * 2 if cfg.bidirectional else cfg.hidden_size
    
    dec = Decoder(
        tgt_vocab.size, cfg.emb_size, enc_out_dim, cfg.hidden_size,
        cfg.enc_layers, cfg.cell, cfg.dropout
    ).to(device)
    
    model = Seq2Seq(enc, dec, pad_idx=src_vocab.pad_idx, device=device).to(device)
    
    # Train the model
    best_model_path = f"model_{run_name}.pt"
    _, test_acc = train_model(
        model=model,
        loaders=loaders,
        src_vocab=src_vocab,
        tgt_vocab=tgt_vocab,
        device=device,
        config=cfg,
        save_path=best_model_path,
        log_to_wandb=True
    )
    
    # Log final test accuracy as summary metric
    # wandb.run.summary['test_accuracy'] = test_acc
    
    # Finish the run
    wandb.finish()

if __name__ == "__main__":
    # Define an enhanced sweep configuration
    sweep_cfg = {
        
        'method': 'bayes',  # Use Bayesian optimization
        'name':'Transliteration_without_Attention',
        'metric': {'name': 'val_acc', 'goal': 'maximize'},
        'parameters': {
            
            # Model architecture
            'emb_size': {'values': [128, 256, 512]},
            'hidden_size': {'values': [128, 256, 512, 1024]},
            'enc_layers': {'values': [1, 2, 3, 4]},
            'cell': {'values': ['RNN', 'GRU', 'LSTM']},  
            'bidirectional': {'values': [True, False]},  # Bidirectional encode
            
            # Training parameters
            'dropout': {'values': [0.0, 0.1, 0.2, 0.3, 0.5]},
            'lr': {'values': [1e-4, 2e-4, 5e-4, 8e-4, 1e-3]},
            'batch_size': {'values': [32, 64, 128]},
            'epochs': {'values': [10, 15, 20]},
            'teacher_forcing': {'values': [0.3, 0.5, 0.7, 1.0]},  # Explicit teacher forcing
            'optimizer': {'values': ['Adam', 'NAdam']},  # Added optimizer options
            # Reproducibility
            'seed': {'values': [42, 43, 44, 45, 46]},  # Different seeds for robustness
        }
    }

    # Start the sweep
    sweep_id = wandb.sweep(
        sweep_cfg,
        entity='cs24m042-iit-madras-foundation',  # Replace with your username
        project='DA6401-Assignment-3'
    )
    
    # Run the sweep agent
    wandb.agent(sweep_id, function=objective, count=20)

Create sweep with ID: kx7x532z
Sweep URL: https://wandb.ai/cs24m042-iit-madras-foundation/DA6401-Assignment-3/sweeps/kx7x532z


[34m[1mwandb[0m: Agent Starting Run: eesdjwbp with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	bidirectional: True
[34m[1mwandb[0m: 	cell: RNN
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	emb_size: 512
[34m[1mwandb[0m: 	enc_layers: 2
[34m[1mwandb[0m: 	epochs: 20
[34m[1mwandb[0m: 	hidden_size: 512
[34m[1mwandb[0m: 	lr: 0.0008
[34m[1mwandb[0m: 	optimizer: Adam
[34m[1mwandb[0m: 	seed: 45
[34m[1mwandb[0m: 	teacher_forcing: 1


Using device: cuda
Loading cached vocabularies from ./cache/te_dakshina_vocab.pkl


Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Train 1:   0%|          | 0/915 [00:00<?, ?it/s]

Val 1:   0%|          | 0/89 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.0554
Epoch 1/20:
  Train Loss: 1.2515, Train Acc: 0.0806
  Val Loss: 2.5904, Val Acc: 0.0554


Train 2:   0%|          | 0/915 [00:00<?, ?it/s]

Val 2:   0%|          | 0/89 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.1138
Epoch 2/20:
  Train Loss: 0.6541, Train Acc: 0.1946
  Val Loss: 2.5525, Val Acc: 0.1138


Train 3:   0%|          | 0/915 [00:00<?, ?it/s]

Val 3:   0%|          | 0/89 [00:00<?, ?it/s]

Saved new best model with validation accuracy: 0.1285
Epoch 3/20:
  Train Loss: 0.5007, Train Acc: 0.2401
  Val Loss: 2.6587, Val Acc: 0.1285


Train 4:   0%|          | 0/915 [00:00<?, ?it/s]

Val 4:   0%|          | 0/89 [00:00<?, ?it/s]

Epoch 4/20:
  Train Loss: 0.4225, Train Acc: 0.2339
  Val Loss: 2.7117, Val Acc: 0.1221


Train 5:   0%|          | 0/915 [00:00<?, ?it/s]

Val 5:   0%|          | 0/89 [00:00<?, ?it/s]

Epoch 5/20:
  Train Loss: 0.3690, Train Acc: 0.2422
  Val Loss: 2.8883, Val Acc: 0.1147


Train 6:   0%|          | 0/915 [00:00<?, ?it/s]

In [None]:
# import os
# import torch
# import wandb
# import argparse
# from models import Encoder, Decoder, Seq2Seq, seed_everything
# from data import get_dataloaders
# from training import train_model, save_predictions_to_csv, compute_detailed_accuracy

# def main(args):
#     # Set seeds for reproducibility
#     seed_everything(args.seed)
    
#     # Set device
#     device = torch.device(args.device if torch.cuda.is_available() and args.device == 'cuda' else 'cpu')
#     print(f"Using device: {device}")
    
#     # Initialize wandb if requested
#     if args.use_wandb:
#         wandb.init(
#             project=args.wandb_project,
#             name=args.run_name,
#             config=vars(args)
#         )
    
#     # Load datasets
#     print(f"Loading {args.language} data...")
#     loaders, src_vocab, tgt_vocab = get_dataloaders(
#         args.language,
#         batch_size=args.batch_size,
#         device=args.device
#     )
    
#     # Create model components
#     print("Building model...")
#     enc = Encoder(
#         src_vocab.size, args.emb_size, args.hidden_size,
#         args.enc_layers, args.cell_type, args.dropout, 
#         bidirectional=args.bidirectional
#     ).to(device)
    
#     # Calculate encoder output dimension (doubled if bidirectional)
#     enc_out_dim = args.hidden_size * 2 if args.bidirectional else args.hidden_size
    
#     dec = Decoder(
#         tgt_vocab.size, args.emb_size, enc_out_dim, args.hidden_size,
#         args.enc_layers, args.cell_type, args.dropout, 
#         use_attn=args.use_attention
#     ).to(device)
    
#     model = Seq2Seq(enc, dec, pad_idx=src_vocab.pad_idx, device=device).to(device)
    
#     # Create output directory if it doesn't exist
#     os.makedirs(args.output_dir, exist_ok=True)
    
#     # Train the model
#     print("Training model...")
#     best_model_path = os.path.join(args.output_dir, f"{args.run_name}_best.pt")
#     model, test_acc = train_model(
#         model=model,
#         loaders=loaders,
#         src_vocab=src_vocab,
#         tgt_vocab=tgt_vocab,
#         device=device,
#         config=args,
#         save_path=best_model_path,
#         log_to_wandb=args.use_wandb
#     )
    
#     # Save final model
#     final_model_path = os.path.join(args.output_dir, f"{args.run_name}_final.pt")
#     torch.save(model.state_dict(), final_model_path)
#     print(f"Saved final model to {final_model_path}")
    
#     # Generate and save detailed predictions on test set
#     print("Generating final predictions on test set...")
#     test_results = compute_detailed_accuracy(model, loaders['test'], tgt_vocab, src_vocab, device)
#     test_acc = test_results[0]
    
#     correct_data = test_results[1]
#     incorrect_data = test_results[2]
    
#     correct_csv = os.path.join(args.output_dir, f"{args.run_name}_correct.csv")
#     incorrect_csv = os.path.join(args.output_dir, f"{args.run_name}_incorrect.csv")
    
#     save_predictions_to_csv(
#         correct_data[0], correct_data[1], correct_data[2],
#         correct_csv
#     )
    
#     save_predictions_to_csv(
#         incorrect_data[0], incorrect_data[1], incorrect_data[2],
#         incorrect_csv
#     )
    
#     print(f"Final test accuracy: {test_acc:.4f}")
#     print(f"Saved correct predictions to {correct_csv}")
#     print(f"Saved incorrect predictions to {incorrect_csv}")
    
#     if args.use_wandb:
#         wandb.finish()

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Train a seq2seq transliteration model")
    
#     # Data parameters
#     parser.add_argument("--language", type=str, default="te", help="Language code (e.g., 'te' for Telugu)")
#     parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save models and outputs")
    
#     # Model architecture
#     parser.add_argument("--emb_size", type=int, default=256, help="Embedding size")
#     parser.add_argument("--hidden_size", type=int, default=512, help="Hidden state size")
#     parser.add_argument("--enc_layers", type=int, default=2, help="Number of encoder layers")
#     parser.add_argument("--cell_type", type=str, default="LSTM", choices=["RNN", "GRU", "LSTM"], help="RNN cell type")
#     parser.add_argument("--bidirectional", action="store_true", help="Use bidirectional encoder")
#     parser.add_argument("--use_attention", action="store_true", default=True, help="Use attention mechanism")
    
#     # Training parameters
#     parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
#     parser.add_argument("--epochs", type=int, default=15, help="Number of epochs")
#     parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
#     parser.add_argument("--dropout", type=float, default=0.3, help="Dropout rate")
#     parser.add_argument("--teacher_forcing", type=float, default=0.5, help="Teacher forcing ratio")
#     parser.add_argument("--optimizer", type=str, default="Adam", choices=["Adam", "NAdam"], help="Optimizer")
#     parser.add_argument("--seed", type=int, default=42, help="Random seed")
#     parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use")
    
#     # Wandb parameters
#     parser.add_argument("--use_wandb", action="store_true", help="Log to Weights & Biases")
#     parser.add_argument("--wandb_project", type=str, default="transliteration", help="WandB project name")
#     parser.add_argument("--run_name", type=str, default="seq2seq_model", help="Run name")
    
#     args = parser.parse_args()
#     main(args)