# Import required libraries

In [None]:
import pandas as pd
import numpy as np
import torch
import sentencepiece as spm
from utils.dataframe import (
    load_gen_df, save_tmp_df, load_tmp_df, load_models_df,
    save_model_variants_df, load_model_variants_df,
    save_model_variants_hf, load_model_variants_hf,
    save_model_variants_gen_df, load_model_variants_gen_df,
    convert_to_hf, save_model_variants_chunk_hf,
)
from utils.gpu import get_device
from utils.common import (
    apply_lora, TRAIN_ARGS,
    generate_masked_predictions_hf_batch, generate_mt5_predictions_hf_batch,
    compute_metrics_hf_batch,
    convert_to_mean_scores_df,
    get_fine_tuned_model, get_embedded_fine_tuned_model,
    compute_multilingual_masked_perplexity_hf_batch, compute_mt5_perplexity_batch,
    extract_metrics_from_logs,
    plot_training_metrics, plot_evaluation_metrics
)
from IPython.display import display
from tqdm.notebook import tqdm
from transformers import (
    logging,
    AutoTokenizer, MT5ForConditionalGeneration
)
from peft import PeftModel
from torch import nn

# Configure notebook settings

In [None]:
# Enable tqdm (notebook) progress bars for pandas operations (e.g. `progress_apply`)
tqdm.pandas()

In [None]:
# Only show ERROR-level (and more severe) messages from the transformers library,
# hiding its warnings
logging.set_verbosity(logging.ERROR)

# Common

In [None]:
# Device that all models and tensors in this notebook are moved to
# (chosen by the project helper utils.gpu.get_device)
device = get_device()

In [None]:
# Directory of each fine-tuned mT5 variant (PEFT adapter + tokenizer), keyed by
# tokenizer variant name
model_names = dict(
    bpe="model-variants/models/mT5_BPE",
    unigram="model-variants/models/mT5_UNIGRAM",
)

In [None]:
def get_final_model_with_contextual_embeddings(spt_name):
    """
    Load a fine-tuned mT5 variant, its tokenizer and its projected contextual embeddings.

    Args:
        spt_name: key into ``model_names`` ("bpe" or "unigram").

    Returns:
        A ``(model, tokenizer, contextual_embeddings)`` tuple:
        - model: ``google/mt5-small`` with the PEFT adapter from
          ``model_names[spt_name]`` applied, moved to ``device`` and set to eval mode;
        - tokenizer: the slow (``use_fast=False``) tokenizer from the same directory;
        - contextual_embeddings: the tensor stored in
          ``model-variants/gen/{spt_name}_projected_contextual_embeddings.pt``, on ``device``.
    """
    model_path = model_names[spt_name]

    # Load tokenizer & model (base mT5 + PEFT adapter)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=True)
    base_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    model = PeftModel.from_pretrained(base_model, model_path).to(device)
    model.eval()

    # map_location remaps storages at load time, so a tensor that was saved on a
    # GPU still loads when that GPU is absent (or `device` is a different one)
    contextual_embeddings = torch.load(
        f"model-variants/gen/{spt_name}_projected_contextual_embeddings.pt",
        map_location=device,
    ).to(device)

    return model, tokenizer, contextual_embeddings

# Generate Predictions

