In [15]:
import os
import sys
from pathlib import Path
import numpy as np
import torch
from PIL import Image
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

PROJECT_ROOT = Path("..").resolve()
print("PROJECT_ROOT:", PROJECT_ROOT)

# make `src` importable
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.data.dataloader import (
    ForgeryDataset,
    detection_collate_fn,
    get_val_transform,
)
from src.models.mask2former_v1 import Mask2FormerForgeryModel
from src.utils.config_utils import load_yaml, sanitize_model_kwargs
from src.training.train_cv import build_solution_df
from src.models.kaggle_metric import score as kaggle_score


PROJECT_ROOT: C:\Users\piiop\Desktop\Portfolio\Projects\RecodAI_LUC


In [16]:
# Constants
CFG_PATH = PROJECT_ROOT / "config" / "base.yaml"
WEIGHTS  = PROJECT_ROOT / "weights" / "full_train" / "model_full_data_baseline.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Define paths
train_authentic = PROJECT_ROOT / "data" / "train_images" / "authentic"   
train_forged    = PROJECT_ROOT / "data" / "train_images" / "forged"
train_masks     = PROJECT_ROOT / "data" / "train_masks" 

print("authentic:", "exists:", train_authentic.exists())
print("forged   :", "exists:", train_forged.exists())
print("masks    :", "exists:", train_masks.exists())


authentic: exists: True
forged   : exists: True
masks    : exists: True


In [3]:
# Build Dataset and inspect
ds = ForgeryDataset(
    transform=None,
    is_train=True,
)

print("Total samples:", len(ds))

# if empty, inspect why
if len(ds) == 0:
    # Quickly list what files it *would* see
    authentic_files = sorted(os.listdir(train_authentic)) if train_authentic.exists() else []
    forged_files    = sorted(os.listdir(train_forged))    if train_forged.exists() else []
    mask_files      = sorted(os.listdir(train_masks))     if train_masks.exists() else []

    print(f"#authentic files: {len(authentic_files)}")
    print(f"#forged files   : {len(forged_files)}")
    print(f"#mask files     : {len(mask_files)}")

    print("First few authentic:", authentic_files[:5])
    print("First few forged   :", forged_files[:5])
    print("First few masks    :", mask_files[:5])
else:
    # peek at first few entries
    print("First 3 samples:")
    for sample in ds.samples[:3]:
        print(sample)


Total samples: 5176
First 3 samples:
{'image_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_images\\authentic\\10.png', 'mask_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_masks\\10.npy', 'is_forged': False}
{'image_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_images\\authentic\\10015.png', 'mask_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_masks\\10015.npy', 'is_forged': False}
{'image_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_images\\authentic\\10017.png', 'mask_path': 'C:\\Users\\piiop\\Desktop\\Portfolio\\Projects\\RecodAI_LUC\\data\\train_masks\\10017.npy', 'is_forged': False}


In [4]:
pos = sum(1 for s in ds.samples if s["is_forged"] and os.path.exists(s["mask_path"]) and np.load(s["mask_path"]).sum() > 0)
print("forged samples:", sum(s["is_forged"] for s in ds.samples))
print("forged-with-positive-mask:", pos)


forged samples: 2799
forged-with-positive-mask: 2799


In [5]:
def _safe_load_mask(mask_path: str):
    p = Path(mask_path)
    if not p.exists():
        return None, "missing"
    try:
        m = np.load(p)
        return m, "ok"
    except Exception as e:
        return None, f"load_error: {type(e).__name__}: {e}"

