In [None]:
!pip install datasets accelerate evaluate sentencepiece
!pip install --upgrade transformers

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5


In [None]:
# ---------------------------
# 0) Imports
# ---------------------------
import os, re, math, time, random, shutil, difflib, glob, json, zipfile, tarfile
from typing import List, Tuple
from collections import Counter
from pathlib import Path

import torch
from datasets import load_dataset
from transformers.trainer_utils import get_last_checkpoint
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    EarlyStoppingCallback, TrainerCallback, TrainingArguments,
    TrainerControl, TrainerState
)

try:
    from google.colab import files
except Exception:
    files = None

In [None]:
# -----------------------------
# 1) Config (parallel to Code A)
# -----------------------------
# Seed Python's and PyTorch's global RNGs; the same value is also passed to the
# training arguments (seed / data_seed) further down.
SEED = 42
random.seed(SEED); torch.manual_seed(SEED)

# JSONL splits; every record is read through its "input" and "output" fields.
DATA_DIR   = "/content/output"
TRAIN_FILE = f"{DATA_DIR}/llm_train_gen.jsonl"
VALID_FILE = f"{DATA_DIR}/llm_valid_gen.jsonl"
TEST_FILE  = f"{DATA_DIR}/llm_test_gen.jsonl"

# >>> Use FLAN-T5 here <<<
BASE_MODEL = "google/flan-t5-base"   # or "google/flan-t5-large" (watch VRAM)
OUT_DIR    = "/content/flan_t5_degree2_ckpt"    # Trainer output_dir (checkpoints, snapshots)
FINAL_DIR  = "/content/flan_t5_degree2_final"   # where the final model + tokenizer are saved

MAX_IN_LEN  = 512             # tokenizer max_length for prompts (longer prompts are truncated)
MAX_OUT_LEN = 256             # tokenizer max_length for targets
EVENT_TOKEN = "<EVENTSEP>"    # separator between event lines; also added to the tokenizer
CHUNK_N_SENT = None           # falsy -> no sentence chunking in to_prompted(); else sentences per chunk

In [None]:
# -----------------------------
# 2) Regex & utils
# -----------------------------
# Splits a target/prediction on the literal "<EVENTSEP>" (case-insensitive), swallowing
# the whitespace around it.
EVENT_SEP = re.compile(r"\s*<EVENTSEP>\s*", re.IGNORECASE)
# Capture the text after "event type:" / "trigger:" up to the FIRST period (non-greedy),
# so a value that itself contains "." is cut at that period.
TYPE_RE   = re.compile(r'event\s*type\s*:\s*(.+?)\.', re.IGNORECASE)
TRIG_RE   = re.compile(r'trigger\s*:\s*(.+?)\.', re.IGNORECASE)
# Runs of word characters; used to tokenize triggers for overlap scoring.
TOK_RE    = re.compile(r"\w+", re.UNICODE)

def norm(s: str) -> str:
    """Lower-case `s`, trim it and collapse whitespace runs to single spaces.

    Falsy input (``None`` or ``""``) is returned unchanged.
    """
    if not s:
        return s
    lowered = s.strip().lower()
    return re.sub(r"\s+", " ", lowered)

def parse_pairs(text: str) -> List[Tuple[str,str]]:
    """Parse an event string into normalized ``(event_type, trigger)`` pairs.

    Returns an empty list for empty input or when the text contains the phrase
    "no events" anywhere (case-insensitive). Chunks missing either a type or a
    trigger are dropped.
    """
    if not text:
        return []
    if "no events" in text.lower():
        return []
    pairs = []
    for chunk in EVENT_SEP.split(text):
        if not chunk.strip():
            continue
        type_match = TYPE_RE.search(chunk)
        trig_match = TRIG_RE.search(chunk)
        if type_match is None or trig_match is None:
            continue
        etype = norm(type_match.group(1))
        trigger = norm(trig_match.group(1))
        if etype and trigger:
            pairs.append((etype, trigger))
    return pairs

def dedup_events_str(text: str) -> str:
    """Drop repeated event chunks (compared after `norm`), keeping first occurrences.

    The surviving chunks keep their original spelling and are re-joined with
    ``" <EVENT_TOKEN> "``. Empty/None input yields ``""``.
    """
    if not text:
        return ""
    first_seen = {}  # normalized chunk -> original chunk (dicts keep insertion order)
    for piece in EVENT_SEP.split(text):
        piece = piece.strip()
        if not piece:
            continue
        key = norm(piece)
        if key and key not in first_seen:
            first_seen[key] = piece
    return f" {EVENT_TOKEN} ".join(first_seen.values())

def token_set(s):
    """Return the set of lower-cased word tokens in `s` (empty set for falsy input)."""
    if not s:
        return set()
    return {tok for tok in TOK_RE.findall(s.lower())}

def trigger_overlap(p,g):
    """Score a predicted trigger `p` against a gold trigger `g`.

    Returns 1.0 when they are equal after `norm`, 0.5 when they share at least
    one word token, and 0.0 otherwise (including when either is empty).
    """
    if not (p and g):
        return 0.0
    if norm(p) == norm(g):
        return 1.0
    shared = token_set(p) & token_set(g)
    return 0.5 if shared else 0.0

In [None]:
# -----------------------------
# 3) Data loading + cleaning / alignment (+ optional upsampling)
# -----------------------------
import unicodedata
from difflib import SequenceMatcher
from datasets import Dataset, concatenate_datasets

# ---- A. Input normalization ----
_WS    = re.compile(r"\s+")
_CTRL  = re.compile(r"[\u0000-\u001F\u007F]")
HTML_TAG = re.compile(r"</?[^>]+>")
# Curly quotes and en/em dashes mapped to their ASCII look-alikes.
_PUNCT_TO_ASCII = str.maketrans({
    "“": "\"", "”": "\"",
    "‘": "'", "’": "'",
    "–": "-", "—": "-",
})

def normalize_input_text(s: str) -> str:
    """Clean a source sentence before prompting.

    Steps, in order: NFKC normalization, curly quotes/dashes to ASCII, control
    characters to spaces, naive HTML tag removal, removal of backslash commands
    (``\\alpha``) and short ``{...}`` groups, whitespace collapse and trim.
    Non-string input yields ``""``.
    """
    if not isinstance(s, str):
        return ""
    text = unicodedata.normalize("NFKC", s).translate(_PUNCT_TO_ASCII)
    for pattern in (
        _CTRL,                               # control characters
        HTML_TAG,                            # strip naive HTML tags
        re.compile(r"\\[a-zA-Z]+"),          # \alpha, \cite, ...
        re.compile(r"\{[^{}]{0,200}\}"),     # {…} light cleanup
    ):
        text = pattern.sub(" ", text)
    return _WS.sub(" ", text).strip()

# ---- B. Target (gold) normalization ----
# strict single-line validator
EVENT_LINE_RE = re.compile(
    r'^\s*Event\s*type\s*:\s*([^.\n\r:]+)\.\s*Trigger\s*:\s*([^.\n\r<]+)\.\s*$',
    re.IGNORECASE
)

def normalize_output_line(line: str) -> str | None:
    if not isinstance(line, str): return None
    line = unicodedata.normalize("NFKC", line)
    line = line.replace("–","-").replace("—","-").strip()
    if not line.endswith("."):
        line += "."
    m = EVENT_LINE_RE.match(line)
    if not m:
        return None
    et = re.sub(r"\s+", " ", m.group(1).strip().lower())
    tr = re.sub(r"\s+", " ", m.group(2).strip().lower())
    return f"Event type: {et}. Trigger: {tr}."

def normalize_target_text(y: str, token: str = EVENT_TOKEN) -> str:
    """Canonicalize a gold target made of event lines separated by `token`.

    Each piece goes through `normalize_output_line`; invalid pieces and repeats
    are dropped (first occurrence wins). Returns ``"No events."`` when the input
    is empty/non-string or nothing valid remains.
    """
    if not isinstance(y, str) or not y.strip():
        return "No events."
    splitter = rf"\s*{re.escape(token)}\s*"
    ordered = {}  # canonical line -> None; used as an insertion-ordered set
    for piece in re.split(splitter, y):
        if not piece.strip():
            continue
        canonical = normalize_output_line(piece)
        if canonical:
            ordered.setdefault(canonical, None)
    if not ordered:
        return "No events."
    return f" {token} ".join(ordered)