In [None]:
# Function to generate predictions
def generate_predictions(spt_name, batch_size=128, max_length=512):
    """
    Generate text for every row of the combined multilingual dataset and save it.

    Row ``i`` of the dataset is generated from row ``i`` of the loaded contextual
    embeddings. The tokenized ``burmese`` text is not fed to the model. It only
    supplies the sequence length and the attention mask.

    Args:
        spt_name: key into ``model_names`` ("bpe" or "unigram").
        batch_size: number of dataset rows processed per ``Dataset.map`` batch.
        max_length: maximum token length for input truncation and for generation.

    Raises:
        ValueError: if there are fewer embedding rows than dataset rows.

    Side effects:
        Displays the first rows, then saves the dataset with an added ``generated``
        column as ``{spt_name}_final_predictions``.
    """
    # Load model, tokenizer and embeddings
    model, tokenizer, contextual_embeddings = get_final_model_with_contextual_embeddings(spt_name)

    # Load dataset
    dataset = load_models_df("multilingual_combined")
    dataset = convert_to_hf(dataset)

    # remove comment for debug
    # dataset = dataset.select(range(100))

    # Assumes one embedding row per dataset row, in dataset order — TODO confirm.
    # Fail early instead of hitting an index error mid-generation.
    if contextual_embeddings.shape[0] < len(dataset):
        raise ValueError(
            f"Expected at least {len(dataset)} contextual embeddings for '{spt_name}', "
            f"got {contextual_embeddings.shape[0]}."
        )

    def predict_fn(batch, indices):
        """
        Generate text for one batch; ``indices`` are the dataset row indices of the batch.
        """
        # Tokenize input texts (only the padded length and attention mask are used)
        inputs = tokenizer(
            batch["burmese"],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(device)
        seq_len = inputs["input_ids"].shape[1]

        # Take the embeddings that belong to *these* rows, not the first n rows
        contextual_embeds = contextual_embeddings[indices]
        if contextual_embeds.dim() == 2:  # one vector per row -> repeat across tokens
            contextual_embeds = contextual_embeds.unsqueeze(1).expand(-1, seq_len, -1)

        # Generate text (without decoder_input_ids)
        output_tokens = model.generate(
            inputs_embeds=contextual_embeds,
            attention_mask=inputs["attention_mask"],
            num_beams=2,
            repetition_penalty=1.5,  # Reduce excessive repetition
            max_length=max_length
        )

        # Decode predictions
        generated_texts = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        return {"generated": generated_texts}

    # Process dataset in batches; with_indices passes each batch's row indices
    dataset = dataset.map(predict_fn, batched=True, batch_size=batch_size, with_indices=True)

    # Display results
    display(dataset.to_pandas().head())

    # Save dataset
    save_model_variants_gen_df(dataset, f"{spt_name}_final_predictions")

In [None]:
# Generate and save predictions for the BPE tokenizer variant
generate_predictions(spt_name="bpe")

In [None]:
# Generate and save predictions for the Unigram tokenizer variant
generate_predictions(spt_name="unigram")

# Evaluate Model Performance
Compute BLEU, ROUGE-1, ROUGE-2, ROUGE-L, chrF-S, BERTScore, and perplexity scores.

## Metrics

In [None]:
# Function to Compute Metrics for Fine-Tuned Model using HF Dataset
def compute_metric(spt_name):
    """
    Score saved predictions for one tokenizer variant and store the per-row metrics.

    Loads ``{spt_name}_final_predictions``, runs ``compute_metrics_hf_batch`` on it,
    prints the mean of each metric column, and saves the scored dataset as
    ``{spt_name}_final_metrics``.

    Args:
        spt_name: tokenizer variant name ("bpe" or "unigram").
    """
    label = spt_name.upper()

    hf_dataset = convert_to_hf(load_model_variants_gen_df(f"{spt_name}_final_predictions"))

    # debugging: uncomment to score only the first 100 rows
    # hf_dataset = hf_dataset.select(range(100))

    print(f"Processing Data for {label}...")
    hf_dataset = compute_metrics_hf_batch(hf_dataset, device)

    # (printed name, dataset column) for each reported metric, in display order
    report = (
        ("BLEU", "bleu"),
        ("ROUGE-1", "rouge-1"),
        ("ROUGE-2", "rouge-2"),
        ("ROUGE-L", "rouge-l"),
        ("chrF-S", "chrf-s"),
        ("BERT", "bert_score"),
    )
    print(f"Metrics scores for {label}:")
    for shown_as, column in report:
        print(f"{shown_as} Score: {np.mean(hf_dataset[column])}")

    save_tmp_df(hf_dataset, f"{spt_name}_final_metrics")

In [None]:
# Score the BPE variant's predictions
compute_metric(spt_name="bpe")

In [None]:
# Score the Unigram variant's predictions
compute_metric(spt_name="unigram")

## Perplexity

In [None]:
def compute_perplexity_batch(texts, model, tokenizer, contextual_embeddings):
    """
    Compute one perplexity value per text under an mT5 model conditioned on contextual embeddings.

    The texts are scored as decoder ``labels``. The contextual embeddings are fed to
    the encoder through ``inputs_embeds``.

    Args:
        texts: list of strings to score.
        model: seq2seq model called with ``inputs_embeds``, ``attention_mask`` and ``labels``.
        tokenizer: tokenizer matching ``model``.
        contextual_embeddings: tensor with one row per text. A 2-D tensor
            (presumably ``(batch, hidden_dim)`` — TODO confirm) is repeated across
            the token dimension so it matches the attention mask.

    Returns:
        numpy array of shape ``(len(texts),)`` with the perplexity of each text.
    """
    # Tokenize texts
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)

    # Labels are the token ids; padding is set to -100 so it is ignored by the loss
    labels = inputs["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100

    # Expand contextual embeddings to the token dimension
    if contextual_embeddings is not None:
        contextual_embeddings = contextual_embeddings.to(device)
        if contextual_embeddings.dim() == 2:  # one vector per text
            seq_len = inputs["input_ids"].shape[1]
            contextual_embeddings = contextual_embeddings.unsqueeze(1).expand(-1, seq_len, -1)

    with torch.no_grad():
        outputs = model(
            inputs_embeds=contextual_embeddings,
            attention_mask=inputs["attention_mask"],
            labels=labels,
        )
        logits = outputs.logits  # (batch, label_len, vocab)

    # Encoder-decoder models build decoder_input_ids by shifting `labels` right,
    # so logits[:, t] already predicts labels[:, t]. Applying a causal-LM shift
    # here would misalign every token.
    token_mask = labels != -100
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    per_token_loss = loss_fn(
        logits.reshape(-1, logits.size(-1)).float(), labels.reshape(-1)
    ).view(labels.shape)

    # Mean loss over each text's real tokens; clamp guards against 0 tokens
    per_token_loss = per_token_loss * token_mask
    n_tokens = token_mask.sum(dim=1).clamp(min=1)
    sentence_loss = per_token_loss.sum(dim=1) / n_tokens

    # Perplexity = exp(mean negative log-likelihood)
    return torch.exp(sentence_loss).cpu().numpy()

In [None]:
def compute_perplexity(spt_name, batch_size=16):
    """
    Compute per-row perplexity of the generated texts for one tokenizer variant and save it.

    Args:
        spt_name: key into ``model_names`` ("bpe" or "unigram").
        batch_size: number of rows scored per ``Dataset.map`` batch.

    Raises:
        ValueError: if there are fewer embedding rows than dataset rows.

    Side effects:
        Prints progress and the mean perplexity, then saves the dataset with an
        added ``perplexity`` column as ``{spt_name}_final_perplexity``.
    """
    # Load model, tokenizer and embeddings
    model, tokenizer, contextual_embeddings = get_final_model_with_contextual_embeddings(spt_name)

    # load dataset
    print(f"Loading dataset for {spt_name}...")
    perplexity_dataset = load_model_variants_gen_df(f"{spt_name}_final_predictions")
    perplexity_dataset = convert_to_hf(perplexity_dataset)

    # for debug, remove comment
    # perplexity_dataset = perplexity_dataset.select(range(100))

    # Assumes one embedding row per dataset row, in dataset order — TODO confirm
    if contextual_embeddings.shape[0] < len(perplexity_dataset):
        raise ValueError(
            f"Expected at least {len(perplexity_dataset)} contextual embeddings for "
            f"'{spt_name}', got {contextual_embeddings.shape[0]}."
        )

    print(f"Computing perplexity in batches of {batch_size}...")

    # NOTE: this closure must not be named `compute_perplexity_batch`. That name
    # would shadow the module-level scorer, and the call below would recurse into
    # this 1-argument function.
    def perplexity_fn(batch, indices):
        # Ensure all elements are strings; None becomes ""
        texts = ["" if text is None else str(text) for text in batch["generated"]]
        # Only the embedding rows for this batch, matching the texts one-to-one
        batch_embeddings = contextual_embeddings[indices]
        scores = compute_perplexity_batch(texts, model, tokenizer, batch_embeddings)
        return {"perplexity": scores}

    # Compute perplexity in batches; with_indices passes each batch's row indices
    perplexity_dataset = perplexity_dataset.map(
        perplexity_fn, batched=True, batch_size=batch_size, with_indices=True
    )

    # Display Results
    mean_perplexity = np.mean(perplexity_dataset["perplexity"])
    print(f"Perplexity Score: {mean_perplexity:.4f}")

    # Save dataset
    save_tmp_df(perplexity_dataset, f"{spt_name}_final_perplexity")

In [None]:
# Perplexity of the BPE variant's generated texts
compute_perplexity(spt_name="bpe")

In [None]:
# Perplexity of the Unigram variant's generated texts
compute_perplexity(spt_name="unigram")

## Save Evaluation Results

In [None]:
# combine evaluation results: for each variant, attach the per-row metric and
# perplexity columns to its predictions and save the merged table
for spt_name in model_names:
    print(f"Processing {spt_name}...")

    evaluation_results = load_model_variants_gen_df(f"{spt_name}_final_predictions")

    # per-row text metrics
    metrics = load_tmp_df(f"{spt_name}_final_metrics")
    evaluation_results = evaluation_results.assign(
        **{
            column: metrics[column]
            for column in ("bleu", "rouge-1", "rouge-2", "rouge-l", "chrf-s", "bert_score")
        }
    )

    # per-row perplexity
    perplexity = load_tmp_df(f"{spt_name}_final_perplexity")
    evaluation_results = evaluation_results.assign(perplexity=perplexity["perplexity"])

    save_model_variants_gen_df(evaluation_results, f"{spt_name}_final_evaluation_results")

# Benchmarking and Analysis

In [None]:
# load each variant's merged evaluation table, keyed by upper-case variant name
final_benchmarking_datasets = {
    spt_name.upper(): load_model_variants_gen_df(f"{spt_name}_final_evaluation_results")
    for spt_name in model_names
}

In [None]:
# convert to mean score df: summarise each variant's per-row scores with the
# project helper convert_to_mean_scores_df (input keyed by "BPE"/"UNIGRAM")
final_benchmarking_mean_scores = convert_to_mean_scores_df(final_benchmarking_datasets)

In [None]:
# Display mean scores (rich notebook rendering via IPython.display)
display(final_benchmarking_mean_scores)

In [None]:
# save benchmarking results: the cross-variant mean-score table, stored under
# "final_evaluation_results" (per-variant tables use "{spt_name}_final_evaluation_results")
save_model_variants_gen_df(final_benchmarking_mean_scores, "final_evaluation_results")