def summarize_dataset_emits(ds, max_examples=5):
    samples = ds.samples
    n = len(samples)

    counts = {
        "total": n,
        "is_forged_true": 0,
        "is_forged_false": 0,
        "mask_missing": 0,
        "mask_load_error": 0,
        "mask_ok": 0,
        "mask_ndim_2": 0,
        "mask_ndim_3": 0,
        "mask_ndim_other": 0,
        "mask_union_sum_gt0": 0,
        "shape_mismatch_disk": 0,
        "authentic_with_nonempty_mask_on_disk": 0,
        "forged_with_nonempty_mask_on_disk": 0,
        # after __getitem__
        "getitem_image_label_1": 0,
        "getitem_image_label_0": 0,
        "getitem_instances_gt0": 0,
        "getitem_instances_eq0": 0,
        "forged_label1_but_zero_instances": 0,
    }

    examples = {k: [] for k in [
        "mask_load_error",
        "shape_mismatch_disk",
        "authentic_with_nonempty_mask_on_disk",
        "forged_with_nonempty_mask_on_disk",
        "forged_label1_but_zero_instances",
    ]}

    for i, s in enumerate(samples):
        is_forged = bool(s.get("is_forged", False))
        counts["is_forged_true" if is_forged else "is_forged_false"] += 1

        img_path = s["image_path"]
        mask_path = s.get("mask_path", "")

        # disk-level checks (don’t rely on __getitem__)
        m, status = _safe_load_mask(mask_path)
        if status == "missing":
            counts["mask_missing"] += 1
        elif status.startswith("load_error"):
            counts["mask_load_error"] += 1
            if len(examples["mask_load_error"]) < max_examples:
                examples["mask_load_error"].append({"i": i, "img": img_path, "mask": mask_path, "err": status})
        else:
            counts["mask_ok"] += 1
            if m.ndim == 2:
                counts["mask_ndim_2"] += 1
                union = (m > 0).astype(np.uint8)
            elif m.ndim == 3:
                counts["mask_ndim_3"] += 1
                # mirror dataloader behavior: accept channel-first (C,H,W) and union over C
                union = np.any(m > 0, axis=0).astype(np.uint8)
            else:
                counts["mask_ndim_other"] += 1
                union = None

            if union is not None and union.sum() > 0:
                counts["mask_union_sum_gt0"] += 1
                if is_forged:
                    counts["forged_with_nonempty_mask_on_disk"] += 1
                    if len(examples["forged_with_nonempty_mask_on_disk"]) < max_examples:
                        examples["forged_with_nonempty_mask_on_disk"].append({"i": i, "img": img_path, "mask": mask_path, "ndim": int(m.ndim)})
                else:
                    counts["authentic_with_nonempty_mask_on_disk"] += 1
                    if len(examples["authentic_with_nonempty_mask_on_disk"]) < max_examples:
                        examples["authentic_with_nonempty_mask_on_disk"].append({"i": i, "img": img_path, "mask": mask_path, "ndim": int(m.ndim)})

            # shape mismatch vs image (only if union derived)
            if union is not None:
                with Image.open(img_path) as im:
                    w, h = im.size
                if union.shape != (h, w):
                    counts["shape_mismatch_disk"] += 1
                    if len(examples["shape_mismatch_disk"]) < max_examples:
                        examples["shape_mismatch_disk"].append({
                            "i": i, "img": img_path, "mask": mask_path,
                            "img_hw": [h, w], "mask_hw": list(union.shape),
                            "mask_ndim": int(m.ndim),
                        })

        # emitted target checks via __getitem__
        try:
            img_t, tgt = ds[i]
        except Exception as e:
            # treat as load_error example
            if len(examples["mask_load_error"]) < max_examples:
                examples["mask_load_error"].append({"i": i, "img": img_path, "mask": mask_path, "err": f"getitem_error: {type(e).__name__}: {e}"})
            continue

        il = float(tgt.get("image_label", torch.tensor(-1.0)).item())
        if il == 1.0:
            counts["getitem_image_label_1"] += 1
        elif il == 0.0:
            counts["getitem_image_label_0"] += 1

        inst = tgt.get("masks", torch.zeros(0))
        k = int(inst.shape[0]) if isinstance(inst, torch.Tensor) and inst.ndim >= 3 else 0
        if k > 0:
            counts["getitem_instances_gt0"] += 1
        else:
            counts["getitem_instances_eq0"] += 1

        # key contradiction: forged image_label=1 but no instances
        if il == 1.0 and k == 0:
            counts["forged_label1_but_zero_instances"] += 1
            if len(examples["forged_label1_but_zero_instances"]) < max_examples:
                examples["forged_label1_but_zero_instances"].append({
                    "i": i,
                    "img": img_path,
                    "mask": mask_path,
                    "boxes_shape": list(tgt.get("boxes", torch.zeros((0,4))).shape),
                    "masks_shape": list(inst.shape) if isinstance(inst, torch.Tensor) else None,
                })

    return counts, examples

