# Importing all the important libraries

In [None]:
!pip install transformers==4.45.2 sentence-transformers==3.1.1

In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, EarlyStoppingCallback, Trainer
from utils.gpu import get_device
from utils.common import (
    get_prefix, apply_lora, 
    TRAIN_ARGS, 
    extract_metrics_from_logs,
    plot_training_metrics,
    plot_evaluation_metrics,
    get_prefix_fine_tuned_model,
    generate_mt5_predictions_hf_batch
)
from utils.dataframe import (
    load_model_variants_hf,
    save_model_variants_gen_df,
    load_models_df,
    convert_to_hf,
)

2025-02-19 05:41:58.067615: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-19 05:41:58.091754: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-19 05:41:58.091785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 05:41:58.107221: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Transformers is only compatible with Keras 2, but yo

# Common

In [2]:
# Resolve the compute device once; it is passed to the model and helper calls below.
device = get_device()

Tensorflow GPUs:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Using PyTorch device: cuda
GPU Name: NVIDIA A10G


2025-02-19 05:42:00.879213: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-02-19 05:42:00.917558: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-02-19 05:42:00.919528: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

In [None]:
# Hugging Face checkpoint id used to load both the tokenizer and the base model.
# NOTE(review): the Hub lists this checkpoint as "google/mt5-small" (lowercase) —
# confirm that the mixed-case id resolves in every environment this runs in.
mT5_model_name = "google/mT5-small"

In [None]:
# SentencePiece tokenization (SPT) variants; each one has its own tokenized
# train/test dataset and gets its own fine-tuned model.
spt_models = ["bpe", "unigram"]

# 2. Prefix Tuning Transformer Models for Burmese

## Prefix Tuning

In [None]:
# Tokenized training split for every SPT variant, keyed by variant name
tokenized_train_datasets = dict(
    (spt_name, load_model_variants_hf(f"mt5_{spt_name}_train"))
    for spt_name in spt_models
)

In [None]:
# Tokenized test split for every SPT variant, keyed by variant name
tokenized_test_datasets = dict(
    (spt_name, load_model_variants_hf(f"mt5_{spt_name}_test"))
    for spt_name in spt_models
)

