In [14]:
!pip install scikit-learn tqdm pandas sanskrit-text aksharamukha tensorboard nltk ipywidgets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [15]:
import os
import math
import torch
import logging
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import sanskrit_text as skt
from typing import List, Dict
from tqdm.notebook import tqdm
from aksharamukha import transliterate
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

# Constants and Hyperparameters
MAX_LENGTH = 100
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 1e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BEAM_WIDTH = 5
LENGTH_PENALTY = 0.6
GRAD_CLIP = 1.0
LABEL_SMOOTHING = 0.1
LOG_DIR = './logs'

# Early Stopping Class with Enhanced Criteria
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001, min_epochs=10):
        self.patience = patience
        self.min_delta = min_delta
        self.min_epochs = min_epochs
        self.counter = 0
        self.best_loss = float('inf')
        self.best_acc = -float('inf')
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.stop_reason = None

    def __call__(self, epoch, val_loss, val_acc, train_loss, train_acc):
        if epoch < self.min_epochs:
            return False
        loss_improved = val_loss < (self.best_loss - self.min_delta)
        acc_improved = val_acc > (self.best_acc + self.min_delta)
        if loss_improved or acc_improved:
            self.counter = 0
            if loss_improved:
                self.best_loss = val_loss
            if acc_improved:
                self.best_acc = val_acc
        else:
            self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            self.stop_reason = f"No improvement for {self.patience} epochs"
            return True
        if val_loss > 2 * train_loss:
            self.early_stop = True
            self.stop_reason = "Validation loss much higher than training loss"
            return True
        if val_acc > 95 and val_loss < 0.1:
            self.early_stop = True
            self.stop_reason = "Reached excellent performance"
            return True
        return False

    def get_status(self):
        return {
            'counter': self.counter,
            'best_loss': self.best_loss,
            'best_acc': self.best_acc,
            'stop_reason': self.stop_reason
        }

# Positional Encoding Module
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=MAX_LENGTH):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

# Sandhi Transformer Model with Enhanced Features
class SandhiTransformer(nn.Module):
    def __init__(self, char_vocab_size: int, phon_feature_dim: int, 
                 hidden_dim: int, num_layers: int, num_heads: int, 
                 dropout: float = 0.1, max_length: int = MAX_LENGTH):
        super(SandhiTransformer, self).__init__()
        self.hidden_dim = hidden_dim

        # Character Embedding with Dropout
        self.char_embedding = nn.Embedding(char_vocab_size, hidden_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        # Phonetic Feature Transformation
        self.phon_feature_linear = nn.Linear(phon_feature_dim, hidden_dim)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, max_len=max_length)

        # Transformer Encoders for Characters and Phonetic Features
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='relu'
        )
        self.char_transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.phon_transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention-based Memory Fusion
        self.memory_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)

        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='relu'
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Layer Normalization
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # Classifier with Label Smoothing
        self.classifier = nn.Linear(hidden_dim, char_vocab_size)

    def attention_fusion(self, memory1, memory2):
        att_fusion = self.memory_attention(memory1, memory2, memory2)[0]
        return att_fusion

    def forward(self, char_inputs, phon_feature_inputs, split_output, 
                src_key_padding_mask=None, tgt_mask=None, tgt_key_padding_mask=None):
        # Character Embedding for Word
        word_char_embedded = self.char_embedding(char_inputs)  # (batch_size, seq_len, hidden_dim)
        word_char_embedded = self.dropout(word_char_embedded)
        word_char_embedded = self.pos_encoder(word_char_embedded)
        word_char_embedded = word_char_embedded.transpose(0, 1)  # (seq_len, batch_size, hidden_dim)

        # Phonetic Feature Transformation
        feature_embedded = self.phon_feature_linear(phon_feature_inputs)  # (batch_size, seq_len, hidden_dim)
        feature_embedded = self.dropout(feature_embedded)
        feature_embedded = self.pos_encoder(feature_embedded)
        feature_embedded = feature_embedded.transpose(0, 1)  # (seq_len, batch_size, hidden_dim)

        # Transformer Encoders
        char_transformed = self.char_transformer_encoder(word_char_embedded, src_key_padding_mask=src_key_padding_mask)
        phon_transformed = self.phon_transformer_encoder(feature_embedded, src_key_padding_mask=src_key_padding_mask)

        # Attention-based Memory Fusion
        att_fusion = self.attention_fusion(char_transformed, phon_transformed)  # (seq_len, batch_size, hidden_dim)

        # Split Token Embedding
        split_char_embedded = self.char_embedding(split_output)  # (batch_size, tgt_seq_len, hidden_dim)
        split_char_embedded = self.dropout(split_char_embedded)
        split_char_embedded = self.pos_encoder(split_char_embedded)
        split_char_embedded = split_char_embedded.transpose(0, 1)  # (tgt_seq_len, batch_size, hidden_dim)

        # Transformer Decoder
        transformed = self.decoder(split_char_embedded, att_fusion, tgt_mask=tgt_mask, 
                                   tgt_key_padding_mask=tgt_key_padding_mask)
        transformed = transformed.transpose(0, 1)  # (batch_size, tgt_seq_len, hidden_dim)

        # Layer Normalization
        transformed = self.layer_norm(transformed)

        # Classifier
        logits = self.classifier(transformed)  # (batch_size, tgt_seq_len, vocab_size)
        return logits

