In [None]:
!pip install datasets accelerate evaluate sentencepiece
!pip install --upgrade transformers

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5
Collecting transformers
  Downloading transformers-4.56.1-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.56.1-py3-none-any.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m122.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.56.0
    Uninstalling transformers-4.56.0:
      Successfully uninstalled transformers-4.56.0
Successfully installed transformers-4.56.1

In [None]:
# ---------------------------
# 0) Imports
# ---------------------------
import os
import re
import math
import time
import random
import shutil
import difflib
from typing import List, Tuple
from collections import Counter
from transformers.trainer_utils import get_last_checkpoint

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    EarlyStoppingCallback, TrainerCallback, TrainingArguments,
    TrainerControl, TrainerState
)

# Only available inside Google Colab:
try:
    from google.colab import files
except Exception:
    files = None  # Allows the script to run outside Colab (download step will be skipped)

In [None]:
# -----------------------------
# 0) Config
# -----------------------------
SEED = 42
random.seed(SEED); torch.manual_seed(SEED)

DATA_DIR   = "/content/output"
TRAIN_FILE = f"{DATA_DIR}/llm_train_gen.jsonl"
VALID_FILE = f"{DATA_DIR}/llm_valid_gen.jsonl"
TEST_FILE  = f"{DATA_DIR}/llm_test_gen.jsonl"

# Pick base model
BASE_MODEL = "t5-base"          # or: "google/flan-t5-base" | "google/flan-t5-large"
OUT_DIR    = "/content/t5_degree2_ckpt"
FINAL_DIR  = "/content/t5_degree2_final"

CHUNK_N_SENT = None

MAX_IN_LEN  = 512
MAX_OUT_LEN = 256
EVENT_TOKEN = "<EVENTSEP>"

In [None]:
# -----------------------------
# 1) Regex & utils
# -----------------------------
EVENT_SEP = re.compile(r"\s*<EVENTSEP>\s*", re.IGNORECASE)
TYPE_RE   = re.compile(r'event\s*type\s*:\s*(.+?)\.', re.IGNORECASE)
TRIG_RE   = re.compile(r'trigger\s*:\s*(.+?)\.', re.IGNORECASE)
TOK_RE    = re.compile(r"\w+", re.UNICODE)

def norm(s: str) -> str:
    return re.sub(r"\s+"," ", s.strip().lower()) if s else s

def parse_pairs(text: str) -> List[Tuple[str,str]]:
    if not text or "no events" in text.lower(): return []
    out=[]
    for ch in [c for c in EVENT_SEP.split(text) if c.strip()]:
        t = TYPE_RE.search(ch); g = TRIG_RE.search(ch)
        et = norm(t.group(1)) if t else None
        tr = norm(g.group(1)) if g else None
        if et and tr: out.append((et,tr))
    return out

def dedup_events_str(text: str) -> str:
    if not text: return ""
    seen=set(); kept=[]
    for p in [p.strip() for p in EVENT_SEP.split(text) if p.strip()]:
        key = norm(p)
        if key and key not in seen:
            seen.add(key); kept.append(p)
    return f" {EVENT_TOKEN} ".join(kept)

def token_set(s): return set(TOK_RE.findall(s.lower())) if s else set()

def trigger_overlap(p,g):
    if not p or not g: return 0.0
    if norm(p)==norm(g): return 1.0
    ps,gs = token_set(p), token_set(g)
    return 0.5 if (ps and gs and ps&gs) else 0.0

In [None]:
# -----------------------------
# 2) Load plain sets (to build few-shot & ontology)
# -----------------------------
def load_plain(path): return load_dataset("json", data_files={"data": path})["data"]

train_plain = load_plain(TRAIN_FILE)
valid_plain = load_plain(VALID_FILE)
test_plain  = load_plain(TEST_FILE)

TYPE_FREQ = Counter()
for ex in train_plain:
    for et,tr in parse_pairs(ex["output"]):
        TYPE_FREQ[et]+=1
ONTOLOGY = sorted(TYPE_FREQ.keys())

def nearest_type(t: str) -> str:
    if not t or not ONTOLOGY: return t
    cand = difflib.get_close_matches(t, ONTOLOGY, n=1, cutoff=0.8)
    return cand[0] if cand else t

# Few-shot واقعی و کوتاه
def build_fewshot_bank(ds, k=60):
    bank=[]
    for ex in ds:
        out = ex["output"]
        if (EVENT_TOKEN in out) and (len(out)<600):
            src = ex["input"]
            m = re.search(r'\"(.+?)\"\s*\nUse <EVENTSEP>', src, re.DOTALL)
            sent = m.group(1) if m else src
            bank.append((sent.strip(), out.strip()))
    random.shuffle(bank)
    return bank[:k]

FEWSHOT = build_fewshot_bank(train_plain, k=60)

def sample_k_shots(k=3):
    k = min(k, len(FEWSHOT))
    return random.sample(FEWSHOT, k) if k>0 else []

Generating data split: 0 examples [00:00, ? examples/s]

Generating data split: 0 examples [00:00, ? examples/s]

Generating data split: 0 examples [00:00, ? examples/s]

In [None]:
# -----------------------------
# 3) Prompt builder (DEGREE2-style)
# -----------------------------
PROMPT_HEADER = (
    "Extract ALL events from the sentence below.\n"
    f"Output only lines like: {EVENT_TOKEN} Event type: <TYPE>. Trigger: <TRIGGER>.\n"
    "If no events, output exactly: No events.\n\n"
)

def build_prompt(sentence: str, k=3) -> str:
    head = PROMPT_HEADER
    shots = sample_k_shots(k)
    if shots:
        head += "### Examples\n"
        for s,o in shots:
            head += f'Sentence: "{s}"\n{o}\n\n'
    head += "### Now extract\n"
    head += f'Sentence: "{sentence}"\nOutput:\n'
    return head

In [None]:
# -----------------------------
# 4) Chunk text to sentences
# -----------------------------
def chunk_text_to_sentences(text: str) -> list:
    return re.split(r'(?<=[\.\!\?])\s+', text.strip())

def chunk_doc(sentence_or_doc: str, n=3):
    sents = chunk_text_to_sentences(sentence_or_doc)
    if len(sents)<=n: return [sentence_or_doc]
    chunks=[]
    for i in range(0,len(sents),n):
        chunks.append(" ".join(sents[i:i+n]))
    return chunks

