# Import required libraries

In [1]:
import numpy as np
import pandas as pd
import torch
from utils.dataframe import (
    save_tmp_df, load_tmp_df, load_models_df,
    save_model_variants_gen_df, load_model_variants_gen_df,
    convert_to_hf,
)
from utils.gpu import get_device
from utils.common import (
    compute_metrics_hf_batch,
    convert_to_mean_scores_df,
)
from IPython.display import display
from tqdm.notebook import tqdm
from transformers import (
    logging,
    AutoTokenizer, MT5ForConditionalGeneration
)
from peft import PeftModel

# Notebook settings

In [2]:
# Register tqdm's pandas integration (adds `progress_apply` / `progress_map` to pandas objects).
# NOTE(review): no cell in this notebook calls those methods; `dataset.map` shows its own progress bar.
tqdm.pandas()

In [3]:
# Only surface errors from the transformers library; its warnings and info messages are hidden.
logging.set_verbosity(logging.ERROR)

# Common setup

In [4]:
# Torch device that the models, input tensors, and metric computation below are placed on.
# `get_device` comes from utils.gpu; per the captured output it selected "mps" on the author's machine.
device = get_device()


Devices:  [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU details:  {'device_name': 'METAL'}
Using device: mps


In [5]:
# Directory of each fine-tuned mT5 variant, keyed by SentencePiece model type.
# Each directory supplies both the tokenizer and the PEFT adapter weights.
model_names = dict(
    bpe="model-variants/models/mT5_BPE",
    unigram="model-variants/models/mT5_UNIGRAM",
)

In [None]:
def get_final_model_with_contextual_embeddings(spt_name):
    """
    Load the fine-tuned model, its tokenizer and the projected contextual embeddings for one variant.

    Args:
        spt_name (str): Key into `model_names` ("bpe" or "unigram").

    Returns:
        tuple: (model, tokenizer, contextual_embeddings)
            - model: "google/mt5-small" wrapped with the PEFT weights from the variant
              directory, moved to `device` and switched to eval mode.
            - tokenizer: slow (use_fast=False) tokenizer loaded from the variant directory.
            - contextual_embeddings: object loaded by `torch.load` from
              `model-variants/gen/{spt_name}_projected_contextual_embeddings.pt`, mapped onto `device`.
    """
    variant_dir = model_names[spt_name]
    embeddings_path = f"model-variants/gen/{spt_name}_projected_contextual_embeddings.pt"

    tokenizer = AutoTokenizer.from_pretrained(variant_dir, use_fast=False, legacy=True)

    # Attach the fine-tuned PEFT weights on top of the pretrained mT5-small base.
    base_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    model = PeftModel.from_pretrained(base_model, variant_dir).to(device)
    model.eval()

    contextual_embeddings = torch.load(embeddings_path, map_location=device)

    return model, tokenizer, contextual_embeddings

# Generate Predictions

In [None]:
def generate_predictions(spt_name, batch_size=128, max_length=512, num_samples=None):
    """
    Generate predictions for the `multilingual_combined` dataset with one tokenizer variant.

    For every example, its own row of the projected contextual embeddings is added
    element-wise to its token embeddings. The sum is passed to `model.generate` as
    `inputs_embeds`. The dataset, with a new `generated` column, is displayed and saved
    as `{spt_name}_final_predictions`.

    Args:
        spt_name (str): Tokenizer variant key in `model_names` ("bpe" or "unigram").
        batch_size (int): Number of examples per `dataset.map` batch.
        max_length (int): Max tokenized input length and max generation length.
        num_samples (int | None): If set, only the first `num_samples` rows are processed
            (useful for debugging). None processes the full dataset.

    Raises:
        ValueError: If there are fewer contextual-embedding rows than examples to process.
    """
    model, tokenizer, contextual_embeddings = get_final_model_with_contextual_embeddings(spt_name)

    # load data
    dataset = convert_to_hf(load_models_df("multilingual_combined"))
    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Row i of the contextual embeddings is treated as belonging to dataset row i
    # (compute_perplexity slices them by dataset position too). TODO confirm against the producer.
    if contextual_embeddings.shape[0] < len(dataset):
        raise ValueError(
            f"Only {contextual_embeddings.shape[0]} contextual embeddings for "
            f"{len(dataset)} examples; expected one row per example."
        )

    embedding_layer = model.get_input_embeddings()

    def predict_fn(batch, indices):
        """Generate text for one batch; `indices` are the batch's row positions in `dataset`."""
        inputs = tokenizer(
            batch["burmese"],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)
        seq_len = inputs["input_ids"].shape[1]

        # Select the rows of *these* examples. Slicing [:batch_size] would hand every batch
        # the embeddings of the first examples in the dataset.
        row_idx = torch.as_tensor(indices, dtype=torch.long, device=contextual_embeddings.device)
        contextual_embeds = contextual_embeddings.index_select(0, row_idx).to(device)

        if contextual_embeds.dim() == 2:
            # (batch, hidden) -> (batch, seq_len, hidden): same vector at every token position.
            contextual_embeds = contextual_embeds.unsqueeze(1).expand(-1, seq_len, -1)
        elif contextual_embeds.shape[1] != seq_len:
            # Per-token embeddings: crop or zero-pad the sequence axis to this batch's length.
            contextual_embeds = contextual_embeds[:, :seq_len]
            missing = seq_len - contextual_embeds.shape[1]
            if missing > 0:
                contextual_embeds = torch.nn.functional.pad(contextual_embeds, (0, 0, 0, missing))

        # Inference only: avoid building an autograd graph for the embedding lookup.
        with torch.no_grad():
            input_embeds = embedding_layer(inputs["input_ids"])

            # Inject contextual information by element-wise addition (not concatenation).
            final_embeds = input_embeds + contextual_embeds.to(input_embeds.dtype)

            output_tokens = model.generate(
                inputs_embeds=final_embeds,
                attention_mask=inputs["attention_mask"],
                num_beams=2,
                use_cache=True,
                repetition_penalty=1.5,  # Avoids excessive repetition
                max_length=max_length,
            )

        return {"generated": tokenizer.batch_decode(output_tokens, skip_special_tokens=True)}

    dataset = dataset.map(predict_fn, batched=True, batch_size=batch_size, with_indices=True)

    display(dataset.to_pandas().head())

    save_model_variants_gen_df(dataset, f"{spt_name}_final_predictions")

In [14]:
# Generate predictions with the BPE tokenizer variant
generate_predictions(spt_name="bpe")

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Unnamed: 0,english,burmese,generated
0,it's not worth seeing the nubian floor exhibit...,အထက် အီဂျစ်မှာ နူဘီးယား ကြမ်းပြင်ပြပွဲကို ကြည့...,အီဂျစ်မှာ နူဘီးယား ကြမ်းပြင် ပွဲကို ကြည့်ဖို့ ...
1,there are remote whitewashed villages that adv...,စွန့်စားချင်သူတွေ လည်ပတ်ချင်ကြတဲ့ ဝေးလံခေါင်သီ...,ဝေးလံခေါင် မြို့ကို လည်ပတ် နေတဲ့ ရွာတွေ ရှိတယ်။
2,"she makes these little tricks, very good, and ...",သူမက ဒီပျဉ်းစေ့ကြိုးတွေ လုပ်ပေးတယ် အရမ်းကောင်း...,သူမက ဒီပျဉ်းစေ့ကြိုး တွေကို လုပ်ပေး တာ အရမ်းကေ...
3,the pair regained zimbabwe's times and finishe...,ထိုစုံတွဲသည် ဇင်ဘာဘွေ၏ အကြိမ်များကို ပြန်လည်ရရ...,ဇင်ဘာဘွေ သည် ဇင်ဘာဘွေ ၏ အကြိမ်အကြိမ် များကို ပ...
4,potential of clarifying its notices to taxpaye...,အခွန်ထမ်းများထံ ၎င်း၏သတိပေးချက်များကို ရှင်းလင...,အခွန်ထမ်း များသည် ၎င်းတို့၏ လုပ်ငန်းတာဝန် များ...


Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
# Generate predictions with the Unigram tokenizer variant
generate_predictions(spt_name="unigram")

# Evaluate Model Performance
Compute BLEU, ROUGE-1, ROUGE-2, ROUGE-L, chrF-S, BERTScore, and perplexity scores for the generated predictions of each tokenizer variant.

## Metrics

In [None]:
# Score the saved predictions of one tokenizer variant and persist per-example metrics.
def compute_metric(spt_name):
    """
    Compute translation metrics for the saved predictions of one tokenizer variant.

    Loads `{spt_name}_final_predictions`, scores it with `compute_metrics_hf_batch`,
    prints the mean of each metric column, and saves the scored dataset as
    `{spt_name}_final_metrics`.

    Args:
        spt_name (str): Tokenizer variant key, e.g. "bpe" or "unigram".
    """
    label = spt_name.upper()
    scored = convert_to_hf(load_model_variants_gen_df(f"{spt_name}_final_predictions"))

    # For a quick debug run, score only the first 100 rows:
    # scored = scored.select(range(100))

    print(f"Processing Data for {label}...")
    scored = compute_metrics_hf_batch(scored, device)

    # (printed heading, metric column) pairs, reported in this order.
    metric_report = (
        ("BLEU Score", "bleu"),
        ("ROUGE-1 Score", "rouge-1"),
        ("ROUGE-2 Score", "rouge-2"),
        ("ROUGE-L Score", "rouge-l"),
        ("chrF-S Score", "chrf-s"),
        ("BERT Score", "bert_score"),
    )
    print(f"Metrics scores for {label}:")
    for heading, column in metric_report:
        print(f"{heading}: {np.mean(scored[column])}")

    save_tmp_df(scored, f"{spt_name}_final_metrics")

In [None]:
# Metrics for the BPE tokenizer variant
compute_metric(spt_name="bpe")

In [None]:
# Metrics for the Unigram tokenizer variant
compute_metric(spt_name="unigram")

## Perplexity

In [None]:
def _align_contextual_embeddings(contextual_embeddings, batch_size, seq_len, hidden_dim, dtype, device, normalize):
    """Reshape contextual embeddings to (batch_size, seq_len, hidden_dim) so they can be added to token embeddings."""
    if contextual_embeddings.shape[0] < batch_size:
        raise ValueError(
            f"Got {contextual_embeddings.shape[0]} contextual embedding rows for {batch_size} texts."
        )
    aligned = contextual_embeddings[:batch_size].to(device=device, dtype=dtype)

    if normalize:
        # Optional unit-normalisation over the last axis (generation does not do this).
        aligned = aligned / (aligned.norm(dim=-1, keepdim=True) + 1e-6)

    if aligned.dim() == 2:
        # One vector per text: repeat it at every token position. The last axis is the
        # hidden dimension, so it must NOT be sliced by the sequence length.
        aligned = aligned.unsqueeze(1).expand(-1, seq_len, -1)
    else:
        # Per-token embeddings: crop or zero-pad the sequence axis to the tokenized length.
        aligned = aligned[:, :seq_len]
        if aligned.shape[1] < seq_len:
            aligned = torch.nn.functional.pad(aligned, (0, 0, 0, seq_len - aligned.shape[1]))

    # Crop or zero-pad the hidden axis to the model's embedding size.
    aligned = aligned[..., :hidden_dim]
    if aligned.shape[-1] < hidden_dim:
        aligned = torch.nn.functional.pad(aligned, (0, hidden_dim - aligned.shape[-1]))
    return aligned


def compute_perplexity_batch(texts, model, tokenizer, contextual_embeddings, max_length=512,
                             normalize_embeddings=False, max_loss=10.0):
    """
    Compute per-text perplexity under an mT5 (encoder-decoder) model with contextual embeddings.

    Each text is fed to the encoder (token embeddings plus the contextual embeddings) and is
    also used as the decoder labels. The score therefore measures how well the model
    reproduces the text under that conditioning.

    Args:
        texts (list[str]): Texts to score.
        model: Seq2seq model exposing `get_input_embeddings()` and accepting
            `inputs_embeds`, `attention_mask`, `labels`; must return an object with `.logits`.
        tokenizer: Tokenizer with `pad_token_id`, callable on a list of texts.
        contextual_embeddings (torch.Tensor | None): (N, hidden) or (N, seq, hidden) with
            N >= len(texts); the first len(texts) rows are used. None disables injection.
        max_length (int): Truncation length for tokenization.
        normalize_embeddings (bool): Unit-normalise contextual embeddings before adding them.
            Default False, so inputs match `generate_predictions`.
        max_loss (float): Mean token loss is clamped to this value before exponentiation,
            so perplexity never exceeds exp(max_loss).

    Returns:
        np.ndarray: One perplexity per text, shape (len(texts),).

    Raises:
        ValueError: If fewer contextual-embedding rows than texts are supplied.
    """
    embedding_layer = model.get_input_embeddings()
    model_device = embedding_layer.weight.device
    hidden_dim = embedding_layer.weight.shape[-1]

    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    ).to(model_device)
    input_ids = inputs["input_ids"]

    # Labels are the input ids with padding set to -100 so the loss ignores it.
    labels = input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100

    with torch.no_grad():
        input_embeds = embedding_layer(input_ids)
        if contextual_embeddings is not None:
            # Direct addition, as in `generate_predictions`.
            input_embeds = input_embeds + _align_contextual_embeddings(
                contextual_embeddings,
                batch_size=len(texts),
                seq_len=input_ids.shape[1],
                hidden_dim=hidden_dim,
                dtype=input_embeds.dtype,
                device=model_device,
                normalize=normalize_embeddings,
            )

        outputs = model(
            inputs_embeds=input_embeds,
            attention_mask=inputs["attention_mask"],
            labels=labels,
        )

    # Encoder-decoder: the model builds decoder inputs by shifting `labels` right, so
    # logits[:, t] already scores labels[:, t]. A causal-LM style extra shift would
    # compare each prediction with the *next* token, so none is applied here.
    logits = outputs.logits.float()
    per_token_loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(labels.shape)

    # Average over real (non-padding) tokens; clamp avoids 0/0 on degenerate rows.
    token_mask = (labels != -100).to(per_token_loss.dtype)
    token_counts = token_mask.sum(dim=1).clamp(min=1)
    sentence_loss = (per_token_loss * token_mask).sum(dim=1) / token_counts

    # Clamp the loss before exp so a single bad sentence cannot overflow.
    return torch.exp(torch.clamp(sentence_loss, max=max_loss)).cpu().numpy()