# Sandhi Dataset Class
class SandhiDataset(Dataset):
    def __init__(self, words: List[str], splits: List[str],
                 char2idx: Dict[str, int], max_length: int = MAX_LENGTH):
        self.words = words
        self.splits = splits
        self.char2idx = char2idx
        self.max_length = max_length

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        split = self.splits[idx]
        try:
            # Tokenize the word
            word_token = self._char_tokenizer(word)
            # Tokenize the split
            split_token = self._char_tokenizer(split, is_target=True)
            # Get phonetic feature vectors for the word
            phonetic_features = self._phonetic_features(word)
            # Pad the tokenized strings and phonetic feature vectors
            word_token = self._pad_sequence(word_token, self.max_length, self.char2idx['<pad>'])
            split_token = self._pad_sequence(split_token, self.max_length, self.char2idx['<pad>'])
            phonetic_features = self._pad_sequence(phonetic_features, self.max_length, [0] * 26)
            # Return the tokens and feature vectors
            return {
                'word_token': torch.tensor(word_token, dtype=torch.long),
                'split_token': torch.tensor(split_token, dtype=torch.long),
                'phon_features': torch.tensor(phonetic_features, dtype=torch.float),
                'word': word,
                'split_text': split
            }
        except Exception as e:
            print(f"Error processing index {idx}: {str(e)}")
            raise

    def _pad_sequence(self, sequence, max_length, pad_value):
        # Handle sequences with nested lists (e.g., phonetic features)
        if isinstance(pad_value, list):
            inner_dim = len(pad_value)
            if len(sequence) > max_length:
                return sequence[:max_length]
            else:
                padding = [pad_value for _ in range(max_length - len(sequence))]
                return sequence + padding
        else:
            if len(sequence) > max_length:
                return sequence[:max_length]
            else:
                return sequence + [pad_value] * (max_length - len(sequence))

    def _phonetic_features(self, string):
        phon_vec = [list(tup[1].values()) for tup in skt.get_ucchaarana_vectors(string)]
        return phon_vec

    def _char_tokenizer(self, string, is_target=False):
        try:
            segments = string.split('+') if '+' in string else [string]
            char_encoding = []
            # Introduce a start token for targets
            if is_target:
                char_encoding.append(self.char2idx.get('<s>', 2))  # Assuming <s> is index 2
            for i, segment in enumerate(segments):
                chars = [self.char2idx.get(tup[0], self.char2idx['<unk>']) for tup in skt.get_ucchaarana_vectors(segment)]
                char_encoding += chars
                # Add '+' between segments, except after the last segment
                if i + 1 < len(segments):
                    char_encoding.append(self.char2idx.get('+', 3))  # Assuming '+' is index 3
            if is_target:
                char_encoding.append(self.char2idx.get('</s>', 4))  # Assuming '</s>' is index 4
            return char_encoding
        except Exception as e:
            print(f"Error tokenizing string '{string}': {str(e)}")
            raise

# Character to Index Mapper with Start Token
def char2idx_mapper(data):
    bag_of_chars = set()
    char2idx = {}
    # Define special tokens with explicit indices
    special_tokens = ['<pad>', '<unk>', '<s>', '</s>', '+']
    for i, token in enumerate(special_tokens):
        char2idx[token] = i
    # Use the word column to construct character mapping
    for sans_word in tqdm(data["word"], desc="Constructing character mapping"):
        varna_decomp = [tup[0] for tup in skt.get_ucchaarana_vectors(sans_word)]
        bag_of_chars.update(varna_decomp)
    # Sort the regular characters
    sorted_chars = sorted(bag_of_chars)
    # Add regular characters to the mapping
    for char in sorted_chars:
        if char not in char2idx:
            char2idx[char] = len(char2idx)
    return char2idx

# Data Preparation Function
def prepare_data(csv_path: str):
    print("\nLoading dataset...")
    data = pd.read_csv(csv_path)
    total_samples = len(data)
    print(f"Total samples: {total_samples:,}")

    # Construct character to id mapping
    char2idx = char2idx_mapper(data)
    print(f"Vocabulary size: {len(char2idx):,}")

    # Split data into train, validation, and test
    train_words, temp_words, train_splits, temp_splits = train_test_split(
        data['word'].values, data['split'].values, test_size=0.2, random_state=42
    )
    val_words, test_words, val_splits, test_splits = train_test_split(
        temp_words, temp_splits, test_size=0.5, random_state=42
    )

    # Prepare datasets
    train_dataset = SandhiDataset(train_words, train_splits, char2idx)
    val_dataset = SandhiDataset(val_words, val_splits, char2idx)
    test_dataset = SandhiDataset(test_words, test_splits, char2idx)

    # Display dataset sizes
    print(f"\nTrain: {len(train_words):,} ({len(train_words)/total_samples*100:.1f}%)")
    print(f"Val: {len(val_words):,} ({len(val_words)/total_samples*100:.1f}%)")
    print(f"Test: {len(test_words):,} ({len(test_words)/total_samples*100:.1f}%)")

    # Save char2idx mapping
    torch.save(char2idx, os.path.join(os.path.dirname(csv_path), "char2idx.pt"))
    print(f"Character to index mapping saved at {os.path.join(os.path.dirname(csv_path), 'char2idx.pt')}")
    return train_dataset, val_dataset, test_dataset, char2idx