In [None]:
# -----------------------------
# 5) Supervised datasets => prompted IO
# -----------------------------
def to_prompted(ds, fewshot_k=3):
    def _map(ex):
        txt = ex["input"]
        if CHUNK_N_SENT:
            pieces = chunk_doc(txt, n=CHUNK_N_SENT)
            txt = pieces[0]
        return {"input": build_prompt(txt, k=fewshot_k), "output": ex["output"]}
    return ds.map(_map, remove_columns=[c for c in ds.column_names if c not in ("input","output")])

train_ds = to_prompted(train_plain, fewshot_k=3)
valid_ds = to_prompted(valid_plain, fewshot_k=3)
test_ds  = to_prompted(test_plain,  fewshot_k=3)

Map:   0%|          | 0/32431 [00:00<?, ? examples/s]

Map:   0%|          | 0/8042 [00:00<?, ? examples/s]

Map:   0%|          | 0/9400 [00:00<?, ? examples/s]

In [None]:
# -----------------------------
# 6) Tokenizer & model
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
added = tokenizer.add_special_tokens({"additional_special_tokens":[EVENT_TOKEN]})

def preprocess(batch):
    enc = tokenizer(batch["input"], max_length=MAX_IN_LEN, truncation=True, padding="max_length")
    with tokenizer.as_target_tokenizer():
        lab = tokenizer(batch["output"], max_length=MAX_OUT_LEN, truncation=True, padding="max_length")
    enc["labels"] = lab["input_ids"]
    return enc

train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
valid_tok = valid_ds.map(preprocess, batched=True, remove_columns=valid_ds.column_names)

model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
if added>0: model.resize_token_embeddings(len(tokenizer))
model.gradient_checkpointing_enable()

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Map:   0%|          | 0/32431 [00:00<?, ? examples/s]



Map:   0%|          | 0/8042 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# =====================================================
# A) Resume utilities + final snapshot helper
# =====================================================
def get_latest_ckpt_path(out_dir):
    """Return HF-style latest checkpoint path or None."""
    try:
        return get_last_checkpoint(out_dir)
    except Exception:
        return None

def save_final_snapshot(trainer, tokenizer, base_dir, tag=None):
    """
    Save a final snapshot regardless of save_steps.
    e.g.: /content/t5_degree2_ckpt/checkpoint-final-20000
    Also writes '/content/t5_degree2_ckpt/LATEST.txt'
    """
    step = int(trainer.state.global_step) if trainer and trainer.state else 0
    tag  = tag or f"final-{step}"
    snap_dir = os.path.join(base_dir, f"checkpoint-{tag}")
    os.makedirs(snap_dir, exist_ok=True)
    trainer.save_model(snap_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(snap_dir)
    with open(os.path.join(base_dir, "LATEST.txt"), "w") as f:
        f.write(snap_dir)
    print(f"[Snapshot] Saved '{snap_dir}' and updated LATEST.txt")
    return snap_dir

In [None]:
# =====================================================
# B) StopAtStepCallback: precise early-stop + save
# =====================================================
class StopAtStepCallback(TrainerCallback):
    """
    Stop cleanly at a target step (env STOP_AT_STEP or /content/stop_at_step.txt),
    save a final checkpoint, and allow later cells to continue.
    """
    def __init__(self, target_step=None, out_dir=None, tokenizer=None):
        self.target_step = target_step
        self.out_dir = out_dir
        self.tokenizer = tokenizer

    def _read_target(self):
        env_val = os.getenv("STOP_AT_STEP", "").strip()
        if env_val.isdigit():
            return int(env_val)
        fpath = "/content/stop_at_step.txt"
        if os.path.exists(fpath):
            try:
                with open(fpath) as f:
                    val = f.read().strip()
                if val.isdigit():
                    return int(val)
            except:
                pass
        return self.target_step

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        tgt = self._read_target()
        if tgt is None:
            return
        if state.global_step >= tgt:
            trainer = kwargs.get("trainer", None)
            print(f"\n[StopAtStep] Reached step {state.global_step} (target={tgt}). Saving and stopping...")
            if trainer is not None and self.out_dir is not None:
                save_final_snapshot(trainer, self.tokenizer, self.out_dir, tag=f"stop-{state.global_step}")
            control.should_training_stop = True
            control.should_save = True
            return control

In [None]:
# -----------------------------
# 7) Training (Adafactor, 10 epochs per paper)
# -----------------------------
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

args = Seq2SeqTrainingArguments(
    output_dir=OUT_DIR,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=800,
    save_steps=800,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=10,
    max_steps=-1,  # keep epochs; you can still stop exactly via STOP_AT_STEP/file
    learning_rate=1e-4,
    weight_decay=1e-5,
    lr_scheduler_type="constant",
    optim="adafactor",

    label_smoothing_factor=0.1,
    max_grad_norm=1.0,
    predict_with_generate=False,

    fp16=False,
    bf16=use_bf16,
    report_to="none",
    logging_steps=100,
    save_safetensors=True,
    seed=SEED, data_seed=SEED,
    remove_unused_columns=False,
)

# Auto-resume (if checkpoint exists)
resume_path = get_latest_ckpt_path(OUT_DIR)
if resume_path:
    print(f"[Resume] Found latest checkpoint: {resume_path}")
else:
    print("[Resume] No checkpoint found; starting from base model.")

stop_cb = StopAtStepCallback(target_step=None, out_dir=OUT_DIR, tokenizer=tokenizer)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3), stop_cb]
)