In [None]:
def compute_perplexity(spt_name, part_num=1, batch_size=8, max_length=512, num_splits=2):
    """
    Compute per-example perplexity of the generated text for one contiguous part of the dataset.

    The predictions saved by `generate_predictions` are divided into `num_splits` contiguous
    parts so a long run can be done one part at a time; only part `part_num` is scored.
    Each example's own row of the contextual embeddings is used, matching how
    `generate_predictions` injects them. Results are saved as
    `{spt_name}_final_perplexity_part_{part_num}`.

    Arguments:
        spt_name (str): Model name identifier ("bpe" or "unigram").
        part_num (int): Part number (1..num_splits) to process.
        batch_size (int): Batch size for processing.
        max_length (int): Maximum sequence length (must match `generate_predictions`).
        num_splits (int): Number of parts. The first `num_splits - 1` parts contain
            `len(dataset) // num_splits` rows; the last part also takes the remainder,
            so no example is dropped.

    Raises:
        ValueError: If `part_num` is out of range, or there are fewer contextual-embedding
            rows than this part needs.
    """
    # Validate part number
    if part_num not in range(1, num_splits + 1):
        raise ValueError(f"Invalid part number. Please choose between 1 and {num_splits}.")

    # Load dataset
    print(f"Loading dataset for {spt_name} (Part {part_num})...")
    dataset = convert_to_hf(load_model_variants_gen_df(f"{spt_name}_final_predictions"))

    # Load model, tokenizer and contextual embeddings
    model, tokenizer, contextual_embeddings = get_final_model_with_contextual_embeddings(spt_name)

    # Contiguous row range of the requested part; the last part absorbs the remainder.
    split_size = len(dataset) // num_splits
    start = (part_num - 1) * split_size
    end = len(dataset) if part_num == num_splits else start + split_size
    dataset_part = dataset.select(range(start, end))

    # Contextual embeddings are sliced by the same dataset positions.
    contextual_embeddings_part = None
    if contextual_embeddings is not None:
        if contextual_embeddings.shape[0] < end:
            raise ValueError(
                f"Only {contextual_embeddings.shape[0]} contextual embeddings; "
                f"part {part_num} needs rows {start}..{end - 1}."
            )
        contextual_embeddings_part = contextual_embeddings[start:end]

    # remove comment for debug
    # dataset_part = dataset_part.select(range(100))

    print(f"Processing Part {part_num} with {len(dataset_part)} samples...")

    def compute_perplexity_fn(batch, indices):
        """Score one `dataset.map` batch; `indices` are row positions within `dataset_part`."""
        texts = [str(text) if text is not None else "" for text in batch["generated"]]

        contextual_embeds = None
        if contextual_embeddings_part is not None:
            # Take the rows of *these* examples, not the first rows of the part.
            row_idx = torch.as_tensor(
                indices, dtype=torch.long, device=contextual_embeddings_part.device
            )
            contextual_embeds = contextual_embeddings_part.index_select(0, row_idx)
            if contextual_embeds.dim() == 2:  # (batch, hidden)
                # Broadcast to the padded length compute_perplexity_batch will see
                # (padding to the longest text, truncated at max_length), keeping the
                # full hidden size.
                token_ids = tokenizer(texts, truncation=True, max_length=max_length)["input_ids"]
                seq_len = max(len(ids) for ids in token_ids)
                contextual_embeds = contextual_embeds.unsqueeze(1).expand(-1, seq_len, -1)

        perplexity_scores = compute_perplexity_batch(texts, model, tokenizer, contextual_embeds, max_length)
        return {"perplexity": perplexity_scores}

    # Compute perplexity in batches
    dataset_part = dataset_part.map(
        compute_perplexity_fn, batched=True, batch_size=batch_size, with_indices=True
    )

    # Save results
    save_tmp_df(dataset_part, f"{spt_name}_final_perplexity_part_{part_num}")

    print(f"Completed Part {part_num} Processing.")

