In [None]:
# cv_analysis.ipynb

import json
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from ipywidgets import Dropdown, HBox, IntSlider, VBox, interact
from torch.utils.data import DataLoader

# Parent of the kernel's working directory, resolved to an absolute path; the rest of the
# notebook builds every experiment/config path from it.
PROJECT_ROOT = Path("..").resolve()

# Put that directory at the front of sys.path (once) so the `src.*` imports below resolve.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data.dataloader import (
    ForgeryDataset,
    detection_collate_fn,
    get_val_transform,
)
from src.models.mask2former_v1 import Mask2FormerForgeryModel
from src.utils.config_utils import load_yaml, sanitize_model_kwargs

# Apply matplotlib's "ggplot" style sheet to every figure drawn after this point.
plt.style.use("ggplot")

In [None]:
# -----------------
# Constants / Paths
# -----------------

# Result directories under <PROJECT_ROOT>/experiments.
# NOTE(review): names suggest out-of-fold and full-train outputs — confirm against the training scripts.
OOF_ROOT = PROJECT_ROOT / "experiments" / "oof_results"
FULL_TRAIN_ROOT = PROJECT_ROOT / "experiments" / "full_train_results"

# CSV that is loaded into the `cls_threshold` DataFrame further down.
CLS_THRESHOLD_PATH = (
    PROJECT_ROOT / "experiments" / "cls_threshold_sweep" / "cls_threshold_sweep.csv"
)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Location of config/base.yaml under the project root.
CFG_PATH = PROJECT_ROOT / "config" / "base.yaml"

In [None]:
# -----------------
# Load cls_threshold
# -----------------
cls_threshold = pd.read_csv(CLS_THRESHOLD_PATH)

# Sanity check: preview the first rows, then report the row count and the column names.
display(cls_threshold.head())
print(f"Loaded {cls_threshold.shape[0]} rows")
print(list(cls_threshold.columns))

In [None]:
# Number of rows recorded for each distinct value of the `cls_threshold` column.
cls_threshold['cls_threshold'].value_counts()

In [None]:
# Collapse the per-image rows into one row per cls_threshold value by averaging
# each logged indicator / probability column.
summary = (
    cls_threshold
    .groupby("cls_threshold")
    .agg(
        gate_pass_rate=pd.NamedAgg(column="gate_pass", aggfunc="mean"),
        avg_num_keep=pd.NamedAgg(column="num_keep", aggfunc="mean"),
        any_fg_pre_rate=pd.NamedAgg(column="any_fg_pre_keep", aggfunc="mean"),
        any_fg_post_rate=pd.NamedAgg(column="any_fg_post_keep", aggfunc="mean"),
        avg_max_cls_prob=pd.NamedAgg(column="max_cls_prob", aggfunc="mean"),
        avg_max_mask_prob=pd.NamedAgg(column="max_mask_prob", aggfunc="mean"),
        avg_image_forged_prob=pd.NamedAgg(column="image_forged_prob", aggfunc="mean"),
    )
    .reset_index()
    .sort_values("cls_threshold")
)

display(summary)

## Cls weight sweep (`w_mask_cls`)

In [None]:
# notebooks/analyze_cls_weight_sweep.ipynb (run from the notebook's own directory, one level below the repo root: paths here start with "../")
from pathlib import Path
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Logs are written by ClsCollapseLogger to:
# experiments/cls_collapse/<run_name>/{meta.json, optimizer.json, step_losses.csv, epoch_summary.csv, debug.jsonl}

# Directory that holds one sub-directory per logged run. The path is relative, so it is
# resolved against the kernel's current working directory (unlike the PROJECT_ROOT-based
# paths defined earlier in this notebook).
BASE = Path("../experiments/cls_collapse")

def _safe_read_json(p: Path):
    if not p.exists():
        return None
    return json.loads(p.read_text())

def _safe_read_csv(p: Path):
    if not p.exists():
        return None
    return pd.read_csv(p)

def _safe_read_jsonl(p: Path):
    if not p.exists():
        return None
    rows = []
    with p.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return pd.DataFrame(rows) if rows else pd.DataFrame()