[Resume] No checkpoint found; starting from base model.


  trainer = Seq2SeqTrainer(


In [None]:
try:
    trainer.train(resume_from_checkpoint=resume_path)
except KeyboardInterrupt:
    print("\n[Interrupt] Caught KeyboardInterrupt — saving final snapshot...")
    save_final_snapshot(trainer, tokenizer, OUT_DIR, tag=f"interrupt-{int(time.time())}")
    raise  # (optional) see the interrupt; comment out to silently continue

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss
800,1.5001,1.46787
1600,1.4686,1.438593
2400,1.4557,1.426441
3200,1.4475,1.411881
4000,1.4401,1.412003
4800,1.4308,1.404725
5600,1.4307,1.40936
6400,1.4281,1.403879
7200,1.4252,1.405518
8000,1.4245,1.404045



[Interrupt] Caught KeyboardInterrupt — saving final snapshot...
[Snapshot] Saved '/content/t5_degree2_ckpt/checkpoint-interrupt-1755393849' and updated LATEST.txt


KeyboardInterrupt: 

In [None]:
import os, json, time, shutil
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from google.colab import files

assert 'OUT_DIR' in globals(), "⚠️ OUT_DIR is not defined."

def find_best_checkpoint(out_dir: str):
    state_path = os.path.join(out_dir, "trainer_state.json")
    if os.path.exists(state_path):
        with open(state_path, "r") as f:
            st = json.load(f)
        best = st.get("best_model_checkpoint", None)
        if best and os.path.exists(best):
            return best
    return None

def pick_checkpoint_to_zip(out_dir: str):
    best = find_best_checkpoint(out_dir)
    if best:
        print(f"[info] Best checkpoint: {best}")
        return best
    latest = get_last_checkpoint(out_dir)
    if latest:
        print(f"[info] Fallback to latest checkpoint: {latest}")
        return latest
    raise FileNotFoundError("No checkpoint found.")

ckpt_dir = pick_checkpoint_to_zip(OUT_DIR)

# Create a zip file
stamp = time.strftime("%Y%m%d-%H%M%S")
zip_base = f"/content/{Path(ckpt_dir).name}_{stamp}"
zip_path = shutil.make_archive(zip_base, "zip", ckpt_dir)

# Download directly in Colab
files.download(zip_path)


[info] Fallback to latest checkpoint: /content/t5_degree2_ckpt/checkpoint-14400


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# After training (complete/early/stop), save a final snapshot and export FINAL_DIR
final_snap = save_final_snapshot(trainer, tokenizer, OUT_DIR)  # e.g. checkpoint-final-<step>

os.makedirs(FINAL_DIR, exist_ok=True)
trainer.save_model(FINAL_DIR); tokenizer.save_pretrained(FINAL_DIR)
print("Saved best model to:", FINAL_DIR)

[Snapshot] Saved '/content/t5_degree2_ckpt/checkpoint-final-13660' and updated LATEST.txt
Saved best model to: /content/t5_degree2_final


In [None]:
# --- Zip & Download FINAL_DIR in Colab ---
try:
    from google.colab import files as colab_files
    stamp = time.strftime("%Y%m%d-%H%M%S")
    zip_base = f"/content/{os.path.basename(FINAL_DIR)}_{stamp}"
    for ext in (".zip",):
        if os.path.exists(zip_base + ext):
            os.remove(zip_base + ext)
    print("Zipping ... این کار ممکنه چند دقیقه طول بکشه.")
    archive_path = shutil.make_archive(zip_base, 'zip', FINAL_DIR)
    size_gb = os.path.getsize(archive_path) / (1024**3)
    print(f"Done: {archive_path}  (~{size_gb:.2f} GB)")
    colab_files.download(archive_path)
except Exception as e:
    print(f"[Zip/Download] Skipped or failed (non-Colab env?): {e}")

Zipping ... این کار ممکنه چند دقیقه طول بکشه.
Done: /content/t5_degree2_final_20250817-012846.zip  (~0.77 GB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# -----------------------------
# 8) Inference helpers (recall-friendly decoding)
# -----------------------------
def clean_and_canon(text: str) -> str:
    text = dedup_events_str(text)
    parts=[]
    for ch in [c for c in EVENT_SEP.split(text) if c.strip()]:
        t = TYPE_RE.search(ch); g = TRIG_RE.search(ch)
        et = norm(t.group(1)) if t else None
        tr = norm(g.group(1)) if g else None
        if et:
            et = nearest_type(et)
        if et and tr:
            parts.append(f"Event type: {et}. Trigger: {tr}.")
    return f" {EVENT_TOKEN} ".join(parts) if parts else "No events."

def generate_batch(prompts: List[str], mdl, tok, bs=8, device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval().to(device)
    gen_kwargs = dict(
        max_new_tokens=160, min_new_tokens=16,
        num_beams=6, num_beam_groups=3, diversity_penalty=0.2,
        no_repeat_ngram_size=3, length_penalty=0.9, repetition_penalty=1.05,
        early_stopping=False,
    )
    outs=[]
    with torch.no_grad():
        for i in range(0,len(prompts),bs):
            enc = tok(prompts[i:i+bs], return_tensors="pt", padding=True, truncation=True, max_length=MAX_IN_LEN).to(device)
            gen = mdl.generate(**enc, **gen_kwargs)
            dec = tok.batch_decode(gen, skip_special_tokens=True)
            outs += [clean_and_canon(t) for t in dec]
    return outs

In [None]:
# -----------------------------
# 9) Evaluation (Strict / Partial / Relaxed)
# -----------------------------
def match_partial(preds, golds):
    used=set(); exact=partial=0
    for pt,ptr in preds:
        best=(0.0,-1)
        for j,(gt,gtr) in enumerate(golds):
            if j in used: continue
            if pt!=gt:   continue
            ov = trigger_overlap(ptr,gtr)
            if ov>best[0]: best=(ov,j)
        if best[1]!=-1:
            used.add(best[1])
            if math.isclose(best[0],1.0): exact+=1
            elif best[0]>=0.5: partial+=1
    return exact, partial, len(preds), len(golds)

def prf(e,p,pt,gt, w=0.5):
    wtp = e + w*p
    P = wtp/pt if pt else 0.0
    R = wtp/gt if gt else 0.0
    F = (2*P*R)/(P+R) if (P+R) else 0.0
    return P,R,F

def relaxed_recall_by_chunks(pairs_pred, pairs_gold, chunk_size=1):
    correct=partial=extra=possible=impossible=0
    for pp,gg in zip(pairs_pred, pairs_gold):
        ce,cp,pt,gt = match_partial(pp,gg)
        matched = ce + 0.5*cp
        possible += gt
        impossible += max(0, gt-1)
        extra += max(0.0, matched-1.0)
        correct += ce; partial += cp
    denom = max(1, possible - impossible)
    num   = max(0.0, (correct + 0.5*partial) - extra)
    return num/denom

def evaluate(ds, mdl, tok):
    prompts = [ex["input"] for ex in ds]
    gtexts  = [ex["output"] for ex in ds]
    preds   = generate_batch(prompts, mdl, tok, bs=8)

    strict_tp=strict_pred=strict_gold=0
    part_e=part_p=part_pt=part_gt=0
    chunks_pred=[]; chunks_gold=[]

    for ptxt,gtxt in zip(preds, gtexts):
        pp = parse_pairs(ptxt)
        gg = parse_pairs(gtxt)
        # strict
        sp,sg = set(pp), set(gg)
        tp = len(sp & sg)
        strict_tp += tp; strict_pred += len(sp); strict_gold += len(sg)
        # partial
        ce,cp,pt,gt = match_partial(pp,gg)
        part_e += ce; part_p += cp; part_pt += pt; part_gt += gt
        chunks_pred.append(pp); chunks_gold.append(gg)

    sP = strict_tp/strict_pred if strict_pred else 0.0
    sR = strict_tp/strict_gold if strict_gold else 0.0
    sF = (2*sP*sR)/(sP+sR) if (sP+sR) else 0.0

    pP,pR,pF = prf(part_e,part_p,part_pt,part_gt, w=0.5)
    r_rel = relaxed_recall_by_chunks(chunks_pred, chunks_gold)

    print("\n===== STRICT =====")
    print(f"P={sP:.4f} R={sR:.4f} F1={sF:.4f}")
    print("===== PARTIAL (MUC 0.5) =====")
    print(f"P={pP:.4f} R={pR:.4f} F1={pF:.4f}")
    print("===== RELAXED (DEGREE2) =====")
    rF = (2*pP*r_rel)/(pP+r_rel) if (pP+r_rel)>0 else 0.0
    print(f"Relaxed-Recall={r_rel:.4f} | Relaxed-F1≈{rF:.4f}")
    return dict(strict_f1=sF, partial_f1=pF, relaxed_recall=r_rel, relaxed_f1=rF)

In [None]:
# -----------------------------
# 10) Quick sanity + Eval
# -----------------------------
best_tok   = AutoTokenizer.from_pretrained(FINAL_DIR)
best_model = AutoModelForSeq2SeqLM.from_pretrained(FINAL_DIR)

print("\n--- VALID ---")
evaluate(valid_ds, best_model, best_tok)
print("\n--- TEST ---")
evaluate(test_ds, best_model, best_tok)

demo_sent = 'The hijacking of Lufthansa Flight 615 was an act of terrorism committed by a Palestinian group that occurred on 29 October 1972 and aimed at the liberation of the three surviving perpetrators of the Munich massacre from a West German prison.'
demo_prompt = build_prompt(demo_sent, k=3)
print("\nDEMO:\n", generate_batch([demo_prompt], best_model, best_tok)[0])


--- VALID ---


Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call.



===== STRICT =====
P=0.6084 R=0.2817 F1=0.3851
===== PARTIAL (MUC 0.5) =====
P=0.6105 R=0.2761 F1=0.3802
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6655 | Relaxed-F1≈0.6368

--- TEST ---

===== STRICT =====
P=0.6123 R=0.1402 F1=0.2282
===== PARTIAL (MUC 0.5) =====
P=0.6135 R=0.1390 F1=0.2266
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6268 | Relaxed-F1≈0.6201

DEMO:
 Event type: killing. Trigger: massacre.


In [None]:
# ---------------------------
# 0) Imports
# ---------------------------
import os, re, math, time, random, shutil, difflib, glob, json, torch
from typing import List, Tuple
from collections import Counter
from pathlib import Path

from datasets import load_dataset
from transformers.trainer_utils import get_last_checkpoint
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    EarlyStoppingCallback, TrainerCallback, TrainingArguments,
    TrainerControl, TrainerState
)

try:
    from google.colab import files
except Exception:
    files = None

In [None]:
# -----------------------------
# 1) Config (IDENTICAL to Code A)
# -----------------------------
SEED = 42
random.seed(SEED); torch.manual_seed(SEED)

DATA_DIR   = "/content/output"
TRAIN_FILE = f"{DATA_DIR}/llm_train_gen.jsonl"
VALID_FILE = f"{DATA_DIR}/llm_valid_gen.jsonl"
TEST_FILE  = f"{DATA_DIR}/llm_test_gen.jsonl"

BASE_MODEL = "t5-base"
OUT_DIR    = "/content/t5_degree2_ckpt"
FINAL_DIR  = "/content/t5_degree2_final"

MAX_IN_LEN  = 512
MAX_OUT_LEN = 256
EVENT_TOKEN = "<EVENTSEP>"
CHUNK_N_SENT = None

In [None]:
# -----------------------------
# 2) Regex & utils (IDENTICAL)
# -----------------------------
EVENT_SEP = re.compile(r"\s*<EVENTSEP>\s*", re.IGNORECASE)
TYPE_RE   = re.compile(r'event\s*type\s*:\s*(.+?)\.', re.IGNORECASE)
TRIG_RE   = re.compile(r'trigger\s*:\s*(.+?)\.', re.IGNORECASE)
TOK_RE    = re.compile(r"\w+", re.UNICODE)

def norm(s: str) -> str:
    return re.sub(r"\s+"," ", s.strip().lower()) if s else s

def parse_pairs(text: str) -> List[Tuple[str,str]]:
    if not text or "no events" in text.lower(): return []
    out=[]
    for ch in [c for c in EVENT_SEP.split(text) if c.strip()]:
        t = TYPE_RE.search(ch); g = TRIG_RE.search(ch)
        et = norm(t.group(1)) if t else None
        tr = norm(g.group(1)) if g else None
        if et and tr: out.append((et,tr))
    return out

def dedup_events_str(text: str) -> str:
    if not text: return ""
    seen=set(); kept=[]
    for p in [p.strip() for p in EVENT_SEP.split(text) if p.strip()]:
        key = norm(p)
        if key and key not in seen:
            seen.add(key); kept.append(p)
    return f" {EVENT_TOKEN} ".join(kept)

def token_set(s): return set(TOK_RE.findall(s.lower())) if s else set()

def trigger_overlap(p,g):
    if not p or not g: return 0.0
    if norm(p)==norm(g): return 1.0
    ps,gs = token_set(p), token_set(g)
    return 0.5 if (ps and gs and ps&gs) else 0.0

In [None]:
# -----------------------------
# 3) Data loading + few-shot (fewshot_k=3 like Code A)
# -----------------------------
def load_plain(path): return load_dataset("json", data_files={"data": path})["data"]

train_plain = load_plain(TRAIN_FILE)
valid_plain = load_plain(VALID_FILE)
test_plain  = load_plain(TEST_FILE)

TYPE_FREQ = Counter()
for ex in train_plain:
    for et,tr in parse_pairs(ex["output"]):
        TYPE_FREQ[et]+=1
ONTOLOGY = sorted(TYPE_FREQ.keys())

def nearest_type(t: str) -> str:
    if not t or not ONTOLOGY: return t
    t = norm(t)
    cand = difflib.get_close_matches(t, ONTOLOGY, n=1, cutoff=0.8)
    return cand[0] if cand else t

def build_fewshot_bank(ds, k=60):
    bank=[]
    for ex in ds:
        out = ex["output"]
        if (EVENT_TOKEN in out) and (len(out)<600):
            src = ex["input"]
            m = re.search(r'\"(.+?)\"\s*\nUse <EVENTSEP>', src, re.DOTALL)
            sent = m.group(1) if m else src
            bank.append((sent.strip(), out.strip()))
    random.shuffle(bank)
    return bank[:k]

FEWSHOT = build_fewshot_bank(train_plain, k=60)
def sample_k_shots(k=3):
    k = min(k, len(FEWSHOT))
    return random.sample(FEWSHOT, k) if k>0 else []

Generating data split: 0 examples [00:00, ? examples/s]

Generating data split: 0 examples [00:00, ? examples/s]

Generating data split: 0 examples [00:00, ? examples/s]

In [None]:
# -----------------------------
# 4) Prompt builder (IDENTICAL, no “Allowed types…”)
# -----------------------------
PROMPT_HEADER = (
    "Extract ALL events from the sentence below.\n"
    f"Output only lines like: {EVENT_TOKEN} Event type: <TYPE>. Trigger: <TRIGGER>.\n"
    "If no events, output exactly: No events.\n\n"
)
def build_prompt(sentence: str, k=3) -> str:
    head = PROMPT_HEADER
    shots = sample_k_shots(k)
    if shots:
        head += "### Examples\n"
        for s,o in shots:
            head += f'Sentence: "{s}"\n{o}\n\n'
    head += "### Now extract\n"
    head += f'Sentence: "{sentence}"\nOutput:\n'
    return head

def chunk_text_to_sentences(text: str) -> list:
    return re.split(r'(?<=[\.\!\?])\s+', text.strip())

def chunk_doc(sentence_or_doc: str, n=3):
    sents = chunk_text_to_sentences(sentence_or_doc)
    if len(sents)<=n: return [sentence_or_doc]
    chunks=[]
    for i in range(0,len(sents),n):
        chunks.append(" ".join(sents[i:i+n]))
    return chunks

def to_prompted(ds, fewshot_k=3):
    def _map(ex):
        txt = ex["input"]
        if CHUNK_N_SENT:
            pieces = chunk_doc(txt, n=CHUNK_N_SENT)
            txt = pieces[0]
        return {"input": build_prompt(txt, k=fewshot_k), "output": ex["output"]}
    return ds.map(_map, remove_columns=[c for c in ds.column_names if c not in ("input","output")])

train_ds = to_prompted(train_plain, fewshot_k=3)
valid_ds = to_prompted(valid_plain, fewshot_k=3)
test_ds  = to_prompted(test_plain,  fewshot_k=3)

Map:   0%|          | 0/32431 [00:00<?, ? examples/s]

Map:   0%|          | 0/8042 [00:00<?, ? examples/s]

Map:   0%|          | 0/9400 [00:00<?, ? examples/s]

In [None]:
# -----------------------------
# 5) Tokenizer & collator — robust resume from checkpoint archive
# -----------------------------
import os, glob, shutil, zipfile, tarfile

RESUME_DIR = "/content/t5_degree2_ckpt/checkpoint-13600"  # must match OUT_DIR/checkpoint-13600

def looks_like_checkpoint_dir(p: str) -> bool:
    if not os.path.isdir(p):
        return False
    files = set(os.listdir(p))
    has_model = any(f in files for f in (
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json"
    ))
    has_state = "trainer_state.json" in files
    return has_model and has_state

def print_tree(path: str, max_depth: int = 2):
    print(f"\n>>> Tree under: {path}")
    base_depth = path.rstrip("/").count("/")
    for root, dirs, files in os.walk(path):
        depth = root.count("/") - base_depth
        if depth > max_depth:
            continue
        show = [f for f in files if f.endswith((".json",".bin",".safetensors",".pt",".pth"))]
        if show:
            print("  -", root, "->", show)

def extract_archive_gently(archive_path: str, dest_dir: str):
    # Try as zip
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, "r") as zf:
            zf.extractall(dest_dir)
        print(f"[ok] Extracted zip into: {dest_dir}")
        return True
    # Try as tar / compressed tar
    try:
        if tarfile.is_tarfile(archive_path):
            with tarfile.open(archive_path, "r:*") as tf:
                tf.extractall(dest_dir)
            print(f"[ok] Extracted tar into: {dest_dir}")
            return True
    except tarfile.TarError:
        pass
    # Not a recognized archive
    print(f"[warn] '{archive_path}' is not a valid zip/tar archive.")
    return False

# 1) If a checkpoint archive exists, try to extract (WITHOUT touching optimizer/scheduler)
zip_candidates = (
    glob.glob(os.path.join(OUT_DIR, "checkpoint-13600*")) +   # maybe user placed here
    glob.glob("/content/checkpoint-13600*")                   # or uploaded to /content
)
# Prefer files that look like archives
archive_candidates = [p for p in zip_candidates if os.path.isfile(p) and not p.endswith(".ipynb")]
archive_candidates.sort()

if archive_candidates and not looks_like_checkpoint_dir(RESUME_DIR):
    candidate = archive_candidates[-1]
    print(f"[info] Found archive candidate: {candidate}")
    extracted = extract_archive_gently(candidate, OUT_DIR)
    if not extracted:
        print("[hint] If this is a folder, upload the *folder* (unzipped) or a real .zip/.tar(.gz) archive.")

# 2) Fix common nesting: e.g., OUT_DIR/checkpoint-13600_2025.../checkpoint-13600/
if not looks_like_checkpoint_dir(RESUME_DIR):
    alt_dirs = [d for d in glob.glob(os.path.join(OUT_DIR, "checkpoint-13600*")) if os.path.isdir(d)]
    # Prefer an alt that already looks like a valid checkpoint
    valid_alts = [d for d in alt_dirs if looks_like_checkpoint_dir(d)]
    if valid_alts:
        src = valid_alts[0]
        if src != RESUME_DIR:
            if os.path.exists(RESUME_DIR):
                shutil.rmtree(RESUME_DIR)
            os.rename(src, RESUME_DIR)
            print(f"[fix] Renamed '{src}' -> '{RESUME_DIR}'")

# 3) If files landed directly under OUT_DIR, move them into checkpoint-13600
if not looks_like_checkpoint_dir(RESUME_DIR):
    needed = [
        "config.json", "trainer_state.json",
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json",
        "optimizer.pt", "scheduler.pt", "rng_state.pth", "scaler.pt",
        # tokenizer & extras
        "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
        "added_tokens.json", "generation_config.json", "training_args.bin"
    ]
    present = [f for f in needed if os.path.exists(os.path.join(OUT_DIR, f))]
    if present:
        print(f"[fix] Creating '{RESUME_DIR}' and moving loose files into it…")
        os.makedirs(RESUME_DIR, exist_ok=True)
        for f in present:
            src = os.path.join(OUT_DIR, f)
            if os.path.exists(src):
                shutil.move(src, os.path.join(RESUME_DIR, f))

# 4) Final validation
print_tree(OUT_DIR, max_depth=2)
if not looks_like_checkpoint_dir(RESUME_DIR):
    raise RuntimeError(
        f"❌ '{RESUME_DIR}' is not a valid checkpoint folder.\n"
        f"Expected model weights (pytorch_model.bin or model.safetensors) AND trainer_state.json.\n"
        f"Make sure you uploaded the correct archive or folder."
    )

# 5) Load tokenizer from checkpoint for exact continuity
try:
    tokenizer = AutoTokenizer.from_pretrained(RESUME_DIR)
    print("[tok] Loaded tokenizer from checkpoint.")
except Exception as e:
    print(f"[tok] Failed to load tokenizer from checkpoint ({e}); falling back to base model tokenizer.")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=None, padding="longest")