### BPE

In [None]:
# BPE, part 1 of 2
compute_perplexity(spt_name="bpe", part_num=1)

In [None]:
# BPE, part 2 of 2
compute_perplexity(spt_name="bpe", part_num=2)

### Unigram

In [None]:
# Unigram, part 1 of 2
compute_perplexity(spt_name="unigram", part_num=1)

In [None]:
# Unigram, part 2 of 2
compute_perplexity(spt_name="unigram", part_num=2)

## Save Evaluation Results

In [None]:
# Combine predictions, per-example metrics and perplexity into one evaluation table per variant.
# `compute_perplexity` saves one file per part; this must equal the number of parts it was run with.
PERPLEXITY_NUM_SPLITS = 2
METRIC_COLUMNS = ["bleu", "rouge-1", "rouge-2", "rouge-l", "chrf-s", "bert_score"]

for spt_name in model_names.keys():
    print(f"Processing {spt_name}...")

    evaluation_results = load_model_variants_gen_df(f"{spt_name}_final_predictions")

    # load metrics and set (written by `compute_metric`)
    metrics = load_tmp_df(f"{spt_name}_final_metrics")
    assert len(metrics) == len(evaluation_results), (
        f"{spt_name}: {len(metrics)} metric rows vs {len(evaluation_results)} predictions"
    )
    for column in METRIC_COLUMNS:
        # Positional assignment: row i of the metrics belongs to prediction i.
        evaluation_results[column] = metrics[column].to_numpy()

    # load perplexity parts in order and stitch them back together
    perplexity = pd.concat(
        [
            load_tmp_df(f"{spt_name}_final_perplexity_part_{part}")
            for part in range(1, PERPLEXITY_NUM_SPLITS + 1)
        ],
        ignore_index=True,
    )
    assert len(perplexity) == len(evaluation_results), (
        f"{spt_name}: {len(perplexity)} perplexity rows vs {len(evaluation_results)} predictions; "
        "re-run compute_perplexity for every part"
    )
    evaluation_results["perplexity"] = perplexity["perplexity"].to_numpy()

    save_model_variants_gen_df(evaluation_results, f"{spt_name}_final_evaluation_results")

# Benchmarking and Analysis

In [None]:
# load data: per-example evaluation results of every variant, keyed by upper-cased variant name.
# A comprehension keeps the loop variables (`spt_name`, `df`) out of the notebook's global namespace.
final_benchmarking_datasets = {
    spt_name.upper(): load_model_variants_gen_df(f"{spt_name}_final_evaluation_results")
    for spt_name in model_names.keys()
}

In [None]:
# convert to mean score df
# Aggregates the per-variant evaluation frames via utils.common.convert_to_mean_scores_df
# (presumably one row of mean scores per variant; confirm against that helper).
final_benchmarking_mean_scores = convert_to_mean_scores_df(final_benchmarking_datasets)

In [None]:
# Display mean scores (rich display of the aggregated benchmark table)
display(final_benchmarking_mean_scores)

In [None]:
# save benchmarking results under the name "final_evaluation_results"
save_model_variants_gen_df(final_benchmarking_mean_scores, "final_evaluation_results")