# 05_compare_checkpoints.ipynb

Comparação objetiva entre checkpoints do fine-tuning (ex.: 6064 vs 9000).

Este notebook NÃO treina. Ele apenas:
1) Localiza/copia checkpoints.
2) Avalia em val(1k).
3) Compara em 200 amostras (baseline já existente).
4) Escolhe o melhor e atualiza `outputs/t5_lora_best/`.

> Se necessário, ajuste o caminho do checkpoint-9000 (`CKPT_9000_DEST`, definido na seção "Avaliar 6064 e 9000 em val(1k)").

## Utilitários de carregamento/avaliação (val 1k)

In [3]:
import torch, json
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel
import evaluate

# Choose the accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

# Raw validation split (plain strings); keep at most the first 1000 rows.
raw_val = load_dataset("json", data_files={"val":"../data/val.jsonl"})["val"]
N = min(1000, len(raw_val))
val_subset = raw_val.select(range(N))
texts = val_subset["input_text"]
refs  = val_subset["target_text"]
print("Val utilizado:", N)

# Corpus-level metrics shared by the evaluation helpers below.
rouge = evaluate.load("rouge")
bleu  = evaluate.load("sacrebleu")

# Base model id; every checkpoint is a PEFT adapter on top of it.
MODEL_NAME = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def load_peft_from_ckpt(base_model_name: str, ckpt_path: str):
    """Load the base seq2seq model and attach the PEFT adapter found at ``ckpt_path``.

    Both the base model and the PEFT wrapper are moved to the global ``device``.
    The returned model is switched to eval mode.
    """
    backbone = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    backbone = backbone.to(device)
    peft_model = PeftModel.from_pretrained(backbone, ckpt_path)
    peft_model = peft_model.to(device)
    peft_model.eval()
    return peft_model

@torch.no_grad()
def eval_checkpoint(model, tokenizer, texts, refs, max_in=128, max_out=224, beams=1):
    """Decode each text one at a time and score the predictions against ``refs``.

    Args:
        model: seq2seq model exposing ``generate``.
        tokenizer: tokenizer matching ``model``.
        texts: input strings.
        refs: one reference string per input.
        max_in: truncation length for the encoder input.
        max_out: ``max_new_tokens`` for generation.
        beams: ``num_beams`` for generation (1 = greedy).

    Returns:
        ``{"rougeL": float, "bleu": float}``, i.e. aggregated ROUGE-L and the sacreBLEU score.
    """
    def _predict(text):
        batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_in).to(device)
        generated = model.generate(**batch, max_new_tokens=max_out, num_beams=beams)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        if device == "mps":
            # Release cached MPS memory between samples.
            torch.mps.empty_cache()
        return decoded

    predictions = [_predict(text) for text in texts]
    rouge_scores = rouge.compute(predictions=predictions, references=refs, use_aggregator=True)
    bleu_scores = bleu.compute(predictions=predictions, references=[[ref] for ref in refs])
    return {"rougeL": float(rouge_scores["rougeL"]), "bleu": float(bleu_scores["score"])}

Device: mps
Val utilizado: 1000


## Avaliar 6064 e 9000 em val(1k) e salvar snapshot

In [4]:
import os, json

# Checkpoint locations; adjust CKPT_9000_DEST if the 9000 backup lives elsewhere.
CKPT_9000_DEST = "../outputs/t5_lora_mps/checkpoint-9000"
CKPT_6064 = "../outputs/t5_lora_mps/checkpoint-6064"

# Evaluate only checkpoints that exist on disk (insertion order: 6064, then 9000).
ckpts_to_eval = {
    tag: path
    for tag, path in (("6064", CKPT_6064), ("9000", CKPT_9000_DEST))
    if os.path.isdir(path)
}
assert ckpts_to_eval, "Nenhum checkpoint encontrado para avaliar."
# Hidden-state guard: a later cell can rebind the global `refs`; scoring the
# val texts against a reference list of a different length would be meaningless.
assert len(texts) == len(refs), "texts/refs desalinhados — reexecute a célula de utilitários."

results = {}
for tag, path in ckpts_to_eval.items():
    model = load_peft_from_ckpt(MODEL_NAME, path)
    metrics = eval_checkpoint(model, tokenizer, texts, refs, max_in=128, max_out=224, beams=1)
    results[tag] = metrics
    print(f"[{tag}] ROUGE-L={metrics['rougeL']:.4f} | BLEU={metrics['bleu']:.2f}")
    # Drop this checkpoint before loading the next one, so two models never
    # coexist in accelerator memory.
    del model
    if device == "mps":
        torch.mps.empty_cache()
    elif device == "cuda":
        torch.cuda.empty_cache()