In [None]:
# custom class for Prefix Tuning
class PrefixTuningTrainer(Trainer):
    """Trainer whose loss prepends projected prefix embeddings to the encoder input.

    Args:
        prefixes: Embedding module holding the prefix vectors; all of its
            ``num_embeddings`` rows are used for every example in the batch.
        prefix_projection: Module applied to the prefix embeddings; its output
            width must equal ``model.config.d_model``.

    All other positional/keyword arguments are forwarded to ``Trainer``.

    NOTE(review): ``prefixes`` and ``prefix_projection`` are stored on the
    trainer, not on ``self.model``. Unless the code that creates them also
    registers them on the model, the optimizer built by ``Trainer`` and
    ``model.save_pretrained`` will presumably not see their parameters —
    confirm that the prefix is actually trained and persisted.
    """

    def __init__(self, *args, prefixes, prefix_projection, **kwargs):
        super().__init__(*args, **kwargs)
        device = self.model.device
        # Keep the prefix modules on the same device as the model.
        self.prefixes = prefixes.to(device)
        self.prefix_projection = prefix_projection.to(device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the seq2seq loss with the prefix prepended to the encoder input.

        Args:
            model: The model being trained; must expose ``encoder.embed_tokens``,
                ``decoder.embed_tokens`` and ``config``.
            inputs: Batch dict with ``input_ids``, ``attention_mask`` and ``labels``.
            return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Accepted for compatibility with newer ``Trainer``
                versions that pass it; unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If ``decoder_start_token_id`` is unset or the model
                returns no loss.
        """
        device = model.device  # every tensor below is placed on the model's device

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["labels"].to(device)

        decoder_start_token_id = model.config.decoder_start_token_id
        if decoder_start_token_id is None:
            raise ValueError("model.config.decoder_start_token_id must be set to build decoder inputs.")
        pad_token_id = getattr(model.config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = decoder_start_token_id

        # Teacher forcing: the decoder input at step t must be the target token
        # of step t-1, so shift the labels one position to the right and start
        # with decoder_start_token_id. Because decoder_inputs_embeds is passed
        # explicitly, the model does not perform this shift itself.
        decoder_input_ids = labels.new_full(labels.shape, decoder_start_token_id)
        decoder_input_ids[:, 1:] = labels[:, :-1]
        # -100 is only a loss-masking marker and is not a valid embedding index.
        decoder_input_ids = decoder_input_ids.masked_fill(decoder_input_ids == -100, pad_token_id)

        # Look up every prefix vector and repeat it for each example in the batch.
        prefix_ids = torch.arange(self.prefixes.num_embeddings, device=device)
        expanded_prefixes = self.prefixes(prefix_ids).unsqueeze(0).expand(input_ids.shape[0], -1, -1)

        # Project prefix embeddings to model hidden dimension
        projected_prefixes = self.prefix_projection(expanded_prefixes)
        assert projected_prefixes.shape[-1] == model.config.d_model, "Prefix projection mismatch!"

        # Convert token IDs to embeddings
        inputs_embeds = model.encoder.embed_tokens(input_ids)
        decoder_inputs_embeds = model.decoder.embed_tokens(decoder_input_ids)

        # Concatenate prefix embeddings with inputs
        inputs_embeds = torch.cat([projected_prefixes, inputs_embeds], dim=1)

        # Prefix positions are always attended to; keep the mask's original dtype.
        prefix_mask = attention_mask.new_ones((attention_mask.shape[0], projected_prefixes.shape[1]))
        updated_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # Pass the untouched labels so positions marked -100 stay excluded from the loss.
        outputs = model(
            inputs_embeds=inputs_embeds,
            attention_mask=updated_attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
        )

        # Extract loss
        loss = outputs.loss if hasattr(outputs, "loss") else outputs["loss"]

        if loss is None:
            raise ValueError("Model did not return a loss. Ensure labels are provided.")

        return (loss, outputs) if return_outputs else loss

In [None]:
def fine_tune_model_with_prefix_tuning(
    spt_name,
    batch_size,
    num_train_epochs=3,
    learning_rate=5e-5,
    early_stopping_patience=3,
):
    """
    Fine-tune mT5 with prefix tuning plus LoRA on one SentencePiece tokenization (SPT).

    Args:
        spt_name: Key of the tokenized datasets to train on (an entry of ``spt_models``).
        batch_size: Per-device batch size used for both training and evaluation.
        num_train_epochs: Number of training epochs.
        learning_rate: Learning rate handed to ``TrainingArguments``.
        early_stopping_patience: Patience of the ``EarlyStoppingCallback``.

    Raises:
        KeyError: If no tokenized train/test dataset is loaded for ``spt_name``.

    Side effects:
        Writes checkpoints and logs under ``model-variants/results`` and
        ``model-variants/logs`` and saves the model and tokenizer under
        ``model-variants/models/Prefix_mT5_<SPT>``.
    """
    # Fail fast, before the model is loaded, if the datasets are missing.
    if spt_name not in tokenized_train_datasets or spt_name not in tokenized_test_datasets:
        raise KeyError(f"No tokenized train/test datasets loaded for SPT variant '{spt_name}'")

    print(f"Fine-tuning mT5 using SPT-{spt_name.upper()} with prefix tuning...")

    # Load tokenizer & model
    tokenizer = AutoTokenizer.from_pretrained(mT5_model_name, use_fast=False, legacy=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(mT5_model_name)

    # Move model to GPU before applying LoRA
    model.to(device)

    # Prefix embeddings and their projection, created from the base model
    prefixes, prefix_projection = get_prefix(model, device)

    # Apply LoRA for efficient parameter tuning
    model = apply_lora(model, mT5_model_name, device)

    # print_trainable_parameters() prints the summary itself and returns None,
    # so wrapping it in print() would emit an extra "None" line.
    model.print_trainable_parameters()

    # load dataset
    train_data = tokenized_train_datasets[spt_name]
    val_data = tokenized_test_datasets[spt_name]

    # For a quick debug run, uncomment to train on a small subset:
    #train_data = train_data.select(range(100))
    #val_data = val_data.select(range(100))

    trained_model_name = f"Prefix_mT5_{spt_name.upper()}"

    # Define TrainingArguments
    training_args = TrainingArguments(
        **TRAIN_ARGS,
        output_dir=f"model-variants/results/{trained_model_name}",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        logging_dir=f"model-variants/logs/{trained_model_name}",
        label_names=["labels", "input_ids"]
    )

    # Initialize PrefixTuningTrainer
    trainer = PrefixTuningTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
        prefixes=prefixes,
        prefix_projection=prefix_projection
    )

    # Train the model
    trainer.train()

    # Save trained model and tokenizer
    # NOTE(review): only `model` and `tokenizer` are written here; `prefixes` and
    # `prefix_projection` are not saved by this function — confirm whether the
    # prediction step needs them.
    save_path = f"model-variants/models/{trained_model_name}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print(f"Model `mT5` fine-tuned and saved at `{save_path}`.")

In [None]:
# Fine-tune the BPE variant with a per-device batch size of 8
fine_tune_model_with_prefix_tuning(spt_name="bpe", batch_size=8)

In [None]:
# unigram
# This cell previously passed "bpe", which retrained the BPE model and never
# produced the Prefix_mT5_UNIGRAM logs/model that later cells read.
fine_tune_model_with_prefix_tuning("unigram", 8)

## Train Results

### BPE

In [None]:
# Pull the train / eval / final metrics of the BPE run out of its training logs
(
    mt5_bpe_prefix_fined_tuned_train_results,
    mt5_bpe_prefix_fined_tuned_eval_results,
    mt5_bpe_prefix_fine_tuned_final_results,
) = extract_metrics_from_logs("Prefix_mT5_BPE")

In [None]:
# Plot the training metrics extracted from the Prefix_mT5_BPE logs
plot_training_metrics(mt5_bpe_prefix_fined_tuned_train_results, "Prefix mT5 BPE")

In [None]:
# Plot the evaluation metrics extracted from the Prefix_mT5_BPE logs
plot_evaluation_metrics(mt5_bpe_prefix_fined_tuned_eval_results, "Prefix mT5 BPE")

In [None]:
# Show the final results of the BPE run
display(mt5_bpe_prefix_fine_tuned_final_results)

In [None]:
# Persist the BPE train, eval and final results, each under its own name
save_model_variants_gen_df(mt5_bpe_prefix_fined_tuned_train_results, "mt5_bpe_prefix_fined_tuned_train_results")
save_model_variants_gen_df(mt5_bpe_prefix_fined_tuned_eval_results, "mt5_bpe_prefix_fined_tuned_eval_results")
save_model_variants_gen_df(mt5_bpe_prefix_fine_tuned_final_results, "mt5_bpe_prefix_fine_tuned_final_results")

### Unigram

In [None]:
# Pull the train / eval / final metrics of the Unigram run out of its training logs
(
    mt5_unigram_prefix_fined_tuned_train_results,
    mt5_unigram_prefix_fined_tuned_eval_results,
    mt5_unigram_prefix_fine_tuned_final_results,
) = extract_metrics_from_logs("Prefix_mT5_UNIGRAM")

In [None]:
# Plot the training metrics extracted from the Prefix_mT5_UNIGRAM logs
plot_training_metrics(mt5_unigram_prefix_fined_tuned_train_results, "Prefix mT5 Unigram")

In [None]:
# Plot the evaluation metrics extracted from the Prefix_mT5_UNIGRAM logs
plot_evaluation_metrics(mt5_unigram_prefix_fined_tuned_eval_results, "Prefix mT5 Unigram")

In [None]:
# Show the final results of the Unigram run
display(mt5_unigram_prefix_fine_tuned_final_results)

In [None]:
# Persist the Unigram train, eval and final results, each under its own name
save_model_variants_gen_df(mt5_unigram_prefix_fined_tuned_train_results, "mt5_unigram_prefix_fined_tuned_train_results")
save_model_variants_gen_df(mt5_unigram_prefix_fined_tuned_eval_results, "mt5_unigram_prefix_fined_tuned_eval_results")
save_model_variants_gen_df(mt5_unigram_prefix_fine_tuned_final_results, "mt5_unigram_prefix_fine_tuned_final_results")

## Generate Predictions

In [None]:
# Generate and store predictions from a prefix-fine-tuned mT5 variant (Hugging Face Dataset)
def generate_predictions_prefix_fine_tuned_model(spt_name):
    """Run the fine-tuned mT5 variant for ``spt_name`` over the combined dataset.

    Loads the model/tokenizer pair, generates predictions with
    ``generate_mt5_predictions_hf_batch``, displays the first rows and saves the
    result under the name ``mt5_{spt_name}_trained_predictions``.
    """
    # Fetch the fine-tuned model with its tokenizer and switch to eval mode
    model, tokenizer = get_prefix_fine_tuned_model("mT5", spt_name, mT5_model_name, device)
    model.eval()

    # Combined dataset, converted for use with the Hugging Face helpers
    hf_dataset = convert_to_hf(load_models_df("multilingual_combined"))

    # For a quick debug run, subsample first:
    # hf_dataset = hf_dataset.select(range(10))

    # Text generation over the whole dataset
    predictions = generate_mt5_predictions_hf_batch(hf_dataset, model, tokenizer, device)

    # Preview the first rows
    display(predictions.to_pandas().head())

    # Persist the predictions
    save_model_variants_gen_df(predictions, f"mt5_{spt_name}_trained_predictions")

In [None]:
# Predictions from the BPE-tokenized mT5 model
generate_predictions_prefix_fine_tuned_model(spt_name="bpe")

In [None]:
# Predictions from the Unigram-tokenized mT5 model
generate_predictions_prefix_fine_tuned_model(spt_name="unigram")