# Scientific Image Forgery — Clean Training & Inference Pipeline

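This notebook consolidates preprocessing, validation (image-level F1), calibration, and inference with consistent transforms. Most old conflicting cells were removed. However, some later cells still redefine globals: for example, `IMAGE_SIZE` is 512 in the first config cell and 518 in the training/DINOv2 cells, and `DEVICE` is a `torch.device` in one cell and a plain string in another. Check which value is in effect before training.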
Key changes:
- Unified preprocessing across train/val/test (identical resize/normalize/morphology)
- Image-level F1 metric (leaderboard) computed on validation
- Grid-search over (pixel_threshold, area_frac)
- Save & reuse calibrated params in checkpoint for test-time inference

In [None]:

# === CONFIG & IMPORTS (NEW) ===
import os, math, json, numpy as np, torch, torch.nn.functional as F
from sklearn.metrics import f1_score
import torchvision.transforms as T

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# A single input resolution, reused for train/val/test.
IMAGE_SIZE = 512  # or 256, just be consistent across the notebook

# ImageNet normalization statistics (per RGB channel); keep them identical everywhere.
MEAN, STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Optional morphological opening of predicted masks.
# If enabled, it MUST be applied identically in validation and test.
USE_MORPH, MORPH_KERNEL = False, 3

# Checkpoint file for the best (calibrated) model.
CKPT_PATH = "best_model_calibrated.pth"


In [None]:

# === PREPROCESSING (REPLACE) ===
# Use the SAME preprocessing everywhere.

# PIL image -> resized (bicubic) -> float tensor in [0,1] -> ImageNet-normalized tensor.
# NOTE(review): ForgeryDatasetV2 resizes with bilinear interpolation instead of bicubic;
# confirm which preprocessing path train/val/test actually use.
img_transform = T.Compose(
    [
        T.Resize(size=(IMAGE_SIZE, IMAGE_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
)

def resize_mask(mask_tensor, size=(IMAGE_SIZE, IMAGE_SIZE)):
    """
    Resize a mask tensor with nearest-neighbour interpolation (no blurring of labels).

    mask_tensor: FloatTensor in [0,1] or {0,1}, shape HxW or 1xHxW.
    size: target (H, W); defaults to the IMAGE_SIZE value at definition time.
    Returns a 1 x size[0] x size[1] tensor (for HxW / 1xHxW inputs).
    """
    chw = mask_tensor.unsqueeze(0) if mask_tensor.ndim == 2 else mask_tensor
    batched = chw.unsqueeze(0)  # F.interpolate expects a leading batch dimension
    return F.interpolate(batched, size=size, mode='nearest').squeeze(0)

def maybe_morphology(mask_bin):
    """
    Apply a binary morphological opening (erosion followed by dilation) if USE_MORPH.

    mask_bin: FloatTensor 1xHxW with values {0,1}. When USE_MORPH is False it is returned
    unchanged; otherwise a new 1xHxW {0,1} tensor of the same dtype is returned.
    The structuring element is a MORPH_KERNEL x MORPH_KERNEL square. Pixels outside the
    image count as 0 (zero padding in conv2d).
    Keep this identical in validation and test if you enable it.
    """
    if not USE_MORPH:
        return mask_bin
    k = MORPH_KERNEL
    pad = k // 2
    # Same dtype as the mask so conv2d does not reject e.g. half-precision inputs.
    kernel = torch.ones((1, 1, k, k), device=mask_bin.device, dtype=mask_bin.dtype)
    full = k * k
    # Threshold at half-integers instead of testing exact equality: with
    # cudnn.benchmark enabled, cuDNN may choose FFT/Winograd convolution algorithms
    # whose window sums are not exact integers (e.g. 8.9999 instead of 9).
    # Erode: pixel stays 1 only if every pixel in the kxk neighborhood is 1
    eroded = (F.conv2d(mask_bin.unsqueeze(0), kernel, padding=pad) > full - 0.5).to(mask_bin.dtype)
    # Dilate: pixel is 1 if any neighbor is 1
    dilated = (F.conv2d(eroded, kernel, padding=pad) > 0.5).to(mask_bin.dtype)
    return dilated.squeeze(0)


# Recod.ai / LUC — Scientific Image Forgery Detection 

This notebook trains a segmentation model (a DINOv2 ViT backbone with a lightweight convolutional decoder head) on a T4 GPU to detect copy-move forgeries in biomedical images.
It:
- Loads `train_images/authentic` and `train_images/forged`
- Merges multiple `.npy` masks per case into a single binary mask
- Trains with PyTorch on CUDA (optional mixed precision via `torch.amp`)
- Evaluates with an F1-like metric (approx)
- Exports a submission CSV using RLE encoding (or "authentic" if no mask predicted)


## ✨ Update: Data Augmentation (no deformation, size preserved)

Changes made:
- Added an **AugmentedWrapper** dataset that applies **horizontal / vertical flips** to both image and mask, and **brightness / hue** changes to the **image only** (mask never altered except for flips).
- Rewired the training pipeline to **double** the effective number of training samples via `ConcatDataset(original, augmented)` → **X2 images** seen per epoch.
- **No resizing / warping** was added by these changes. Image shapes stay exactly as in the original pipeline after the existing `resize_pair` step.


In [None]:
# No extra installs required; we stick to torchvision to avoid dependency conflicts.
import torch, torchvision, sys, numpy
# Report library/runtime versions and GPU availability for reproducibility logs.
for _name, _value in (
    ("Torch:", torch.__version__),
    ("Torchvision:", torchvision.__version__),
    ("Python:", sys.version),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(_name, _value)

In [None]:
import os, glob, random
from pathlib import Path
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import torchvision.models.segmentation as tvseg

from sklearn.model_selection import StratifiedKFold

# ===== NEW PATHS & CONFIG =====
from pathlib import Path
from typing import Optional, Dict, List
import random

# RecoDAI LUC Scientific Image Forgery Detection
# Competition training images (split into authentic/forged) and per-image .npy masks
# that are matched to forged images by file stem further below.
RECODAI_AUTH_DIR = Path("/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images/authentic")
RECODAI_FORG_DIR = Path("/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images/forged")
RECODAI_MASK_DIR = Path("/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks")

# Extra train-only datasets (appended to TRAIN only, never to VAL)
NUCLEI_ROOT = Path("/kaggle/input/nuclei-segmentation-in-microscope-cell-images/Nuclei")
AUGMENT_IMG_ROOT = Path("/kaggle/input/data-augment")   # PNGs are somewhere under this root
AUGMENT_MASK_ROOT = Path("/kaggle/input/data-augment")  # NPys are somewhere under this root

# Split config: fraction of each label's pool assigned to TRAIN (the rest goes to VAL)
TRAIN_FRACTION = 0.70
RANDOM_SEED = 42

# Labels
LABEL_AUTHENTIC = 0
LABEL_FORGED = 1

# Seeds Python's `random` module (note: seed_everything(SEED) later reseeds it too).
random.seed(RANDOM_SEED)


OUT_DIR    = "/kaggle/working"
os.makedirs(OUT_DIR, exist_ok=True)

# -------------------
# Training config
# -------------------
# NOTE(review): DEVICE and IMAGE_SIZE were already defined in the first config cell
# (as torch.device and 512). Objects built earlier (e.g. img_transform, the default
# size of resize_mask) keep the old values. The DINOv2 cell redefines IMAGE_SIZE, LR
# and USE_AMP again. Confirm which values are intended.
SEED          = 42
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"   # a plain string here
IMAGE_SIZE    = 518            # 518 = 37 x 14 (ViT patch size)
BATCH_SIZE    = 6              
EPOCHS        = 5             
LR            = 3e-4
WEIGHT_DECAY  = 1e-4
N_FOLDS       = 5
FOLD_TO_RUN   = 0              
USE_AMP       = True
PRINT_EVERY   = 50

print("Device:", DEVICE)

In [None]:
# --- DINOv2 config ---
# These assignments override same-named values from the earlier "Training config" cell.
IMAGE_SIZE = 518     # multiple of 14 (ViT patch size); 518 = 37x14
NUM_CLASSES = 1      # binary mask (change if you use multi-class)
# NOTE(review): build_model() does not read BACKBONE_SIZE (its backbone defaults to
# 'vitb14'); keep the two in sync.
BACKBONE_SIZE = 'vitb14'  # choices: 'vits14', 'vitb14', 'vitl14', 'vitg14'

# training knobs (keep your existing ones if you prefer); LR here replaces the 3e-4 above
LR = 1e-4
WEIGHT_DECAY = 1e-4
USE_AMP = True

# normalization expected by ImageNet-pretrained backbones (incl. DINOv2)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

def dinov2_normalize(img_tensor):
    """
    Normalize an image tensor with the ImageNet mean/std used by DINOv2.

    img_tensor: float tensor [C,H,W] with values in [0,1].
    Returns a new normalized tensor; the input is not modified.
    """
    mean, std = IMAGENET_MEAN, IMAGENET_STD
    return TF.normalize(img_tensor, mean, std)


In [None]:
# ===== HELPERS =====
def map_by_stem(paths):
    """
    Index paths by file stem (name without the final suffix).

    If several paths share a stem, the one appearing last in `paths` wins.
    """
    return dict((path.stem, path) for path in paths)

def to_item(img_path: Path, label: int, mask_path: Optional[Path], src: str) -> Dict:
    """
    Build one sample record for the dataset classes.

    The image path is stored as a string under both "path" and "img". The mask path
    (string, or None when there is no mask) is stored under both "mask" and
    "mask_path", so readers that look up either key keep working. `label` is coerced
    to int and `src` records which dataset the sample came from.
    """
    image_str = str(img_path)
    mask_str = str(mask_path) if mask_path is not None else None
    return dict(
        path=image_str,
        img=image_str,
        mask=mask_str,
        mask_path=mask_str,
        label=int(label),
        src=src,
    )


def debug_counts(items: List[Dict], title="Items"):
    """Print the total number of items and how many carry label 0 (authentic) / 1 (forged)."""
    from collections import Counter
    tally = Counter(record["label"] for record in items)
    n_authentic, n_forged = tally.get(0, 0), tally.get(1, 0)
    print(f"{title}: total={len(items)}  authentic={n_authentic}  forged={n_forged}")


In [None]:
def seed_everything(seed=42):
    """
    Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    cuDNN is left in fast, non-deterministic mode (benchmark=True, deterministic=False),
    so GPU runs are not bit-exact even with a fixed seed.
    """
    import os, random
    import numpy as np
    import torch
    random.seed(seed)
    np.random.seed(seed)
    # Only affects Python processes started after this point; the hash seed of the
    # running interpreter is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # torch.manual_seed seeds the CPU generator (and CUDA), so it must run even without a GPU.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

# Seed all RNGs with the training-config SEED before any shuffling/splitting happens.
seed_everything(SEED)

In [None]:
def load_mask_npy(mask_path: str) -> np.ndarray:
    """
    Load a .npy mask file and collapse it into one binary (H, W) mask.

    Supported contents:
      - (H, W): single mask
      - (N, H, W): multiple instance masks (OR-ed together)
      - list/tuple of (H, W) or (N, H, W) arrays (all OR-ed; they must share H, W)
      - dict: the "masks" key if present, otherwise the first key
    np.save() stores dicts and ragged lists as *object* arrays (0-d for a dict), and
    np.load hands those back as ndarrays, never as dict/list. They are unwrapped first
    so that the container branches are reachable.

    Returns a uint8 array with values {0, 1} (1 wherever any input value is > 0).
    Raises ValueError for unsupported layouts.
    """
    # allow_pickle=True is required for object arrays; only use on trusted dataset files.
    m = np.load(mask_path, allow_pickle=True)

    if isinstance(m, np.ndarray) and m.dtype == object:
        if m.ndim == 0:
            m = m.item()      # e.g. a saved dict
        elif m.ndim == 1:
            m = list(m)       # e.g. a saved list of differently shaped arrays

    if isinstance(m, dict):
        # If your dict uses a different key, edit here
        key = "masks" if "masks" in m else list(m.keys())[0]
        m = m[key]

    # unpack list/tuple containers: binarize each entry before combining
    if isinstance(m, (list, tuple)):
        arrs = []
        for item in m:
            item = np.asarray(item)
            if item.ndim == 2:
                arrs.append(item > 0)
            elif item.ndim == 3:
                arrs.append((item > 0).any(axis=0))
        if not arrs:
            raise ValueError(f"Unsupported mask list/tuple in {mask_path}")
        m = np.stack(arrs, axis=0)

    m = np.asarray(m)
    if m.ndim == 2:
        mask = m > 0
    elif m.ndim == 3:
        mask = (m > 0).any(axis=0)   # logical OR across instances
    else:
        raise ValueError(f"Unsupported mask ndim {m.ndim} for {mask_path}")
    return mask.astype(np.uint8)


def resize_pair(img_pil: Image.Image, mask_np: np.ndarray, size: int) -> tuple[Image.Image, Image.Image]:
    """
    Resize an image and its mask to size x size.

    The image uses bilinear resampling. The mask is converted to a PIL image and
    resized with nearest-neighbour so label values are not interpolated.
    Returns (resized_image, resized_mask) as PIL images.
    """
    target = (size, size)
    return (
        img_pil.resize(target, resample=Image.BILINEAR),
        Image.fromarray(mask_np).resize(target, resample=Image.NEAREST),
    )


def random_flip_rotate(img: Image.Image, mask: Image.Image) -> tuple[Image.Image, Image.Image]:
    """
    Apply the same random geometric augmentation to an image and its mask.

    - horizontal flip with probability 0.5
    - vertical flip with probability 0.2
    - always: rotation by an angle drawn uniformly from [-10, 10] degrees
      (expand=False, so the size is preserved); bilinear for the image, nearest for the mask
    """
    import random

    for prob, flip in ((0.5, TF.hflip), (0.2, TF.vflip)):
        if random.random() < prob:
            img, mask = flip(img), flip(mask)

    angle = random.uniform(-10, 10)
    img = TF.rotate(img, angle=angle, interpolation=InterpolationMode.BILINEAR, expand=False)
    mask = TF.rotate(mask, angle=angle, interpolation=InterpolationMode.NEAREST, expand=False)
    return img, mask


In [None]:
def build_items(base_dir: str):
    """
    Build a list of samples from a base directory that contains:
      - authentic/ (images)
      - forged/ (images)
      - masks_npy/ (masks named <stem>.npy)

    Each returned dict has:
      - path: image path (str)
      - case_id: file stem
      - label: 1 if forged, 0 if authentic
      - mask_path: path to <case_id>.npy for forged images if it exists, else None

    Files are listed in sorted order so the result (and any seeded shuffle of it) is
    reproducible; glob.glob() order is filesystem-dependent. Subdirectories inside the
    class folders are ignored. Missing class folders are skipped.
    """
    items = []
    mask_dir = f"{base_dir}/masks_npy"

    for cls in ["authentic", "forged"]:
        img_dir = f"{base_dir}/{cls}"
        if not os.path.exists(img_dir):
            continue
        for p in sorted(glob.glob(os.path.join(img_dir, "*"))):
            if not os.path.isfile(p):
                continue
            case_id = Path(p).stem
            mask_path = None
            if cls == "forged":
                cand = os.path.join(mask_dir, f"{case_id}.npy")
                mask_path = cand if os.path.exists(cand) else None
            items.append({
                "path": p,
                "case_id": case_id,
                "label": 1 if cls == "forged" else 0,
                "mask_path": mask_path,
            })
    return items


# New name so it won't conflict with the older ForgeryDataset in memory
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode

def _to_single_channel_mask(mask_np: np.ndarray) -> torch.Tensor:
    """Convert any (H,W), (H,W,C), or (C,H,W) mask to (1,H,W) binary tensor."""
    if mask_np.dtype != np.float32:
        mask_np = mask_np.astype(np.float32)

    if mask_np.ndim == 2:
        m = torch.from_numpy(mask_np).unsqueeze(0)        # (1,H,W)
    elif mask_np.ndim == 3:
        # handle (H,W,C) or (C,H,W)
        if mask_np.shape[0] <= 4:
            m = torch.from_numpy(mask_np)                 # (C,H,W)
        else:
            m = torch.from_numpy(mask_np).permute(2,0,1)  # (H,W,C)->(C,H,W)
        if m.shape[0] > 1:
            m = m.max(dim=0, keepdim=True)[0]            # collapse channels
    else:
        mask_np = np.squeeze(mask_np).astype(np.float32)
        return _to_single_channel_mask(mask_np)

    return (m > 0.5).float()                              # (1,H,W)

class ForgeryDatasetV2(Dataset):
    """
    Map-style dataset yielding (image, mask, case_id, label) tuples.

    Each record is a dict with an image path (first truthy value among 'path', 'img',
    'image', 'image_path'), an optional mask path ('mask_path' or 'mask') and 'label'.

    - image: float tensor (3,H,W) in [0,1], resized bilinearly (antialiased) to
      image_size, then passed through `transforms` if given.
    - mask: float tensor (1,H,W) with values {0,1}, resized with nearest-neighbour. An
      all-zero mask is used when there is no mask path or np.load fails (a warning is
      printed). np.load is called without allow_pickle, so object .npy files fall into
      the failure path.
    - case_id: the image path string.
    - label: int(record.get("label", 0)).
    """

    def __init__(self, items, transforms=None, image_size=512, is_train=True):
        """
        items: list of record dicts (see class docstring)
        transforms: optional callable applied to the image tensor (C,H,W)
        image_size: int or (H,W) target size for both image and mask
        is_train: stored for callers that want to branch on it; unused here
        """
        self.items = items
        self.transforms = transforms
        self.is_train = is_train
        size = (image_size, image_size) if isinstance(image_size, int) else image_size
        self.image_size = tuple(size)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        record = self.items[idx]

        # First truthy path-like entry; mirrors an `a or b or c or d` chain.
        img_path = None
        for key in ("path", "img", "image", "image_path"):
            img_path = record.get(key)
            if img_path:
                break
        if img_path is None:
            raise KeyError(f"Missing image path at idx={idx}: {record}")

        rgb = np.array(Image.open(img_path).convert("RGB"))              # (H,W,3) uint8
        image = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0   # (3,H,W)

        mask_path = record.get("mask_path") or record.get("mask")
        raw_mask = None
        if mask_path is not None:
            try:
                raw_mask = np.load(mask_path)
            except Exception as e:
                print(f"[WARN] Could not load mask at {mask_path}: {e}. Using empty mask.")
        if raw_mask is None:
            raw_mask = np.zeros(rgb.shape[:2], dtype=np.float32)

        mask = _to_single_channel_mask(raw_mask)                          # (1,H,W)

        # Bring both to the common size: smooth resize for pixels, nearest for labels.
        image = tv_resize(image, size=self.image_size, interpolation=InterpolationMode.BILINEAR, antialias=True)
        mask = F.interpolate(mask.unsqueeze(0), size=self.image_size, mode="nearest").squeeze(0)

        if self.transforms is not None:
            image = self.transforms(image)

        return image, mask, img_path, int(record.get("label", 0))


In [None]:
# --- Data augmentation wrapper (no deformation; flips + color only) ---
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from PIL import Image
import numpy as np
import torch, random

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

class AugmentedWrapper(Dataset):
    """Wrap a base Dataset that yields (img_t, mask_t, case_id, label)
    and return an augmented view of each sample:
      - Horizontal and/or vertical flips are applied identically to image and mask.
      - Brightness and hue adjustments are applied to the image only.
      - No geometric resizing/cropping/warping is introduced here.

    The base image is expected as a float [3,H,W] tensor in [0,1] (values are clamped
    to that range before the PIL round-trip). The returned image is also in [0,1]
    unless `transforms` is given. `transforms` is an optional callable applied to the
    output image tensor, like ForgeryDatasetV2's `transforms` (e.g. a normalization).
    The mask is returned as float32 [1,H,W].
    NOTE(review): when combined with the unaugmented base via ConcatDataset, make sure
    both halves get the same normalization.
    """
    def __init__(self, base_ds,
                 hflip_p: float = 0.5, vflip_p: float = 0.5,
                 brightness: tuple = (0.9, 1.1),  # multiplicative
                 hue: tuple = (-0.03, 0.03),      # fraction in [-0.5, 0.5]
                 transforms=None):
        self.base_ds    = base_ds
        self.hflip_p    = hflip_p
        self.vflip_p    = vflip_p
        self.brightness = brightness
        self.hue        = hue
        self.transforms = transforms

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        img_t, mask_t, case_id, label = self.base_ds[idx]
        if mask_t.dim() == 2:
            mask_t = mask_t.unsqueeze(0)  # -> [1,H,W]

        # --- Flips (same op on image and mask; last axis = W, second-to-last = H) ---
        if random.random() < self.hflip_p:
            img_t = torch.flip(img_t, dims=[-1])
            mask_t = torch.flip(mask_t, dims=[-1])
        if random.random() < self.vflip_p:
            img_t = torch.flip(img_t, dims=[-2])
            mask_t = torch.flip(mask_t, dims=[-2])

        # --- Color (image only) ---
        # Mild adjustments to respect copy-move appearance; no saturation/contrast changes.
        img_pil = TF.to_pil_image(img_t.clamp(0.0, 1.0))  # size preserved
        img_pil = TF.adjust_brightness(img_pil, random.uniform(*self.brightness))
        img_pil = TF.adjust_hue(img_pil, random.uniform(*self.hue))

        img_out = TF.to_tensor(img_pil).to(dtype=img_t.dtype)
        mask_out = mask_t.to(dtype=torch.float32)
        if self.transforms is not None:
            img_out = self.transforms(img_out)
        return img_out, mask_out, case_id, label

In [None]:
# ===== BUILD 70/30 TRAIN/VAL FROM RECODAI =====
# Gather images
recodai_auth_imgs = sorted(RECODAI_AUTH_DIR.glob("*.png"))
recodai_forg_imgs = sorted(RECODAI_FORG_DIR.glob("*.png"))

# Gather masks (npy) and map by stem for quick lookup
recodai_masks_map = map_by_stem(sorted(RECODAI_MASK_DIR.glob("*.npy")))

# Build labeled list (image, label, mask_path or None)
recodai_all = []

# Authentic: normally no mask; set to None (your dataset class can handle None or you can create zero masks later)
for p in recodai_auth_imgs:
    recodai_all.append(to_item(p, LABEL_AUTHENTIC, None, src="recodai"))

# Forged: try to match mask by stem
# NOTE(review): forged images without a matched mask keep LABEL_FORGED but mask=None;
# ForgeryDatasetV2 then substitutes an all-zero mask for them.
missing_masks = 0
for p in recodai_forg_imgs:
    m = recodai_masks_map.get(p.stem, None)
    if m is None:
        missing_masks += 1
    recodai_all.append(to_item(p, LABEL_FORGED, m, src="recodai"))

if missing_masks:
    print(f"[WARN] RecoDAI forged images missing masks: {missing_masks} (will set mask=None for those)")

# Shuffle before the split; reproducible given the seeds set earlier (random.seed / seed_everything)
random.shuffle(recodai_all)

# Stratified 70/30 split by label
# Done manually: separate per-label pools, each cut at TRAIN_FRACTION (no sklearn splitter).
auth_pool = [it for it in recodai_all if it["label"] == LABEL_AUTHENTIC]
forg_pool = [it for it in recodai_all if it["label"] == LABEL_FORGED]

def split_pool(pool, frac):
    """
    Split a sequence into (head, tail), with the head holding round(len(pool) * frac) items.

    Python's round() rounds halves to even, so e.g. 0.5 -> 0 and 2.5 -> 2.
    """
    cut = int(round(frac * len(pool)))
    head, tail = pool[:cut], pool[cut:]
    return head, tail

# Per-label cut keeps the authentic/forged ratio roughly equal in TRAIN and VAL.
auth_train, auth_val = split_pool(auth_pool, TRAIN_FRACTION)
forg_train, forg_val = split_pool(forg_pool, TRAIN_FRACTION)

train_items = auth_train + forg_train
val_items   = auth_val   + forg_val

# Shuffle each split so labels are interleaved rather than grouped (authentic first)
random.shuffle(train_items)
random.shuffle(val_items)

debug_counts(train_items, "RecoDAI TRAIN")
debug_counts(val_items,   "RecoDAI VAL")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Tuple, Optional
from scipy import ndimage

def load_mask_instances(mask_path: str) -> List[np.ndarray]:
    """
    Load instance masks as a list of [H,W] uint8 {0,1} arrays.

    - (N,H,W) ndarray -> one mask per slice along axis 0
    - (H,W) ndarray -> one mask per connected component (scipy.ndimage.label)
    - list/tuple -> each entry is handled by the two rules above, results concatenated
    - dict -> the "masks" key if present, otherwise the first key
    np.save() stores dicts and ragged lists as *object* arrays (0-d for a dict); those
    are unwrapped first so the container branches are reachable.
    Raises ValueError for arrays that are not 2-D or 3-D.
    """
    # allow_pickle=True is required for object arrays; only use on trusted dataset files.
    m = np.load(mask_path, allow_pickle=True)

    def _to_list_of_masks(arr: np.ndarray) -> List[np.ndarray]:
        arr = np.asarray(arr)
        if arr.ndim == 2:
            # Connected components as instances
            lab, n = ndimage.label(arr > 0)
            inst = [(lab == k) for k in range(1, n + 1)]
            return [x.astype(np.uint8) for x in inst] if n > 0 else []
        elif arr.ndim == 3:
            return [((arr[k] > 0).astype(np.uint8)) for k in range(arr.shape[0])]
        else:
            raise ValueError(f"Unsupported array ndim {arr.ndim}")

    if isinstance(m, np.ndarray) and m.dtype == object:
        if m.ndim == 0:
            m = m.item()      # e.g. a saved dict
        elif m.ndim == 1:
            m = list(m)       # e.g. a saved list of differently shaped arrays

    if isinstance(m, (list, tuple)):
        out = []
        for item in m:
            out.extend(_to_list_of_masks(np.asarray(item)))
        return out

    if isinstance(m, dict):
        key = "masks" if "masks" in m else list(m.keys())[0]
        return _to_list_of_masks(np.asarray(m[key]))

    # ndarray
    return _to_list_of_masks(np.asarray(m))


def overlay_instances_on_image(
    img_rgb: np.ndarray,
    inst_masks: List[np.ndarray],
    alpha: float = 0.35
) -> np.ndarray:
    """
    Tint each instance mask with its own color on top of an (H, W, 3) RGB image.

    Colors come from a fixed-seed generator, so the i-th instance gets the same color on
    every call. Masks whose shape differs from (H, W) are resized with nearest-neighbour.
    Masked pixels become (1 - alpha) * pixel + alpha * color; the result is clipped to
    [0, 255] and returned as uint8.
    """
    canvas = img_rgb.astype(np.float32).copy()
    height, width, _ = canvas.shape

    palette = np.random.default_rng(1234).integers(
        low=0, high=255, size=(max(1, len(inst_masks)), 3), dtype=np.uint8
    )

    for idx, inst in enumerate(inst_masks):
        if inst.shape != (height, width):
            inst = np.array(Image.fromarray((inst > 0).astype(np.uint8)).resize((width, height), Image.NEAREST))
        selected = inst > 0
        tint = palette[idx % len(palette)].astype(np.float32)
        canvas[selected] = (1 - alpha) * canvas[selected] + alpha * tint

    return np.clip(canvas, 0, 255).astype(np.uint8)


def show_forged_example(
    item: dict,
    image_size: Optional[int] = None,
    title: Optional[str] = None
) -> None:
    """
    Plot one forged sample side by side: the plain image and the image with instance overlays.

    - item: record dict with "path" and a non-None "mask_path" (asserted)
    - image_size: if given, image (bilinear) and instances (nearest) are resized to
      image_size x image_size for display; otherwise native size is shown
    - title: left-panel title; defaults to the image file name
    """
    assert item["mask_path"] is not None, "This helper expects a forged item with a mask_path."
    image = Image.open(item["path"]).convert("RGB")
    instances = load_mask_instances(item["mask_path"])

    if image_size is not None:
        target = (image_size, image_size)
        image = image.resize(target, Image.BILINEAR)
        instances = [
            np.array(Image.fromarray((inst > 0).astype(np.uint8)).resize(target, Image.NEAREST))
            for inst in instances
        ]

    pixels = np.array(image)
    blended = overlay_instances_on_image(pixels, instances, alpha=0.38)

    plt.figure(figsize=(10, 5))
    panels = (
        (pixels, title or f"Forged image — {Path(item['path']).name}"),
        (blended, f"Overlay with {len(instances)} instance(s)"),
    )
    for position, (array, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(array)
        plt.axis("off")
        plt.title(caption)
    plt.show()


def show_forged_grid(
    forged_list: List[dict],
    n: int = 6,
    image_size: int = 512,
    cols: int = 3
) -> None:
    """
    Show a grid of up to `n` randomly chosen forged examples with instance overlays.

    Items whose "mask_path" is None (forged images without a matched mask) are skipped,
    because there is nothing to overlay and np.load(None) would fail. If nothing is
    left to show, a message is printed and nothing is plotted. Each cell shows the image
    resized to image_size x image_size (bilinear) with instances resized (nearest).
    """
    candidates = [it for it in forged_list if it.get("mask_path") is not None]
    n = min(n, len(candidates))
    if n <= 0:
        print("[show_forged_grid] nothing to show: no items with a mask_path")
        return
    sel = random.sample(candidates, n)

    rows = (n + cols - 1) // cols
    plt.figure(figsize=(5 * cols, 5 * rows))
    for i, it in enumerate(sel, 1):
        img = Image.open(it["path"]).convert("RGB").resize((image_size, image_size), Image.BILINEAR)
        insts = load_mask_instances(it["mask_path"])
        insts = [np.array(Image.fromarray((m > 0).astype(np.uint8)).resize((image_size, image_size), Image.NEAREST)) for m in insts]
        over  = overlay_instances_on_image(np.array(img), insts, alpha=0.38)

        ax = plt.subplot(rows, cols, i)
        ax.imshow(over)
        ax.axis("off")
        ax.set_title(f"{Path(it['path']).name}\n{len(insts)} instance(s)")
    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINOv2-based lightweight segmentation head; returns dict with "out" to be drop-in compatible.
def _load_dinov2(backbone='vitb14'):
    hub_id = {
        'vits14': 'dinov2_vits14',
        'vitb14': 'dinov2_vitb14',
        'vitl14': 'dinov2_vitl14',
        'vitg14': 'dinov2_vitg14',
    }[backbone]
    m = torch.hub.load('facebookresearch/dinov2', hub_id)
    m.eval()
    return m

class DINOv2Seg(nn.Module):
    """
    DINOv2 ViT backbone plus a small convolutional head for per-pixel logits.

    forward(x) returns {"out": logits} with logits shaped [B, out_ch, H, W], bilinearly
    upsampled to x's spatial size. The dict matches the torchvision segmentation output
    convention used by the training loop.
    """
    def __init__(self, backbone='vitb14', out_ch=1):
        super().__init__()
        self.vit = _load_dinov2(backbone)
        self.embed_dim = {'vits14':384,'vitb14':768,'vitl14':1024,'vitg14':1536}[backbone]
        self.decode = nn.Sequential(
            nn.Conv2d(self.embed_dim, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, out_ch, 1)
        )

    def _tokens_from_vit(self, x):
        """
        Return patch tokens as [B, N, D] (no CLS), regardless of the exact API.
        Supports torch.hub dinov2 and timm variants.
        """
        # 1) Preferred: forward_features with return_all_tokens (timm & some hub builds)
        ff = getattr(self.vit, "forward_features", None)
        if callable(ff):
            try:
                out = ff(x, return_all_tokens=True)  # timm ViT supports this
                # timm sometimes returns a tensor [B, 1+N, D] or a dict
                if isinstance(out, dict):
                    if "x_norm_patchtokens" in out:         # dinov2-style dict
                        tokens = out["x_norm_patchtokens"]   # [B, N, D]
                        return tokens
                    if "tokens" in out:
                        t = out["tokens"]                    # [B, 1+N, D]
                        return t[:, 1:, :]
                elif out.dim() == 3:
                    # [B, 1+N, D] or [B, N, D]
                    return out[:, 1:, :] if out.size(1) > 0 and out.size(1) != int(out.size(1)**0.5)**2 else out
            except TypeError:
                # Some builds don't accept return_all_tokens
                out = ff(x)
                if isinstance(out, dict) and "x_norm_patchtokens" in out:
                    return out["x_norm_patchtokens"]

        # 2) Fallback: get_intermediate_layers (available on many ViTs, incl. dinov2)
        gil = getattr(self.vit, "get_intermediate_layers", None)
        if callable(gil):
            # returns list of tuples (cls, tokens) or tensors; take the last block
            inter = gil(x, n=1, return_class_token=True)
            last = inter[-1]
            if isinstance(last, (list, tuple)) and len(last) == 2:
                cls, tokens = last
                return tokens                         # [B, N, D]
            if torch.is_tensor(last) and last.dim() == 3:
                # [B, 1+N, D]
                return last[:, 1:, :]

        # 3) Absolute fallback: plain forward → usually [B, D] pooled → not usable
        y = self.vit(x)
        if torch.is_tensor(y) and y.dim() == 3:
            return y[:, 1:, :]
        raise RuntimeError("Could not extract patch tokens from DINOv2 backbone; got shape "
                           f"{tuple(y.shape) if torch.is_tensor(y) else type(y)}")

    def _token_grid(self, x, n_tokens):
        """
        Return (rows, cols) of the patch-token grid for input x.

        Uses the backbone's `patch_size` attribute (defaulting to 14 if absent) and
        x's height/width, so non-square inputs work. If that does not account for
        n_tokens, falls back to a square grid. Raises RuntimeError if neither fits.
        """
        patch = getattr(self.vit, "patch_size", 14)
        if isinstance(patch, (tuple, list)):
            ph, pw = int(patch[0]), int(patch[-1])
        else:
            ph = pw = int(patch)
        rows, cols = x.shape[-2] // ph, x.shape[-1] // pw
        if rows * cols == n_tokens:
            return rows, cols
        side = int(n_tokens ** 0.5)
        if side * side == n_tokens:
            return side, side
        raise RuntimeError(
            f"Cannot arrange {n_tokens} patch tokens into a grid for input of size "
            f"{tuple(x.shape[-2:])} (patch size {ph}x{pw})"
        )

    def forward(self, x):
        tokens = self._tokens_from_vit(x)     # [B, N, D]
        B, N, D = tokens.shape
        rows, cols = self._token_grid(x, N)
        feat = tokens.transpose(1, 2).reshape(B, D, rows, cols)   # [B, D, rows, cols]
        logits = self.decode(feat)
        logits = F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
        return {"out": logits}


def build_model(backbone='vitb14', out_ch=1):
    """
    Build the DINOv2 segmentation model.

    backbone: 'vits14' (small), 'vitb14' (base, default), 'vitl14' or 'vitg14'.
    out_ch: number of output logit channels (1 for a binary mask).
    Calling build_model() with no arguments behaves as before (vitb14, 1 channel).
    """
    return DINOv2Seg(backbone=backbone, out_ch=out_ch)

class DiceLoss(nn.Module):
    """
    Soft Dice loss on sigmoid probabilities, pooled over the whole batch.

    loss = 1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = logits.sigmoid()
        overlap = (probs * targets).sum()
        numerator = 2 * overlap + self.eps
        denominator = probs.sum() + targets.sum() + self.eps
        return 1 - numerator / denominator


def bce_dice_loss(logits, targets, alpha=0.5, pos_weight=4.0):
    """
    Weighted mix of BCE-with-logits and Dice: alpha * BCE + (1 - alpha) * Dice.

    targets of shape (B,H,W) get a channel axis added to match (B,1,H,W) logits.
    pos_weight up-weights positive (forged) pixels in the BCE term.
    """
    targets = targets.unsqueeze(1) if targets.dim() == 3 else targets
    weight = torch.tensor([pos_weight], device=logits.device)
    bce = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, pos_weight=weight)
    dice = DiceLoss()(logits, targets)
    return alpha * bce + (1 - alpha) * dice


In [None]:
@torch.no_grad()
def batch_f1(logits, targets, thr=0.5):
    """
    Quick pixel F1 proxy (not the official image-wise oF1).
    Good enough to monitor training.
    """
    probs = torch.sigmoid(logits)
    preds = (probs > thr).float()
    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()
    precision = tp / (tp + fp + 1e-6)
    recall    = tp / (tp + fn + 1e-6)
    return 2 * precision * recall / (precision + recall + 1e-6)


def train_one_epoch(model, loader, optimizer, device, scaler=None, print_every=50):
    """
    Run one training epoch and return (mean_loss, mean_pixel_f1) over the batches.

    - loader yields (imgs, masks, _, _); masks of shape (B,H,W) are expanded to (B,1,H,W).
    - model(imgs) must return a dict with key "out" holding logits.
    - Mixed precision (CUDA autocast + GradScaler) is used when `scaler` is given.
    - Progress is printed every `print_every` batches (disabled if print_every <= 0).
    Batches are counted as they are consumed, so loaders without __len__ work. An empty
    loader returns (nan, nan) instead of dividing by zero.
    """
    model.train()
    running_loss, running_f1 = 0.0, 0.0
    n_batches = 0

    for it, (imgs, masks, _, _) in enumerate(loader):
        imgs, masks = imgs.to(device), masks.to(device)
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)   # (B,1,H,W)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=(scaler is not None)):
            out = model(imgs)["out"]                # model returns {"out": logits}
            loss = bce_dice_loss(out, masks)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        n_batches = it + 1
        running_loss += loss.item()
        running_f1   += batch_f1(out.detach(), masks)

        if print_every and print_every > 0 and n_batches % print_every == 0:
            print(f"[train] it {n_batches:04d} | loss {running_loss/n_batches:.4f} | f1 {running_f1/n_batches:.4f}")

    if n_batches == 0:
        return float("nan"), float("nan")
    return running_loss / n_batches, running_f1 / n_batches


@torch.no_grad()
def valid_one_epoch(model, loader, device, thr=0.5):
    """
    Evaluate without gradients and return (mean_loss, mean_pixel_f1) over the batches.

    masks of shape (B,H,W) are expanded to (B,1,H,W); `thr` is the probability
    threshold passed to batch_f1. Batches are counted as they are consumed, and an
    empty loader returns (nan, nan) instead of dividing by zero.
    """
    model.eval()
    running_loss, running_f1 = 0.0, 0.0
    n_batches = 0

    for imgs, masks, _, _ in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)

        out  = model(imgs)["out"]
        loss = bce_dice_loss(out, masks)
        running_loss += loss.item()
        running_f1   += batch_f1(out, masks, thr=thr)
        n_batches += 1

    if n_batches == 0:
        return float("nan"), float("nan")
    return running_loss / n_batches, running_f1 / n_batches

@torch.no_grad()
def find_best_threshold(model, loader, device, thr_grid=None):
    """
    Grid-search the probability threshold that maximizes pixel F1 on `loader`.

    All logits and masks are gathered on the CPU and scored with batch_f1, pooled over
    the whole set (pixel-level, not image-level F1). masks of shape (B,H,W) are
    expanded to (B,1,H,W) to match the logits, as in train/valid_one_epoch.
    thr_grid defaults to 0.05, 0.10, ..., 0.95. Returns (best_thr, best_f1); ties keep
    the lower threshold.
    """
    model.eval()
    if thr_grid is None:
        thr_grid = np.linspace(0.05, 0.95, 19)
    all_logits = []
    all_masks  = []
    for imgs, masks, _, _ in loader:
        imgs = imgs.to(device)
        out  = model(imgs)["out"]  # (B,1,H,W)
        if masks.dim() == 3:
            # Without this, (B,1,H,W) * (B,H,W) silently broadcasts to (B,B,H,W).
            masks = masks.unsqueeze(1)
        all_logits.append(out.cpu())
        all_masks.append(masks.cpu())
    logits = torch.cat(all_logits, dim=0)
    targets = torch.cat(all_masks, dim=0)
    best_thr, best_f1 = 0.5, -1.0
    for thr in thr_grid:
        f1 = batch_f1(logits, targets, thr=float(thr))
        if f1 > best_f1:
            best_f1, best_thr = float(f1), float(thr)
    print(f"[val] best F1={best_f1:.4f} @ thr={best_thr:.2f}")
    return best_thr, best_f1

In [None]:
import random

from tqdm.auto import tqdm

# ===== ADD TRAIN-ONLY DATA =====
# Builds items for the external sources and appends them to TRAIN only.
# Items whose image path is already in train_items are skipped, so re-running
# this cell does not duplicate training data.


def _train_item_key(item):
    """Return the image path an item is de-duplicated on, or None if it has none."""
    return item.get("path") or item.get("img") or item.get("image") or item.get("image_path")


# 1) Nuclei images: put ALL as Authentic in TRAIN
nuclei_imgs = sorted(NUCLEI_ROOT.rglob("*.png"))
nuclei_items = [
    to_item(p, LABEL_AUTHENTIC, None, src="nuclei")
    for p in tqdm(nuclei_imgs, desc="Nuclei → Authentic", unit="img")
]

# 2) Data-augment set: all PNGs as Forged + match NPYs by stem if available
augment_imgs = sorted(AUGMENT_IMG_ROOT.rglob("*.png"))
augment_masks_map = map_by_stem(sorted(AUGMENT_MASK_ROOT.rglob("*.npy")))

augment_items = []
aug_missing = 0
for p in tqdm(augment_imgs, desc="Augment → Forged (match masks)", unit="img"):
    m = augment_masks_map.get(p.stem, None)
    if m is None:
        aug_missing += 1
    augment_items.append(to_item(p, LABEL_FORGED, m, src="augment"))

if aug_missing:
    print(f"[WARN] Augment forged images missing masks: {aug_missing} (mask=None)")

# Append both groups ONLY to TRAIN, skipping images whose path is already present.
# NOTE(review): assumes items are dicts, as the other de-dup step in this notebook also does.
_seen_paths = {_train_item_key(it) for it in train_items} - {None}
for _group_name, _group in (("nuclei", nuclei_items), ("augment", augment_items)):
    _added = 0
    for it in _group:
        _key = _train_item_key(it)
        if _key is not None:
            if _key in _seen_paths:
                continue
            _seen_paths.add(_key)
        train_items.append(it)
        _added += 1
    print(f"[append] {_group_name} added: {_added} / {len(_group)}")

# Shuffle final train
random.shuffle(train_items)

debug_counts(nuclei_items,  "Added Nuclei (TRAIN-only)")
debug_counts(augment_items, "Added Augment (TRAIN-only)")
debug_counts(train_items,   "FINAL TRAIN")
debug_counts(val_items,     "FINAL VAL (unchanged)")


In [None]:
# === FIX PATHS + RESCAN AUGMENT & NUCLEI, APPEND TO TRAIN ===
from pathlib import Path
from tqdm.auto import tqdm
import random

# 1) Dataset roots. The augment set keeps its PNG images and .npy masks under one directory.
NUCLEI_ROOT = Path("/kaggle/input/nuclei-segmentation-in-microscope-cell-images/Nuclei/Nuclei")
AUGMENT_IMG_ROOT = Path("/kaggle/input/data-augment")
AUGMENT_MASK_ROOT = Path("/kaggle/input/data-augment")

# 2) Recursively list every candidate file, sorted for a deterministic order.
nuclei_imgs, augment_imgs, augment_masks = (
    sorted(root.rglob(pattern))
    for root, pattern in (
        (NUCLEI_ROOT, "*.png"),
        (AUGMENT_IMG_ROOT, "*.png"),
        (AUGMENT_MASK_ROOT, "*.npy"),
    )
)
# Index masks by file stem so each image can look up its mask by name (last one wins on clashes).
augment_masks_map = dict(zip((mask.stem for mask in augment_masks), augment_masks))

for _caption, _files in (
    ("nuclei png:  ", nuclei_imgs),
    ("augment png: ", augment_imgs),
    ("augment npy: ", augment_masks),
):
    print(f"[scan] {_caption}{len(_files)}")
# 3) Item factory shared by every source. It fills both key spellings ('path'/'img', 'mask_path'/'mask').
def to_item(img_path, label, mask_path, src):
    """Build one dataset item dict.

    Args:
        img_path: image location, stored as ``str``.
        label: image-level class, stored as ``int``.
        mask_path: mask location stored as ``str``, or None when there is no mask.
        src: source tag (e.g. "nuclei", "augment"), stored unchanged.

    Returns:
        dict with keys path, img, mask_path, mask, label, src (in that order).
    """
    image_str = str(img_path)
    mask_str = None if mask_path is None else str(mask_path)
    return dict(
        path=image_str,
        img=image_str,
        mask_path=mask_str,
        mask=mask_str,
        label=int(label),
        src=src,
    )

LABEL_AUTHENTIC = 0
LABEL_FORGED = 1

# Nuclei images carry no mask and are all labelled authentic.
nuclei_items = [
    to_item(img, LABEL_AUTHENTIC, None, src="nuclei")
    for img in tqdm(nuclei_imgs, desc="Build nuclei items", unit="img")
]

# Augment images are labelled forged; the .npy mask with the same stem is attached when present.
augment_items = []
aug_missing = 0
for img in tqdm(augment_imgs, desc="Build augment items (+mask match)", unit="img"):
    mask_file = augment_masks_map.get(img.stem, None)
    aug_missing += mask_file is None
    augment_items.append(to_item(img, LABEL_FORGED, mask_file, src="augment"))
if aug_missing:
    print(f"[WARN] augment images missing masks: {aug_missing}")

# 4) De-duplicate before appending so a re-run does not add the same files twice.
existing = {entry.get("path") or entry.get("img") for entry in train_items}

# Both candidate lists are filtered against the paths present *before* either is appended.
added_nuclei = [entry for entry in nuclei_items if entry["path"] not in existing]
added_aug = [entry for entry in augment_items if entry["path"] not in existing]

for _new_entries in (added_nuclei, added_aug):
    train_items.extend(_new_entries)
    existing.update(entry["path"] for entry in _new_entries)

random.shuffle(train_items)

print(f"[append] nuclei added: {len(added_nuclei)} / {len(nuclei_items)}")
print(f"[append] augment added: {len(added_aug)} / {len(augment_items)}")

# 5) Normalize keys (so downstream code can rely on 'path'/'mask_path')
def _normalize_items(items):
    fixed = []
    for it in items:
        img_path  = it.get("path") or it.get("img") or it.get("image") or it.get("image_path")
        mask_path = it.get("mask_path") or it.get("mask")
        fixed.append({
            "path": str(img_path) if img_path is not None else None,
            "img":  str(img_path) if img_path is not None else None,
            "mask": (str(mask_path) if mask_path is not None else None),
            "mask_path": (str(mask_path) if mask_path is not None else None),
            "label": int(it.get("label", 0)),
            "src": it.get("src", "unknown"),
        })
    return fixed

train_items = _normalize_items(train_items)

# 6) Count items per source to confirm the augment/nuclei items really were appended.
from collections import Counter
c = Counter(item.get("src", "unknown") for item in train_items)
print("Train by src:", dict(c))

# Peek at the first item of each added source (None when the source is absent).
for _src_name, _caption in (("augment", "Example augment item:"), ("nuclei", "Example nuclei item:")):
    print(_caption, next((item for item in train_items if item["src"] == _src_name), None))


In [None]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def dinov2_norm(img_t):
    """Normalize ``img_t`` channel-wise with the ImageNet mean/std via ``TF.normalize``.

    ``TF`` is not imported in this cell. It is presumably
    ``torchvision.transforms.functional`` imported elsewhere — confirm.
    """
    channel_mean, channel_std = IMAGENET_MEAN, IMAGENET_STD
    return TF.normalize(img_t, channel_mean, channel_std)

# ===== Build train/valid datasets using REAL augment path + optional on-the-fly augs =====
from torch.utils.data import ConcatDataset, DataLoader

# ---- Robust source tagging (prefer existing 'src'; else detect by path) ----
AUG_PATH_TOKEN    = "/kaggle/input/data-augment"
NUCLEI_PATH_TOKEN = "/kaggle/input/nuclei-segmentation-in-microscope-cell-images/nuclei/nuclei"  # lowercase compare
RECODAI_TOKEN     = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"

def _src_of(it):
    # 1) honor pre-set src if present
    if "src" in it and it["src"]:
        return it["src"]
    # 2) detect by path
    p = (it.get("path") or it.get("img") or "").lower()
    if AUG_PATH_TOKEN in p:
        return "augment"
    if NUCLEI_PATH_TOKEN in p:
        return "nuclei"
    if RECODAI_TOKEN in p:
        return "recodai"
    return "unknown"

# ---- Split training items by source ----
train_by_src = {"augment": [], "nuclei": [], "recodai": [], "unknown": []}
for it in train_items:
    # A tag outside the four buckets lands in "unknown" (part of the base pool below)
    # instead of raising KeyError.
    train_by_src.get(_src_of(it), train_by_src["unknown"]).append(it)

print("Train split by src:")
for k in ["recodai", "nuclei", "augment", "unknown"]:
    print(f"  {k:8s}: {len(train_by_src[k])}")

# ---- Build datasets ----
# Base = real data (no pre-generated augmentations)
_base_pool = train_by_src["recodai"] + train_by_src["nuclei"] + train_by_src["unknown"]
_base_train_ds = ForgeryDatasetV2(
    _base_pool,
    image_size=IMAGE_SIZE,
    is_train=False  # deterministic branch (no random augs)
)

# Augment-from-disk = ONLY the files under /kaggle/input/data-augment
_aug_disk_pool = train_by_src["augment"]
if len(_aug_disk_pool) == 0:
    print("[WARN] No augment (disk) items found. Did the 'ADD TRAIN-ONLY DATA' cell run with the corrected paths?")
_aug_disk_ds = ForgeryDatasetV2(
    _aug_disk_pool,
    image_size=IMAGE_SIZE,
    is_train=False  # these are already augmented renders
)

# (Optional) On-the-fly augmentation branch (stochastic)
# NOTE: Make sure your AugmentedWrapper accepts **tensor** images; adapt if it expects PIL.
DO_ON_THE_FLY = True
if DO_ON_THE_FLY:
    _aug_onthefly_ds = AugmentedWrapper(
        ForgeryDatasetV2(_base_pool, image_size=IMAGE_SIZE, is_train=True)
    )
    train_ds = ConcatDataset([_base_train_ds, _aug_disk_ds, _aug_onthefly_ds])
else:
    train_ds = ConcatDataset([_base_train_ds, _aug_disk_ds])

print("Dataset sizes:")
print("  base:            ", len(_base_train_ds))
print("  augment (disk):  ", len(_aug_disk_ds))
if DO_ON_THE_FLY:
    print("  augment (OTF):   ", len(_aug_onthefly_ds))
print("  TOTAL train:     ", len(train_ds))

# ---- Build loaders ----
# NOTE(review): valid_ds gets transforms=dinov2_norm, but none of the train datasets above
# pass `transforms`. Unless ForgeryDatasetV2 applies the same normalization internally when
# transforms is None, train and val inputs are preprocessed differently — confirm and align.
valid_ds = ForgeryDatasetV2(val_items, image_size=IMAGE_SIZE, is_train=False, transforms=dinov2_norm)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True
)

In [None]:
# Sanity-check dataset sizes right before the training loop.
for _name, _ds in (("train_ds", train_ds), ("valid_ds", valid_ds)):
    print(f"len({_name}) =", len(_ds))

In [None]:
# Rebuild train/valid datasets and loaders directly from train_items / val_items.
# NOTE(review): running this cell replaces the train_loader / valid_loader built by the
# previous dataset cell (ConcatDataset of disk + on-the-fly augmentation; val normalized
# with dinov2_norm). Run whichever of the two cells matches the pipeline you want to train.
_nb_globals = globals()
# IMAGE_SIZE is an int in the config cell, so the fallback is an int as well.
_image_size = _nb_globals.get("IMAGE_SIZE", 512)

train_dataset = ForgeryDatasetV2(
    train_items,
    transforms=_nb_globals.get("train_transforms"),
    image_size=_image_size,
    is_train=True,
)

valid_dataset = ForgeryDatasetV2(
    val_items,
    transforms=_nb_globals.get("valid_transforms"),
    image_size=_image_size,
    is_train=False,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=_nb_globals.get("VAL_BATCH_SIZE", BATCH_SIZE),
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)


In [None]:

# === VALIDATION HELPERS (ROBUST) — REPLACE ===
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import f1_score

def _batch_to_tensors(batch, device):
    """
    Returns: imgs [B,3,H,W], labels [B] (0/1), optional masks [B,1,H,W] or None.
    Tries to infer from common batch shapes/keys.
    """
    imgs = None
    labels = None
    masks = None

    # Case A: dict-like batch
    if isinstance(batch, dict):
        # images
        if "image" in batch:
            imgs = batch["image"]
        elif "images" in batch:
            imgs = batch["images"]
        else:
            raise KeyError("Could not find images in batch (expected keys: 'image' or 'images').")

        # masks (optional)
        if "mask" in batch:
            masks = batch["mask"]
        elif "masks" in batch:
            masks = batch["masks"]

        # image-level labels (optional)
        if "image_label" in batch:
            labels = batch["image_label"]
        elif "labels" in batch:
            labels = batch["labels"]

    # Case B: tuple/list batch
    elif isinstance(batch, (list, tuple)):
        # Common patterns:
        #  (images,) | (images, masks) | (images, labels) | (images, masks, labels)
        if len(batch) == 1:
            imgs = batch[0]
        elif len(batch) == 2:
            imgs, second = batch
            if torch.is_tensor(second):
                if second.ndim >= 3:
                    masks = second
                else:
                    labels = second
            else:
                labels = second
        elif len(batch) >= 3:
            imgs, masks, labels = batch[0], batch[1], batch[2]
        else:
            raise TypeError(f"Unsupported batch structure with len={len(batch)}")
    else:
        raise TypeError(f"Unsupported batch type: {type(batch)}")

    imgs = imgs.to(device)

    if masks is not None:
        if masks.ndim == 3:
            masks = masks.unsqueeze(1)
        masks = masks.to(device).float()

    if labels is None:
        if masks is None:
            raise ValueError("No image-level labels found and no masks to derive them from.")
        labels = (masks.view(masks.size(0), -1).max(dim=1).values > 0.5).long()

    labels = labels.detach().cpu().numpy().astype(np.uint8)
    return imgs, labels, masks

@torch.no_grad()
def collect_val_probs(model, val_loader, device):
    """
    Run ``model`` over ``val_loader`` and gather sigmoid probability maps plus image labels.

    Args:
        model: segmentation model. Its forward may return a tensor or a dict holding
            the logits under "logits" or "out".
        val_loader: iterable of batches understood by ``_batch_to_tensors``.
        device: torch device used for inference.

    Returns:
      prob_masks: FloatTensor [N,1,IMAGE_SIZE,IMAGE_SIZE] in [0,1], on CPU
      labels:     np.array [N] (0/1)

    Raises:
        KeyError: the model returns a dict with neither "logits" nor "out".
    """
    model.eval()
    all_probs = []
    all_labels = []

    for batch in val_loader:
        imgs, labels, _ = _batch_to_tensors(batch, device)

        logits = model(imgs)                       # tensor, or dict of tensors
        # The training loop reads model(imgs)["out"], so dict outputs must be unwrapped
        # through "out" as well as "logits". Otherwise the dict would reach `.ndim` and fail.
        if isinstance(logits, dict):
            if "logits" in logits:
                logits = logits["logits"]
            elif "out" in logits:
                logits = logits["out"]
            else:
                raise KeyError(f"Model output dict has neither 'logits' nor 'out': {list(logits)}")
        if logits.ndim == 3:                       # [B,h',w'] -> [B,1,h',w']
            logits = logits.unsqueeze(1)

        probs = torch.sigmoid(logits)

        # Bring every map to the common IMAGE_SIZE resolution.
        if probs.shape[-2:] != (IMAGE_SIZE, IMAGE_SIZE):
            probs = F.interpolate(probs, size=(IMAGE_SIZE, IMAGE_SIZE), mode='bilinear', align_corners=False)

        all_probs.append(probs.cpu())
        all_labels.append(labels)

    prob_masks = torch.cat(all_probs, dim=0)
    labels = np.concatenate(all_labels, axis=0)
    return prob_masks, labels

def image_level_preds(prob_masks, pixel_thr, area_frac):
    """Turn per-pixel probabilities into image-level forged/authentic predictions.

    Args:
        prob_masks: tensor [N,1,H,W] of probabilities in [0,1].
        pixel_thr: pixels with prob >= pixel_thr count as forged.
        area_frac: an image is predicted forged when its forged-pixel fraction is >= area_frac.

    Returns:
        np.ndarray of uint8 with shape [N] (1 = forged, 0 = authentic).
    """
    with torch.no_grad():
        bin_masks = (prob_masks >= pixel_thr).float()
        if USE_MORPH and bin_masks.shape[0] > 0:
            # maybe_morphology is documented to take one 1xHxW mask. Feed bin_masks[i]
            # (not the 4-D slice bin_masks[i:i+1], which stacked into a 5-D tensor)
            # and stack back to [N,1,H,W].
            bin_masks = torch.stack(
                [maybe_morphology(bin_masks[i]) for i in range(bin_masks.shape[0])], dim=0
            )
        # Average over every non-batch axis so the result is one area fraction per image.
        area = bin_masks.flatten(1).mean(dim=1).cpu().numpy()
    return (area >= area_frac).astype(np.uint8)

def image_f1(prob_masks, labels, pixel_thr, area_frac):
    """Image-level F1 (positive class = forged) for one (pixel_thr, area_frac) setting.

    ``zero_division=0`` returns the same 0.0 that sklearn's default yields when no image
    is predicted (or labelled) forged. It does so without emitting an
    UndefinedMetricWarning at every such point of the grid search.
    """
    preds = image_level_preds(prob_masks, pixel_thr, area_frac)
    return f1_score(labels, preds, zero_division=0)


In [None]:

# === GRID SEARCH (NEW) ===
import numpy as np

def grid_search_f1(prob_masks, labels, thr_values=None, area_values=None):
    """Search a (pixel threshold x area fraction) grid for the best image-level F1.

    Args:
        prob_masks: per-pixel probabilities passed to ``image_f1``.
        labels: image-level ground truth (0/1) aligned with ``prob_masks``.
        thr_values: pixel thresholds to try. Default: 13 values from 0.20 to 0.80 (step 0.05).
        area_values: forged-area fractions to try. Default: 10 log-spaced values
            from 5e-4 (0.05% of pixels) to 2e-2 (2%).

    Returns:
        dict {"f1", "thr", "area"} for the first grid point (thresholds outer,
        areas inner) reaching the highest F1. A one-line summary is printed.
    """
    thr_grid = np.linspace(0.20, 0.80, 13) if thr_values is None else thr_values
    area_grid = np.geomspace(5e-4, 2e-2, 10) if area_values is None else area_values

    best_f1, best_thr, best_area = -1.0, None, None
    for t in thr_grid:
        for a in area_grid:
            score = image_f1(prob_masks, labels, t, a)
            # Strict ">" keeps the earliest grid point on ties.
            if score > best_f1:
                best_f1, best_thr, best_area = float(score), float(t), float(a)

    best = {"f1": best_f1, "thr": best_thr, "area": best_area}
    print(f"[GRID] best F1={best['f1']:.4f} at thr={best['thr']:.3f}, area={best['area']:.5f}")
    return best


In [None]:
# === TRAIN/VALIDATE & SAVE-BEST (FIXED: no scheduler kwarg) ===
# num_epochs must be set before the scheduler is built, because CosineAnnealingLR uses it as T_max.
num_epochs = 2  # change as you like

model = build_model().to(DEVICE)

# Optimizer & optional scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

USE_SCHEDULER = True  # set False if you don't want LR scheduling
if USE_SCHEDULER:
    from torch.optim.lr_scheduler import CosineAnnealingLR
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

best_score = -1.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    # ---- TRAIN ----
    model.train()
    # IMPORTANT: do NOT pass 'scheduler=' here, since your train_one_epoch doesn't accept it.
    # train_one_epoch returns (mean_loss, mean_f1); formatting that tuple with ":.4f"
    # raises TypeError, so unpack it. A bare loss value is accepted as well.
    train_out = train_one_epoch(model, train_loader, optimizer, DEVICE)
    if isinstance(train_out, (tuple, list)):
        train_loss, *train_extra = train_out
    else:
        train_loss, train_extra = train_out, []
    train_msg = f"train_loss: {train_loss:.4f}"
    if train_extra:
        train_msg += f" | train_f1: {train_extra[0]:.4f}"
    print(train_msg)

    # ---- VALIDATE & CALIBRATE ----
    val_probs, val_labels = collect_val_probs(model, valid_loader, DEVICE)
    search = grid_search_f1(val_probs, val_labels)
    val_score = search['f1']
    print(f"[VAL] F1={val_score:.4f} at thr={search['thr']:.3f}, area={search['area']:.5f}")

    preds = image_level_preds(val_probs, search['thr'], search['area'])
    print(f"[VAL] % predicted forged: {100*preds.mean():.2f}%  (GT {100*val_labels.mean():.2f}%)")

    # ---- SAVE IF BEST (by leaderboard metric) ----
    if val_score > best_score:
        best_score = val_score
        to_save = {
            "state_dict": model.state_dict(),
            "image_size": IMAGE_SIZE,
            "mean": MEAN, "std": STD,
            "use_morph": USE_MORPH, "morph_kernel": MORPH_KERNEL,
            "pixel_thr": search['thr'],
            "area_frac": search['area'],
            "val_f1": best_score,
            "arch": getattr(model, "name", "model"),
            "epoch": epoch,
        }
        torch.save(to_save, CKPT_PATH)
        print(f"✅ Saved new best to {CKPT_PATH} | F1={best_score:.4f}")

    # ---- STEP SCHEDULER (outside train_one_epoch) ----
    if USE_SCHEDULER:
        scheduler.step()
        # (optional) show current LR
        if hasattr(optimizer, "param_groups"):
            print("lr:", optimizer.param_groups[0]["lr"])


In [None]:

# === SANITY CHECK (NEW) ===
# Reload the best checkpoint and re-score validation with the thresholds saved in it.
ckpt = torch.load(CKPT_PATH, map_location="cpu")

# Rebuild the architecture; build_model() has to come from an earlier cell.
try:
    model = build_model()
except NameError as err:
    raise RuntimeError("build_model() must be defined before running sanity check.") from err

model.load_state_dict(ckpt["state_dict"])
model.to(DEVICE).eval()

# Restore preprocessing settings and calibrated thresholds from the checkpoint.
IMAGE_SIZE = int(ckpt["image_size"])
MEAN = tuple(ckpt["mean"])
STD = tuple(ckpt["std"])
USE_MORPH = bool(ckpt["use_morph"])
MORPH_KERNEL = int(ckpt["morph_kernel"])
PIXEL_THR, AREA_FRAC = float(ckpt["pixel_thr"]), float(ckpt["area_frac"])

_resize = T.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=T.InterpolationMode.BICUBIC)
img_transform = T.Compose([_resize, T.ToTensor(), T.Normalize(MEAN, STD)])

# NOTE(review): valid_loader is not rebuilt here, so the new img_transform only matters
# if ForgeryDatasetV2 reads that global — confirm.
with torch.no_grad():
    val_probs, val_labels = collect_val_probs(model, valid_loader, DEVICE)
    preds = image_level_preds(val_probs, PIXEL_THR, AREA_FRAC)
    print(f"Sanity — GT forged rate:   {100*val_labels.mean():.2f}%")
    print(f"Sanity — Pred forged rate: {100*preds.mean():.2f}%")
    print(f"Sanity — Val F1 (fixed):   {image_f1(val_probs, val_labels, PIXEL_THR, AREA_FRAC):.4f}")


In [None]:

# === INFERENCE / TEST (REPLACE) ===
# Load the best checkpoint and rebuild preprocessing exactly.

ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Expect a build_model() defined earlier
model = build_model()
model.load_state_dict(ckpt["state_dict"])
model.to(DEVICE).eval()

# Restore preprocessing/thresholds. IMAGE_SIZE is cast to int, as in the sanity-check cell,
# so it is a plain int wherever it is compared with tensor shapes or passed to Resize.
IMAGE_SIZE = int(ckpt["image_size"])
MEAN, STD = tuple(ckpt["mean"]), tuple(ckpt["std"])
USE_MORPH = bool(ckpt["use_morph"])
MORPH_KERNEL = int(ckpt["morph_kernel"])
PIXEL_THR = float(ckpt["pixel_thr"])
AREA_FRAC = float(ckpt["area_frac"])

# rebuild transforms from saved params (identical to train/val)
img_transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(MEAN, STD),
])

@torch.no_grad()
def infer_batch(imgs):
    """Predict image-level labels and probability maps for one test batch.

    Args:
        imgs: tensor [B,3,H,W], assumed to be already preprocessed with img_transform
            (not checked here).

    Returns:
        (y_pred, probs): y_pred is an int64 np.ndarray [B] (1 = forged, 0 = authentic);
        probs is a CPU tensor of sigmoid probabilities at IMAGE_SIZE x IMAGE_SIZE.

    Uses the notebook globals model, DEVICE, IMAGE_SIZE, PIXEL_THR, AREA_FRAC, USE_MORPH.
    """
    logits = model(imgs.to(DEVICE))  # tensor, or dict of tensors
    # The training loop reads model(imgs)["out"], so unwrap "out" as well as "logits".
    if isinstance(logits, dict):
        if "logits" in logits:
            logits = logits["logits"]
        elif "out" in logits:
            logits = logits["out"]
        else:
            raise KeyError(f"Model output dict has neither 'logits' nor 'out': {list(logits)}")
    if logits.ndim == 3:
        logits = logits.unsqueeze(1)
    probs = torch.sigmoid(logits)

    if probs.shape[-2:] != (IMAGE_SIZE, IMAGE_SIZE):
        probs = F.interpolate(probs, size=(IMAGE_SIZE, IMAGE_SIZE), mode='bilinear', align_corners=False)

    bin_masks = (probs >= PIXEL_THR).float()
    if USE_MORPH and bin_masks.shape[0] > 0:
        # maybe_morphology is documented to take one 1xHxW mask. Stacking its outputs
        # gives [B,1,H,W] (the 4-D slice bin_masks[i:i+1] produced a 5-D tensor).
        bin_masks = torch.stack(
            [maybe_morphology(bin_masks[i]) for i in range(bin_masks.shape[0])], dim=0
        )

    # image-level decision: forged if the forged-pixel fraction >= AREA_FRAC
    area = bin_masks.flatten(1).mean(dim=1)  # [B]
    y_pred = (area >= AREA_FRAC).long()  # 1=forged, 0=authentic
    return y_pred.cpu().numpy(), probs.cpu()  # (labels, prob masks)

# Example loop over your test loader:
image_level_predictions = []
for batch in test_loader:
    if isinstance(batch, dict):
        # Accept the same image keys as _batch_to_tensors does for validation batches.
        if "image" in batch:
            imgs = batch["image"]
        elif "images" in batch:
            imgs = batch["images"]
        else:
            raise KeyError("Could not find images in test batch (expected keys: 'image' or 'images').")
    elif isinstance(batch, (list, tuple)):
        imgs = batch[0]
    else:
        raise TypeError("Unsupported test batch type; expected dict or tuple/list.")
    y_pred, _ = infer_batch(imgs)
    image_level_predictions.extend(y_pred.tolist())

# TODO: Format & write your submission from 'image_level_predictions'


In [None]:

# === DEBUG BATCH STRUCTURE (OPTIONAL) ===
# Print the container type of one validation batch and what each element holds.
b = next(iter(valid_loader))
print(type(b))
if isinstance(b, dict):
    print("dict keys:", b.keys())
elif not isinstance(b, (list, tuple)):
    print("unknown batch type")
else:
    print("tuple/list length:", len(b))
    for i, x in enumerate(b):
        desc = (
            f"tensor shape {tuple(x.shape)}, dtype {x.dtype}"
            if torch.is_tensor(x)
            else f"type {type(x)}"
        )
        print(f"  idx {i}: {desc}")


In [None]:
# === EXPORT MODEL (NEW) ===
# Save the best calibrated checkpoint (weights + preprocessing + thresholds) to share or reuse.
import os

export_path = "dinov2_forgery_detector_final.pth"

# Load the best calibrated checkpoint (from training)
ckpt = torch.load(CKPT_PATH, map_location="cpu")

# `model` only supplies the optional `.name` used for "arch"; rebuild it if this cell runs standalone.
try:
    model
except NameError:
    model = build_model()
    model.load_state_dict(ckpt["state_dict"])
    model.to(DEVICE).eval()

# Weights come from the checkpoint, not the in-memory `model`. Right after the training
# loop, `model` may hold the LAST epoch's weights, while pixel_thr/area_frac were
# calibrated for the BEST epoch stored in ckpt.
# Each fallback to a notebook global is evaluated only when the key is missing, so an
# unset global (e.g. PIXEL_THR when the inference cell was skipped) cannot raise NameError.
export_ckpt = {
    "arch": getattr(model, "name", "dinov2_forgery_detector"),
    "state_dict": ckpt["state_dict"],
    "image_size": int(ckpt["image_size"] if "image_size" in ckpt else IMAGE_SIZE),
    "mean": tuple(ckpt["mean"] if "mean" in ckpt else MEAN),
    "std": tuple(ckpt["std"] if "std" in ckpt else STD),
    "use_morph": bool(ckpt["use_morph"] if "use_morph" in ckpt else USE_MORPH),
    "morph_kernel": int(ckpt["morph_kernel"] if "morph_kernel" in ckpt else MORPH_KERNEL),
    "pixel_thr": float(ckpt["pixel_thr"] if "pixel_thr" in ckpt else PIXEL_THR),
    "area_frac": float(ckpt["area_frac"] if "area_frac" in ckpt else AREA_FRAC),
    "val_f1": float(ckpt.get("val_f1", 0.0)),
}

# Save export
torch.save(export_ckpt, export_path)
print(f"✅ Model exported to {export_path}")

# Optional: check file size
print(f"File size: {os.path.getsize(export_path)/1e6:.2f} MB")

# Optional: verify reloading works
test_load = torch.load(export_path, map_location="cpu")
print("Reload OK — keys:", list(test_load.keys()))


In [None]:
import torch
import torch.nn as nn

# 1) Reload the best checkpoint cleanly now, into a fresh model on DEVICE.
# NOTE(review): this reads a fold checkpoint from OUT_DIR, not CKPT_PATH (the file the
# training loop in this notebook writes). Confirm which file holds the weights to export.
export_model = build_model().to(DEVICE)
_fold_ckpt_path = f"{OUT_DIR}/dinov2_vitb14_linear_fold{FOLD_TO_RUN}.pt"
state = torch.load(_fold_ckpt_path, map_location=DEVICE)

export_model.load_state_dict(state["state_dict"], strict=True)
export_model.eval()

# 2) Wrap the model so forward returns only the logits tensor (works for dict OR tensor)
class LogitsOut(nn.Module):
    """Frozen, eval-mode wrapper whose forward returns only the logits tensor.

    The wrapped model may return a tensor, or a dict with the logits under "out"
    (torchvision segmentation models and DINOv2Seg) or "logits". The latter is the
    key the validation/inference helpers also accept.
    """

    def __init__(self, m: nn.Module):
        super().__init__()
        self.m = m.eval()
        # Freeze every parameter: the wrapper is only used for export.
        for p in self.m.parameters():
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.m(x)
        if isinstance(y, dict):
            if "out" in y:
                return y["out"]        # torchvision seg + our DINOv2Seg return {"out": logits}
            if "logits" in y:
                return y["logits"]
        if torch.is_tensor(y):
            return y                   # if a plain tensor is returned
        raise RuntimeError(f"Unexpected model output type: {type(y)}")

# Logits-only wrapper around export_model, placed on DEVICE; this is what gets exported.
wrapper = LogitsOut(export_model).to(DEVICE)

# 3) Script first (preferred); if that fails, fall back to trace
# Random dummy input (batch 1, 3 x IMAGE_SIZE x IMAGE_SIZE) used for the smoke-test
# forward passes and as the example input for tracing / ONNX export.
example = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
ts_path = f"{OUT_DIR}/model_fold{FOLD_TO_RUN}_ts.pt"  # generic name

# Best-effort: any exception raised while scripting, running or saving the scripted
# module switches to torch.jit.trace with `example`; both paths write to ts_path.
try:
    with torch.no_grad():
        scripted = torch.jit.script(wrapper)
        _ = scripted(example)  # smoke-test the scripted module before saving
        scripted.save(ts_path)
        print("Saved TorchScript (script):", ts_path)
except Exception as e:
    print("Script failed, falling back to trace. Reason:", e)
    with torch.no_grad():
        # NOTE(review): a traced module records the path taken for `example`; inputs of other
        # shapes may not behave identically — confirm before relying on variable sizes.
        traced = torch.jit.trace(wrapper, example, strict=False)
        _ = traced(example)  # smoke-test the traced module before saving
        traced.save(ts_path)
        print("Saved TorchScript (trace):", ts_path)

# 4) Optional ONNX export (trace-friendly; ViTs may need trace to succeed)
onnx_path = f"{OUT_DIR}/model_fold{FOLD_TO_RUN}.onnx"
try:
    import onnx  # ensures onnx is installed
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            example,
            onnx_path,
            input_names=["input"],
            output_names=["logits"],
            opset_version=17,  # 17 or 18; try 18 if 17 complains
            do_constant_folding=True,
            # Only the batch axis is declared dynamic; height/width follow `example` (IMAGE_SIZE).
            dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        )
    print("Saved ONNX:", onnx_path)
except Exception as e:
    # Deliberately non-fatal: a missing onnx package or any export error only skips this step.
    print("ONNX export skipped:", e)