[info] Found archive candidate: /content/checkpoint-13600_20250817-012440.zip
[ok] Extracted zip into: /content/t5_degree2_ckpt
[fix] Creating '/content/t5_degree2_ckpt/checkpoint-13600' and moving loose files into it…

>>> Tree under: /content/t5_degree2_ckpt
  - /content/t5_degree2_ckpt/checkpoint-13600 -> ['tokenizer.json', 'training_args.bin', 'model.safetensors', 'config.json', 'rng_state.pth', 'trainer_state.json', 'special_tokens_map.json', 'scheduler.pt', 'added_tokens.json', 'optimizer.pt', 'tokenizer_config.json', 'generation_config.json']
[tok] Loaded tokenizer from checkpoint.


In [None]:
# -----------------------------
# 6) Save helpers + StopAtStep (IDENTICAL)
# -----------------------------
def get_latest_ckpt_path(out_dir):
    try:
        return get_last_checkpoint(out_dir)
    except Exception:
        return None

def save_final_snapshot(trainer, tokenizer, base_dir, tag=None):
    step = int(trainer.state.global_step) if trainer and trainer.state else 0
    tag  = tag or f"final-{step}"
    snap_dir = os.path.join(base_dir, f"checkpoint-{tag}")
    os.makedirs(snap_dir, exist_ok=True)
    trainer.save_model(snap_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(snap_dir)
    with open(os.path.join(base_dir, "LATEST.txt"), "w") as f:
        f.write(snap_dir)
    print(f"[Snapshot] Saved '{snap_dir}' and updated LATEST.txt")
    return snap_dir

class StopAtStepCallback(TrainerCallback):
    def __init__(self, target_step=None, out_dir=None, tokenizer=None):
        self.target_step = target_step
        self.out_dir = out_dir
        self.tokenizer = tokenizer
    def _read_target(self):
        env_val = os.getenv("STOP_AT_STEP", "").strip()
        if env_val.isdigit():
            return int(env_val)
        fpath = "/content/stop_at_step.txt"
        if os.path.exists(fpath):
            try:
                with open(fpath) as f:
                    val = f.read().strip()
                if val.isdigit():
                    return int(val)
            except:
                pass
        return self.target_step
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        tgt = self._read_target()
        if tgt is None: return
        if state.global_step >= tgt:
            trainer = kwargs.get("trainer", None)
            print(f"\n[StopAtStep] Reached step {state.global_step} (target={tgt}). Saving and stopping...")
            if trainer is not None and self.out_dir is not None:
                save_final_snapshot(trainer, self.tokenizer, self.out_dir, tag=f"stop-{state.global_step}")
            control.should_training_stop = True
            control.should_save = True
            return control

In [None]:
# -----------------------------
# 7) Load model from checkpoint (EXACT weights)
# -----------------------------
model = AutoModelForSeq2SeqLM.from_pretrained(RESUME_DIR)
# match Code A behavior (grad checkpointing enabled originally)
try:
    model.gradient_checkpointing_enable()
    if getattr(model.config, "use_cache", None):
        model.config.use_cache = False
except Exception:
    pass

In [None]:
# -----------------------------
# 8) TrainingArguments (EXACTLY like Code A)
# -----------------------------
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

args = Seq2SeqTrainingArguments(
    output_dir=OUT_DIR,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=800,
    save_steps=800,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,

    num_train_epochs=10,
    max_steps=-1,
    learning_rate=1e-4,
    weight_decay=1e-5,
    lr_scheduler_type="constant",
    optim="adafactor",

    label_smoothing_factor=0.1,
    max_grad_norm=1.0,
    predict_with_generate=False,

    fp16=False,
    bf16=use_bf16,
    report_to="none",
    logging_steps=100,
    save_safetensors=True,
    seed=SEED, data_seed=SEED,
    remove_unused_columns=False,
)

In [None]:
# -----------------------------
# 9) Tokenize (with text_target) → Build collator (with model) → Build Trainer → Resume
# -----------------------------

# 9.1) (Re)preprocess with the tokenizer loaded from checkpoint (use text_target, not as_target_tokenizer)
def preprocess(batch):
    enc = tokenizer(
        batch["input"],
        max_length=MAX_IN_LEN,
        truncation=True,
        padding="max_length"
    )
    tgt = tokenizer(
        text_target=batch["output"],
        max_length=MAX_OUT_LEN,
        truncation=True,
        padding="max_length"
    )
    enc["labels"] = tgt["input_ids"]
    return enc

# Create tokenized datasets FIRST
train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
valid_tok = valid_ds.map(preprocess, batched=True, remove_columns=valid_ds.column_names)

# 9.2) IMPORTANT: collator must see the model so it can prepare decoder_input_ids when needed
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,                 # <— not None
    padding="longest",
    label_pad_token_id=-100
)