# Collate Function for DataLoader
def collate_fn(batch):
    word_token = torch.stack([item['word_token'] for item in batch])
    split_token = torch.stack([item['split_token'] for item in batch])
    phon_features = torch.stack([item['phon_features'] for item in batch])
    return {
        'word_token': word_token,
        'split_token': split_token,
        'phon_features': phon_features,
        'word': [item['word'] for item in batch],
        'split_text': [item['split_text'] for item in batch]
    }

# Initialize TensorBoard Writer
def initialize_tensorboard():
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    writer = SummaryWriter(log_dir=LOG_DIR)
    return writer

# Initialize the Epoch Summary Table
def initialize_epoch_table():
    header = f"{'Epoch':<6} {'Train Loss':<12} {'Train Acc (%)':<15} {'Val Loss':<12} {'Val Acc (%)':<15} {'Status':<20}"
    separator = "-" * len(header)
    print("\n" + header)
    print(separator)

# Update the Epoch Summary Table
def update_epoch_table(epoch, train_loss, train_acc, val_loss, val_acc, status):
    print(f"{epoch:<6} {train_loss:<12.6f} {train_acc:<15.2f} {val_loss:<12.6f} {val_acc:<15.2f} {status:<20}")

# Beam Search Decoding Function
def beam_search(model, char2idx, idx2char, input_chars, input_features, src_key_padding_mask, beam_width=5, length_penalty=0.6):
    model.eval()
    sequences = [[char2idx['<s>']]]
    scores = torch.zeros(len(sequences), device=DEVICE)
    
    with torch.no_grad():
        for _ in range(MAX_LENGTH):
            all_candidates = []
            for i, seq in enumerate(sequences):
                if seq[-1] == char2idx['</s>']:
                    all_candidates.append((seq, scores[i]))
                    continue
                # Prepare decoder input
                decoder_input = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                seq_len = decoder_input.size(1)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(DEVICE)
                # Forward pass
                output = model(
                    char_inputs=input_chars.unsqueeze(0),  # (1, seq_len)
                    phon_feature_inputs=input_features.unsqueeze(0),  # (1, seq_len, 26)
                    split_output=decoder_input,  # (1, seq_len)
                    src_key_padding_mask=src_key_padding_mask.unsqueeze(0),
                    tgt_mask=tgt_mask
                )
                logits = output[0, -1, :]  # (vocab_size)
                log_probs = torch.log_softmax(logits, dim=-1)
                topk_probs, topk_indices = torch.topk(log_probs, beam_width)
                for k in range(beam_width):
                    candidate = seq + [topk_indices[k].item()]
                    score = scores[i] + topk_probs[k]
                    all_candidates.append((candidate, score))
            # Select top beam_width sequences
            ordered = sorted(all_candidates, key=lambda tup: tup[1]/(len(tup[0])**length_penalty), reverse=True)
            sequences = []
            scores = []
            for seq, score in ordered[:beam_width]:
                sequences.append(seq)
                scores.append(score)
            # Check if all sequences have ended
            if all(seq[-1] == char2idx['</s>'] for seq in sequences):
                break
    # Select the best sequence
    best_seq = sequences[0]
    # Remove start and end tokens
    if best_seq[0] == char2idx['<s>']:
        best_seq = best_seq[1:]
    if best_seq[-1] == char2idx['</s>']:
        best_seq = best_seq[:-1]
    return best_seq

# Modified Sample Predictions with Beam Search
def sample_predictions(model, val_loader, char2idx, idx2char, num_samples=3, beam_width=BEAM_WIDTH, length_penalty=LENGTH_PENALTY):
    """Generate sample predictions from the model using beam search."""
    model.eval()
    samples = []
    with torch.no_grad():
        for batch in val_loader:
            input_chars = batch['word_token'].to(DEVICE)
            input_features = batch['phon_features'].to(DEVICE)
            original_words = batch['word']
            true_splits = batch['split_text']
            src_key_padding_mask = (input_chars == char2idx['<pad>'])
            for i in range(len(original_words)):
                generated_seq = beam_search(
                    model=model,
                    char2idx=char2idx,
                    idx2char=idx2char,
                    input_chars=input_chars[i],
                    input_features=input_features[i],
                    src_key_padding_mask=src_key_padding_mask[i],
                    beam_width=beam_width,
                    length_penalty=length_penalty
                )
                # Decode the generated sequence
                pred_split = ''.join([idx2char.get(idx, '') for idx in generated_seq])
                # Replace multiple '+' with single '+' and remove leading/trailing '+'
                pred_split = '+'.join(filter(None, pred_split.split('+')))
                # Transliterate to Devanagari for consistency
                pred_split = transliterate.process(
                    'IAST', 'Devanagari',
                    transliterate.process('Devanagari', 'IAST', pred_split)
                )
                samples.append((original_words[i], true_splits[i], pred_split))
                if len(samples) >= num_samples:
                    return samples
    return samples