counts, examples = summarize_dataset_emits(ds, max_examples=5)

print("=== Summary counts ===")
for k in sorted(counts.keys()):
    print(f"{k:40s} : {counts[k]}")

=== Summary counts ===
authentic_with_nonempty_mask_on_disk     : 2377
forged_label1_but_zero_instances         : 0
forged_with_nonempty_mask_on_disk        : 2799
getitem_image_label_0                    : 2377
getitem_image_label_1                    : 2799
getitem_instances_eq0                    : 2377
getitem_instances_gt0                    : 2799
is_forged_false                          : 2377
is_forged_true                           : 2799
mask_load_error                          : 0
mask_missing                             : 0
mask_ndim_2                              : 0
mask_ndim_3                              : 5176
mask_ndim_other                          : 0
mask_ok                                  : 5176
mask_union_sum_gt0                       : 5176
shape_mismatch_disk                      : 0
total                                    : 5176


In [6]:
def spotcheck_tensors(ds, n=10):
    out = []
    for i in range(min(n, len(ds))):
        img_t, tgt = ds[i]
        out.append({
            "i": i,
            "img_dtype": str(img_t.dtype),
            "img_shape": list(img_t.shape),
            "img_min": float(img_t.min().item()),
            "img_max": float(img_t.max().item()),
            "mask_dtype": str(tgt["masks"].dtype),
            "mask_shape": list(tgt["masks"].shape),
            "image_label": float(tgt["image_label"].item()),
            "num_boxes": int(tgt["boxes"].shape[0]),
        })
    return out

for row in spotcheck_tensors(ds, n=10):
    print(row)