# 9.3) Callbacks
stop_cb = StopAtStepCallback(target_step=None, out_dir=OUT_DIR, tokenizer=tokenizer)
callbacks = [EarlyStoppingCallback(early_stopping_patience=3), stop_cb]

# 9.4) Build Trainer AFTER tokenization & with the proper collator
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    tokenizer=tokenizer,     # FutureWarning-safe; still fine here
    data_collator=data_collator,
    callbacks=callbacks
)

# Optional auto-stop
os.environ["STOP_AT_STEP"] = "20000"  # remove if you don’t want auto-stop

print("Resuming EXACTLY from:", RESUME_DIR)
trainer.train(resume_from_checkpoint=RESUME_DIR)

Map:   0%|          | 0/32431 [00:00<?, ? examples/s]

Map:   0%|          | 0/8042 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(
You are resuming training from a checkpoint trained with 4.55.2 of Transformers but your current version is 4.55.4. This is not recommended and could yield to errors or unwanted behaviors.


Resuming EXACTLY from: /content/t5_degree2_ckpt/checkpoint-13600


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Step,Training Loss,Validation Loss
14400,1.4103,1.395697
15200,1.41,1.393541
16000,1.41,1.394261
16800,1.4082,1.392042
17600,1.4088,1.396248
18400,1.4097,1.392669
19200,1.4084,1.391294
20000,1.4073,1.391613



[StopAtStep] Reached step 20000 (target=20000). Saving and stopping...


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=20000, training_loss=0.45106986541748045, metrics={'train_runtime': 17800.9279, 'train_samples_per_second': 18.219, 'train_steps_per_second': 2.277, 'total_flos': 9.743082665803776e+16, 'train_loss': 0.45106986541748045, 'epoch': 4.93339911198816})