# Training Function with Advanced Features
def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs, model_save_path, char2idx, writer, idx2char):
    print("\nStarting training...")
    early_stopping = EarlyStopping(patience=5, min_delta=0.001, min_epochs=10)
    scaler = GradScaler()
    epoch_metrics = {}

    # Initialize the progress bars
    epoch_pbar = tqdm(total=num_epochs, desc="Epoch Progress", position=0, leave=True)
    for epoch in range(1, num_epochs + 1):
        epoch_pbar.update(1)
        # Training Phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch} - Training", position=1, leave=False)
        for batch in train_pbar:
            optimizer.zero_grad()
            char_inputs = batch['word_token'].to(DEVICE)
            split_output = batch['split_token'].to(DEVICE)
            phon_feature_inputs = batch['phon_features'].to(DEVICE)
            src_key_padding_mask = (char_inputs == char2idx['<pad>'])

            with autocast():
                output = model(
                    char_inputs=char_inputs,
                    phon_feature_inputs=phon_feature_inputs,
                    split_output=split_output[:, :-1],  # Input to decoder
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_mask=nn.Transformer.generate_square_subsequent_mask(split_output[:, :-1].size(1)).to(DEVICE),
                    tgt_key_padding_mask=(split_output[:, :-1] == char2idx['<pad>'])
                )
                # Reshape for loss computation
                loss = criterion(output.view(-1, output.size(-1)), split_output[:, 1:].reshape(-1))
            
            scaler.scale(loss).backward()
            # Gradient Clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
            # Calculate accuracy
            pred = output.argmax(dim=-1)  # (batch_size, seq_len)
            mask = (split_output[:, 1:] != char2idx['<pad>'])
            train_correct += ((pred == split_output[:, 1:]) & mask).sum().item()
            train_total += mask.sum().item()
            # Update progress bar
            running_avg_loss = train_loss / (train_pbar.n + 1)
            running_avg_acc = (train_correct / train_total) * 100
            train_pbar.set_postfix(loss=running_avg_loss, acc=f"{running_avg_acc:.2f}%")
            # Log training metrics
            current_step = (epoch - 1) * len(train_loader) + train_pbar.n
            writer.add_scalar('Train/Loss', loss.item(), current_step)
            writer.add_scalar('Train/Accuracy', (train_correct / train_total) * 100, current_step)
        
        # Scheduler Step
        scheduler.step()

        # Calculate average training metrics
        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = (train_correct / train_total) * 100
        train_pbar.close()

        # Validation Phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch} - Validation", position=2, leave=False)
        with torch.no_grad():
            for batch in val_pbar:
                char_inputs = batch['word_token'].to(DEVICE)
                split_output = batch['split_token'].to(DEVICE)
                phon_feature_inputs = batch['phon_features'].to(DEVICE)
                src_key_padding_mask = (char_inputs == char2idx['<pad>'])
                # Forward pass
                output = model(
                    char_inputs=char_inputs,
                    phon_feature_inputs=phon_feature_inputs,
                    split_output=split_output[:, :-1],
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_mask=nn.Transformer.generate_square_subsequent_mask(split_output[:, :-1].size(1)).to(DEVICE),
                    tgt_key_padding_mask=(split_output[:, :-1] == char2idx['<pad>'])
                )
                loss = criterion(output.view(-1, output.size(-1)), split_output[:, 1:].reshape(-1))
                val_loss += loss.item()
                # Calculate accuracy
                pred = output.argmax(dim=-1)
                mask = (split_output[:, 1:] != char2idx['<pad>'])
                val_correct += ((pred == split_output[:, 1:]) & mask).sum().item()
                val_total += mask.sum().item()
                # Update progress bar
                running_avg_val_loss = val_loss / (val_pbar.n + 1)
                running_avg_val_acc = (val_correct / val_total) * 100
                val_pbar.set_postfix(loss=running_avg_val_loss, acc=f"{running_avg_val_acc:.2f}%")
                # Log validation metrics
                current_step = (epoch - 1) * len(train_loader) + train_pbar.n
                writer.add_scalar('Validation/Loss', loss.item(), current_step)
                writer.add_scalar('Validation/Accuracy', (val_correct / val_total) * 100, current_step)
        
        # Calculate average validation metrics
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = (val_correct / val_total) * 100
        val_pbar.close()

        # Save metrics
        status = ""
        if avg_val_loss < early_stopping.best_loss:
            early_stopping.best_loss = avg_val_loss
            torch.save(model.state_dict(), model_save_path)
            status = "✓ Saved"
        epoch_metrics[epoch] = {
            'train_loss': avg_train_loss,
            'train_acc': train_accuracy,
            'val_loss': avg_val_loss,
            'val_acc': val_accuracy,
            'status': status
        }

        # Update TensorBoard
        writer.add_scalar('Epoch/Train_Loss', avg_train_loss, epoch)
        writer.add_scalar('Epoch/Train_Accuracy', train_accuracy, epoch)
        writer.add_scalar('Epoch/Val_Loss', avg_val_loss, epoch)
        writer.add_scalar('Epoch/Val_Accuracy', val_accuracy, epoch)

        # Initialize and Update Epoch Summary Table
        if epoch == 1:
            initialize_epoch_table()
        update_epoch_table(epoch, avg_train_loss, train_accuracy, avg_val_loss, val_accuracy, status)

        # Sample Predictions
        print("\nSample predictions for this epoch:")
        print("-" * 50)
        samples = sample_predictions(model, val_loader, char2idx, idx2char, num_samples=3)
        for original, true_split, prediction in samples:
            print(f"Input:     {original}")
            print(f"True:      {true_split}")
            print(f"Predicted: {prediction}")
            print("-" * 50)
        
        # Early Stopping Check
        if early_stopping(epoch, avg_val_loss, val_accuracy, avg_train_loss, train_accuracy):
            print(f"\nEarly stopping: {early_stopping.stop_reason}")
            break

    epoch_pbar.close()
    writer.close()
    return epoch_metrics

