In [None]:
# In root/notebooks/fulltrain_analysis.ipynb

from pathlib import Path
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Root folder holding one sub-directory per training run.
BASE = Path("../experiments/cls_collapse")

# Run analysed throughout this notebook (the 25-epoch training run).
RUN_25 = "full_full_cls_w0.25"
# Optional baseline for the comparison cell. While it equals RUN_25 the
# comparison is skipped; set it to the older 5-epoch run's folder name to enable it.
RUN_5 = "full_full_cls_w0.25"

def read_json(p: Path):
    """Load a JSON document from ``p``.

    Returns an empty dict when the file does not exist, so optional run
    artefacts (meta.json, optimizer.json) may be absent. The file is decoded
    as UTF-8 explicitly rather than with the platform's locale encoding,
    since JSON text is UTF-8 by specification.
    """
    if not p.exists():
        return {}
    return json.loads(p.read_text(encoding="utf-8"))

def read_csv(p: Path):
    """Load a CSV file into a DataFrame; an absent file yields an empty DataFrame."""
    if not p.exists():
        return pd.DataFrame()
    return pd.read_csv(p)

def read_jsonl(p: Path):
    """Load a JSON-Lines file into a DataFrame, one record per non-blank line.

    Returns an empty DataFrame when ``p`` does not exist. The file is streamed
    line by line and decoded as UTF-8. Records are split only on real line
    breaks (\\n, \\r, \\r\\n). ``str.splitlines`` would also split on characters
    such as U+2028 that may legally appear unescaped inside JSON strings, which
    would cut a valid record in two.

    Raises:
        ValueError: if a line is not valid JSON (e.g. a record truncated by a
            crashed run); the message names the file and 1-based line number.
    """
    if not p.exists():
        return pd.DataFrame()
    records = []
    with p.open(encoding="utf-8") as fh:
        for lineno, raw in enumerate(fh, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as exc:
                # JSONDecodeError subclasses ValueError, so existing handlers still match.
                raise ValueError(f"{p}:{lineno}: invalid JSON record ({exc.msg})") from exc
    return pd.DataFrame(records)

def load_run(run_name: str):
    """Gather the logged artefacts of one training run under BASE into a dict.

    Keys: name, dir, meta (meta.json), opt (optimizer.json), steps
    (step_losses.csv), epochs (epoch_summary.csv), dbg (debug.jsonl).
    Whatever the read_* helpers return for a missing file is stored as-is.
    """
    run_dir = BASE / run_name
    bundle = {"name": run_name, "dir": run_dir}
    bundle["meta"] = read_json(run_dir / "meta.json")
    bundle["opt"] = read_json(run_dir / "optimizer.json")
    bundle["steps"] = read_csv(run_dir / "step_losses.csv")
    bundle["epochs"] = read_csv(run_dir / "epoch_summary.csv")
    bundle["dbg"] = read_jsonl(run_dir / "debug.jsonl")
    return bundle

run = load_run(RUN_25)
# load_run tolerates missing files, so a typo in RUN_25 would otherwise only
# surface later as a cascade of "missing ..." messages — fail fast instead.
if not run["dir"].is_dir():
    raise FileNotFoundError(f"run folder not found: {run['dir'].resolve()}")
run["dir"], run["steps"].shape, run["epochs"].shape


In [None]:
# 1) Final-epoch summary + quick collapse flags
# Thresholds for the quick flags below.
# NOTE(review): 0.125 == 1/8 — presumably the value cls_max_mean settles at when
# the classifier collapses to a uniform score for this model; confirm.
CLS_COLLAPSE_LEVEL = 0.125
CLS_COLLAPSE_TOL = 1e-3
MASK_SATURATION_LEVEL = 0.99

epochs = run["epochs"].copy()
has_epoch_col = "epoch" in epochs.columns
epochs = epochs.sort_values("epoch") if has_epoch_col else epochs
last = epochs.iloc[-1].to_dict() if len(epochs) else {}

summary = {
    "run": run["name"],
    # Count distinct logged epochs; max(epoch) + 1 is only correct for 0-based
    # epoch indices with no gaps (e.g. wrong for 1-based or resumed runs).
    "epochs_logged": int(epochs["epoch"].nunique()) if len(epochs) and has_epoch_col else None,
    "cls_max_mean_last": last.get("cls_max_mean"),
    "cls_max_p95_last": last.get("cls_max_p95"),
    "mask_max_mean_last": last.get("mask_max_mean"),
    "mask_max_p95_last": last.get("mask_max_p95"),
    "img_forged_mean_last": last.get("img_forged_mean"),
    "img_forged_p95_last": last.get("img_forged_p95"),
}

cls_last = summary["cls_max_mean_last"]
mask_last = summary["mask_max_mean_last"]
# pd.notna is False for both None (column absent) and NaN (not logged in the
# last epoch); bool() keeps the flags plain Python booleans.
summary, {
    "cls_collapsed@~0.125": bool(pd.notna(cls_last) and abs(cls_last - CLS_COLLAPSE_LEVEL) < CLS_COLLAPSE_TOL),
    "mask_saturated@>0.99": bool(pd.notna(mask_last) and mask_last > MASK_SATURATION_LEVEL),
}


In [None]:
# 2) Learning curves (step-level): work on a copy ordered by global_step when that column exists.
steps = run["steps"].copy()
if "global_step" in steps.columns:
    steps = steps.sort_values("global_step")

def plot_step(metric):
    """Line-plot a step_losses.csv metric against global_step.

    Reads the notebook-level ``steps`` frame and ``run`` dict. Prints a notice
    and returns without plotting if ``metric`` or ``global_step`` is missing.
    """
    required = (metric, "global_step")
    if any(col not in steps.columns for col in required):
        print(f"missing {metric} or global_step in step_losses.csv")
        return
    fig, ax = plt.subplots()
    ax.plot(steps["global_step"], steps[metric])
    ax.set_xlabel("global_step")
    ax.set_ylabel(metric)
    ax.set_title(f"{run['name']} — {metric}")
    ax.grid(True)
    plt.show()

for step_metric in ("loss_total", "loss_mask_cls", "loss_auth_penalty"):
    plot_step(step_metric)


In [None]:
# 3) Epoch curves (collapse indicators over time)
def plot_epoch(metric):
    """Plot an epoch_summary.csv metric per epoch, with a marker at each epoch.

    Reads the notebook-level ``epochs`` frame and ``run`` dict. Prints a notice
    and returns without plotting if ``metric`` or ``epoch`` is not a column.
    """
    have_cols = metric in epochs.columns and "epoch" in epochs.columns
    if not have_cols:
        print(f"missing {metric} or epoch in epoch_summary.csv")
        return
    fig, ax = plt.subplots()
    ax.plot(epochs["epoch"], epochs[metric], marker="o")
    ax.set(xlabel="epoch", ylabel=metric, title=f"{run['name']} — {metric} (epoch)")
    ax.grid(True)
    plt.show()

for epoch_metric in ("cls_max_mean", "cls_max_p95", "mask_max_mean", "img_forged_mean"):
    plot_epoch(epoch_metric)


In [None]:
# 4) Debug JSONL: cls target density + auth penalty stats (if present)
dbg = run["dbg"].copy()
# Records per tag, or a notice when the records carry no "tag" field.
tag_counts = dbg["tag"].value_counts() if "tag" in dbg.columns else "no tag column"
dbg.head(), tag_counts


In [None]:
def plot_dbg(tag, y):
    """Plot field ``y`` of the debug.jsonl records tagged ``tag`` against global_step.

    Reads the notebook-level ``dbg`` frame and ``run`` dict. Prints a specific
    diagnostic and returns without plotting when the data is unavailable.
    """
    if dbg.empty:
        print("no debug.jsonl")
        return
    if "tag" not in dbg.columns:
        # The file exists and has records, but none carry a "tag" field.
        print("no tag column in debug.jsonl records")
        return
    sub = dbg[dbg["tag"] == tag]
    if sub.empty or y not in sub.columns:
        print(f"no rows for tag={tag} with field={y}")
        return
    if "global_step" not in sub.columns:
        print("no global_step in debug.jsonl records")
        return
    # The frame's columns are the union of keys over all records, so records of
    # this tag that omit the field (or the step) hold NaN. matplotlib breaks the
    # line at NaN and does not draw an isolated point between gaps, so drop them.
    sub = sub.dropna(subset=["global_step", y]).sort_values("global_step")
    if sub.empty:
        print(f"no rows for tag={tag} with field={y}")
        return
    fig, ax = plt.subplots()
    ax.plot(sub["global_step"], sub[y])
    ax.set_xlabel("global_step")
    ax.set_ylabel(y)
    ax.set_title(f"{run['name']} — {tag}:{y}")
    ax.grid(True)
    plt.show()

for dbg_tag, dbg_field in [
    ("loss_cls_targets", "pos_frac"),
    ("loss_auth_penalty_stats", "per_image_penalty_mean"),
    ("loss_auth_penalty_stats", "loss_auth_penalty"),
]:
    plot_dbg(dbg_tag, dbg_field)


In [None]:
# (Optional) compare your new 25-epoch run to an older run (set RUN_5 to the old run folder name)
if RUN_5 != RUN_25:
    run_old = load_run(RUN_5)

    def _sorted_epochs(frame, label):
        """Return ``frame`` sorted by epoch, or None (after a notice) if it has no epoch column."""
        # load_run yields an empty, column-less frame for a missing folder/CSV,
        # and sort_values("epoch") would raise KeyError on it.
        if "epoch" not in frame.columns:
            print(f"{label}: epoch_summary.csv missing or has no 'epoch' column")
            return None
        return frame.sort_values("epoch")

    e_new = _sorted_epochs(run["epochs"], RUN_25)
    e_old = _sorted_epochs(run_old["epochs"], RUN_5)

    def compare_epoch(metric):
        """Overlay ``metric`` per epoch for the old (RUN_5) and new (RUN_25) runs."""
        if metric not in e_new.columns or metric not in e_old.columns:
            print("missing", metric)
            return
        fig, ax = plt.subplots()
        ax.plot(e_old["epoch"], e_old[metric], marker="o", label=f"{RUN_5}")
        ax.plot(e_new["epoch"], e_new[metric], marker="o", label=f"{RUN_25}")
        ax.set_xlabel("epoch")
        ax.set_ylabel(metric)
        ax.set_title(metric)
        ax.grid(True)
        ax.legend()
        plt.show()

    if e_new is not None and e_old is not None:
        for m in ["cls_max_mean", "mask_max_mean", "img_forged_mean"]:
            compare_epoch(m)


In [None]:
import torch, numpy as np
from torch.utils.data import DataLoader
from src.data.dataloader import ForgeryDataset, get_val_transform
from src.models.mask2former_v1 import Mask2FormerForgeryModel
from src.utils.config_utils import sanitize_model_kwargs
import yaml

# Raw image-level forged probabilities with the auth gate bypassed, then the
# fraction of images that would pass each candidate gate threshold.
WEIGHTS_PATH = "../weights/full_train/full_cls_w0.25.pth"
GATE_THRESHOLDS = [0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95]

# Context manager closes the config handle (a bare open() leaves it to the GC).
with open("../config/base.yaml", "r") as fh:
    cfg = yaml.safe_load(fh)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): ForgeryDataset is built without a split argument — confirm it