os.makedirs("../outputs", exist_ok=True)
with open("../outputs/eval_ckpt_6064_vs_9000.json","w",encoding="utf-8") as f:
    json.dump(results, f, indent=2)
print("Snapshot salvo em outputs/eval_ckpt_6064_vs_9000.json")

[6064] ROUGE-L=0.1008 | BLEU=0.66


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: f2fd3a1a-9283-449d-80fe-e964b1b93c57)')' thrown while requesting HEAD https://huggingface.co/google/flan-t5-base/resolve/main/config.json
Retrying in 1s [Retry 1/5].


[9000] ROUGE-L=0.1259 | BLEU=1.19
Snapshot salvo em outputs/eval_ckpt_6064_vs_9000.json


## Comparação “200 amostras” (gera arquivos por checkpoint)

In [8]:
import json, torch, os

# Load the existing 200-sample baseline. The `with` block closes the file,
# and blank lines (e.g. a trailing empty line) are skipped instead of crashing json.loads.
with open("../outputs/baseline_val200.jsonl", "r", encoding="utf-8") as fh:
    rows = [json.loads(line) for line in fh if line.strip()]
assert rows, "baseline_val200.jsonl está vazio."
inputs = [r["input"] for r in rows]
refs200 = [r["ref"] for r in rows]

@torch.no_grad()
def gen_preds(model, tokenizer, texts, bs=2, max_in=128, max_out=224, beams=1):
    """Generate predictions for ``texts`` in mini-batches of ``bs`` items.

    Each batch is padded to its longest item and truncated to ``max_in`` tokens.
    Generation uses ``max_new_tokens=max_out`` and ``num_beams=beams``.
    Returns the decoded strings in the same order as ``texts``.
    """
    decoded = []
    for start in range(0, len(texts), bs):
        chunk = texts[start:start + bs]
        batch = tokenizer(chunk, return_tensors="pt", padding=True,
                          truncation=True, max_length=max_in).to(device)
        generated = model.generate(**batch, max_new_tokens=max_out, num_beams=beams)
        decoded.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
        if device == "mps":
            # Release cached MPS memory between batches.
            torch.mps.empty_cache()
    return decoded

def compare_and_save(tag, ckpt_path):
    """Run checkpoint ``tag`` on the 200 baseline inputs and write a JSONL comparison file.

    Each output line holds ``input``, ``finetuned_pred`` and ``ref``.
    Returns the path of the written file.
    """
    model = load_peft_from_ckpt(MODEL_NAME, ckpt_path)
    preds = gen_preds(model, tokenizer, inputs, bs=2, max_in=128, max_out=224, beams=1)
    out_path = f"../outputs/compare_finetuned_{tag}_vs_baseline.jsonl"
    records = (
        {"input": x, "finetuned_pred": pf, "ref": y}
        for x, pf, y in zip(inputs, preds, refs200)
    )
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    print(f"Comparação salva: {out_path}")
    return out_path

# Produce one comparison file per checkpoint present on disk (6064 first, then 9000).
for tag, ckpt_dir in (("6064", CKPT_6064), ("9000", CKPT_9000_DEST)):
    if os.path.isdir(ckpt_dir):
        compare_and_save(tag, ckpt_dir)

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 8c6e6250-2f72-4fd3-9c12-c555d77783dc)')' thrown while requesting HEAD https://huggingface.co/google/flan-t5-base/resolve/main/config.json
Retrying in 1s [Retry 1/5].


Comparação salva: ../outputs/compare_finetuned_6064_vs_baseline.jsonl


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 60e55549-88ae-4553-a813-f7df5ad04466)')' thrown while requesting HEAD https://huggingface.co/google/flan-t5-base/resolve/main/config.json
Retrying in 1s [Retry 1/5].


Comparação salva: ../outputs/compare_finetuned_9000_vs_baseline.jsonl


## Métricas nas 200 amostras (ambos)

In [9]:
import evaluate, json, os

# Reload the metrics so this cell can also be run on its own.
rouge, bleu = evaluate.load("rouge"), evaluate.load("sacrebleu")