# Evaluation Function without BLEU Score
def evaluate_model(model, test_loader, char2idx, idx2char):
    model.eval()
    test_loss = 0
    test_correct = 0
    test_total = 0
    criterion = nn.CrossEntropyLoss(ignore_index=char2idx['<pad>'], label_smoothing=LABEL_SMOOTHING)
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating on Test Set"):
            char_inputs = batch['word_token'].to(DEVICE)
            split_output = batch['split_token'].to(DEVICE)
            phon_feature_inputs = batch['phon_features'].to(DEVICE)
            src_key_padding_mask = (char_inputs == char2idx['<pad>'])
            # Forward pass
            output = model(
                char_inputs=char_inputs,
                phon_feature_inputs=phon_feature_inputs,
                split_output=split_output[:, :-1],
                src_key_padding_mask=src_key_padding_mask,
                tgt_mask=nn.Transformer.generate_square_subsequent_mask(split_output[:, :-1].size(1)).to(DEVICE),
                tgt_key_padding_mask=(split_output[:, :-1] == char2idx['<pad>'])
            )
            loss = criterion(output.view(-1, output.size(-1)), split_output[:, 1:].reshape(-1))
            test_loss += loss.item()
            # Calculate accuracy
            pred = output.argmax(dim=-1)
            mask = (split_output[:, 1:] != char2idx['<pad>'])
            test_correct += ((pred == split_output[:, 1:]) & mask).sum().item()
            test_total += mask.sum().item()
    avg_test_loss = test_loss / len(test_loader)
    test_accuracy = (test_correct / test_total) * 100
    print("\nTest Set Evaluation:")
    print(f"Test Loss: {avg_test_loss:.6f}")
    print(f"Test Accuracy: {test_accuracy:.2f}%")

# Main Function
def main():
    print("\nLoading dataset...")

    # Define Paths
    base_path = './'
    dataset_path = os.path.join(base_path, "sandhi_data.csv")
    model_save_path = os.path.join(base_path, "model.pt")
    tensorboard_log_dir = os.path.join(base_path, LOG_DIR)

    # Prepare Data
    train_dataset, val_dataset, test_dataset, char2idx = prepare_data(csv_path=dataset_path)

    # Create idx2char mapping
    idx2char = {idx: char for char, idx in char2idx.items()}

    # Create Data Loaders with optimized parameters
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            collate_fn=collate_fn, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             collate_fn=collate_fn, num_workers=4, pin_memory=True)

    # Initialize TensorBoard
    writer = initialize_tensorboard()

    # Initialize Model
    print("\nInitializing model...")
    vocab_size = len(char2idx)
    model = SandhiTransformer(
        char_vocab_size=vocab_size,
        phon_feature_dim=26,
        hidden_dim=256,
        num_layers=4,
        num_heads=8,
        dropout=0.1,
        max_length=MAX_LENGTH
    ).to(DEVICE)
    print(f"Model initialized and moved to {DEVICE}.")

    # Optimizer and Scheduler
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

    # Loss Function with Label Smoothing
    criterion = nn.CrossEntropyLoss(ignore_index=char2idx['<pad>'], label_smoothing=LABEL_SMOOTHING)

    # Display Configuration
    print("\nTraining Configuration:")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Number of Epochs: {EPOCHS}")
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Vocabulary Size: {vocab_size}")
    print(f"Feature Dimension: 26")
    print(f"TensorBoard Log Directory: {tensorboard_log_dir}")

    # Train the Model
    epoch_metrics = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        num_epochs=EPOCHS,
        model_save_path=model_save_path,
        char2idx=char2idx,
        writer=writer,
        idx2char=idx2char  # Pass idx2char here
    )

    # Save Final Model
    torch.save(model.state_dict(), model_save_path)
    print(f"\nFinal model saved at {model_save_path}")

    # Evaluate on the Test Set
    evaluate_model(model, test_loader, char2idx, idx2char)

# Run the Main Function
if __name__ == '__main__':
    main()


Loading dataset...

Loading dataset...
Total samples: 1,009,439


Constructing character mapping:   0%|          | 0/1009439 [00:00<?, ?it/s]

Vocabulary size: 53