In [None]:
# -----------------------------
# 10) Save final + zip best (robust)
# -----------------------------
import os, json, time, shutil, zipfile
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint

# 10.0) Always save a final snapshot and export FINAL_DIR
final_snap = save_final_snapshot(trainer, tokenizer, OUT_DIR)  # e.g. OUT_DIR/checkpoint-final-<step>
os.makedirs(FINAL_DIR, exist_ok=True)
trainer.save_model(FINAL_DIR); tokenizer.save_pretrained(FINAL_DIR)
print("Saved best model to:", FINAL_DIR)

def _is_valid_checkpoint_dir(p: str) -> bool:
    if not os.path.isdir(p):
        return False
    files = set(os.listdir(p))
    has_model = any(f in files for f in (
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json"
    ))
    has_state = "trainer_state.json" in files
    return has_model and has_state

def _safe_zip_dir(src_dir: str, out_prefix: str) -> str:
    """
    Create a .zip archive of src_dir as /content/<out_prefix>.zip
    Verifies it's a valid zip and prints its size.
    """
    stamp = time.strftime("%Y%m%d-%H%M%S")
    base = f"/content/{out_prefix}_{stamp}"
    zip_path = shutil.make_archive(base, "zip", src_dir)
    # verify zip
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            bad = zf.testzip()
        if bad is not None:
            print(f"[warn] Zip integrity issue with file: {bad}")
    except zipfile.BadZipFile:
        print("[warn] Created file is not recognized as zip (BadZipFile).")
    # size
    size_gb = os.path.getsize(zip_path) / (1024**3)
    print(f"[zip] {zip_path}  (~{size_gb:.2f} GB)")
    return zip_path