{'i': 0, 'img_dtype': 'torch.float32', 'img_shape': [3, 512, 648], 'img_min': 0.0, 'img_max': 1.0, 'mask_dtype': 'torch.uint8', 'mask_shape': [0, 512, 648], 'image_label': 0.0, 'num_boxes': 0}
{'i': 1, 'img_dtype': 'torch.float32', 'img_shape': [3, 1200, 1600], 'img_min': 0.0, 'img_max': 0.9411764740943909, 'mask_dtype': 'torch.uint8', 'mask_shape': [0, 1200, 1600], 'image_label': 0.0, 'num_boxes': 0}
{'i': 2, 'img_dtype': 'torch.float32', 'img_shape': [3, 256, 320], 'img_min': 0.11764705926179886, 'img_max': 0.843137264251709, 'mask_dtype': 'torch.uint8', 'mask_shape': [0, 256, 320], 'image_label': 0.0, 'num_boxes': 0}
{'i': 3, 'img_dtype': 'torch.float32', 'img_shape': [3, 666, 1000], 'img_min': 0.0, 'img_max': 1.0, 'mask_dtype': 'torch.uint8', 'mask_shape': [0, 666, 1000], 'image_label': 0.0, 'num_boxes': 0}
{'i': 4, 'img_dtype': 'torch.float32', 'img_shape': [3, 712, 414], 'img_min': 0.01568627543747425, 'img_max': 1.0, 'mask_dtype': 'torch.uint8', 'mask_shape': [0, 712, 414], 'ima

In [10]:
# Check CV score with dummy submission

# Build the exact same ground-truth dataframe CV uses (row_id, annotation, shape)
solution_df = build_solution_df(ds)

# Dummy submission: predict "authentic" for every row_id
dummy_submission = pd.DataFrame(
    {"row_id": solution_df["row_id"].values, "annotation": ["authentic"] * len(solution_df)}
)

dummy_score = kaggle_score(
    solution_df.copy(),
    dummy_submission.copy(),
    row_id_column_name="row_id",
)

print("Dummy all-authentic score (CV-aligned):", float(dummy_score))


Dummy all-authentic score (CV-aligned): 0.45923493044822256


runs model.forward_logits(images) and records per-image and per-query statistics:

img_forged_prob: is the image head always predicting “authentic” or “forged”?

cls_max/cls_mean/keep_rate@thr: is the per-query class head collapsed low (so nothing ever survives cls filtering)?

mask_max: is the mask head dead (mask probs never rise) or always-hot (spamming FG)?
It also attaches the GT image_label from targets so you can compare distributions by class.

quick_plots(df_img) + summary prints: visual checks for collapse (single spike distributions, near-zero keep rates, etc.).

So overall: it tests (1) you loaded the right weights into the right architecture and (2) whether collapse is happening in the image head, cls filtering, or mask logits

In [12]:
cfg = load_yaml(str(CFG_PATH))
model_cfg = cfg.get("model", {})

In [None]:
state = torch.load(WEIGHTS, map_location="cpu")

# unwrap common checkpoint formats
if isinstance(state, dict) and "state_dict" in state:
    sd = state["state_dict"]
elif isinstance(state, dict) and "model" in state:
    sd = state["model"]
else:
    sd = state  # assume it's already a state_dict

keys = list(sd.keys())
print("num keys:", len(keys))
print("sample keys:", keys[:20])
# Ensuring we're on the correct architecture
print("has pixel_decoder:", any(k.startswith("pixel_decoder.") for k in keys))
print("has mask_feature_proj:", any(k.startswith("mask_feature_proj.") for k in keys))


num keys: 432
sample keys: ['backbone.body.0.0.weight', 'backbone.body.0.0.bias', 'backbone.body.0.1.weight', 'backbone.body.0.1.bias', 'backbone.body.1.0.layer_scale', 'backbone.body.1.0.block.0.weight', 'backbone.body.1.0.block.0.bias', 'backbone.body.1.0.block.2.weight', 'backbone.body.1.0.block.2.bias', 'backbone.body.1.0.block.3.weight', 'backbone.body.1.0.block.3.bias', 'backbone.body.1.0.block.5.weight', 'backbone.body.1.0.block.5.bias', 'backbone.body.1.1.layer_scale', 'backbone.body.1.1.block.0.weight', 'backbone.body.1.1.block.0.bias', 'backbone.body.1.1.block.2.weight', 'backbone.body.1.1.block.2.bias', 'backbone.body.1.1.block.3.weight', 'backbone.body.1.1.block.3.bias']
has pixel_decoder: True
has mask_feature_proj: False


In [None]:
def build_full_eval_loader(
    img_size=256,
    batch_size=4,
    num_workers=4,
):


    ds = ForgeryDataset(
        transform=get_val_transform(img_size=img_size),
        is_train=False,
    )

    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,  # shuffle so repeated runs don't stare at same few images
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=detection_collate_fn,  # list[Tensor], list[dict]
        persistent_workers=(num_workers > 0),
    )
    return ds, loader

def load_model(weights_path, device):
    mk = sanitize_model_kwargs(model_cfg)

    # Disable gate for analysis unless you're explicitly studying gating
    model = Mask2FormerForgeryModel(
        **mk
    ).to(device)

    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model