Train: 807,551 (80.0%)
Val: 100,944 (10.0%)
Test: 100,944 (10.0%)
Character to index mapping saved at ./char2idx.pt

Initializing model...




Model initialized and moved to cuda.

Training Configuration:
Batch Size: 64
Learning Rate: 0.0001
Number of Epochs: 30
Model Parameters: 10,829,621
Vocabulary Size: 53
Feature Dimension: 26
TensorBoard Log Directory: ././logs

Starting training...


Epoch Progress:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]



Epoch 1 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]


Epoch  Train Loss   Train Acc (%)   Val Loss     Val Acc (%)     Status              
-------------------------------------------------------------------------------------
1      0.993359     90.22           0.816144     96.18           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 2 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 2 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

2      0.818025     96.01           0.789001     97.14           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्रौ+अशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 3 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 3 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

3      0.793214     96.89           0.776903     97.57           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 4 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 4 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

4      0.780495     97.36           0.767975     97.92           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 5 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 5 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

5      0.772179     97.66           0.762831     98.10           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 6 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 6 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

6      0.766496     97.87           0.759393     98.24           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 7 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 7 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

7      0.762120     98.04           0.757211     98.34           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 8 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 8 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

8      0.758727     98.17           0.754187     98.44           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 9 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 9 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

9      0.755715     98.28           0.753022     98.49           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 10 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 10 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

10     0.753322     98.37           0.750545     98.59           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 11 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 11 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

11     0.751164     98.44           0.749165     98.63           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 12 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 12 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

12     0.749196     98.52           0.748069     98.68           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 13 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 13 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

13     0.747330     98.59           0.747331     98.70           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 14 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 14 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

14     0.745737     98.65           0.746895     98.73           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 15 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 15 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

15     0.744194     98.71           0.745493     98.78           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 16 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 16 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

16     0.742795     98.76           0.745230     98.79           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 17 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 17 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

17     0.741385     98.81           0.744490     98.82           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 18 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 18 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

18     0.740178     98.86           0.743883     98.84           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 19 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 19 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

19     0.739025     98.90           0.743724     98.86           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 20 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 20 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

20     0.737932     98.94           0.743238     98.88           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 21 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 21 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

21     0.736855     98.98           0.742982     98.89           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 22 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 22 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

22     0.735999     99.02           0.743063     98.90                               

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 23 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 23 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

23     0.735149     99.05           0.742739     98.91           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 24 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 24 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

24     0.734476     99.07           0.742465     98.92           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 25 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 25 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

25     0.733789     99.10           0.742515     98.92                               

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 26 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 26 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

26     0.733278     99.12           0.742364     98.93           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 27 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 27 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

27     0.732853     99.14           0.742389     98.94                               

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 28 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 28 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

28     0.732530     99.15           0.742361     98.94           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 29 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 29 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

29     0.732321     99.16           0.742356     98.94           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------


Epoch 30 - Training:   0%|          | 0/12618 [00:00<?, ?it/s]

Epoch 30 - Validation:   0%|          | 0/1578 [00:00<?, ?it/s]

30     0.732166     99.16           0.742328     98.94           ✓ Saved             

Sample predictions for this epoch:
--------------------------------------------------
Input:     कन्यामात्रावशेषिताः
True:      कन्या+मात्र+अवशेषिताः
Predicted: कन्या+मात्र+अवशेषिताः
--------------------------------------------------
Input:     तापसोऽभवम्
True:      तापसः+अभवम्
Predicted: तापसः+अभवम्
--------------------------------------------------
Input:     सुकृता
True:      सु+कृता
Predicted: सु+कृता
--------------------------------------------------

Final model saved at ./model.pt


Evaluating on Test Set:   0%|          | 0/1578 [00:00<?, ?it/s]


Test Set Evaluation:
Test Loss: 0.741269
Test Accuracy: 98.97%


In [16]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Constants and Hyperparameters
MAX_LENGTH = 100
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BEAM_WIDTH = 5
LENGTH_PENALTY = 0.6

# Positional Encoding Module
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=MAX_LENGTH):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