def find_best_checkpoint(out_dir: str) -> str | None:
    """
    Prefer 'best_model_checkpoint' in trainer_state.json.
    Fallbacks:
      1) Last recorded 'best_model_checkpoint' in log_history
      2) Latest checkpoint directory under out_dir
      3) As a last resort, the final snapshot we just saved
    Returns a directory path or None.
    """
    state_path = os.path.join(out_dir, "trainer_state.json")
    best = None
    if os.path.exists(state_path):
        try:
            with open(state_path, "r") as f:
                st = json.load(f)
            # direct best
            cand = st.get("best_model_checkpoint", None)
            if cand and os.path.exists(cand) and _is_valid_checkpoint_dir(cand):
                best = cand
            # search in log_history (in reverse, most recent first)
            if best is None:
                for rec in reversed(st.get("log_history", [])):
                    if isinstance(rec, dict) and "best_model_checkpoint" in rec:
                        cand = rec["best_model_checkpoint"]
                        if cand and os.path.exists(cand) and _is_valid_checkpoint_dir(cand):
                            best = cand
                            break
        except Exception as e:
            print(f"[warn] Could not parse trainer_state.json ({e}).")

    if best is None:
        # fallback to the latest checkpoint (HF naming)
        try:
            latest = get_last_checkpoint(out_dir)
            if latest and _is_valid_checkpoint_dir(latest):
                best = latest
        except Exception:
            pass

    if best is None and final_snap and _is_valid_checkpoint_dir(final_snap):
        best = final_snap

    return best

def pick_checkpoint_to_zip(out_dir: str) -> str:
    ckpt = find_best_checkpoint(out_dir)
    if ckpt is None:
        raise FileNotFoundError("No valid checkpoint found to zip.")
    print(f"[info] Best checkpoint selected: {ckpt}")
    return ckpt

try:
    # 10.1) Zip BEST checkpoint
    ckpt_dir = pick_checkpoint_to_zip(OUT_DIR)
    zip_path = _safe_zip_dir(ckpt_dir, out_prefix=Path(ckpt_dir).name)
    # 10.2) Offer download (Colab)
    if files is not None:
        files.download(zip_path)
except Exception as e:
    print("[zip-best] skip:", e)

# 10.3) (Optional) also zip FINAL_DIR — set to True if you want both
ZIP_FINAL_DIR_TOO = False
if ZIP_FINAL_DIR_TOO:
    try:
        final_zip = _safe_zip_dir(FINAL_DIR, out_prefix=os.path.basename(FINAL_DIR))
        if files is not None:
            files.download(final_zip)
    except Exception as e:
        print("[zip-final] skip:", e)