@torch.no_grad()
def collect_logit_stats(model, loader, device, num_batches=50, cls_thrs=(0.1, 0.2, 0.3)):
    """
    Returns:
      df_img: per-image rows (best for plotting / debugging collapse)
      df_query: optional per-query rows (heavier; use only if needed)
    """
    rows_img = []
    rows_query = []

    it = iter(loader)
    for b in range(num_batches):
        images, targets = next(it)
        images = [x.to(device, non_blocking=True) for x in images]

        mask_logits, class_logits, img_logits = model.forward_logits(images)

        cls_probs = class_logits.sigmoid()               # [B,Q]
        img_probs = img_logits.sigmoid()                 # [B]
        mask_probs = mask_logits.sigmoid().flatten(2)    # [B,Q,HW]

        # image-level summaries (mirrors train_full debug intent)
        cls_max = cls_probs.max(dim=1).values            # [B]
        mask_max = mask_probs.max(dim=2).values.max(dim=1).values  # [B]
        cls_mean = cls_probs.mean(dim=1)                 # [B]
        cls_std  = cls_probs.std(dim=1)                  # [B]

        # target info (authentic vs forged)
        y = torch.stack([t["image_label"].float() for t in targets]).cpu().numpy()

        for i in range(len(images)):
            r = {
                "batch": b,
                "i": i,
                "image_label": float(y[i]),                 # 0 authentic / 1 forged
                "img_forged_prob": float(img_probs[i].item()),
                "cls_max": float(cls_max[i].item()),
                "cls_mean": float(cls_mean[i].item()),
                "cls_std": float(cls_std[i].item()),
                "mask_max": float(mask_max[i].item()),
            }
            for thr in cls_thrs:
                r[f"keep_rate@{thr}"] = float((cls_probs[i] > thr).float().mean().item())
                r[f"num_keep@{thr}"] = int((cls_probs[i] > thr).sum().item())
            rows_img.append(r)

        # per-query (optional but useful if collapse is “all queries identical”)
        # comment this block out if you want it lighter
        B, Q = cls_probs.shape
        for bi in range(B):
            for q in range(Q):
                rows_query.append({
                    "batch": b,
                    "i": bi,
                    "q": q,
                    "image_label": float(y[bi]),
                    "img_forged_prob": float(img_probs[bi].item()),
                    "cls_prob": float(cls_probs[bi, q].item()),
                    "mask_prob_mean": float(mask_probs[bi, q].mean().item()),
                    "mask_prob_max": float(mask_probs[bi, q].max().item()),
                })

    df_img = pd.DataFrame(rows_img)
    df_query = pd.DataFrame(rows_query)
    return df_img, df_query


def quick_plots(df_img):
    # 1) forged-prob distribution split by label
    for lbl in [0.0, 1.0]:
        sub = df_img[df_img["image_label"] == lbl]
        plt.figure()
        plt.hist(sub["img_forged_prob"].values, bins=40)
        plt.title(f"img_forged_prob (label={int(lbl)})")
        plt.xlabel("prob(forged)")
        plt.ylabel("count")
        plt.show()

    # 2) cls_max and mask_max (collapse detectors)
    for col in ["cls_max", "mask_max"]:
        plt.figure()
        plt.hist(df_img[col].values, bins=40)
        plt.title(col)
        plt.xlabel(col)
        plt.ylabel("count")
        plt.show()

    # 3) keep-rate sanity
    keep_cols = [c for c in df_img.columns if c.startswith("keep_rate@")]
    for col in keep_cols:
        plt.figure()
        plt.hist(df_img[col].values, bins=40)
        plt.title(col)
        plt.xlabel(col)
        plt.ylabel("count")
        plt.show()


# ---------------------------
# Run it
# ---------------------------

ds, loader = build_full_eval_loader(img_size=256, batch_size=4, num_workers=4)
model = load_model(WEIGHTS, DEVICE)

df_img, df_query = collect_logit_stats(model, loader, DEVICE, num_batches=50)
display(df_img.head())
display(df_img.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))

quick_plots(df_img)

# Useful “collapse checks” at a glance:
summary = {
    "img_forged_prob_mean": df_img["img_forged_prob"].mean(),
    "img_forged_prob_p95": df_img["img_forged_prob"].quantile(0.95),
    "cls_max_mean": df_img["cls_max"].mean(),
    "cls_max_p95": df_img["cls_max"].quantile(0.95),
    "mask_max_mean": df_img["mask_max"].mean(),
    "mask_max_p95": df_img["mask_max"].quantile(0.95),
}
print(pd.Series(summary).sort_index())

# If you suspect "all queries are the same", check per-image variance across queries:
per_image_cls_std = df_query.groupby(["batch","i"])["cls_prob"].std()
print("per-image cls_prob std: mean=", per_image_cls_std.mean(), " p05=", per_image_cls_std.quantile(0.05))