# Sandhi Transformer Model
class SandhiTransformer(nn.Module):
    def __init__(self, char_vocab_size: int, phon_feature_dim: int, 
                 hidden_dim: int, num_layers: int, num_heads: int, 
                 dropout: float = 0.1, max_length: int = MAX_LENGTH):
        super(SandhiTransformer, self).__init__()
        self.hidden_dim = hidden_dim

        # Character Embedding with Dropout
        self.char_embedding = nn.Embedding(char_vocab_size, hidden_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        # Phonetic Feature Transformation
        self.phon_feature_linear = nn.Linear(phon_feature_dim, hidden_dim)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, max_len=max_length)

        # Transformer Encoders for Characters and Phonetic Features
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='relu'
        )
        self.char_transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.phon_transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention-based Memory Fusion
        self.memory_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)

        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='relu'
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Layer Normalization
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # Classifier
        self.classifier = nn.Linear(hidden_dim, char_vocab_size)

    def attention_fusion(self, memory1, memory2):
        att_fusion = self.memory_attention(memory1, memory2, memory2)[0]
        return att_fusion

    def forward(self, char_inputs, phon_feature_inputs, split_output, 
                src_key_padding_mask=None, tgt_mask=None, tgt_key_padding_mask=None):
        # Character Embedding for Word
        word_char_embedded = self.char_embedding(char_inputs)  # (batch_size, seq_len, hidden_dim)
        word_char_embedded = self.dropout(word_char_embedded)
        word_char_embedded = self.pos_encoder(word_char_embedded)
        word_char_embedded = word_char_embedded.transpose(0, 1)  # (seq_len, batch_size, hidden_dim)

        # Phonetic Feature Transformation
        feature_embedded = self.phon_feature_linear(phon_feature_inputs)  # (batch_size, seq_len, hidden_dim)
        feature_embedded = self.dropout(feature_embedded)
        feature_embedded = self.pos_encoder(feature_embedded)
        feature_embedded = feature_embedded.transpose(0, 1)  # (seq_len, batch_size, hidden_dim)

        # Transformer Encoders
        char_transformed = self.char_transformer_encoder(word_char_embedded, src_key_padding_mask=src_key_padding_mask)
        phon_transformed = self.phon_transformer_encoder(feature_embedded, src_key_padding_mask=src_key_padding_mask)

        # Attention-based Memory Fusion
        att_fusion = self.attention_fusion(char_transformed, phon_transformed)  # (seq_len, batch_size, hidden_dim)

        # Split Token Embedding
        split_char_embedded = self.char_embedding(split_output)  # (batch_size, tgt_seq_len, hidden_dim)
        split_char_embedded = self.dropout(split_char_embedded)
        split_char_embedded = self.pos_encoder(split_char_embedded)
        split_char_embedded = split_char_embedded.transpose(0, 1)  # (tgt_seq_len, batch_size, hidden_dim)

        # Transformer Decoder
        transformed = self.decoder(split_char_embedded, att_fusion, tgt_mask=tgt_mask, 
                                   tgt_key_padding_mask=tgt_key_padding_mask)
        transformed = transformed.transpose(0, 1)  # (batch_size, tgt_seq_len, hidden_dim)

        # Layer Normalization
        transformed = self.layer_norm(transformed)

        # Classifier
        logits = self.classifier(transformed)  # (batch_size, tgt_seq_len, vocab_size)
        return logits

# Beam Search Decoding Function
def beam_search(model, char2idx, idx2char, input_chars, input_features, src_key_padding_mask, beam_width=BEAM_WIDTH, length_penalty=LENGTH_PENALTY):
    model.eval()
    sequences = [[char2idx['<s>']]]
    scores = torch.zeros(len(sequences), device=DEVICE)
    
    with torch.no_grad():
        for _ in range(MAX_LENGTH):
            all_candidates = []
            for i, seq in enumerate(sequences):
                if seq[-1] == char2idx['</s>']:
                    all_candidates.append((seq, scores[i]))
                    continue
                # Prepare decoder input
                decoder_input = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                seq_len = decoder_input.size(1)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(DEVICE)
                # Forward pass
                output = model(
                    char_inputs=input_chars.unsqueeze(0),  # (1, seq_len)
                    phon_feature_inputs=input_features.unsqueeze(0),  # (1, seq_len, 26)
                    split_output=decoder_input,  # (1, seq_len)
                    src_key_padding_mask=src_key_padding_mask.unsqueeze(0),
                    tgt_mask=tgt_mask
                )
                logits = output[0, -1, :]  # (vocab_size)
                log_probs = torch.log_softmax(logits, dim=-1)
                topk_probs, topk_indices = torch.topk(log_probs, beam_width)
                for k in range(beam_width):
                    candidate = seq + [topk_indices[k].item()]
                    score = scores[i] + topk_probs[k]
                    all_candidates.append((candidate, score))
            # Select top beam_width sequences
            ordered = sorted(all_candidates, key=lambda tup: tup[1]/(len(tup[0])**length_penalty), reverse=True)
            sequences = []
            scores = []
            for seq, score in ordered[:beam_width]:
                sequences.append(seq)
                scores.append(score)
            # Check if all sequences have ended
            if all(seq[-1] == char2idx['</s>'] for seq in sequences):
                break
    # Select the best sequence
    best_seq = sequences[0]
    # Remove start and end tokens
    if best_seq[0] == char2idx['<s>']:
        best_seq = best_seq[1:]
    if best_seq[-1] == char2idx['</s>']:
        best_seq = best_seq[:-1]
    return best_seq

# Function to Load Model and Mappings
def load_model(model_path: str, char2idx_path: str, hidden_dim=256, num_layers=4, num_heads=8, phon_feature_dim=26):
    # Load char2idx
    if not os.path.exists(char2idx_path):
        raise FileNotFoundError(f"char2idx mapping not found at {char2idx_path}")
    char2idx = torch.load(char2idx_path)
    idx2char = {idx: char for char, idx in char2idx.items()}
    
    # Initialize Model
    vocab_size = len(char2idx)
    model = SandhiTransformer(
        char_vocab_size=vocab_size,
        phon_feature_dim=phon_feature_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        dropout=0.1,
        max_length=MAX_LENGTH
    ).to(DEVICE)
    
    # Load Model Weights
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    
    return model, char2idx, idx2char