def load_run(run_dir: Path):
    """Load every log artefact of one run directory into a dict.

    Reads ``meta.json``, ``optimizer.json``, ``step_losses.csv``,
    ``epoch_summary.csv`` and ``debug.jsonl`` from ``run_dir``. Missing JSON files
    become ``{}``; missing CSV / JSONL files become ``None``.

    Returns a dict with keys ``run_dir``, ``run_name``, ``w_mask_cls``, ``meta``,
    ``optimizer``, ``step_losses``, ``epoch_summary`` and ``debug``.
    ``w_mask_cls`` is the last numeric value of the ``w_mask_cls`` column, taken
    from the step CSV first and the epoch CSV second, or ``None`` when neither
    file provides a numeric value.
    """
    meta = _safe_read_json(run_dir / "meta.json") or {}
    opt = _safe_read_json(run_dir / "optimizer.json") or {}
    steps = _safe_read_csv(run_dir / "step_losses.csv")
    epochs = _safe_read_csv(run_dir / "epoch_summary.csv")
    dbg = _safe_read_jsonl(run_dir / "debug.jsonl")

    # Prefer the weight from the logged CSV snapshots. A column that exists but
    # holds no numeric value (all blank / non-numeric) used to raise IndexError on
    # `.iloc[-1]`; it is now skipped so the next source can be tried.
    w = None
    for table in (steps, epochs):
        if table is None or "w_mask_cls" not in table.columns:
            continue
        numeric = pd.to_numeric(table["w_mask_cls"], errors="coerce").dropna()
        if len(numeric):
            w = float(numeric.iloc[-1])
            break

    return {
        "run_dir": run_dir,
        "run_name": run_dir.name,
        "w_mask_cls": w,
        "meta": meta,
        "optimizer": opt,
        "step_losses": steps,
        "epoch_summary": epochs,
        "debug": dbg,
    }

def load_sweep_runs(base=BASE, name_prefix="full_full_cls_w"):
    """Load all run directories directly under ``base`` whose name starts with ``name_prefix``.

    Directories are visited in sorted order. A run is kept only if at least one
    of its ``epoch_summary`` / ``step_losses`` tables was found.
    """
    candidate_dirs = sorted(d for d in base.glob(f"{name_prefix}*") if d.is_dir())
    loaded = (load_run(d) for d in candidate_dirs)
    return [
        run
        for run in loaded
        if not (run["epoch_summary"] is None and run["step_losses"] is None)
    ]

# Load the sweep runs with the default base directory and name prefix, then list each
# run's name next to the weight recovered from its CSV logs.
runs = load_sweep_runs()
print("Found runs:", len(runs))
for r in runs:
    print(r["run_name"], "w_mask_cls=", r["w_mask_cls"])

# ----------------------------
# Build a compact sweep table
# ----------------------------
# Fields copied from the last epoch_summary row and the last step_losses row;
# each one becomes a "<field>_last" column of the sweep table.
_EPOCH_END_FIELDS = ("cls_max_mean", "cls_max_p95", "mask_max_mean", "img_forged_mean")
_LAST_STEP_FIELDS = ("loss_total", "loss_mask_cls", "loss_auth_penalty")

# Explicit column list so the table keeps its schema even when no run was found
# (otherwise sorting an empty, column-less frame by "w_mask_cls" raises KeyError).
SWEEP_COLUMNS = [
    "run_name",
    "run_dir",
    "w_mask_cls",
    "epochs_logged",
    "steps_logged",
    *(f"{name}_last" for name in _EPOCH_END_FIELDS),
    *(f"{name}_last" for name in _LAST_STEP_FIELDS),
]


def _last_row(table, order_col):
    """Return the row with the largest ``order_col`` as a dict; ``{}`` if the table is missing or empty."""
    if table is None or not len(table):
        return {}
    if order_col in table.columns:
        table = table.sort_values(order_col)
    # Without the ordering column, fall back to the last row in file order.
    return table.iloc[-1].to_dict()


def _logged_count(table, col, offset=0):
    """Return ``int(max(col)) + offset``; NaN when the table, the column or the maximum is missing."""
    if table is None or not len(table) or col not in table.columns:
        return np.nan
    top = table[col].max()
    return int(top) + offset if pd.notna(top) else np.nan


