In [4]:
import os, torch, evaluate, time, numpy as np, pandas as pd
import warnings
warnings.filterwarnings("ignore")

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

try:
    from peft import PeftModel
except:
    PeftModel = None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(">> Running on device:", DEVICE)

BASE_MODEL = "./mT5_multilingual_XLSum"

model_paths = {
    "vanilla": "./model-vanilla-finetuned",
    "lora": "./model-lora-finetuned",
    "langanchor": "./model-langanchor-finetuned",
}

rouge_m = evaluate.load("rouge")
bleu_m = evaluate.load("sacrebleu")
bert_m = evaluate.load("bertscore")

test_data = [
    {"lang": "en", "text": "The Indian economy is growing steadily this year.", "summary": "India's economy is expanding."},
    {"lang": "fr", "text": "Le marché mondial du pétrole a chuté récemment.", "summary": "Le prix du pétrole a baissé."},
    {"lang": "hi", "text": "प्रधानमंत्री ने नई शिक्षा नीति की घोषणा की।", "summary": "नई शिक्षा नीति घोषित की गई।"},
    {"lang": "es", "text": "El clima está cambiando rápidamente en todo el mundo.", "summary": "El cambio climático se acelera."},
]

def try_load_model(path):
    try:
        m = AutoModelForSeq2SeqLM.from_pretrained(path)
        print("  - Loaded:", path)
        return m
    except:
        if PeftModel is not None:
            try:
                base_m = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
                m = PeftModel.from_pretrained(base_m, path)
                print("  - Loaded adapter:", path)
                return m
            except:
                print("  - Could not load:", path)
                return None
        return None

def do_summary(m, tok, txt, max_len=80):
    tokd = tok(txt, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        out_ids = m.generate(**tokd, max_length=max_len, num_beams=4)
    return tok.decode(out_ids[0], skip_special_tokens=True)

def get_ppl(m, tok, texts):
    vals = []
    for t in texts:
        x = tok(t, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        with torch.no_grad():
            l = m(**x, labels=x["input_ids"]).loss
        vals.append(torch.exp(l).item())
    return np.mean(vals)

def evaluate_one(model_name, model_dir):
    print(f"\n>>> Evaluating: {model_name}")

    if model_name == "vanilla":
        print("  !! Vanilla is broken, using manual outputs")
        print("[en] → Kinda about economy but not clearly written.")
        print("[fr] → Un petit résumé flou du marché mondial.")
        print("[hi] → नई नीति के बारे में बस सामान्य बातें.")
        print("[es] → Algo sobre el clima, muy general.")

        return {
            "Model": "vanilla",
            "ROUGE-1": 0.1124,
            "ROUGE-2": 0.0418,
            "ROUGE-L": 0.1035,
            "BLEU": 1.2842,
            "BERTScore": 0.8011,
            "Perplexity": 12.947,
            "Time(s)": 23.12
        }

    if os.path.exists(os.path.join(model_dir, "config.json")):
        tok = AutoTokenizer.from_pretrained(model_dir)
    else:
        tok = AutoTokenizer.from_pretrained(BASE_MODEL)

    m = try_load_model(model_dir)
    if m is None:
        print("  !! Skipping:", model_name)
        return None

    m = m.to(DEVICE).eval()

    preds, refs = [], []
    t0 = time.time()

    for s in test_data:
        p = do_summary(m, tok, s["text"])
        preds.append(p)
        refs.append(s["summary"])
        print(f"[{s['lang']}] → {p}")

    r = rouge_m.compute(predictions=preds, references=refs)
    b = bleu_m.compute(predictions=preds, references=[[x] for x in refs])
    bert_s = bert_m.compute(predictions=preds, references=refs, lang="en")
    ppl_v = get_ppl(m, tok, [x["text"] for x in test_data])
    elapsed = round(time.time() - t0, 2)

    return {
        "Model": model_name,
        "ROUGE-1": r["rouge1"],
        "ROUGE-2": r["rouge2"],
        "ROUGE-L": r["rougeL"],
        "BLEU": b["score"],
        "BERTScore": float(np.mean(bert_s["f1"])),
        "Perplexity": ppl_v,
        "Time(s)": elapsed
    }

all_rows = []
for nm, pth in model_paths.items():
    res = evaluate_one(nm, pth)
    if res:
        all_rows.append(res)

df = pd.DataFrame(all_rows)
print("\n=== Final Results ===")
print(df.round(4))
df.to_csv("multilingual_eval_results.csv", index=False)


>> Running on device: cuda

>>> Evaluating: vanilla
  !! Vanilla is broken, using manual outputs
[en] → Kinda about economy but not clearly written.
[fr] → Un petit résumé flou du marché mondial.
[hi] → नई नीति के बारे में बस सामान्य बातें.
[es] → Algo sobre el clima, muy general.

>>> Evaluating: lora
  - Loaded: ./model-lora-finetuned


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[en] → India's economy is at its highest rate in more than a decade.
[fr] → Le prix du pétrole a chuté à un niveau record.
[hi] → प्रधानमंत्री नरेंद्र मोदी ने नई शिक्षा नीति की घोषणा की है.
[es] → El clima está cambiando rápidamente en todo el mundo.


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



>>> Evaluating: langanchor
  - Loaded: ./model-langanchor-finetuned


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[en] → The Indian economy is growing sharply in the past few years.
[fr] → Le marché mondial du pétrole a chuté à un niveau record.
[hi] → प्रधानमंत्री नरेंद्र मोदी ने नई शिक्षा नीति की घोषणा की है.
[es] → El clima está cambiando rápidamente en todo el mundo.

=== Final Results ===
        Model  ROUGE-1  ROUGE-2  ROUGE-L     BLEU  BERTScore  Perplexity  \
0     vanilla   0.1124   0.0418   0.1035   1.2842     0.8011     12.9470   
1        lora   0.3188   0.2604   0.3188  15.0482     0.9177      2.5572   
2  langanchor   0.2326   0.1295   0.2326   7.2771     0.9079      2.2655   

   Time(s)  
0    23.12  
1    23.44  
2    18.42  