# Function to Process Input Word
def process_input(word: str, char2idx: Dict[str, int], phon_feature_dim=26):
    # Tokenize the word
    segments = word.split('+') if '+' in word else [word]
    char_encoding = [char2idx.get('<s>', 2)]  # Start token
    for i, segment in enumerate(segments):
        try:
            chars = [char2idx.get(tup[0], char2idx['<unk>']) for tup in skt.get_ucchaarana_vectors(segment)]
            char_encoding += chars
            if i + 1 < len(segments):
                char_encoding.append(char2idx.get('+', 3))  # '+' token
        except Exception as e:
            print(f"Error tokenizing segment '{segment}': {str(e)}")
            raise
    char_encoding.append(char2idx.get('</s>', 4))  # End token
    
    # Pad or truncate
    if len(char_encoding) > MAX_LENGTH:
        char_encoding = char_encoding[:MAX_LENGTH]
    else:
        char_encoding += [char2idx['<pad>']] * (MAX_LENGTH - len(char_encoding))
    
    # Convert to tensor
    char_tensor = torch.tensor(char_encoding, dtype=torch.long).to(DEVICE)
    
    # Extract phonetic features
    phonetic_features = []
    try:
        for tup in skt.get_ucchaarana_vectors(word):
            phon_vec = list(tup[1].values())
            phonetic_features.append(phon_vec)
    except Exception as e:
        print(f"Error extracting phonetic features for '{word}': {str(e)}")
        raise
    
    # Pad phonetic features
    if len(phonetic_features) > MAX_LENGTH:
        phonetic_features = phonetic_features[:MAX_LENGTH]
    else:
        phonetic_features += [[0] * phon_feature_dim for _ in range(MAX_LENGTH - len(phonetic_features))]
    
    phon_tensor = torch.tensor(phonetic_features, dtype=torch.float).to(DEVICE)
    
    return char_tensor, phon_tensor

# Function to Perform Sandhi Splitting
def sandhi_split(model: nn.Module, word: str, char2idx: Dict[str, int], idx2char: Dict[int, str]):
    # Process input
    try:
        input_chars, input_features = process_input(word, char2idx)
    except Exception as e:
        print(f"Failed to process input word '{word}': {str(e)}")
        return None
    
    # Create src_key_padding_mask
    src_key_padding_mask = (input_chars == char2idx['<pad>'])
    
    # Perform beam search
    try:
        generated_seq = beam_search(
            model=model,
            char2idx=char2idx,
            idx2char=idx2char,
            input_chars=input_chars,
            input_features=input_features,
            src_key_padding_mask=src_key_padding_mask,
            beam_width=BEAM_WIDTH,
            length_penalty=LENGTH_PENALTY
        )
    except Exception as e:
        print(f"Error during beam search for '{word}': {str(e)}")
        return None
    
    # Decode the generated sequence
    pred_split = ''.join([idx2char.get(idx, '') for idx in generated_seq])
    # Replace multiple '+' with single '+' and remove leading/trailing '+'
    pred_split = '+'.join(filter(None, pred_split.split('+')))
    # Transliterate to Devanagari for consistency (if needed)
    pred_split = transliterate.process(
        'IAST', 'Devanagari',
        transliterate.process('Devanagari', 'IAST', pred_split)
    )
    
    return pred_split

# Load Model and Mappings
# Replace 'model.pt' and 'char2idx.pt' with actual file paths
model_path = 'model.pt'       # Replace with your model's path
char2idx_path = 'char2idx.pt' # Replace with your char2idx mapping path

try:
    model, char2idx, idx2char = load_model(model_path, char2idx_path)
    print("Model and character mappings loaded successfully.")
except Exception as e:
    print(f"Error loading model or mappings: {str(e)}")

# Function to Inference Sandhi Splitting
def infer_sandhi_split(word: str):
    split = sandhi_split(model, word, char2idx, idx2char)
    if split:
        print(f"\nInput Word (Devanagari): {word}")
        print(f"Sandhi Split (Devanagari): {split}\n")
    else:
        print("\nFailed to generate sandhi split.\n")

# Create interactive widgets
word_input = widgets.Text(
    value='',
    placeholder='Enter Sanskrit word in Devanagari',
    description='Word:',
    disabled=False
)

button = widgets.Button(
    description='Split Sandhi',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click to split sandhi',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)

output_area = widgets.Output()

# Define button click event
def on_button_click(b):
    with output_area:
        clear_output()
        word = word_input.value.strip()
        if word.lower() == 'exit':
            print("Exiting the inference interface.")
            # Optionally, you can disable the widgets after exit
            word_input.disabled = True
            button.disabled = True
        elif word:
            infer_sandhi_split(word)
        else:
            print("Please enter a valid Sanskrit word.")

button.on_click(on_button_click)

# Display widgets
display(word_input, button, output_area)

Model and character mappings loaded successfully.


Text(value='', description='Word:', placeholder='Enter Sanskrit word in Devanagari')

Button(button_style='success', description='Split Sandhi', icon='check', style=ButtonStyle(), tooltip='Click t…

Output()