def load_preds(file_path):
    """Read a comparison JSONL file and return ``(predictions, references)``.

    Each non-blank line must be a JSON object with ``finetuned_pred`` and ``ref`` keys.
    Blank lines are skipped, and the file handle is closed deterministically.
    """
    with open(file_path, "r", encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    return [r["finetuned_pred"] for r in rows], [r["ref"] for r in rows]

# Comparison files written by compare_and_save() for each checkpoint.
p6064, p9000 = (
    "../outputs/compare_finetuned_6064_vs_baseline.jsonl",
    "../outputs/compare_finetuned_9000_vs_baseline.jsonl",
)
# Per-checkpoint metrics on the 200-sample comparison, filled in below.
m = dict()

def metric(preds, refs):
    """Score ``preds`` against one reference each.

    Returns ``{"rougeL": float, "bleu": float}``, i.e. aggregated ROUGE-L and the sacreBLEU score.
    """
    wrapped_refs = [[ref] for ref in refs]
    rouge_l = rouge.compute(predictions=preds, references=refs, use_aggregator=True)["rougeL"]
    bleu_score = bleu.compute(predictions=preds, references=wrapped_refs)["score"]
    return {"rougeL": float(rouge_l), "bleu": float(bleu_score)}

# Score each checkpoint's 200-sample comparison file. Dedicated names are used so the
# global `refs` (val 1k references from the utilities cell) is NOT overwritten;
# otherwise re-running the val(1k) evaluation would mix 1000 texts with 200 references.
for tag, cmp_path in (("6064", p6064), ("9000", p9000)):
    if not os.path.exists(cmp_path):
        continue
    preds_200, refs_200 = load_preds(cmp_path)
    m[tag] = {k: round(v, 4) for k, v in metric(preds_200, refs_200).items()}
    print(f"200 amostras — {tag} →", m[tag])

with open("../outputs/metrics_200_ckpt_6064_vs_9000.json","w",encoding="utf-8") as f:
    json.dump(m, f, indent=2)
print("Métricas salvas em outputs/metrics_200_ckpt_6064_vs_9000.json")

200 amostras — 6064 → {'rougeL': 0.1117, 'bleu': 0.9039}
200 amostras — 9000 → {'rougeL': 0.1266, 'bleu': 1.2264}
Métricas salvas em outputs/metrics_200_ckpt_6064_vs_9000.json


## Selecionar o melhor checkpoint e atualizar `outputs/t5_lora_best/`

In [12]:
import os, shutil, json

BEST_DIR = "../outputs/t5_lora_best"
metrics_path = "../outputs/metrics_200_ckpt_6064_vs_9000.json"
# Read with the same utf-8 encoding the metrics cell writes, and close the handle.
with open(metrics_path, "r", encoding="utf-8") as fh:
    m = json.load(fh)
print("Métricas carregadas:", m)

def better(a, b):
    """Return the tag of the better checkpoint: ``"6064"`` for ``a`` or ``"9000"`` for ``b``.

    BLEU decides; ROUGE-L breaks a BLEU tie. A full tie goes to ``"9000"``.
    """
    key = "bleu" if a["bleu"] != b["bleu"] else "rougeL"
    return "6064" if a[key] > b[key] else "9000"

# Pick the winner; if only one checkpoint has metrics, it wins by default.
if "6064" in m and "9000" in m:
    choose = better(m["6064"], m["9000"])
elif "6064" in m:
    choose = "6064"
elif "9000" in m:
    choose = "9000"
else:
    raise RuntimeError("Sem métricas válidas. Rode as células anteriores.")

# Resolve the winner through the same path constants used for evaluation,
# so adjusting CKPT_9000_DEST keeps selection consistent with what was scored.
src = {"6064": CKPT_6064, "9000": CKPT_9000_DEST}[choose]

print(f"Melhor: {choose} → {src}")
# Validate the source BEFORE touching BEST_DIR. Otherwise a missing checkpoint
# would delete the current best model and then fail on copy.
if not os.path.isdir(src):
    raise FileNotFoundError(f"Checkpoint não encontrado: {src}")

# Copy into a staging directory first, so a failed copy leaves the previous best intact.
staging_dir = BEST_DIR + ".tmp"
if os.path.exists(staging_dir):
    shutil.rmtree(staging_dir)
shutil.copytree(src, staging_dir)
if os.path.exists(BEST_DIR):
    shutil.rmtree(BEST_DIR)
os.replace(staging_dir, BEST_DIR)
print("Modelo final atualizado em:", BEST_DIR)

Métricas carregadas: {'6064': {'rougeL': 0.1117, 'bleu': 0.9039}, '9000': {'rougeL': 0.1266, 'bleu': 1.2264}}
Melhor: 9000 → ../outputs/t5_lora_mps/checkpoint-9000
Modelo final atualizado em: ../outputs/t5_lora_best