def _summarize_run(run):
    """Flatten one ``load_run`` result into a single sweep-table row (a dict keyed by SWEEP_COLUMNS)."""
    epochs, steps = run["epoch_summary"], run["step_losses"]
    last_epoch = _last_row(epochs, "epoch")
    last_step = _last_row(steps, "global_step")

    row = {
        "run_name": run["run_name"],
        "run_dir": str(run["run_dir"]),
        "w_mask_cls": run["w_mask_cls"],
        "epochs_logged": _logged_count(epochs, "epoch"),
        # max(global_step) + 1 — presumably global_step is 0-based; confirm with the logger.
        "steps_logged": _logged_count(steps, "global_step", offset=1),
    }
    # epoch-end collapse detectors
    row.update({f"{name}_last": last_epoch.get(name, np.nan) for name in _EPOCH_END_FIELDS})
    # loss snapshot
    row.update({f"{name}_last": last_step.get(name, np.nan) for name in _LAST_STEP_FIELDS})
    return row


rows = [_summarize_run(r) for r in runs]

sweep_df = pd.DataFrame(rows, columns=SWEEP_COLUMNS).sort_values("w_mask_cls", na_position="last")
display(sweep_df)

# ----------------------------
# Plots: collapse vs w
# ----------------------------
def plot_vs_w(df, x="w_mask_cls", ys=()):
    """Draw one line chart per column named in ``ys``, plotted against column ``x``.

    Rows whose ``x`` value is missing are dropped and the remainder is sorted by
    ``x`` before plotting. Each chart is shown in its own figure.
    """
    ordered = df.dropna(subset=[x]).sort_values(x)
    for column in ys:
        fig, ax = plt.subplots()
        ax.plot(ordered[x].values, ordered[column].values, marker="o")
        ax.set_xlabel(x)
        ax.set_ylabel(column)
        ax.set_title(f"{column} vs {x}")
        ax.grid(True)
        plt.show()

# One figure per last-epoch / last-step column of the sweep table, each plotted against w_mask_cls.
plot_vs_w(
    sweep_df,
    ys=["cls_max_mean_last", "cls_max_p95_last", "mask_max_mean_last", "img_forged_mean_last", "loss_total_last"]
)

# ----------------------------
# Per-run learning curves
# ----------------------------
def plot_run_curves(runs, metric="loss_total"):
    """Plot ``metric`` against ``global_step`` for each run, one figure per run.

    Runs without a ``step_losses`` table, or whose table lacks the ``metric``
    column, are skipped silently.
    """
    for run in runs:
        table = run["step_losses"]
        has_metric = table is not None and metric in table.columns
        if not has_metric:
            continue
        ordered = table.sort_values("global_step")
        fig, ax = plt.subplots()
        ax.plot(ordered["global_step"], ordered[metric])
        ax.set_xlabel("global_step")
        ax.set_ylabel(metric)
        ax.set_title(f"{run['run_name']} (w={run['w_mask_cls']})")
        ax.grid(True)
        plt.show()

# Per-run curves for three step-level loss columns; runs that did not log a column are skipped.
plot_run_curves(runs, metric="loss_total")
plot_run_curves(runs, metric="loss_mask_cls")
plot_run_curves(runs, metric="loss_auth_penalty")

# ----------------------------
# Debug JSONL: auth penalty + cls target density over time
# ----------------------------
def extract_debug_timeseries(runs, tags=("loss_auth_penalty_stats", "loss_cls_targets")):
    """Collect selected debug events of all runs into one DataFrame.

    Parameters
    ----------
    runs : iterable of dict
        Run dicts providing the keys ``debug``, ``run_name`` and ``w_mask_cls``.
    tags : iterable of str, optional
        Values of the ``tag`` column to keep. Defaults to the two event tags this
        notebook plots.

    Returns
    -------
    pandas.DataFrame
        The kept rows of every run, concatenated with a fresh index and annotated
        with ``run_name`` and ``w_mask_cls``; a ``global_step`` column filled with
        NaN is added for runs that did not log one. Empty when nothing matched.
    """
    wanted = list(tags)
    frames = []
    for run in runs:
        dbg = run["debug"]
        # Skip runs without usable debug data. A frame without a "tag" column
        # cannot be filtered and previously raised KeyError.
        if dbg is None or dbg.empty or "tag" not in dbg.columns:
            continue
        keep = dbg[dbg["tag"].isin(wanted)].copy()
        if keep.empty:
            continue
        keep["run_name"] = run["run_name"]
        keep["w_mask_cls"] = run["w_mask_cls"]
        # Guarantee a common x-axis column for downstream plotting.
        if "global_step" not in keep.columns:
            keep["global_step"] = np.nan
        frames.append(keep)
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