[Snapshot] Saved '/content/t5_degree2_ckpt/checkpoint-final-20000' and updated LATEST.txt
Saved best model to: /content/t5_degree2_final
[info] Best checkpoint selected: /content/t5_degree2_ckpt/checkpoint-20000
[zip] /content/checkpoint-20000_20250824-135044.zip  (~0.77 GB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# -----------------------------
# 11) Inference & Evaluation (IDENTICAL to A)
# -----------------------------
def clean_and_canon(text: str) -> str:
    text = dedup_events_str(text)
    parts=[]
    for ch in [c for c in EVENT_SEP.split(text) if c.strip()]:
        t = TYPE_RE.search(ch); g = TRIG_RE.search(ch)
        et = norm(t.group(1)) if t else None
        tr = norm(g.group(1)) if g else None
        if et:
            et = nearest_type(et)
        if et and tr:
            parts.append(f"Event type: {et}. Trigger: {tr}.")
    return f" {EVENT_TOKEN} ".join(parts) if parts else "No events."

def generate_batch(
    prompts: List[str],
    mdl,
    tok,
    bs: int = 8,
    device: str | None = None,
    samples: int = 3,
    top_p: float = 0.9,
    temperature: float = 0.7,
    max_new_tokens: int = 160,
    min_new_tokens: int = 12,
    no_repeat_ngram_size: int = 3,
    repetition_penalty: float = 1.02,
):

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval().to(device)

    outs = []
    with torch.no_grad():
        for i in range(0, len(prompts), bs):
            batch_prompts = prompts[i:i+bs]

            all_decoded = []
            for _ in range(samples):
                enc = tok(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_IN_LEN
                ).to(device)

                gen_kwargs = dict(
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=min_new_tokens,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    repetition_penalty=repetition_penalty,
                    early_stopping=False,
                    trust_remote_code=True,
                )
                gen = mdl.generate(**enc, **gen_kwargs)
                dec = tok.batch_decode(gen, skip_special_tokens=True)
                all_decoded.append(dec)

            for k in range(len(batch_prompts)):
                variants = [all_decoded[s][k] for s in range(samples)]
                merged_text = f" {EVENT_TOKEN} ".join(variants)
                outs.append(clean_and_canon(merged_text))

    return outs

def match_partial(preds, golds):
    used=set(); exact=partial=0
    for pt,ptr in preds:
        best=(0.0,-1)
        for j,(gt,gtr) in enumerate(golds):
            if j in used: continue
            if pt!=gt:   continue
            ov = trigger_overlap(ptr,gtr)
            if ov>best[0]: best=(ov,j)
        if best[1]!=-1:
            used.add(best[1])
            if math.isclose(best[0],1.0): exact+=1
            elif best[0]>=0.5: partial+=1
    return exact, partial, len(preds), len(golds)

def prf(e,p,pt,gt, w=0.5):
    wtp = e + w*p
    P = wtp/pt if pt else 0.0
    R = wtp/gt if gt else 0.0
    F = (2*P*R)/(P+R) if (P+R) else 0.0
    return P,R,F

def relaxed_recall_by_chunks(pairs_pred, pairs_gold, chunk_size=1):
    correct=partial=extra=possible=impossible=0
    for pp,gg in zip(pairs_pred, pairs_gold):
        ce,cp,pt,gt = match_partial(pp,gg)
        matched = ce + 0.5*cp
        possible += gt
        impossible += max(0, gt-1)
        extra += max(0.0, matched-1.0)
        correct += ce; partial += cp
    denom = max(1, possible - impossible)
    num   = max(0.0, (correct + 0.5*partial) - extra)
    return num/denom

def evaluate(ds, mdl, tok):
    prompts = [ex["input"] for ex in ds]
    gtexts  = [ex["output"] for ex in ds]
    preds   = generate_batch(prompts, mdl, tok, bs=8)

    strict_tp=strict_pred=strict_gold=0
    part_e=part_p=part_pt=part_gt=0
    chunks_pred=[]; chunks_gold=[]

    for ptxt,gtxt in zip(preds, gtexts):
        pp = parse_pairs(ptxt)
        gg = parse_pairs(gtxt)
        sp,sg = set(pp), set(gg)
        tp = len(sp & sg)
        strict_tp += tp; strict_pred += len(sp); strict_gold += len(sg)
        ce,cp,pt,gt = match_partial(pp,gg)
        part_e += ce; part_p += cp; part_pt += pt; part_gt += gt
        chunks_pred.append(pp); chunks_gold.append(gg)

    sP = strict_tp/strict_pred if strict_pred else 0.0
    sR = strict_tp/strict_gold if strict_gold else 0.0
    sF = (2*sP*sR)/(sP+sR) if (sP+sR) else 0.0

    pP,pR,pF = prf(part_e,part_p,part_pt,part_gt, w=0.5)
    r_rel = relaxed_recall_by_chunks(chunks_pred, chunks_gold)

    print("\n===== STRICT =====")
    print(f"P={sP:.4f} R={sR:.4f} F1={sF:.4f}")
    print("===== PARTIAL (MUC 0.5) =====")
    print(f"P={pP:.4f} R={pR:.4f} F1={pF:.4f}")
    print("===== RELAXED (DEGREE2) =====")
    rF = (2*pP*r_rel)/(pP+r_rel) if (pP+r_rel)>0 else 0.0
    print(f"Relaxed-Recall={r_rel:.4f} | Relaxed-F1≈{rF:.4f}")
    return dict(strict_f1=sF, partial_f1=pF, relaxed_recall=r_rel, relaxed_f1=rF)

# Quick sanity + Eval
best_tok   = AutoTokenizer.from_pretrained(FINAL_DIR)
best_model = AutoModelForSeq2SeqLM.from_pretrained(FINAL_DIR)

print("\n--- VALID ---")
evaluate(valid_ds, best_model, best_tok)
print("\n--- TEST ---")
evaluate(test_ds, best_model, best_tok)

demo_sent = (
    "The hijacking of Lufthansa Flight 615 was an act of terrorism committed by a Palestinian group "
    "that occurred on 29 October 1972 and aimed at the liberation of the three surviving perpetrators "
    "of the Munich massacre from a West German prison."
)
demo_prompt = build_prompt(demo_sent, k=3)
print("\nDEMO:\n", generate_batch([demo_prompt], best_model, best_tok)[0])


--- VALID ---


Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call.



===== STRICT =====
P=0.5966 R=0.2762 F1=0.3776
===== PARTIAL (MUC 0.5) =====
P=0.5979 R=0.2704 F1=0.3723
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6517 | Relaxed-F1≈0.6236

--- TEST ---

===== STRICT =====
P=0.5958 R=0.1364 F1=0.2220
===== PARTIAL (MUC 0.5) =====
P=0.5970 R=0.1352 F1=0.2205
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6098 | Relaxed-F1≈0.6033

DEMO:
 Event type: commitment. Trigger: committed.
