In [1]:
from utils.common import (
    generate_masked_predictions_hf_batch, generate_predictions_mT5_hf_batch,
    compute_metrics_hf_batch,
    convert_to_mean_scores_df,
    get_fine_tuned_model, get_embedded_fine_tuned_model,
    compute_multilingual_masked_perplexity_hf_batch, compute_multilingual_mt5_perplexity_batch,
    extract_metrics_from_logs,
    plot_training_metrics, plot_evaluation_metrics
)

In [2]:
import torch

In [3]:
# ✅ Function to Generate Predictions with Debugging
def generate_predictions_mT5_debug(model, tokenizer, sentences, max_length=128,
                                   num_beams=5, do_sample=True, temperature=0.9,
                                   verbose=True):
    """
    Generate one prediction per sentence with a (fine-tuned) mT5 model, with debug output.

    Each sentence is tokenized, generated and decoded on its own (no batching).
    No prefix embeddings are injected here; the model is used exactly as loaded.

    Args:
        model: seq2seq model exposing ``eval()``, ``device`` and ``generate(...)``.
        tokenizer: matching tokenizer (callable, plus ``decode``).
        sentences: iterable of input strings.
        max_length: input truncation length, also passed as ``max_length`` to ``generate``.
        num_beams: beam width forwarded to ``generate``.
        do_sample: forwarded to ``generate``. The defaults combine beam search with
            sampling, so results are not deterministic unless a seed is set or
            ``do_sample=False``.
        temperature: sampling temperature; only forwarded when ``do_sample`` is True.
        verbose: print the input, raw generated token ids and decoded text for every
            sentence, and a warning when the decoded prediction is empty.

    Returns:
        list of ``{"input": sentence, "prediction": decoded_text}`` dicts, in input order.
    """
    model.eval()  # ✅ Set model to evaluation mode
    predictions = []

    # temperature is only meaningful when sampling, so forward it only in that case.
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = temperature

    for sentence in sentences:
        # ✅ Tokenize Input
        inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True,
                           max_length=max_length).to(model.device)

        # ✅ Generate Output
        with torch.no_grad():
            output_tokens = model.generate(**inputs, **gen_kwargs)

        if verbose:
            # Printed before decoding so the raw ids are visible even if decode fails.
            print(f"🔹 Input: {sentence}")
            print(f"🟢 Generated Tokens: {output_tokens}")

        # ✅ Decode Predictions (first returned sequence only)
        prediction = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

        if verbose:
            print(f"🟢 Decoded Prediction: {prediction}\n")
            if not prediction.strip():
                print("⚠️ Empty prediction after decoding with skip_special_tokens=True\n")

        predictions.append({"input": sentence, "prediction": prediction})

    return predictions

In [22]:
# ✅ Example Usage
# Use Apple-Silicon MPS when it is available, otherwise fall back to CPU so the notebook
# also runs on other machines.
# NOTE(review): assumes get_fine_tuned_model accepts "cpu" as its device argument — confirm.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
test_sentences = ["most of the e-mails are from abroad.", "What is your name?", "This is an example sentence."]
model, tokenizer = get_fine_tuned_model("mT5", "bpe", "google/mT5-small", DEVICE)  # ✅ Load Fine-Tuned Model
results = generate_predictions_mT5_debug(model, tokenizer, test_sentences)

🔹 Input: most of the e-mails are from abroad.
🟢 Generated Tokens: tensor([[     0,    259, 211873,  21541, 169990,   2495,  14527,  29974,  10202,
           2495,  88859, 148554,  28771,  12868,    259,  22599,  71053,      1]],
       device='mps:0')
🟢 Decoded Prediction: မေးခွန်း အများစုဟာ အမေရိကန်မှာ ရှိတယ်။

🔹 Input: What is your name?
🟢 Generated Tokens: tensor([[     0,    259, 102868,  37588,  40959,    259, 228293,    259, 202929,
              1]], device='mps:0')
🟢 Decoded Prediction: မင်း ဘယ် အမည် လဲ။

🔹 Input: This is an example sentence.
🟢 Generated Tokens: tensor([[     0,  19725,  10202,    259, 230882,  15217,  13288,  14893,  40959,
            259,  71814,  21232,      1]], device='mps:0')
🟢 Decoded Prediction: ဒါဟာ အဓိပ္ပာယ် တစ်ခုပါ။



In [21]:
# ✅ Show each input next to its decoded prediction
for entry in results:
    source_text, predicted_text = entry["input"], entry["prediction"]
    print(f"🔹 Input: {source_text}\n🔹 Prediction: {predicted_text}\n")

🔹 Input: they don't have any technology.
🔹 Prediction: သူတို့ဟာ နည်းပညာ မရှိဘူး။

🔹 Input: What is your name?
🔹 Prediction: မင်း ဘယ်လို အမည် လဲ။

🔹 Input: This is an example sentence.
🔹 Prediction: ဒါဟာ အဓိပ္ပာယ် တစ်ခုပါ။