# Gather the selected debug events of all runs; preview the first rows (or show the empty frame as-is).
dbg_df = extract_debug_timeseries(runs)
display(dbg_df.head() if not dbg_df.empty else dbg_df)

def plot_debug_series(dbg_df, tag, y, title=None):
    """Plot field ``y`` of the debug events tagged ``tag`` over ``global_step``, one figure per weight.

    Parameters
    ----------
    dbg_df : pandas.DataFrame
        Debug events with at least the columns ``tag``, ``w_mask_cls`` and
        ``global_step`` (unless empty).
    tag : str
        Event tag to select.
    y : str
        Column to plot on the y-axis.
    title : str, optional
        Figure title prefix; defaults to ``"<tag>:<y>"``. The sweep weight is
        always appended as ``" (w=<value>)"`` so the per-weight figures stay
        distinguishable even when a custom title is given.

    Prints a short message and returns without plotting when there are no
    events, no rows for ``tag``, or no ``y`` column.
    """
    if dbg_df.empty:
        print("No debug events found.")
        return
    sub = dbg_df[dbg_df["tag"] == tag]
    if sub.empty or y not in sub.columns:
        print(f"No rows for tag={tag} with field {y}")
        return
    # Rows without a weight cannot be assigned to a figure.
    sub = sub.dropna(subset=["w_mask_cls"])
    heading = title or f"{tag}:{y}"

    for w, group in sub.groupby("w_mask_cls"):
        group = group.sort_values("global_step")
        fig, ax = plt.subplots()
        ax.plot(group["global_step"].values, group[y].values)
        ax.set_xlabel("global_step")
        ax.set_ylabel(y)
        ax.set_title(f"{heading} (w={w})")
        ax.grid(True)
        plt.show()

# Three debug fields over global_step; each call draws one figure per w_mask_cls value.
plot_debug_series(dbg_df, tag="loss_cls_targets", y="pos_frac", title="Matched-query positive fraction over time")
plot_debug_series(dbg_df, tag="loss_auth_penalty_stats", y="per_image_penalty_mean", title="Per-image auth penalty mean over time")
plot_debug_series(dbg_df, tag="loss_auth_penalty_stats", y="loss_auth_penalty", title="Auth penalty loss over time")


# ----------------------------
# Quick "collapse flags" summary
# ----------------------------
# Thresholds behind the two flag columns (the column names below echo them).
# NOTE(review): 0.125 is presumably the value cls_max_mean settles at when the
# classification head collapses — confirm where this number comes from.
COLLAPSED_CLS_MAX_MEAN = 0.125
COLLAPSED_CLS_TOL = 1e-3
SATURATED_MASK_MAX_MEAN = 0.99

flags = sweep_df.copy()

# Coerce to float first: np.isfinite raises TypeError on object-dtype columns,
# which is what these columns become when the sweep is empty or a logged value
# is not numeric. Non-numeric entries turn into NaN and are flagged False.
cls_max_mean = pd.to_numeric(flags["cls_max_mean_last"], errors="coerce").astype(float)
mask_max_mean = pd.to_numeric(flags["mask_max_mean_last"], errors="coerce").astype(float)

flags["cls_collapsed@~0.125"] = np.isfinite(cls_max_mean) & (
    (cls_max_mean - COLLAPSED_CLS_MAX_MEAN).abs() < COLLAPSED_CLS_TOL
)
flags["mask_saturated@~1.0"] = np.isfinite(mask_max_mean) & (mask_max_mean > SATURATED_MASK_MAX_MEAN)

display(
    flags[
        [
            "run_name",
            "w_mask_cls",
            "cls_max_mean_last",
            "mask_max_mean_last",
            "img_forged_mean_last",
            "cls_collapsed@~0.125",
            "mask_saturated@~1.0",
        ]
    ].sort_values("w_mask_cls", na_position="last")
)