# yields the intended evaluation images.
ds = ForgeryDataset(transform=get_val_transform(img_size=cfg["data"]["img_size"]))
loader = DataLoader(ds, batch_size=cfg["trainer"]["batch_size"], shuffle=False,
                    collate_fn=lambda x: tuple(zip(*x)))

# Copy before pop() so cfg["model"] is never mutated, in case
# sanitize_model_kwargs hands back the same dict object.
mk = dict(sanitize_model_kwargs(cfg["model"]))
mk.pop("auth_gate_forged_threshold", None)
model = Mask2FormerForgeryModel(**mk, auth_gate_forged_threshold=-1.0).to(device)

# torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.eval()

prob_batches = []
with torch.no_grad():
    for images, _ in loader:
        images = [im.to(device) for im in images]
        _, _, img_logits = model.forward_logits(images)     # <— bypass inference/gate
        prob_batches.append(torch.sigmoid(img_logits).cpu())
if not prob_batches:
    # torch.cat([]) would otherwise fail with an unhelpful error.
    raise RuntimeError("loader produced no batches — is the dataset empty?")
all_p = torch.cat(prob_batches).numpy()

print("min/median/mean/p95/max:", np.min(all_p), np.median(all_p), np.mean(all_p), np.quantile(all_p,0.95), np.max(all_p))
for g in GATE_THRESHOLDS:
    print(g, "pass_frac:", float((all_p >= g).mean()))