# ---- C. Anti-hallucination: ensure trigger appears (roughly) in source ----
def _rough_contains(text: str, phrase: str, threshold=0.82) -> bool:
    text_l   = text.lower()
    phrase_l = phrase.lower()
    if phrase_l in text_l:
        return True
    n = len(phrase_l)
    if n < 3:
        return False
    # coarse sliding window; hop = n//2 (fast & robust)
    hop = max(1, n // 2)
    best = 0.0
    for i in range(0, max(1, len(text_l) - n + 1), hop):
        cand = text_l[i:i+n]
        best = max(best, SequenceMatcher(None, cand, phrase_l).ratio())
        if best >= threshold:
            return True
    return False

def align_and_filter_output(src: str, y: str, token: str = EVENT_TOKEN) -> str:
    """Keep only the event lines of `y` whose trigger is (roughly) found in `src`.

    `src` is normalized with `normalize_input_text` first. A target equal to
    "no events." (case-insensitive, trimmed) is returned unchanged. Otherwise
    each `token`-separated chunk is canonicalized and kept only when
    `_rough_contains` finds its trigger in the source; if nothing survives the
    result is ``"No events."``.
    """
    src = normalize_input_text(src)
    if y.strip().lower() == "no events.":
        return y
    anchored = []
    for chunk in re.split(rf"\s*{re.escape(token)}\s*", y):
        if not chunk.strip():
            continue
        canonical = normalize_output_line(chunk.strip())
        if not canonical:
            continue
        parsed = EVENT_LINE_RE.match(canonical)
        trigger = parsed.group(2) if parsed else ""
        # keep only if trigger is (roughly) anchored in source
        if _rough_contains(src, trigger, threshold=0.82):
            anchored.append(canonical)
    if not anchored:
        return "No events."
    return f" {token} ".join(anchored)

# ---- D. Example-level validation ----
def is_valid_example(inp: str, out: str, min_chars_src=20, max_out_len=800) -> bool:
    """Decide whether a cleaned (input, output) pair is usable.

    Rejects sources shorter than `min_chars_src` characters and empty targets or
    targets longer than `max_out_len` characters. Accepts "No events." targets;
    any other target must contain at least one chunk that
    `normalize_output_line` accepts.
    """
    if not (inp and len(inp) >= min_chars_src):
        return False
    if not (out and len(out) <= max_out_len):
        return False
    if out.strip().lower() == "no events.":
        return True
    # must contain at least one valid event line
    pieces = re.split(rf"\s*{re.escape(EVENT_TOKEN)}\s*", out)
    return any(normalize_output_line(piece.strip()) for piece in pieces)

# ---- E. Load raw → clean/align → filter ----
def load_plain(path):
    """Load one JSON/JSONL file with `datasets.load_dataset` and return it as a single split."""
    loaded = load_dataset("json", data_files={"data": path})
    return loaded["data"]

# Raw (uncleaned) splits, loaded in train / valid / test order.
raw_train, raw_valid, raw_test = (
    load_plain(split_path) for split_path in (TRAIN_FILE, VALID_FILE, TEST_FILE)
)

def _clean_map(ex):
    # normalize source
    src = normalize_input_text(ex["input"])
    # normalize target format
    tgt = normalize_target_text(ex["output"], token=EVENT_TOKEN)
    # align triggers to source (drop unanchored)
    tgt = align_and_filter_output(src, tgt, token=EVENT_TOKEN)
    return {"input": src, "output": tgt}

# Apply the cleaning/alignment map to every split ("input" and "output" are overwritten).
train_clean = raw_train.map(_clean_map)
valid_clean = raw_valid.map(_clean_map)
test_clean  = raw_test.map(_clean_map)

# filter malformed/noisy
train_clean = train_clean.filter(lambda ex: is_valid_example(ex["input"], ex["output"]))
valid_clean = valid_clean.filter(lambda ex: is_valid_example(ex["input"], ex["output"]))
test_clean  = test_clean.filter(lambda ex: is_valid_example(ex["input"], ex["output"]))

# ---- F. Prevent leakage: drop exact source duplicates across splits ----
# Only train -> valid/test overlap is removed here; overlap between valid and test
# is not checked by this code.
train_keys = set(train_clean.map(lambda ex: {"k": normalize_input_text(ex["input"])} )["k"])

def _not_in_train(ex):
    # True when this example's normalized source text does not occur in the train split.
    return normalize_input_text(ex["input"]) not in train_keys

valid_clean = valid_clean.filter(_not_in_train)
test_clean  = test_clean.filter(_not_in_train)

# ---- G. Build ontology from *clean* train ----
# Count how often each (normalized) event type occurs in the cleaned train targets.
TYPE_FREQ = Counter(
    etype
    for example in train_clean
    for etype, _trigger in parse_pairs(example["output"])
)
# Alphabetically sorted list of known event types.
ONTOLOGY = sorted(TYPE_FREQ)

def nearest_type(t: str) -> str:
    """Snap an event type to the closest `ONTOLOGY` entry.

    The type is normalized with `norm` and matched with
    ``difflib.get_close_matches`` (cutoff 0.8). Without a close match the
    normalized type is returned; falsy input or an empty ontology returns `t`
    unchanged.
    """
    if not t or not ONTOLOGY:
        return t
    query = norm(t)
    matches = difflib.get_close_matches(query, ONTOLOGY, n=1, cutoff=0.8)
    if matches:
        return matches[0]
    return query

# ---- H. (Optional) Upsample rare types to improve recall on tail classes ----
UPSAMPLE_RARE_TYPES = False           # set True to enable
TARGET_PER_TYPE     = None            # None → median; or set an integer

def _extract_types_from_target(y: str) -> list[str]:
    out=[]
    for ch in [c for c in EVENT_SEP.split(y) if c.strip()]:
        m = TYPE_RE.search(ch)
        if m:
            out.append(norm(m.group(1)))
    return out

# Optional rebalancing of the train split: every event type with fewer than
# TARGET_PER_TYPE examples is padded up to that size by sampling with replacement.
if UPSAMPLE_RARE_TYPES:
    # convert to a Python list to rebalance, then back to Dataset
    rows = []
    for ex in train_clean:
        rows.append({"input": ex["input"], "output": ex["output"], "_types": list(set(_extract_types_from_target(ex["output"])))})
    from collections import defaultdict
    buckets = defaultdict(list)
    for r in rows:
        # Examples without any parsed type go to the "_no_event" bucket.
        types = r["_types"] or ["_no_event"]
        # NOTE(review): an example with several distinct types is appended to every one
        # of those buckets, so it ends up in `aug` once per type — confirm this is intended.
        for t in types:
            buckets[t].append({"input": r["input"], "output": r["output"]})

    # choose target size
    if TARGET_PER_TYPE is None:
        # Upper median of the per-type bucket sizes ("_no_event" excluded); 0 if no types.
        # This overwrites the module-level TARGET_PER_TYPE setting.
        freqs = [len(v) for k,v in buckets.items() if k != "_no_event"]
        TARGET_PER_TYPE = int(sorted(freqs)[len(freqs)//2]) if freqs else 0

    aug = []
    for t, items in buckets.items():
        if t == "_no_event":
            aug.extend(items)  # keep as-is
            continue
        if len(items) >= TARGET_PER_TYPE:
            aug.extend(items)
        else:
            # Top the bucket up with random repeats (sampling with replacement).
            need = TARGET_PER_TYPE - len(items)
            aug.extend(items + random.choices(items, k=need))
    # Replaces the cleaned train split; rows are grouped by bucket, not shuffled here.
    train_clean = Dataset.from_list(aug)

In [None]:
# -----------------------------
# 4) Prompt builder (now with retrieval-based few-shot; fallback to random)
# -----------------------------
# Prepare a bank of good few-shot candidates from *clean* train
def build_fewshot_bank(ds, k_keep=400):
    """Collect up to `k_keep` (source, target) pairs to serve as few-shot examples.

    Only examples whose target contains `EVENT_TOKEN` and is shorter than 600
    characters qualify. The candidates are shuffled with the global `random`
    state before truncation.
    NOTE(review): targets appear to be built by joining event lines with the
    token, so this presumably keeps only multi-event examples — confirm intent.
    """
    bank = [
        (ex["input"].strip(), ex["output"].strip())
        for ex in ds
        if EVENT_TOKEN in ex["output"] and len(ex["output"]) < 600
    ]
    random.shuffle(bank)
    return bank[:k_keep]

FEWSHOT_BANK = build_fewshot_bank(train_clean, k_keep=400)

def _is_same_source(shot_src: str, query: str, query_norm: str) -> bool:
    """True when a bank sentence is the query sentence itself.

    Train sentences can also sit in FEWSHOT_BANK; using such a sentence as its
    own few-shot example would put the gold answer inside the prompt.
    """
    return shot_src == query_norm or shot_src == query.strip()

# Retrieval-based few-shot (TF-IDF); falls back to random sampling when the
# TF-IDF retriever cannot be built (sklearn missing, vectorizer cannot be fit, ...).
RETRIEVER = None
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    _FS_TEXTS = [s for s,_ in FEWSHOT_BANK]
    _VEC = TfidfVectorizer(ngram_range=(1,2), min_df=2).fit(_FS_TEXTS)
    _X   = _VEC.transform(_FS_TEXTS)

    def _extract_types(y: str) -> set[str]:
        """Set of normalized event types present in target `y`."""
        return set(_extract_types_from_target(y))

    def retriever_for(query: str, k=3):
        """Return up to `k` (source, target) shots most similar to `query`.

        Candidates are ranked by TF-IDF cosine similarity. The query sentence
        itself is never returned, and a candidate is skipped when all of its
        event types are already covered by the shots chosen so far.
        """
        q = normalize_input_text(query)
        v = _VEC.transform([q])
        sims = cosine_similarity(v, _X).ravel()
        # Most similar first, minus the query itself (prevents label leakage).
        ranked = [
            i for i in sims.argsort()[::-1]
            if not _is_same_source(FEWSHOT_BANK[i][0], query, q)
        ]
        chosen, seen_types = [], set()
        for i in ranked:
            s, o = FEWSHOT_BANK[i]
            types_i = _extract_types(o)
            # encourage type diversity in the shots
            if types_i and (seen_types & types_i) == types_i:
                continue
            chosen.append((s, o))
            seen_types |= types_i
            if len(chosen) >= k:
                break
        if not chosen:  # fallback: plain top-k by similarity
            chosen = [FEWSHOT_BANK[i] for i in ranked[:k]]
        return chosen

    RETRIEVER = retriever_for
except Exception as err:
    # Best-effort fallback, but say why retrieval was disabled instead of hiding it.
    print(f"[fewshot] TF-IDF retriever unavailable ({type(err).__name__}: {err}); "
          "using random few-shot sampling.")

    def retriever_for(query: str, k=3):
        """Return up to `k` random shots from the bank, never the query sentence itself."""
        q = normalize_input_text(query)
        pool = [shot for shot in FEWSHOT_BANK if not _is_same_source(shot[0], query, q)]
        k = min(k, len(pool))
        return random.sample(pool, k) if k>0 else []
    RETRIEVER = retriever_for

# Rebuild the prompted datasets using the cleaned splits + retrieval few-shot
def build_prompt(sentence: str, k=3) -> str:
    """Build the instruction prompt for `sentence`, with up to `k` few-shot examples.

    Layout: task instructions, an optional "### Examples" section filled from
    `RETRIEVER`, then the "### Now extract" section holding the sentence.
    NOTE(review): the sentence to label comes last, so if the prompt is truncated
    at the tokenizer's max length it is presumably the tail that is lost — verify
    prompt lengths.
    """
    parts = [
        "Extract ALL events from the sentence below.\n",
        f"Output only lines like: {EVENT_TOKEN} Event type: <TYPE>. Trigger: <TRIGGER>.\n",
        "If no events, output exactly: No events.\n\n",
    ]
    examples = RETRIEVER(sentence, k=k) if RETRIEVER else []
    if examples:
        parts.append("### Examples\n")
        parts.extend(f'Sentence: "{src}"\n{tgt}\n\n' for src, tgt in examples)
    parts.append("### Now extract\n")
    parts.append(f'Sentence: "{sentence}"\nOutput:\n')
    return "".join(parts)

def chunk_text_to_sentences(text: str) -> list:
    """Split trimmed `text` at whitespace that follows '.', '!' or '?'."""
    trimmed = text.strip()
    sentences = re.split(r'(?<=[\.\!\?])\s+', trimmed)
    return sentences

def chunk_doc(sentence_or_doc: str, n=3):
    """Split a text into chunks of `n` sentences each.

    A text with at most `n` sentences is returned unchanged as a one-element
    list; otherwise consecutive groups of `n` sentences are joined with spaces.
    """
    sentences = chunk_text_to_sentences(sentence_or_doc)
    if len(sentences) <= n:
        return [sentence_or_doc]
    return [
        " ".join(sentences[start:start + n])
        for start in range(0, len(sentences), n)
    ]

def to_prompted(ds, fewshot_k=3):
    """Replace each example's "input" with its full prompt; "output" is kept as is.

    When `CHUNK_N_SENT` is set, only the first sentence chunk of the input is
    put into the prompt. All columns other than "input"/"output" are removed.
    """
    def _prompt_row(ex):
        text = ex["input"]
        if CHUNK_N_SENT:
            text = chunk_doc(text, n=CHUNK_N_SENT)[0]
        return {"input": build_prompt(text, k=fewshot_k), "output": ex["output"]}

    extra_columns = [name for name in ds.column_names if name not in ("input", "output")]
    return ds.map(_prompt_row, remove_columns=extra_columns)

# final, prompted datasets (cleaned → prompted)
# Each split gets prompts with up to 3 few-shot examples drawn from the train-derived bank.
train_ds = to_prompted(train_clean, fewshot_k=3)
valid_ds = to_prompted(valid_clean, fewshot_k=3)
test_ds  = to_prompted(test_clean,  fewshot_k=3)

In [None]:
# -----------------------------
# 5) Robust resume: checkpoint extraction & tokenizer load
# -----------------------------
# Checkpoint directory to resume from. The step number (7200) is hard-coded here and in
# the "checkpoint-7200*" glob patterns used further down in this cell.
RESUME_DIR = f"{OUT_DIR}/checkpoint-7200"
def looks_like_checkpoint_dir(p: str) -> bool:
    """True when `p` is a directory holding model weights and ``trainer_state.json``.

    Weights may be a single file or a sharded index, in PyTorch ``.bin`` or
    safetensors format.
    """
    if not os.path.isdir(p):
        return False
    entries = set(os.listdir(p))
    weight_names = {
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json",
    }
    return bool(entries & weight_names) and "trainer_state.json" in entries

def print_tree(path: str, max_depth: int = 2):
    """Print the model/state files found under `path`, down to `max_depth` levels.

    Only files ending in .json/.bin/.safetensors/.pt/.pth are listed, one line
    per directory that contains any. Depth 0 is `path` itself.

    Depth is measured relative to `path`, so a trailing slash or a non-"/" path
    separator does not shift it, and directories below `max_depth` are pruned
    from the walk instead of being traversed and then ignored.
    """
    print(f"\n>>> Tree under: {path}")
    base = os.path.normpath(path)
    shown_suffixes = (".json", ".bin", ".safetensors", ".pt", ".pth")
    for root, dirs, filenames in os.walk(path):
        rel = os.path.relpath(root, base)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        if depth >= max_depth:
            dirs[:] = []  # nothing deeper would be printed: stop descending here
        if depth > max_depth:  # only reachable for a negative max_depth
            continue
        show = [f for f in filenames if f.endswith(shown_suffixes)]
        if show:
            print("  -", root, "->", show)

def extract_archive_gently(archive_path: str, dest_dir: str):
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, "r") as zf:
            zf.extractall(dest_dir)
        print(f"[ok] Extracted zip into: {dest_dir}")
        return True
    try:
        if tarfile.is_tarfile(archive_path):
            with tarfile.open(archive_path, "r:*") as tf:
                tf.extractall(dest_dir)
            print(f"[ok] Extracted tar into: {dest_dir}")
            return True
    except tarfile.TarError:
        pass
    print(f"[warn] '{archive_path}' is not a valid zip/tar archive.")
    return False


# Step 1: look for an uploaded archive of the checkpoint (inside OUT_DIR or directly
# under /content) and extract it into OUT_DIR if RESUME_DIR is not usable yet.
archive_candidates = (
    glob.glob(os.path.join(OUT_DIR, "checkpoint-7200*")) +
    glob.glob("/content/checkpoint-7200*")
)
# Keep regular files only (directories and notebooks are not archives).
archive_candidates = [p for p in archive_candidates if os.path.isfile(p) and not p.endswith(".ipynb")]
archive_candidates.sort()
if archive_candidates and not looks_like_checkpoint_dir(RESUME_DIR):
    # Last path in sorted order.
    candidate = archive_candidates[-1]
    print(f"[info] Found checkpoint archive: {candidate}")
    extract_archive_gently(candidate, OUT_DIR)


# Step 2: if extraction produced a differently named "checkpoint-7200*" directory that
# is a valid checkpoint, rename it to RESUME_DIR (removing whatever is currently there).
if not looks_like_checkpoint_dir(RESUME_DIR):
    alts = [d for d in glob.glob(os.path.join(OUT_DIR, "checkpoint-7200*")) if os.path.isdir(d)]
    valid_alts = [d for d in alts if looks_like_checkpoint_dir(d)]
    if valid_alts:
        src = valid_alts[0]
        if src != RESUME_DIR:
            if os.path.exists(RESUME_DIR): shutil.rmtree(RESUME_DIR)
            os.rename(src, RESUME_DIR)
            print(f"[fix] Renamed '{src}' -> '{RESUME_DIR}'")

# Step 3: if checkpoint files ended up loose in OUT_DIR (archive without a top-level
# folder), gather the known file names into RESUME_DIR.
if not looks_like_checkpoint_dir(RESUME_DIR):
    needed = [
        "config.json", "trainer_state.json",
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json",
        "optimizer.pt", "scheduler.pt", "rng_state.pth", "scaler.pt",
        "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
        "added_tokens.json", "generation_config.json", "training_args.bin"
    ]
    present = [f for f in needed if os.path.exists(os.path.join(OUT_DIR, f))]
    if present:
        print(f"[fix] Creating '{RESUME_DIR}' and moving loose files into it…")
        os.makedirs(RESUME_DIR, exist_ok=True)
        for f in present:
            src = os.path.join(OUT_DIR, f)
            if os.path.exists(src):
                shutil.move(src, os.path.join(RESUME_DIR, f))

print_tree(OUT_DIR, max_depth=2)

# Load tokenizer (from checkpoint if present, else base)
try:
    if looks_like_checkpoint_dir(RESUME_DIR):
        tokenizer = AutoTokenizer.from_pretrained(RESUME_DIR)
        print("[tok] Loaded tokenizer from checkpoint.")
    else:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        print("[tok] Loaded tokenizer from base model.")
except Exception as e:
    # Any loading failure falls back to the base tokenizer.
    print(f"[tok] Fallback to base tokenizer due to: {e}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Ensure EVENT_TOKEN exists (keeps vocab consistent if it was in the checkpoint)
# The return value is discarded.
_ = tokenizer.add_special_tokens({"additional_special_tokens":[EVENT_TOKEN]})


[info] Found checkpoint archive: /content/checkpoint-7200_20250910-125541.zip
[ok] Extracted zip into: /content/flan_t5_degree2_ckpt
[fix] Creating '/content/flan_t5_degree2_ckpt/checkpoint-7200' and moving loose files into it…

>>> Tree under: /content/flan_t5_degree2_ckpt
  - /content/flan_t5_degree2_ckpt/checkpoint-7200 -> ['added_tokens.json', 'model.safetensors', 'config.json', 'optimizer.pt', 'training_args.bin', 'generation_config.json', 'special_tokens_map.json', 'scheduler.pt', 'rng_state.pth', 'tokenizer.json', 'tokenizer_config.json', 'trainer_state.json']
[tok] Loaded tokenizer from checkpoint.


In [None]:
# -----------------------------
# 6) Save helpers + StopAtStep (same)
# -----------------------------
def get_latest_ckpt_path(out_dir):
    """Return what `get_last_checkpoint(out_dir)` returns, or None if that call raises."""
    latest = None
    try:
        latest = get_last_checkpoint(out_dir)
    except Exception:
        latest = None
    return latest

def save_final_snapshot(trainer, tokenizer, base_dir, tag=None):
    """Save the trainer's model (and the tokenizer) to ``<base_dir>/checkpoint-<tag>``.

    `tag` defaults to ``final-<global_step>``. The snapshot path is also written
    to ``<base_dir>/LATEST.txt``. Returns the snapshot directory.
    """
    if trainer and trainer.state:
        step = int(trainer.state.global_step)
    else:
        step = 0
    label = tag or f"final-{step}"
    snap_dir = os.path.join(base_dir, f"checkpoint-{label}")
    os.makedirs(snap_dir, exist_ok=True)

    trainer.save_model(snap_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(snap_dir)

    # Pointer file naming the most recent snapshot.
    pointer_path = os.path.join(base_dir, "LATEST.txt")
    with open(pointer_path, "w") as pointer:
        pointer.write(snap_dir)
    print(f"[Snapshot] Saved '{snap_dir}' and updated LATEST.txt")
    return snap_dir

class StopAtStepCallback(TrainerCallback):
    """Stop training once a target global step is reached.

    The target is looked up at every step, in this order: the ``STOP_AT_STEP``
    environment variable, the file ``/content/stop_at_step.txt``, then the
    `target_step` given at construction. When reached, the callback requests a
    checkpoint save and a training stop through the `TrainerControl` object.

    Args:
        target_step: default target step (None disables the default).
        out_dir: directory for the optional extra snapshot.
        tokenizer: tokenizer saved together with the snapshot.
        trainer: optional Trainer used for the extra snapshot. Trainer does not
            hand itself to callbacks, so without this (or a ``trainer`` kwarg)
            no extra snapshot is written; the regular checkpoint requested via
            ``control.should_save`` is unaffected.
    """
    def __init__(self, target_step=None, out_dir=None, tokenizer=None, trainer=None):
        self.target_step = target_step
        self.out_dir = out_dir
        self.tokenizer = tokenizer
        self.trainer = trainer

    def _read_target(self):
        """Return the current target step (env var > file > constructor value)."""
        env_val = os.getenv("STOP_AT_STEP", "").strip()
        if env_val.isdigit():
            return int(env_val)
        fpath = "/content/stop_at_step.txt"
        if os.path.exists(fpath):
            try:
                with open(fpath) as f:
                    val = f.read().strip()
                if val.isdigit():
                    return int(val)
            except (OSError, ValueError):
                # Unreadable or undecodable file: ignore it and use the default.
                pass
        return self.target_step

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """After each optimizer step: request save + stop once the target is reached."""
        tgt = self._read_target()
        if tgt is None: return
        if state.global_step >= tgt:
            trainer = kwargs.get("trainer", None) or self.trainer
            print(f"\n[StopAtStep] Reached step {state.global_step} (target={tgt}). Saving and stopping...")
            if trainer is not None and self.out_dir is not None:
                save_final_snapshot(trainer, self.tokenizer, self.out_dir, tag=f"stop-{state.global_step}")
            control.should_training_stop = True
            control.should_save = True
            return control

In [None]:
# -----------------------------
# 7) Load model (resume if possible; else fresh)
# -----------------------------
# Load weights from the resume checkpoint when it is usable, else from the base model.
if looks_like_checkpoint_dir(RESUME_DIR):
    model = AutoModelForSeq2SeqLM.from_pretrained(RESUME_DIR)
    print(f"[model] Loaded from checkpoint: {RESUME_DIR}")
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
    print(f"[model] Loaded fresh from base: {BASE_MODEL}")

# Expand embeddings if we added special tokens
model.resize_token_embeddings(len(tokenizer))

# Gradient checkpointing parity with Code A
try:
    model.gradient_checkpointing_enable()
    # use_cache is only switched off when the config currently has it set to a truthy value.
    if getattr(model.config, "use_cache", None):
        model.config.use_cache = False
except Exception:
    # Best effort: any failure here is ignored silently and training proceeds
    # without gradient checkpointing.
    pass

[model] Loaded from checkpoint: /content/flan_t5_degree2_ckpt/checkpoint-7200


In [None]:
# -----------------------------
# 8) TrainingArguments (EXACT Code A hyper-params)
# -----------------------------
# bf16 is enabled only on a CUDA device whose major compute capability is >= 8.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

args = Seq2SeqTrainingArguments(
    output_dir=OUT_DIR,
    # Evaluate and checkpoint on the same 800-step schedule; keep at most 3 checkpoints.
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=800,
    save_steps=800,
    save_total_limit=3,
    # Best model = lowest eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,

    # max_steps=-1 leaves the run length to num_train_epochs.
    num_train_epochs=10,
    max_steps=-1,
    learning_rate=1e-4,
    weight_decay=1e-5,
    lr_scheduler_type="constant",
    optim="adafactor",

    label_smoothing_factor=0.1,
    max_grad_norm=1.0,
    # Evaluation during training reports loss only (no generation).
    predict_with_generate=False,

    fp16=False,
    bf16=use_bf16,
    report_to="none",
    logging_steps=100,
    save_safetensors=True,
    seed=SEED, data_seed=SEED,
    remove_unused_columns=False,
)

In [None]:
# -----------------------------
# 9) Tokenize (text_target) → Collator (with model) → Trainer → Train/Resume
# -----------------------------
def preprocess(batch):
    """Tokenize a batch of prompts ("input") and targets ("output").

    Returns the prompt encoding with the target token ids stored under
    "labels". Sequences are truncated to MAX_IN_LEN / MAX_OUT_LEN but NOT
    padded here: padding is left to DataCollatorForSeq2Seq, which pads each
    batch to its longest sequence and fills label padding with -100 so padded
    positions are ignored by the loss. Padding the targets to max_length at
    this stage would put pad-token ids into "labels" and make the loss score
    every padding position.
    """
    enc = tokenizer(
        batch["input"],
        max_length=MAX_IN_LEN, truncation=True
    )
    tgt = tokenizer(
        text_target=batch["output"],
        max_length=MAX_OUT_LEN, truncation=True
    )
    enc["labels"] = tgt["input_ids"]
    return enc

# Tokenize the prompted splits; the text columns are dropped so only model inputs remain.
train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
valid_tok = valid_ds.map(preprocess, batched=True, remove_columns=valid_ds.column_names)

# Pads each batch to its longest sequence; label padding added by the collator is -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="longest",
    label_pad_token_id=-100
)

# target_step=None: the stop step comes from the STOP_AT_STEP env var / stop file only.
stop_cb = StopAtStepCallback(target_step=None, out_dir=OUT_DIR, tokenizer=tokenizer)
# Early stopping after 3 evaluations without improvement of the best-model metric.
callbacks = [EarlyStoppingCallback(early_stopping_patience=3), stop_cb]

# NOTE(review): the recorded cell output shows a warning pointing at this constructor
# call; presumably the `tokenizer=` argument is deprecated in the installed
# transformers version — confirm.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=callbacks
)

# Optional auto-stop (delete to disable)
# The variable is set for the whole kernel process and is read by StopAtStepCallback
# at every training step.
os.environ["STOP_AT_STEP"] = "20000"

# Resume only when RESUME_DIR holds model weights and trainer_state.json.
resume_path = RESUME_DIR if looks_like_checkpoint_dir(RESUME_DIR) else None
if resume_path:
    print("Resuming EXACTLY from:", resume_path)
    # If you previously switched optim/scheduler types, you can delete optimizer/scheduler state here.
    trainer.train(resume_from_checkpoint=resume_path)
else:
    print("Starting fresh training…")
    trainer.train()

Map:   0%|          | 0/32430 [00:00<?, ? examples/s]

Map:   0%|          | 0/8015 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


Resuming EXACTLY from: /content/flan_t5_degree2_ckpt/checkpoint-7200


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Step,Training Loss,Validation Loss
8000,1.3994,1.389901
8800,1.3983,1.389069
9600,1.3986,1.389795
10400,1.396,1.388636
11200,1.3969,1.388947
12000,1.3948,1.387686
12800,1.392,1.387259


KeyboardInterrupt: 

In [None]:
# -----------------------------
# 10) Save final + zip best (robust)
# -----------------------------
def _is_valid_checkpoint_dir(p: str) -> bool:
    if not os.path.isdir(p): return False
    files = set(os.listdir(p))
    has_model = any(f in files for f in (
        "pytorch_model.bin", "pytorch_model.bin.index.json",
        "model.safetensors", "model.safetensors.index.json"
    ))
    has_state = "trainer_state.json" in files
    return has_model and has_state

def _safe_zip_dir(src_dir: str, out_prefix: str) -> str:
    """Zip `src_dir` to ``/content/<out_prefix>_<timestamp>.zip`` and return that path.

    The archive is re-opened and checked with ``testzip``; problems are only
    reported as warnings. The archive size is printed in GB.
    """
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    zip_path = shutil.make_archive(f"/content/{out_prefix}_{timestamp}", "zip", src_dir)
    try:
        with zipfile.ZipFile(zip_path, "r") as archive:
            first_bad = archive.testzip()
    except zipfile.BadZipFile:
        print("[warn] Created file is not recognized as zip (BadZipFile).")
    else:
        if first_bad is not None:
            print(f"[warn] Zip integrity issue with file: {first_bad}")
    size_gb = os.path.getsize(zip_path) / (1024**3)
    print(f"[zip] {zip_path}  (~{size_gb:.2f} GB)")
    return zip_path

# Snapshot the trainer's current model under OUT_DIR (also updates LATEST.txt) ...
final_snap = save_final_snapshot(trainer, tokenizer, OUT_DIR)
# ... and save the same model + tokenizer to FINAL_DIR.
os.makedirs(FINAL_DIR, exist_ok=True)
trainer.save_model(FINAL_DIR); tokenizer.save_pretrained(FINAL_DIR)
# NOTE(review): this saves whatever weights the trainer holds right now. They are the
# "best" model only if training finished and the best checkpoint was reloaded; after an
# interrupted run they are presumably the latest weights — confirm before relying on it.
print("Saved best model to:", FINAL_DIR)

def find_best_checkpoint(out_dir: str) -> str | None:
    state_path = os.path.join(out_dir, "trainer_state.json")
    best = None
    if os.path.exists(state_path):
        try:
            with open(state_path, "r") as f:
                st = json.load(f)
            cand = st.get("best_model_checkpoint", None)
            if cand and os.path.exists(cand) and _is_valid_checkpoint_dir(cand):
                best = cand
            if best is None:
                for rec in reversed(st.get("log_history", [])):
                    if isinstance(rec, dict) and "best_model_checkpoint" in rec:
                        cand = rec["best_model_checkpoint"]
                        if cand and os.path.exists(cand) and _is_valid_checkpoint_dir(cand):
                            best = cand; break
        except Exception as e:
            print(f"[warn] Could not parse trainer_state.json ({e}).")
    if best is None:
        try:
            latest = get_last_checkpoint(out_dir)
            if latest and _is_valid_checkpoint_dir(latest):
                best = latest
        except Exception:
            pass
    if best is None and final_snap and _is_valid_checkpoint_dir(final_snap):
        best = final_snap
    return best

def pick_checkpoint_to_zip(out_dir: str) -> str:
    """Return the checkpoint chosen by `find_best_checkpoint`.

    Raises:
        FileNotFoundError: when no valid checkpoint is found.
    """
    selected = find_best_checkpoint(out_dir)
    if selected is not None:
        print(f"[info] Best checkpoint selected: {selected}")
        return selected
    raise FileNotFoundError("No valid checkpoint found to zip.")

# Zip the selected checkpoint and, when running in Colab (`files` is not None), start a
# browser download. Any failure is reported and the step is skipped rather than raising.
try:
    ckpt_dir = pick_checkpoint_to_zip(OUT_DIR)
    # The archive is named after the checkpoint directory (plus a timestamp).
    zip_path = _safe_zip_dir(ckpt_dir, out_prefix=Path(ckpt_dir).name)
    if files is not None:
        files.download(zip_path)
except Exception as e:
    print("[zip-best] skip:", e)

[Snapshot] Saved '/content/flan_t5_degree2_ckpt/checkpoint-final-12802' and updated LATEST.txt
Saved best model to: /content/flan_t5_degree2_final
[info] Best checkpoint selected: /content/flan_t5_degree2_ckpt/checkpoint-12800
[zip] /content/checkpoint-12800_20250911-104520.zip  (~0.86 GB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# -----------------------------
# 11) Inference & Evaluation (beam or consensus + smart postproc)
# -----------------------------
import re
from difflib import SequenceMatcher

# ---------- Smart post-processing (optional) ----------
_STOPWORDS = {
    "the","a","an","of","in","on","at","to","for","from","by","with",
    "and","or","but","if","as","is","are","was","were","be","been","being",
    "this","that","these","those"
}

def _basic_lemma(word: str) -> str:
    w = word
    for suf in ("'s","’s"):
        if w.endswith(suf) and len(w) > len(suf)+1:
            w = w[:-len(suf)]
    for suf in ("ing","ed","es","s"):
        if w.endswith(suf) and len(w) > len(suf)+1:
            w = w[:-len(suf)]
            break
    return w

def normalize_trigger_str(tr: str) -> str:
    """Normalize a trigger phrase for comparison.

    Trims the phrase, maps curly quotes/dashes to ASCII, strips surrounding
    quotes, lower-cases, replaces punctuation other than "-" and "/" with
    spaces, and collapses whitespace. Results of length <= 1 become "". Stop
    words are then removed and the remaining tokens lemmatized; if that leaves
    nothing, the cleaned phrase is returned instead.
    """
    if not tr:
        return ""
    ascii_map = str.maketrans({
        "“": "\"", "”": "\"", "‘": "'", "’": "'",
        "–": "-", "—": "-",
    })
    cleaned = tr.strip().translate(ascii_map)
    cleaned = cleaned.strip(" '\"`").lower()
    cleaned = re.sub(r"[^\w\s\-/]", " ", cleaned)   # keep letters/digits/space/- and /
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if len(cleaned) <= 1:
        return ""
    lemmas = [_basic_lemma(tok) for tok in cleaned.split() if tok not in _STOPWORDS]
    return " ".join(lemmas).strip() or cleaned

def _token_jaccard(a: str, b: str) -> float:
    A, B = set(a.split()), set(b.split())
    if not A or not B:
        return 0.0
    inter = len(A & B); union = len(A | B)
    return inter/union if union else 0.0

def _char_sim(a: str, b: str) -> float:
    return SequenceMatcher(None, a, b).ratio()

def _near_duplicate(tr1: str, tr2: str, j_th=0.60, c_th=0.80) -> bool:
    if not tr1 or not tr2:
        return False
    return (_token_jaccard(tr1, tr2) >= j_th) or (_char_sim(tr1, tr2) >= c_th)

def merge_near_duplicates(pairs):
    """Merge near-duplicate triggers of the same event type.

    pairs: list[(etype, trigger)]. Each pair is compared with the pairs kept so
    far; on the first same-type near-duplicate the kept entry is replaced by
    the shorter trigger (the new one wins ties, provided it is non-empty).
    Returns the merged list in first-seen order.
    """
    kept = []
    for etype, trigger in pairs:
        match_idx = next(
            (i for i, (k_type, k_trig) in enumerate(kept)
             if etype == k_type and _near_duplicate(trigger, k_trig)),
            None,
        )
        if match_idx is None:
            kept.append((etype, trigger))
            continue
        k_type, k_trig = kept[match_idx]
        shorter = trigger if 0 < len(trigger) <= len(k_trig) else k_trig
        kept[match_idx] = (k_type, shorter)
    return kept

# ---------- Canonicalizers ----------
def clean_and_canon_basic(text: str) -> str:
    """Canonicalize generated text into de-duplicated event lines.

    Each `<EVENTSEP>`-separated chunk is parsed for a type and a trigger, both
    normalized with `norm`, and the type is snapped to the ontology with
    `nearest_type`. Chunks missing either field are dropped.

    De-duplication is applied to the canonical lines as well: two chunks that
    differ in the raw text but collapse to the same line after `nearest_type`
    are emitted once (first occurrence wins).

    Returns the lines joined with ``" <EVENT_TOKEN> "``, or ``"No events."``.
    """
    text = dedup_events_str(text)   # cheap first pass on the raw chunks
    lines = {}                      # canonical line -> None (insertion-ordered set)
    for ch in EVENT_SEP.split(text):
        if not ch.strip():
            continue
        t = TYPE_RE.search(ch); g = TRIG_RE.search(ch)
        et = norm(t.group(1)) if t else None
        tr = norm(g.group(1)) if g else None
        if et:
            et = nearest_type(et)
        if et and tr:
            lines.setdefault(f"Event type: {et}. Trigger: {tr}.", None)
    return f" {EVENT_TOKEN} ".join(lines) if lines else "No events."

def clean_and_canon_smart(text: str) -> str:
    """Parse -> normalize triggers -> nearest_type -> filter -> dedup-near -> rebuild.

    Triggers that normalize to nothing, are purely numeric, or have fewer than
    two alphanumeric characters are discarded. Near-duplicate triggers of the
    same type are merged, exact repeats removed, and the remaining events are
    rendered as canonical lines (or ``"No events."`` when none remain).
    """
    candidates = []
    for etype, trigger in parse_pairs(text):
        canon_type = nearest_type(etype) if etype else etype
        canon_trig = normalize_trigger_str(trigger)
        if not canon_trig or canon_trig.isdigit():
            continue
        # At least 2 alnum chars
        if len(re.sub(r"[^a-z0-9]+", "", canon_trig)) < 2:
            continue
        candidates.append((canon_type, canon_trig))
    if not candidates:
        return "No events."
    # dict.fromkeys removes exact repeats while keeping first-seen order.
    unique = list(dict.fromkeys(merge_near_duplicates(candidates)))
    if not unique:
        return "No events."
    return f" {EVENT_TOKEN} ".join(
        f"Event type: {et}. Trigger: {tr}." for et, tr in unique
    )

# ---------- Generator (beam or consensus sampling) ----------
def generate_batch(
    prompts: List[str],
    mdl,
    tok,
    bs: int = 8,
    device: str | None = None,
    strategy: str = "beam",          # "beam" | "consensus"
    samples: int = 3,                 # used when strategy="consensus"
    top_p: float = 0.9,
    temperature: float = 0.7,
    max_new_tokens: int = 160,
    min_new_tokens: int = 12,
    no_repeat_ngram_size: int = 3,
    repetition_penalty: float = 1.05,
    postproc: str = "basic"           # "basic" | "smart"
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval().to(device)

    canon = clean_and_canon_smart if postproc == "smart" else clean_and_canon_basic
    outs = []

    with torch.no_grad():
        for i in range(0, len(prompts), bs):
            batch_prompts = prompts[i:i+bs]
            enc = tok(
                batch_prompts, return_tensors="pt",
                padding=True, truncation=True, max_length=MAX_IN_LEN
            ).to(device)

            if strategy == "beam":
                gen_kwargs = dict(
                    max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
                    num_beams=6, num_beam_groups=3, diversity_penalty=0.2,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    length_penalty=0.9, repetition_penalty=repetition_penalty,
                    early_stopping=False, trust_remote_code=True
                )
                gen = mdl.generate(**enc, **gen_kwargs)
                dec = tok.batch_decode(gen, skip_special_tokens=True)
                outs += [canon(t) for t in dec]

            else:  # consensus sampling
                all_decoded = []
                for _ in range(samples):
                    gen_kwargs = dict(
                        do_sample=True, top_p=top_p, temperature=temperature,
                        max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
                        no_repeat_ngram_size=no_repeat_ngram_size,
                        repetition_penalty=repetition_penalty,
                        early_stopping=False, trust_remote_code=True
                    )
                    gen = mdl.generate(**enc, **gen_kwargs)
                    dec = tok.batch_decode(gen, skip_special_tokens=True)
                    all_decoded.append(dec)

                # merge K samples per input, then canonicalize
                for k in range(len(batch_prompts)):
                    variants = [all_decoded[s][k] for s in range(samples)]
                    merged_text = f" {EVENT_TOKEN} ".join(variants)
                    outs.append(canon(merged_text))

    return outs

# ---------- Metrics (unchanged) ----------
def match_partial(preds, golds):
    """Greedily align predicted (type, trigger) pairs with gold pairs.

    For each prediction, the unused gold pair of the same type with the highest
    ``trigger_overlap`` is selected. An overlap of ~1.0 counts as an exact
    match, an overlap >= 0.5 as a partial match. A best overlap below 0.5 is
    not a match and leaves the gold pair available for later predictions
    (previously it was consumed without being scored, which could block a later
    exact match).

    Returns ``(exact, partial, len(preds), len(golds))``.
    """
    consumed = set()
    exact = partial = 0
    for p_type, p_trig in preds:
        best_ov, best_j = 0.0, -1
        for j, (g_type, g_trig) in enumerate(golds):
            if j in consumed or p_type != g_type:
                continue
            ov = trigger_overlap(p_trig, g_trig)
            if ov > best_ov:
                best_ov, best_j = ov, j
        if best_j == -1:
            continue
        if math.isclose(best_ov, 1.0):
            exact += 1
        elif best_ov >= 0.5:
            partial += 1
        else:
            # Below the partial threshold: keep the gold pair free.
            continue
        consumed.add(best_j)
    return exact, partial, len(preds), len(golds)

def prf(e,p,pt,gt, w=0.5):
    """Precision, recall and F1 where each partial match is worth ``w`` of a hit.

    ``e``/``p`` are exact/partial match counts, ``pt``/``gt`` the number of
    predicted/gold pairs. Any ratio with a zero denominator is 0.0.
    """
    weighted_hits = e + w*p

    P = 0.0
    if pt:
        P = weighted_hits/pt

    R = 0.0
    if gt:
        R = weighted_hits/gt

    if not (P+R):
        return P, R, 0.0
    return P, R, (2*P*R)/(P+R)

def relaxed_recall_by_chunks(pairs_pred, pairs_gold, chunk_size=1):
    """Relaxed recall: each chunk contributes at most one (weighted) hit.

    Per chunk, ``match_partial`` gives exact/partial counts; matched weight
    above 1.0 is subtracted again, and gold pairs beyond the first are removed
    from the denominator. The denominator is floored at 1. ``chunk_size`` is
    accepted but not used by this implementation.
    """
    total_exact = total_partial = 0
    gold_total = unreachable = 0
    overshoot = 0
    for pred_pairs, gold_pairs in zip(pairs_pred, pairs_gold):
        n_exact, n_partial, _, n_gold = match_partial(pred_pairs, gold_pairs)
        gold_total += n_gold
        unreachable += max(0, n_gold - 1)
        overshoot += max(0.0, (n_exact + 0.5*n_partial) - 1.0)
        total_exact += n_exact
        total_partial += n_partial
    numerator = max(0.0, (total_exact + 0.5*total_partial) - overshoot)
    denominator = max(1, gold_total - unreachable)
    return numerator/denominator

def evaluate(
    ds, mdl, tok,
    strategy: str = "beam",      # "beam" | "consensus"
    postproc: str = "basic",     # "basic" | "smart"
    samples: int = 3,
    bs: int = 8,
    top_p: float = 0.9,
    temperature: float = 0.7
):
    """Generate predictions for ``ds`` and print/return strict, partial and relaxed scores.

    ``ds`` yields dicts with ``"input"`` (prompt) and ``"output"`` (gold text).
    Strict scores compare sets of parsed (type, trigger) pairs; partial scores
    come from ``match_partial``/``prf``; the relaxed F1 combines partial
    precision with ``relaxed_recall_by_chunks``.

    Returns a dict with ``strict_f1``, ``partial_f1``, ``relaxed_recall`` and
    ``relaxed_f1``.
    """
    inputs = [ex["input"] for ex in ds]
    references = [ex["output"] for ex in ds]
    hypotheses = generate_batch(
        inputs, mdl, tok, bs=bs,
        strategy=strategy, postproc=postproc,
        samples=samples, top_p=top_p, temperature=temperature
    )

    strict_tp = strict_pred = strict_gold = 0
    part_e = part_p = part_pt = part_gt = 0
    per_item_pred, per_item_gold = [], []

    for hyp, ref in zip(hypotheses, references):
        pred_pairs = parse_pairs(hyp)
        gold_pairs = parse_pairs(ref)

        pred_set, gold_set = set(pred_pairs), set(gold_pairs)
        strict_tp += len(pred_set & gold_set)
        strict_pred += len(pred_set)
        strict_gold += len(gold_set)

        n_exact, n_partial, n_pred, n_gold = match_partial(pred_pairs, gold_pairs)
        part_e += n_exact
        part_p += n_partial
        part_pt += n_pred
        part_gt += n_gold

        per_item_pred.append(pred_pairs)
        per_item_gold.append(gold_pairs)

    sP = strict_tp/strict_pred if strict_pred else 0.0
    sR = strict_tp/strict_gold if strict_gold else 0.0
    sF = (2*sP*sR)/(sP+sR) if (sP+sR) else 0.0

    pP,pR,pF = prf(part_e,part_p,part_pt,part_gt, w=0.5)
    r_rel = relaxed_recall_by_chunks(per_item_pred, per_item_gold)
    # Relaxed F1 pairs the partial precision with the relaxed recall.
    rF = (2*pP*r_rel)/(pP+r_rel) if (pP+r_rel)>0 else 0.0

    print("\n===== STRICT =====")
    print(f"P={sP:.4f} R={sR:.4f} F1={sF:.4f}")
    print("===== PARTIAL (MUC 0.5) =====")
    print(f"P={pP:.4f} R={pR:.4f} F1={pF:.4f}")
    print("===== RELAXED (DEGREE2) =====")
    print(f"Relaxed-Recall={r_rel:.4f} | Relaxed-F1≈{rF:.4f}")
    return dict(strict_f1=sF, partial_f1=pF, relaxed_recall=r_rel, relaxed_f1=rF)

# ---------- Quick sanity + Eval with FINAL_DIR ----------
# Reload the tokenizer and model saved under FINAL_DIR so evaluation uses the
# exported weights rather than whatever model object is still in memory.
best_tok   = AutoTokenizer.from_pretrained(FINAL_DIR)
best_model = AutoModelForSeq2SeqLM.from_pretrained(FINAL_DIR)

# Baseline: group beam search + basic canonicalization on the validation split.
print("\n--- VALID (beam + basic) ---")
evaluate(valid_ds, best_model, best_tok, strategy="beam", postproc="basic", bs=8)

# Same configuration on the test split.
print("\n--- TEST  (beam + basic) ---")
evaluate(test_ds,  best_model, best_tok, strategy="beam", postproc="basic", bs=8)

# To try recall-oriented mode:
# print("\n--- VALID (consensus + smart) ---")
# evaluate(valid_ds, best_model, best_tok, strategy="consensus", postproc="smart",samples=3, bs=16, top_p=0.9, temperature=0.7)

# Single-sentence smoke test: build a prompt with build_prompt(..., k=3) and
# print the canonicalized prediction for it.
demo_sent = (
    "The hijacking of Lufthansa Flight 615 was an act of terrorism committed by a Palestinian group "
    "that occurred on 29 October 1972 and aimed at the liberation of the three surviving perpetrators "
    "of the Munich massacre from a West German prison."
)
demo_prompt = build_prompt(demo_sent, k=3)
print("\nDEMO (beam+basic):\n", generate_batch([demo_prompt], best_model, best_tok, strategy="beam", postproc="basic", bs=1)[0])
# print("\nDEMO (consensus+smart):\n", generate_batch([demo_prompt], best_model, best_tok, strategy="consensus", postproc="smart", samples=3, bs=1)[0])


--- VALID (beam + basic) ---


Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call.



===== STRICT =====
P=0.5966 R=0.2746 F1=0.3761
===== PARTIAL (MUC 0.5) =====
P=0.5986 R=0.2756 F1=0.3774
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6490 | Relaxed-F1≈0.6228

--- TEST  (beam + basic) ---

===== STRICT =====
P=0.6019 R=0.1379 F1=0.2244
===== PARTIAL (MUC 0.5) =====
P=0.6037 R=0.1383 F1=0.2251
===== RELAXED (DEGREE2) =====
Relaxed-Recall=0.6127 | Relaxed-F1≈0.6082

DEMO (beam+basic):
 Event type: coming_to_be. Trigger: occurred.


In [None]:
# -----------------------------
# 11) Inference & Evaluation (hybrid beam+consensus with soft-smart postproc + reranker)
# -----------------------------
import re
from difflib import SequenceMatcher

# ======== Soft-smart post-processing ========
# Function words removed from a trigger by normalize_trigger_str before the
# remaining tokens are stemmed (articles, prepositions, conjunctions, forms of
# "to be", demonstratives).
_STOPWORDS = {
    "the","a","an","of","in","on","at","to","for","from","by","with",
    "and","or","but","if","as","is","are","was","were","be","been","being",
    "this","that","these","those"
}

def _soft_stem(word: str, min_stem: int = 4) -> str:
    """Very light suffix stripping that never leaves fewer than ``min_stem`` characters.

    Possessive endings (``'s`` / ``’s``) are removed first. Then the first of
    ``ing``, ``ed``, ``es``, ``s`` that can be removed while keeping at least
    ``min_stem`` characters is stripped. The previous guard looked at the
    length before stripping, so short words were over-truncated
    ("bring" -> "br", "speed" -> "spe", "fires" -> "fir").

    Doubled consonants are not undone: "occurred" still becomes "occurr".
    """
    w = word
    for suf in ("'s", "’s"):
        if w.endswith(suf) and len(w) > len(suf) + 1:
            w = w[:-len(suf)]
    for suf in ("ing", "ed", "es", "s"):
        # Guard on the length of what would remain, not on the original length.
        if w.endswith(suf) and len(w) - len(suf) >= min_stem:
            w = w[:-len(suf)]
            break
    return w

def normalize_trigger_str(tr: str) -> str:
    """Lower-case, de-punctuate and lightly stem a trigger phrase.

    Curly quotes and long dashes are mapped to ASCII, surrounding quotes are
    stripped, every character other than word characters, whitespace, ``-``
    and ``/`` becomes a space, and whitespace is collapsed. Stopwords are
    dropped and the remaining tokens stemmed; if that leaves nothing, the
    cleaned (unstemmed) phrase is returned. Empty input or a cleaned phrase of
    at most one character yields ``""``.
    """
    if not tr:
        return ""
    typographic = str.maketrans({
        "“": "\"", "”": "\"",
        "‘": "'", "’": "'",
        "–": "-", "—": "-",
    })
    cleaned = tr.strip().translate(typographic).strip(" '\"`").lower()
    cleaned = re.sub(r"[^\w\s\-/]", " ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if len(cleaned) <= 1:
        return ""
    stems = [_soft_stem(token) for token in cleaned.split() if token not in _STOPWORDS]
    return " ".join(stems).strip() or cleaned

def _token_jaccard(a: str, b: str) -> float:
    """Jaccard similarity of the whitespace-token sets of ``a`` and ``b`` (0.0 if either is empty)."""
    left, right = set(a.split()), set(b.split())
    if not (left and right):
        return 0.0
    return len(left & right) / len(left | right)

def _char_sim(a: str, b: str) -> float:
    """Character-level similarity of ``a`` and ``b`` as ``difflib.SequenceMatcher`` ratio."""
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()

def _near_duplicate(tr1: str, tr2: str, j_th=0.55, c_th=0.78) -> bool:
    """True when two triggers are close by token Jaccard (>= ``j_th``) or char similarity (>= ``c_th``).

    Empty triggers are never duplicates of anything.
    """
    if not (tr1 and tr2):
        return False
    if _token_jaccard(tr1, tr2) >= j_th:
        return True
    return _char_sim(tr1, tr2) >= c_th

def merge_near_duplicates(pairs):
    """pairs: list[(etype, trigger)] → merge near-duplicate triggers per type (keep better-shaped/shorter)."""
    kept = []
    for et, tr in pairs:
        merged = False
        for i, (et2, tr2) in enumerate(kept):
            if et == et2 and _near_duplicate(tr, tr2):
                # choose representative by: longer >=4 preferred (readability) but avoid multiword bloat
                def _score_shape(s):
                    letters = len(re.sub(r"[^a-z0-9]+","", s))
                    tokens  = len(s.split())
                    return letters - 0.5 * max(0, tokens-1)
                rep = tr if _score_shape(tr) >= _score_shape(tr2) else tr2
                kept[i] = (et2, rep)
                merged = True
                break
        if not merged:
            kept.append((et, tr))
    return kept

# Two canonicalizers: basic (Code A parity) and soft-smart
def clean_and_canon_basic(text: str) -> str:
    """Basic canonicalizer: re-render every segment that has both a type and a trigger.

    The text is de-duplicated with ``dedup_events_str`` and split on the event
    separator. Types are normalized and mapped through ``nearest_type``;
    triggers are only normalized. Returns ``"No events."`` if nothing survives.
    """
    rendered = []
    segments = (seg for seg in EVENT_SEP.split(dedup_events_str(text)) if seg.strip())
    for seg in segments:
        type_hit = TYPE_RE.search(seg)
        trig_hit = TRIG_RE.search(seg)
        etype = norm(type_hit.group(1)) if type_hit else None
        trigger = norm(trig_hit.group(1)) if trig_hit else None
        if etype:
            etype = nearest_type(etype)
        if not (etype and trigger):
            continue
        rendered.append(f"Event type: {etype}. Trigger: {trigger}.")
    if not rendered:
        return "No events."
    return f" {EVENT_TOKEN} ".join(rendered)

def clean_and_canon_softsmart(text: str) -> str:
    """Soft-smart canonicalizer: parse, normalize, filter, merge near-duplicates, rebuild.

    A pair is dropped when:
      * it has no usable event type (empty/None before or after ``nearest_type``);
        previously such pairs were rendered as ``Event type: None.``,
      * its normalized trigger is empty or purely numeric,
      * its normalized trigger has fewer than 2 characters in ``[a-z0-9]``.

    Returns the surviving pairs joined with the event separator token, or
    ``"No events."`` when nothing survives.
    """
    cleaned = []
    for et, tr in parse_pairs(text):
        et_norm = nearest_type(et) if et else et
        if not et_norm:
            # Without a type the pair cannot be rendered meaningfully.
            continue
        tr_norm = normalize_trigger_str(tr)
        if not tr_norm or tr_norm.isdigit():
            continue
        if len(re.sub(r"[^a-z0-9]+", "", tr_norm)) < 2:
            continue
        cleaned.append((et_norm, tr_norm))
    if not cleaned:
        return "No events."

    # Merging can leave exact duplicates behind; keep first occurrences in order.
    unique = list(dict.fromkeys(merge_near_duplicates(cleaned)))
    if not unique:
        return "No events."
    return f" {EVENT_TOKEN} ".join(f"Event type: {et}. Trigger: {tr}." for et, tr in unique)

# ======== Helpers for reranking / scoring ========
def _anchor_score(src: str, trig: str) -> float:
    """Score how strongly ``trig`` is anchored in ``src``.

    Both strings go through ``norm``. A literal substring hit scores 1.0.
    Otherwise every occurrence of the trigger's first token opens a window
    (20 characters of padding on each side of a trigger-length span) that is
    compared with the trigger by character similarity; scanning stops once a
    window reaches 0.92. The result is the best window score or the token
    Jaccard of the two strings, whichever is larger. Empty inputs score 0.0.
    """
    haystack, needle = norm(src), norm(trig)
    if not (haystack and needle):
        return 0.0
    if needle in haystack:
        return 1.0
    words = needle.split()
    if not words:
        return 0.0

    pad, good_enough = 20, 0.92
    top = 0.0
    for hit in re.finditer(re.escape(words[0]), haystack):
        begin = hit.start()
        window = haystack[max(0, begin - pad):begin + len(needle) + pad]
        top = max(top, _char_sim(window, needle))
        if top >= good_enough:
            break
    # token Jaccard as a floor
    return float(max(top, _token_jaccard(haystack, needle)))

def _type_prior(et: str) -> float:
    """Square root of the count stored for ``et`` in ``TYPE_FREQ`` (0.0 for unknown types)."""
    # TYPE_FREQ is a module-level mapping defined elsewhere in the notebook.
    count = TYPE_FREQ.get(et, 0)
    return float(count) ** 0.5  # sublinear

def _shape_bonus(tr: str) -> float:
    """Small length-based adjustment: -0.5 for <=2 alphanumerics, 0.0 for 3-4, +0.1 otherwise."""
    n_alnum = len(re.findall(r"[a-z0-9]", tr))
    if n_alnum > 4:
        return 0.1
    if n_alnum > 2:
        return 0.0
    return -0.5

def _score_event(src: str, et: str, tr: str) -> float:
    """Weighted candidate score: 0.70 anchor + 0.25 squashed type prior + 0.05 shape bonus."""
    anchor = _anchor_score(src, tr)
    prior = _type_prior(et)
    shape = _shape_bonus(tr)
    squashed_prior = prior/(prior+1.0)   # maps [0, inf) into [0, 1)
    return 0.70*anchor + 0.25*squashed_prior + 0.05*shape

def _parse_pairs_canon(text: str, canonizer) -> list[tuple[str,str]]:
    """Run ``canonizer`` over ``text`` and parse its output into (type, trigger) pairs."""
    return parse_pairs(canonizer(text))

# ======== Generator variants ========
def _generate_once(prompts, mdl, tok, device, **gen_kwargs):
    """Tokenize ``prompts``, run a single ``mdl.generate`` call and decode the result.

    Inputs are padded and truncated to ``MAX_IN_LEN`` and moved to ``device``.
    ``trust_remote_code=True`` is always forwarded to ``generate`` along with
    ``gen_kwargs``. Returns the decoded strings with special tokens removed.
    """
    encoded = tok(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_IN_LEN)
    encoded = encoded.to(device)
    sequences = mdl.generate(**encoded, trust_remote_code=True, **gen_kwargs)
    return tok.batch_decode(sequences, skip_special_tokens=True)

def _beam_texts(prompts, mdl, tok, device, **kwargs):
    """Decode ``prompts`` with group beam search (8 beams in 4 groups).

    Only ``max_new_tokens`` (default 160) and ``min_new_tokens`` (default 8)
    are read from ``kwargs``; every other decoding option is fixed here.
    """
    beam_cfg = {
        "max_new_tokens": kwargs.get("max_new_tokens", 160),
        "min_new_tokens": kwargs.get("min_new_tokens", 8),
        "num_beams": 8,
        "num_beam_groups": 4,
        "diversity_penalty": 0.2,
        "no_repeat_ngram_size": 4,
        "length_penalty": 0.8,
        "repetition_penalty": 1.03,
        "early_stopping": False,
    }
    return _generate_once(prompts, mdl, tok, device, **beam_cfg)

def _sample_texts(prompts, mdl, tok, device, **kwargs):
    """Decode ``prompts`` once with nucleus sampling.

    ``top_p`` (0.92), ``temperature`` (0.6), ``max_new_tokens`` (160) and
    ``min_new_tokens`` (8) may be overridden through ``kwargs``; the remaining
    options are fixed here.
    """
    sampling_cfg = {
        "do_sample": True,
        "top_p": kwargs.get("top_p", 0.92),
        "temperature": kwargs.get("temperature", 0.6),
        "max_new_tokens": kwargs.get("max_new_tokens", 160),
        "min_new_tokens": kwargs.get("min_new_tokens", 8),
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.02,
        "early_stopping": False,
    }
    return _generate_once(prompts, mdl, tok, device, **sampling_cfg)

# ======== Hybrid generator with reranking ========
def generate_batch(
    prompts: List[str],
    mdl, tok,
    bs: int = 8,
    device: str | None = None,
    strategy: str = "hybrid",        # "beam" | "consensus" | "hybrid"
    samples: int = 3,                # for sampling/consensus
    postproc: str = "softsmart",     # "basic" | "softsmart"
    top_p: float = 0.92,
    temperature: float = 0.6,
    max_new_tokens: int = 160,
    min_new_tokens: int = 8,
    max_events_total: int = 6,       # cap to prevent explosion
    max_events_per_type: int = 3,    # per-type cap
    tau_frac: float = 0.62,          # relative threshold
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval().to(device)
    canonizer = clean_and_canon_softsmart if postproc == "softsmart" else clean_and_canon_basic

    outs = []
    with torch.no_grad():
        for i in range(0, len(prompts), bs):
            batch = prompts[i:i+bs]

            # 1) BEAM baseline
            beam_dec = _beam_texts(batch, mdl, tok, device,
                                   max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens)
            beam_pairs = [_parse_pairs_canon(t, canonizer) for t in beam_dec]

            if strategy == "beam":
                # Beam + postproc only
                outs.extend([canonizer(t) for t in beam_dec])
                continue

            # 2) CONSENSUS candidates (K samples)
            all_samples = []
            for _ in range(samples):
                dec = _sample_texts(batch, mdl, tok, device,
                                    top_p=top_p, temperature=temperature,
                                    max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens)
                all_samples.append([_parse_pairs_canon(t, canonizer) for t in dec])

            # 3) Fuse per item with reranking + fallback
            for k in range(len(batch)):
                src_prompt = batch[k]
                # start from beam pairs
                pool = list(beam_pairs[k])

                # add consensus variants (merged by near-duplicate)
                # gather raw candidates
                cands = []
                for s_id in range(samples):
                    cands.extend(all_samples[s_id][k])

                # fuzzy merge across candidates
                merged = []
                for et,tr in cands:
                    matched = False
                    for i2,(et2,tr2) in enumerate(merged):
                        if et==et2 and _near_duplicate(tr,tr2):
                            # keep better-shaped
                            def _shape(s):
                                letters = len(re.sub(r"[^a-z0-9]+","", s))
                                toks = len(s.split())
                                return letters - 0.5*max(0,toks-1)
                            rep = tr if _shape(tr)>=_shape(tr2) else tr2
                            merged[i2]=(et2,rep)
                            matched=True; break
                    if not matched:
                        merged.append((et,tr))

                # add merged consensus to pool if not exact dup of beam
                for et,tr in merged:
                    if (et,tr) not in pool:
                        pool.append((et,tr))

                # 4) Score & select with adaptive threshold + caps
                # score each (et,tr)
                scored = []
                for et,tr in pool:
                    s = _score_event(src_prompt, et, tr)
                    scored.append((s, et, tr))
                if not scored:
                    outs.append("No events.")
                    continue

                m = max(s for s,_,_ in scored)
                tau = tau_frac * m
                # keep only above threshold
                keep = [(s,et,tr) for (s,et,tr) in sorted(scored, key=lambda x: x[0], reverse=True) if s >= tau]

                # per-type cap
                per_type_count = {}
                final_pairs=[]
                for s,et,tr in keep:
                    if len(final_pairs) >= max_events_total:
                        break
                    c = per_type_count.get(et,0)
                    if c >= max_events_per_type:
                        continue
                    per_type_count[et]=c+1
                    final_pairs.append((et,tr))

                # 5) If everything got pruned, fall back to beam only
                if not final_pairs and beam_pairs[k]:
                    final_pairs = beam_pairs[k]

                if not final_pairs:
                    outs.append("No events.")
                else:
                    parts = [f"Event type: {et}. Trigger: {tr}." for et,tr in final_pairs]
                    outs.append(f" {EVENT_TOKEN} ".join(parts))

    return outs

# ======== Metrics (unchanged) ========
def match_partial(preds, golds):
    """Greedily align predicted (type, trigger) pairs with gold pairs.

    For each prediction, the unused gold pair of the same type with the highest
    ``trigger_overlap`` is selected. An overlap of ~1.0 counts as an exact
    match, an overlap >= 0.5 as a partial match. A best overlap below 0.5 is
    not a match and leaves the gold pair available for later predictions
    (previously it was consumed without being scored, which could block a later
    exact match).

    Returns ``(exact, partial, len(preds), len(golds))``.
    """
    consumed = set()
    exact = partial = 0
    for p_type, p_trig in preds:
        best_ov, best_j = 0.0, -1
        for j, (g_type, g_trig) in enumerate(golds):
            if j in consumed or p_type != g_type:
                continue
            ov = trigger_overlap(p_trig, g_trig)
            if ov > best_ov:
                best_ov, best_j = ov, j
        if best_j == -1:
            continue
        if math.isclose(best_ov, 1.0):
            exact += 1
        elif best_ov >= 0.5:
            partial += 1
        else:
            # Below the partial threshold: keep the gold pair free.
            continue
        consumed.add(best_j)
    return exact, partial, len(preds), len(golds)

def prf(e,p,pt,gt, w=0.5):
    """Precision, recall and F1 where each partial match is worth ``w`` of a hit.

    ``e``/``p`` are exact/partial match counts, ``pt``/``gt`` the number of
    predicted/gold pairs. Any ratio with a zero denominator is 0.0.
    """
    weighted_hits = e + w*p

    P = 0.0
    if pt:
        P = weighted_hits/pt

    R = 0.0
    if gt:
        R = weighted_hits/gt

    if not (P+R):
        return P, R, 0.0
    return P, R, (2*P*R)/(P+R)

def relaxed_recall_by_chunks(pairs_pred, pairs_gold, chunk_size=1):
    """Relaxed recall: each chunk contributes at most one (weighted) hit.

    Per chunk, ``match_partial`` gives exact/partial counts; matched weight
    above 1.0 is subtracted again, and gold pairs beyond the first are removed
    from the denominator. The denominator is floored at 1. ``chunk_size`` is
    accepted but not used by this implementation.
    """
    total_exact = total_partial = 0
    gold_total = unreachable = 0
    overshoot = 0
    for pred_pairs, gold_pairs in zip(pairs_pred, pairs_gold):
        n_exact, n_partial, _, n_gold = match_partial(pred_pairs, gold_pairs)
        gold_total += n_gold
        unreachable += max(0, n_gold - 1)
        overshoot += max(0.0, (n_exact + 0.5*n_partial) - 1.0)
        total_exact += n_exact
        total_partial += n_partial
    numerator = max(0.0, (total_exact + 0.5*total_partial) - overshoot)
    denominator = max(1, gold_total - unreachable)
    return numerator/denominator

def evaluate(
    ds, mdl, tok,
    strategy: str = "hybrid",      # "beam" | "consensus" | "hybrid"
    postproc: str = "softsmart",   # "basic" | "softsmart"
    samples: int = 3,
    bs: int = 8,
    top_p: float = 0.92,
    temperature: float = 0.6
):
    """Generate predictions for ``ds`` and print/return strict, partial and relaxed scores.

    ``ds`` yields dicts with ``"input"`` (prompt) and ``"output"`` (gold text).
    Strict scores compare sets of parsed (type, trigger) pairs; partial scores
    come from ``match_partial``/``prf``; the relaxed F1 combines partial
    precision with ``relaxed_recall_by_chunks``.

    Returns a dict with ``strict_f1``, ``partial_f1``, ``relaxed_recall`` and
    ``relaxed_f1``.
    """
    inputs = [ex["input"] for ex in ds]
    references = [ex["output"] for ex in ds]
    hypotheses = generate_batch(
        inputs, mdl, tok, bs=bs,
        strategy=strategy, postproc=postproc,
        samples=samples, top_p=top_p, temperature=temperature
    )

    strict_tp = strict_pred = strict_gold = 0
    part_e = part_p = part_pt = part_gt = 0
    per_item_pred, per_item_gold = [], []

    for hyp, ref in zip(hypotheses, references):
        pred_pairs = parse_pairs(hyp)
        gold_pairs = parse_pairs(ref)

        pred_set, gold_set = set(pred_pairs), set(gold_pairs)
        strict_tp += len(pred_set & gold_set)
        strict_pred += len(pred_set)
        strict_gold += len(gold_set)

        n_exact, n_partial, n_pred, n_gold = match_partial(pred_pairs, gold_pairs)
        part_e += n_exact
        part_p += n_partial
        part_pt += n_pred
        part_gt += n_gold

        per_item_pred.append(pred_pairs)
        per_item_gold.append(gold_pairs)

    sP = strict_tp/strict_pred if strict_pred else 0.0
    sR = strict_tp/strict_gold if strict_gold else 0.0
    sF = (2*sP*sR)/(sP+sR) if (sP+sR) else 0.0

    pP,pR,pF = prf(part_e,part_p,part_pt,part_gt, w=0.5)
    r_rel = relaxed_recall_by_chunks(per_item_pred, per_item_gold)
    # Relaxed F1 pairs the partial precision with the relaxed recall.
    rF = (2*pP*r_rel)/(pP+r_rel) if (pP+r_rel)>0 else 0.0

    print("\n===== STRICT =====")
    print(f"P={sP:.4f} R={sR:.4f} F1={sF:.4f}")
    print("===== PARTIAL (MUC 0.5) =====")
    print(f"P={pP:.4f} R={pR:.4f} F1={pF:.4f}")
    print("===== RELAXED (DEGREE2) =====")
    print(f"Relaxed-Recall={r_rel:.4f} | Relaxed-F1≈{rF:.4f}")
    return dict(strict_f1=sF, partial_f1=pF, relaxed_recall=r_rel, relaxed_f1=rF)

# ======== Quick sanity + Eval with FINAL_DIR ========
# Reload the tokenizer and model saved under FINAL_DIR for evaluation.
best_tok   = AutoTokenizer.from_pretrained(FINAL_DIR)
best_model = AutoModelForSeq2SeqLM.from_pretrained(FINAL_DIR)

# Hybrid + soft-smart runs are currently disabled; uncomment to evaluate the
# fused beam/sampling path defined in this cell.
#print("\n--- VALID (hybrid + softsmart) ---")
#evaluate(valid_ds, best_model, best_tok, strategy="hybrid", postproc="softsmart", bs=8)

#print("\n--- TEST  (hybrid + softsmart) ---")
#evaluate(test_ds,  best_model, best_tok, strategy="hybrid", postproc="softsmart", bs=8)

# Optional: compare with beam-only (parity with your previous baseline)
# NOTE: this goes through this cell's generate_batch, whose beam settings are
# those of _beam_texts (8 beams in 4 groups).
print("\n--- VALID (beam + basic) ---")
evaluate(valid_ds, best_model, best_tok, strategy="beam", postproc="basic", bs=8)

print("\n--- TEST (beam + basic) ---")
evaluate(test_ds, best_model, best_tok, strategy="beam", postproc="basic", bs=8)


--- VALID (hybrid + softsmart) ---


KeyboardInterrupt: 

In [None]:
# ==============================
# 11) Inference++ (Hybrid + Vote + Anchor-aware Rerank) + Tiny Tuner
# ==============================
import re, math
from difflib import SequenceMatcher

# --------- Utilities ----------
# Function words removed from a trigger by normalize_trigger_str before the
# remaining tokens are stemmed.
_STOPWORDS = {
    "the","a","an","of","in","on","at","to","for","from","by","with",
    "and","or","but","if","as","is","are","was","were","be","been","being",
    "this","that","these","those"
}

def _soft_stem(w: str, min_stem: int = 4) -> str:
    """Very light suffix stripping that never leaves fewer than ``min_stem`` characters.

    Possessive endings (``'s`` / ``’s``) are removed first. Then the first of
    ``ing``, ``ed``, ``es``, ``s`` that can be removed while keeping at least
    ``min_stem`` characters is stripped. The previous guard looked at the
    length before stripping, so short words were over-truncated
    ("bring" -> "br", "speed" -> "spe", "fires" -> "fir").

    Doubled consonants are not undone: "occurred" still becomes "occurr".
    """
    for suf in ("'s", "’s"):
        if w.endswith(suf) and len(w) > len(suf) + 1:
            w = w[:-len(suf)]
    for suf in ("ing", "ed", "es", "s"):
        # Guard on the length of what would remain, not on the original length.
        if w.endswith(suf) and len(w) - len(suf) >= min_stem:
            w = w[:-len(suf)]
            break
    return w

def normalize_trigger_str(tr: str) -> str:
    """Lower-case, de-punctuate and lightly stem a trigger phrase; ``""`` if unusable.

    Curly quotes and long dashes are mapped to ASCII, surrounding quotes are
    stripped, every character other than word characters, whitespace, ``-``
    and ``/`` becomes a space, and whitespace is collapsed. Stopwords are
    dropped and the remaining tokens stemmed (falling back to the cleaned
    phrase if nothing remains). The result is rejected (``""``) when it has
    fewer than 2 characters in ``[a-z0-9]`` or consists of digits only.
    """
    if not tr:
        return ""
    typographic = str.maketrans({
        "“": "\"", "”": "\"",
        "‘": "'", "’": "'",
        "–": "-", "—": "-",
    })
    cleaned = tr.strip().translate(typographic).strip(" '\"`").lower()
    cleaned = re.sub(r"[^\w\s\-/]", " ", cleaned)
    cleaned = re.sub(r"\s+"," ", cleaned).strip()
    if len(cleaned) <= 1:
        return ""

    stems = [_soft_stem(token) for token in cleaned.split() if token not in _STOPWORDS]
    result = " ".join(stems).strip() or cleaned

    too_short = len(re.sub(r"[^a-z0-9]+","", result)) < 2
    if too_short or result.isdigit():
        return ""
    return result

def _token_jaccard(a: str, b: str) -> float:
    """Jaccard similarity of the whitespace-token sets of ``a`` and ``b`` (0.0 if either is empty)."""
    left, right = set(a.split()), set(b.split())
    if not (left and right):
        return 0.0
    return len(left & right) / len(left | right)

def _char_sim(a: str, b: str) -> float:
    """Character-level similarity of ``a`` and ``b`` as ``difflib.SequenceMatcher`` ratio."""
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()

def _near_dup(tr1: str, tr2: str, j=0.55, c=0.78) -> bool:
    """True when two triggers are close by token Jaccard (>= ``j``) or char similarity (>= ``c``).

    Empty triggers are never duplicates of anything.
    """
    if not (tr1 and tr2):
        return False
    if _token_jaccard(tr1, tr2) >= j:
        return True
    return _char_sim(tr1, tr2) >= c

def _merge_near_dups(pairs):
    kept=[]
    for et,tr in pairs:
        placed=False
        for i,(et2,tr2) in enumerate(kept):
            if et==et2 and _near_dup(tr,tr2):
                def _shape(s):
                    letters = len(re.sub(r"[^a-z0-9]+","", s))
                    toks = len(s.split())
                    return letters - 0.5*max(0,toks-1)
                kept[i]=(et, tr if _shape(tr)>=_shape(tr2) else tr2)
                placed=True; break
        if not placed:
            kept.append((et,tr))
    return kept

def _canonize(text: str):
    """Soft-smart canonizer returning a list of (type, trigger) pairs.

    Parsed pairs get their type mapped through ``nearest_type`` and their
    trigger through ``normalize_trigger_str``; pairs missing either part after
    normalization are dropped and the rest are merged with ``_merge_near_dups``.
    Returns ``[]`` when nothing survives.
    """
    usable = []
    for et, tr in parse_pairs(text):
        etype = nearest_type(et) if et else et
        trigger = normalize_trigger_str(tr)
        if not (etype and trigger):
            continue
        usable.append((etype, trigger))
    if usable:
        return _merge_near_dups(usable)
    return []


# Preferred form: the quoted sentence directly followed by the "Output:" cue.
_SRC_RE = re.compile(r'Sentence:\s*"(.+?)"\s*?\nOutput:', re.DOTALL|re.IGNORECASE)
# Looser form: any quoted sentence after "Sentence:".
_SRC_FALLBACK_RE = re.compile(r'Sentence:\s*"(.+?)"', re.DOTALL|re.IGNORECASE)

def recover_src_from_prompt(prompt: str) -> str:
    """Extract the sentence being queried from a prompt.

    The last ``Sentence: "..."`` span is returned. Prompts with a single such
    span behave as before. When a prompt contains several (for example
    exemplars written in the same layout), the final one is taken, because
    that is the one the model is asked to complete; the first match would be
    an exemplar. If the prompt has no such span, it is returned unchanged.
    """
    # The strict pattern is tried first; the fallback covers prompts whose
    # structure differs (no "Output:" line right after the sentence).
    for pattern in (_SRC_RE, _SRC_FALLBACK_RE):
        hits = pattern.findall(prompt)
        if hits:
            return hits[-1].strip()
    return prompt

# --------- Anchor-aware scoring ----------
def _anchor_score(src: str, trig: str) -> float:
    """Score how strongly ``trig`` is anchored in ``src``.

    Both strings go through ``norm``. A literal substring hit scores 1.0.
    Otherwise every occurrence of the trigger's first token opens a window
    (24 characters of padding on each side of a trigger-length span) that is
    compared with the trigger by character similarity; scanning stops once a
    window reaches 0.93. The result is the best window score or the token
    Jaccard of the two strings, whichever is larger. Empty inputs score 0.0.
    """
    haystack, needle = norm(src), norm(trig)
    if not (haystack and needle):
        return 0.0
    if needle in haystack:
        return 1.0
    words = needle.split()
    if not words:
        return 0.0

    pad, good_enough = 24, 0.93
    top = 0.0
    for hit in re.finditer(re.escape(words[0]), haystack):
        begin = hit.start()
        window = haystack[max(0, begin - pad):begin + len(needle) + pad]
        top = max(top, _char_sim(window, needle))
        if top >= good_enough:
            break
    return float(max(top, _token_jaccard(haystack, needle)))

def _type_prior(et: str) -> float:
    """Square root of the count stored for ``et`` in ``TYPE_FREQ`` (0.0 for unknown types)."""
    count = TYPE_FREQ.get(et,0)
    return float(count)**0.5

def _shape_bonus(tr: str) -> float:
    """Small length-based adjustment: -0.4 for <=2 alphanumerics, 0.0 for 3-4, +0.1 otherwise."""
    n_alnum = len(re.findall(r"[a-z0-9]", tr))
    if n_alnum > 4:
        return 0.1
    if n_alnum > 2:
        return 0.0
    return -0.4

def _score_event(src: str, et: str, tr: str, w_anchor=0.70, w_prior=0.25, w_shape=0.05) -> float:
    """Weighted candidate score from anchor strength, squashed type prior and shape bonus."""
    anchor = _anchor_score(src, tr)          # 1.0 means the trigger occurs verbatim
    raw_prior = _type_prior(et)
    prior = raw_prior/(raw_prior+1.0)        # maps [0, inf) into [0, 1)
    shape = _shape_bonus(tr)
    return w_anchor*anchor + w_prior*prior + w_shape*shape

# --------- Text generation backends ----------
def _generate_once(prompts, mdl, tok, device, **gen_kwargs):
    """Tokenize ``prompts``, run a single ``mdl.generate`` call and decode the result.

    Inputs are padded and truncated to ``MAX_IN_LEN`` and moved to ``device``.
    ``trust_remote_code=True`` is always forwarded to ``generate`` along with
    ``gen_kwargs``. Returns the decoded strings with special tokens removed.
    """
    encoded = tok(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_IN_LEN)
    encoded = encoded.to(device)
    sequences = mdl.generate(**encoded, trust_remote_code=True, **gen_kwargs)
    return tok.batch_decode(sequences, skip_special_tokens=True)

def _beam_texts(prompts, mdl, tok, device,
                max_new_tokens=160, min_new_tokens=8,
                num_beams=8, num_beam_groups=4, no_repeat_ngram_size=4,
                length_penalty=0.85, repetition_penalty=1.03, diversity_penalty=0.2):
    """Decode ``prompts`` with group beam search; every decoding option is a keyword argument."""
    beam_cfg = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "diversity_penalty": diversity_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "early_stopping": False,
    }
    return _generate_once(prompts, mdl, tok, device, **beam_cfg)

def _sample_texts(prompts, mdl, tok, device,
                  top_p=0.92, temperature=0.6, max_new_tokens=160, min_new_tokens=8,
                  no_repeat_ngram_size=3, repetition_penalty=1.02):
    """Decode ``prompts`` once with nucleus sampling; every option is a keyword argument."""
    sampling_cfg = {
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "early_stopping": False,
    }
    return _generate_once(prompts, mdl, tok, device, **sampling_cfg)

# --------- Hybrid + Vote + Rerank ----------
def generate_batch_plus(
    prompts: List[str],
    mdl, tok,
    bs: int = 8,
    device: str | None = None,
    samples: int = 3,
    vote_k: int = 2,
    tau_frac: float = 0.60,
    max_events_total: int = 6,
    max_events_per_type: int = 3,
    top_p: float = 0.92,
    temperature: float = 0.6,
    beam_conf: dict | None = None,
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval().to(device)
    outs=[]

    beam_conf = beam_conf or {}
    with torch.no_grad():
        for i in range(0, len(prompts), bs):
            batch = prompts[i:i+bs]
            # BEAM
            beam_dec = _beam_texts(batch, mdl, tok, device, **beam_conf)
            beam_pairs = [_canonize(t) for t in beam_dec]

            # CONSENSUS samples
            all_samples = []
            for _ in range(samples):
                dec = _sample_texts(batch, mdl, tok, device, top_p=top_p, temperature=temperature)
                all_samples.append([_canonize(t) for t in dec])

            # Fuse per item
            for k in range(len(batch)):
                votes = {}
                def add_vote(pairs):
                    for (et,tr) in pairs:
                        votes[(et,tr)] = votes.get((et,tr), 0) + 1

                add_vote(beam_pairs[k])
                for s_id in range(samples):
                    add_vote(all_samples[s_id][k])

                cand = [(et,tr, votes[(et,tr)]) for (et,tr) in votes if votes[(et,tr)] >= vote_k]

                if not cand:
                    keep = beam_pairs[k]
                else:
                    src_text = recover_src_from_prompt(batch[k])
                    scored=[]
                    for et,tr,v in cand:
                        s = _score_event(src_text, et, tr)
                        s += 0.02 * max(0, v-1)
                        scored.append((s, et, tr))
                    m = max(s for s,_,_ in scored) if scored else 0.0
                    tau = tau_frac * m
                    keep = [(et,tr) for (s,et,tr) in sorted(scored, key=lambda x: x[0], reverse=True) if s >= tau]


                if keep:
                    per_type={}
                    final=[]
                    for et,tr in keep:
                        if len(final) >= max_events_total: break
                        c = per_type.get(et,0)
                        if c >= max_events_per_type: continue
                        per_type[et]=c+1
                        final.append((et,tr))
                else:
                    final=[]

                if not final:
                    outs.append("No events.")
                else:
                    outs.append(" {} ".format(EVENT_TOKEN).join([f"Event type: {et}. Trigger: {tr}." for et,tr in final]))
    return outs

# --------- Metrics & Evaluate ----------
def match_partial(preds, golds):
    """Greedily align predicted (type, trigger) pairs with gold pairs.

    For each prediction, the unused gold pair of the same type with the highest
    ``trigger_overlap`` is selected. An overlap of ~1.0 counts as an exact
    match, an overlap >= 0.5 as a partial match. A best overlap below 0.5 is
    not a match and leaves the gold pair available for later predictions
    (previously it was consumed without being scored, which could block a later
    exact match).

    Returns ``(exact, partial, len(preds), len(golds))``.
    """
    consumed = set()
    exact = partial = 0
    for p_type, p_trig in preds:
        best_ov, best_j = 0.0, -1
        for j, (g_type, g_trig) in enumerate(golds):
            if j in consumed or p_type != g_type:
                continue
            ov = trigger_overlap(p_trig, g_trig)
            if ov > best_ov:
                best_ov, best_j = ov, j
        if best_j == -1:
            continue
        if math.isclose(best_ov, 1.0):
            exact += 1
        elif best_ov >= 0.5:
            partial += 1
        else:
            # Below the partial threshold: keep the gold pair free.
            continue
        consumed.add(best_j)
    return exact, partial, len(preds), len(golds)

def prf(e,p,pt,gt, w=0.5):
    """Precision, recall and F1 where each partial match is worth ``w`` of a hit.

    ``e``/``p`` are exact/partial match counts, ``pt``/``gt`` the number of
    predicted/gold pairs. Any ratio with a zero denominator is 0.0.
    """
    weighted_hits = e + w*p

    P = 0.0
    if pt:
        P = weighted_hits/pt

    R = 0.0
    if gt:
        R = weighted_hits/gt

    if not (P+R):
        return P, R, 0.0
    return P, R, (2*P*R)/(P+R)

def relaxed_recall_by_chunks(pairs_pred, pairs_gold, chunk_size=1):
    """Chunk-level relaxed recall over aligned prediction / gold chunks.

    Each chunk is scored with ``match_partial`` (exact = 1, partial = 0.5).
    Gold pairs beyond the first in a chunk are removed from the denominator,
    and credit above 1.0 within a chunk is removed from the numerator.

    Args:
        pairs_pred: per-chunk lists of predicted (event_type, trigger) pairs.
        pairs_gold: per-chunk lists of gold (event_type, trigger) pairs.
        chunk_size: accepted for interface compatibility; not used here.

    Returns:
        Numerator (floored at 0.0) divided by denominator (floored at 1).
    """
    exact_total = 0
    partial_total = 0
    gold_total = 0
    surplus_gold = 0      # gold pairs beyond the first one in each chunk
    surplus_credit = 0    # credit above 1.0 earned inside a single chunk

    for chunk_pred, chunk_gold in zip(pairs_pred, pairs_gold):
        n_exact, n_partial, _n_pred, n_gold = match_partial(chunk_pred, chunk_gold)
        exact_total += n_exact
        partial_total += n_partial
        gold_total += n_gold
        if n_gold > 1:
            surplus_gold += n_gold - 1
        chunk_credit = n_exact + 0.5*n_partial
        if chunk_credit > 1.0:
            surplus_credit += chunk_credit - 1.0

    numerator = (exact_total + 0.5*partial_total) - surplus_credit
    if numerator < 0.0:
        numerator = 0.0
    denominator = gold_total - surplus_gold
    if denominator < 1:
        denominator = 1
    return numerator/denominator

def evaluate_plus(
    ds, mdl, tok,
    bs: int = 8,
    samples: int = 3,
    vote_k: int = 2,
    tau_frac: float = 0.60,
    top_p: float = 0.92,
    temperature: float = 0.6,
    beam_conf: dict | None = None
):
    prompts = [ex["input"] for ex in ds]
    preds   = generate_batch_plus(
        prompts, mdl, tok, bs=bs,
        samples=samples, vote_k=vote_k, tau_frac=tau_frac,
        top_p=top_p, temperature=temperature,
        beam_conf=beam_conf
    )
    gtexts  = [ex["output"] for ex in ds]

    strict_tp=strict_pred=strict_gold=0
    part_e=part_p=part_pt=part_gt=0
    chunks_pred=[]; chunks_gold=[]
    for ptxt,gtxt in zip(preds, gtexts):
        pp = parse_pairs(ptxt)
        gg = parse_pairs(gtxt)
        sp,sg = set(pp), set(gg)
        tp = len(sp & sg)
        strict_tp += tp; strict_pred += len(sp); strict_gold += len(sg)
        ce,cp,pt,gt = match_partial(pp,gg)
        part_e += ce; part_p += cp; part_pt += pt; part_gt += gt
        chunks_pred.append(pp); chunks_gold.append(gg)
    sP = strict_tp/strict_pred if strict_pred else 0.0
    sR = strict_tp/strict_gold if strict_gold else 0.0
    sF = (2*sP*sR)/(sP+sR) if (sP+sR) else 0.0
    pP,pR,pF = prf(part_e,part_p,part_pt,part_gt, w=0.5)
    r_rel = relaxed_recall_by_chunks(chunks_pred, chunks_gold)
    rF = (2*pP*r_rel)/(pP+r_rel) if (pP+r_rel)>0 else 0.0

    print("\n===== STRICT =====")
    print(f"P={sP:.4f} R={sR:.4f} F1={sF:.4f}")
    print("===== PARTIAL (MUC 0.5) =====")
    print(f"P={pP:.4f} R={pR:.4f} F1={pF:.4f}")
    print("===== RELAXED (DEGREE2) =====")
    print(f"Relaxed-Recall={r_rel:.4f} | Relaxed-F1≈{rF:.4f}")
    return dict(strict_f1=sF, partial_f1=pF, relaxed_recall=r_rel, relaxed_f1=rF)

# --------- Tiny tuner on VALID ----------
def tiny_tune_on_valid(mdl, tok, valid_ds):
    """Small grid search over decoding / voting settings on the validation set.

    Every combination of ``tau_frac``, ``vote_k`` and (``top_p``,
    ``temperature``) is scored with ``evaluate_plus`` (bs=16, samples=3).
    The combination with the highest ``relaxed_f1`` wins; on ties the
    earliest combination in grid order is kept.

    Returns:
        dict with keys ``tau_frac``, ``vote_k``, ``top_p``, ``temperature``
        (a fixed default configuration if no combination was selected).
    """
    tau_options      = [0.58, 0.60, 0.62]
    vote_options     = [2, 3]
    sampling_options = [(0.92,0.6), (0.90,0.7)]   # (top_p, temperature)

    # Flattened grid, in the order nested tau -> vote -> sampling loops would visit it.
    configs = [
        dict(tau_frac=tau, vote_k=votes, top_p=nucleus, temperature=temp)
        for tau in tau_options
        for votes in vote_options
        for nucleus, temp in sampling_options
    ]

    top_score, top_config = -1, None
    for cfg in configs:
        print(f"[tune] try tau={cfg['tau_frac']}, vote_k={cfg['vote_k']}, top_p={cfg['top_p']}, T={cfg['temperature']}")
        metrics = evaluate_plus(valid_ds, mdl, tok, bs=16, samples=3, **cfg)
        score = metrics["relaxed_f1"]
        if score > top_score:   # strict '>' keeps the first of equally good configs
            top_score, top_config = score, cfg

    print(f"[tune] best on VALID: relaxed_F1={top_score:.4f} with {top_config}")
    if top_config:
        return top_config
    return dict(tau_frac=0.60, vote_k=2, top_p=0.92, temperature=0.6)

# --------- Run ----------
# Reload the tokenizer and seq2seq model that were saved under FINAL_DIR.
best_tok   = AutoTokenizer.from_pretrained(FINAL_DIR)
best_model = AutoModelForSeq2SeqLM.from_pretrained(FINAL_DIR)

# Grid-search the decoding settings on the validation split.
# NOTE(review): valid_ds / test_ds are expected to be defined by an earlier cell.
print("\n--- TUNE on VALID (hybrid+vote+rerank) ---")
best_params = tiny_tune_on_valid(best_model, best_tok, valid_ds)

# Score the validation split again with the selected settings
# (evaluate_plus defaults apply here, e.g. bs=8 rather than the tuner's bs=16).
print("\n--- VALID (best tuned) ---")
evaluate_plus(valid_ds, best_model, best_tok, **best_params)

# Final held-out numbers with the same settings. As a bare trailing expression,
# the returned metrics dict is also what the notebook shows as the cell result
# when this is the last statement of the cell.
print("\n--- TEST (best tuned) ---")
evaluate_plus(test_ds,  best_model, best_tok, **best_params)



--- TUNE on VALID (hybrid+vote+rerank) ---
[tune] try tau=0.58, vote_k=2, top_p=0.92, T=0.6


KeyboardInterrupt: 