In [24]:
# ✅ Function to Generate Predictions with Prefix-Tuning + LoRA
def generate_predictions_mT5_with_prefix(model, tokenizer, sentences, prefixes, prefix_projection,
                                         max_length=128, num_beams=5, early_stopping=True):
    """
    Generate predictions after prepending projected prefix embeddings to every input.

    All ``prefixes.num_embeddings`` rows of ``prefixes`` are passed through
    ``prefix_projection`` and concatenated in front of each sentence's token
    embeddings, with the attention mask extended to cover them. Generation then
    runs from ``inputs_embeds``. The prefixes are injected at the input layer only;
    no per-layer key/value prefixes are passed. Any adapters (e.g. LoRA) already
    present in ``model`` are used as-is.

    Args:
        model: seq2seq model exposing ``eval()``, ``device``, ``get_input_embeddings()``
            and ``generate(inputs_embeds=..., attention_mask=..., ...)``.
        tokenizer: matching tokenizer (callable, plus ``decode``).
        sentences: iterable of input strings, processed one at a time.
        prefixes: ``torch.nn.Embedding`` holding the prefix vectors.
        prefix_projection: module mapping a prefix vector to the model's embedding
            width (assumed to be d_model — TODO confirm for custom projections).
        max_length: input truncation length, also passed as ``max_length`` to ``generate``.
        num_beams: beam width forwarded to ``generate``.
        early_stopping: forwarded to ``generate``.

    Returns:
        list of ``{"input": sentence, "prediction": decoded_text}`` dicts, in input order.
    """
    model.eval()
    device = model.device
    embed_tokens = model.get_input_embeddings()
    predictions = []

    # Pure inference: do not build autograd graphs through the (trainable) prefix modules.
    with torch.no_grad():
        # The prefix block does not depend on the sentence, so compute it once.
        prefix_ids = torch.arange(prefixes.num_embeddings, device=prefixes.weight.device)
        projected_prefixes = prefix_projection(prefixes(prefix_ids)).unsqueeze(0)  # (1, P, dim)
        num_prefix_tokens = projected_prefixes.shape[1]

        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True,
                               max_length=max_length).to(device)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            batch_size = input_ids.shape[0]

            # ✅ Token embeddings, then prefix embeddings matched to their device/dtype
            # (avoids a dtype mismatch when the model runs in reduced precision).
            token_embeds = embed_tokens(input_ids)
            prefix_embeds = projected_prefixes.to(
                device=token_embeds.device, dtype=token_embeds.dtype
            ).expand(batch_size, -1, -1)
            inputs_embeds = torch.cat([prefix_embeds, token_embeds], dim=1)

            # ✅ Prefix positions are always attended to; keep the tokenizer's mask dtype.
            prefix_mask = torch.ones((batch_size, num_prefix_tokens),
                                     dtype=attention_mask.dtype, device=attention_mask.device)
            full_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

            output_tokens = model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=full_attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=early_stopping,
            )

            # ✅ Decode the first returned sequence
            prediction = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
            predictions.append({"input": sentence, "prediction": prediction})

    return predictions

In [23]:
# 2. Prefix hyperparameters (the embedding and projection modules are built from these below).
# Each prefix embedding row is prefix_length * prefix_projection_dim wide; one row per
# (layer, 2) pair of the loaded model.
prefix_length, prefix_projection_dim = 10, 512
num_prefixes = 2 * model.config.num_layers

In [26]:
# Create the prefix embeddings on the model's device (instead of hard-coding "mps").
# NOTE(review): these weights are freshly, randomly initialized and nothing in this
# section trains them, so predictions that use them reflect an untrained prefix.
# Seeding makes those predictions reproducible across Restart & Run All.
PREFIX_SEED = 0
torch.manual_seed(PREFIX_SEED)
prefixes = torch.nn.Embedding(num_prefixes, prefix_length * prefix_projection_dim).to(model.device)

In [27]:
# Create the prefix projection layer: maps each flattened prefix row to the model's
# embedding width (d_model). Placed on the model's device instead of hard-coding "mps".
prefix_projection = torch.nn.Sequential(
    torch.nn.Linear(prefix_length * prefix_projection_dim, model.config.d_model),
    torch.nn.Tanh()  # Or another activation function
).to(model.device)

In [None]:
# English inputs used for the prefix-based generation run below.
test_sentences = [
    "most of the e-mails are from abroad.",
    "What is your name?",
    "This is an example sentence.",
]

In [28]:
# Run generation with the prefix embeddings prepended to every input.
predictions = generate_predictions_mT5_with_prefix(
    model=model,
    tokenizer=tokenizer,
    sentences=test_sentences,
    prefixes=prefixes,
    prefix_projection=prefix_projection,
)

In [29]:
# Last expression is rendered as the cell output: one {"input", "prediction"} dict per sentence.
predictions

[{'input': 'most of the e-mails are from abroad.',
  'prediction': 'ဂျာမန် ၏ မေးခွန်း အများစုသည် နိုင်ငံခြား သို့ ရောက်ရှိ ကြသည်။'},
 {'input': 'What is your name?', 'prediction': 'ဘယ်လို အဓိပ္ပါယ် လဲ ?'},
 {'input': 'This is an example sentence.',
  'prediction': 'ဂျာမန် ၏ အဓိပ္ပာယ် သည် ၎င်း၏ အဓိပ္ပါယ် တစ်ခုဖြစ်သည်။'}]