In [None]:
# =========================
# FFD (Xception) — FACE-ALIGNED ENSEMBLE + QUALITY FILTERS
# Dataset: balanced_frames_FF++  (Drive: /balanced_frames_FF++/{real,fake})
# Prints ONLY:
#   FFD model is loaded
#   AUC=… | EER=… | AP=…
# =========================

# Quiet installs (no extra prints)
# NOTE(review): pip's stdout/stderr are discarded, the return code is not checked
# and versions are unpinned, so an install failure only surfaces later as an
# ImportError in the imports below.
import sys, subprocess, os, warnings
subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                "timm", "torchvision", "scikit-learn", "pillow",
                "facenet-pytorch", "opencv-python"],
               stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# Mount Drive only if needed (google.colab is imported lazily, only when mounting)
if not os.path.ismount("/content/drive"):
    from google.colab import drive
    drive.mount("/content/drive")

# Silence every warning so the cell prints only its result lines.
# NOTE(review): this also hides deprecation and metric-undefined warnings.
warnings.filterwarnings("ignore")

# -------------------------
# Config
# -------------------------
import math, random
from pathlib import Path
from collections import defaultdict

import numpy as np
from PIL import Image
from sklearn import metrics

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import transforms
from facenet_pytorch import MTCNN

# Use "/content/drive/My Drive" when it exists, otherwise "/content/drive/MyDrive".
DRIVE_ROOT = "/content/drive/My Drive"
if not os.path.exists(DRIVE_ROOT):
    DRIVE_ROOT = "/content/drive/MyDrive"

# Frame folders: frames under real/ get label 0, under fake/ label 1 (see build_samples).
DATA_REAL = f"{DRIVE_ROOT}/balanced_frames_FF++/real"
DATA_FAKE = f"{DRIVE_ROOT}/balanced_frames_FF++/fake"
WEIGHTS_PATH = f"{DRIVE_ROOT}/DeepfakeBench_weights/ffd_best.pth"  # your saved weights

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Inference knobs (accuracy vs. GPU)
IMG_SIZE = 299  # network input side; to_tensor_norm resizes every crop to this
FRAME_CAP_PER_VIDEO = 140     # ↑ more frames per video improves stability
BATCH_SIZE_IMAGES   = 8  # frames decoded + face-detected per loop iteration
FORWARD_CHUNK       = 32  # max crops per forward pass

# TTA/Ensemble settings
SCALES_FACE   = [320, 352]  # square face-crop sides (each also mirrored when USE_HFLIP)
USE_HFLIP     = True
W_FACE, W_CLAHE, W_FRAME = 0.7, 0.2, 0.1  # 3-branch weights (frame branch helps a bit)

# Filters & aggregation (good defaults), applied per video in the post-process
TAU        = 0.20          # drop frames with |p-0.5| < TAU
SHARP_TOP  = 0.8           # keep top X fraction by sharpness
CONF_MIN   = 0.85          # face detector conf (frames without a detection carry 0.0)
SIZE_MIN   = 0.03          # min face area ratio (bbox/img)
AGGREGATOR = "perc90"      # per-video aggregation of frame scores (perc90 = 90th percentile)

# -------------------------
# Reproducibility
# -------------------------
def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()
# Inference-only notebook: disable autograd globally.
torch.set_grad_enabled(False)

# -------------------------
# IO helpers
# -------------------------
VALID_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def list_images(folder):
    """Return the entries directly inside `folder` whose lower-cased suffix is in VALID_EXTS.

    The result is a list of `Path` objects sorted by path; subfolders are not
    searched recursively.
    """
    found = []
    for entry in Path(folder).iterdir():
        if entry.suffix.lower() in VALID_EXTS:
            found.append(entry)
    found.sort()
    return found

def guess_video_name_from_path(p: Path):
    """Infer a frame's source-video id from its file stem.

    The stem is cut at its last "_" (e.g. "000_003_0042" -> "000_003"); when it
    has no underscore, at its last "-"; otherwise the whole stem is returned.
    """
    stem = p.stem
    for sep in ("_", "-"):
        if sep in stem:
            head, _, _ = stem.rpartition(sep)
            return head
    return stem

def safe_open_rgb(path: Path):
    """Load `path` as an RGB PIL image without ever raising.

    The file is opened in a `with` block so its handle is released as soon as
    the pixels have been converted (the converted image is an independent copy).
    Any failure (missing, corrupt or unsupported file) yields an all-black
    IMG_SIZE x IMG_SIZE placeholder so one bad frame cannot abort the loop.
    """
    try:
        with Image.open(path) as im:
            return im.convert("RGB")
    except Exception:
        # Deliberate best-effort fallback: score a black frame instead of crashing.
        return Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))

def compute_eer(y_true, y_score):
    """Equal error rate at the ROC point where FPR and FNR (= 1 - TPR) are closest.

    Returns the mean of FPR and FNR at that point (no interpolation between ROC
    points). Returns NaN when `y_true` holds fewer than two classes: the ROC
    curve is undefined there and `nanargmin` would otherwise raise on an
    all-NaN slice.
    """
    if np.unique(np.asarray(y_true)).size < 2:
        return float("nan")
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    fnr = 1 - tpr
    idx = int(np.nanargmin(np.abs(fnr - fpr)))
    return float((fpr[idx] + fnr[idx]) / 2.0)

def build_samples(real_dir, fake_dir, cap=FRAME_CAP_PER_VIDEO):
    """Collect (path, label, video) triples for every frame in the two folders.

    Frames are grouped per video with guess_video_name_from_path. A video with
    more than `cap` frames keeps `cap` of them, taken at evenly spaced positions
    of its sorted frame list (cap=None keeps everything). Real frames (label 0)
    are listed before fake frames (label 1).
    """
    collected = []

    def _add_folder(folder, label):
        by_video = defaultdict(list)
        for frame_path in list_images(folder):
            by_video[guess_video_name_from_path(frame_path)].append(frame_path)
        for video, frames in by_video.items():
            frames = sorted(frames)
            if cap is not None and len(frames) > cap:
                picks = np.linspace(0, len(frames) - 1, num=cap, dtype=int)
                frames = [frames[j] for j in picks]
            collected.extend((str(fp), label, video) for fp in frames)

    _add_folder(real_dir, 0)
    _add_folder(fake_dir, 1)
    return collected

# Enumerate every (frame path, label, video id) triple once, up front.
samples = build_samples(DATA_REAL, DATA_FAKE)

# -------------------------
# Face detector (MTCNN) & helpers
# -------------------------
# Single shared detector; faces smaller than 60 px are ignored and the three
# values are the per-stage detection thresholds.
mtcnn = MTCNN(keep_all=False, device=DEVICE if torch.cuda.is_available() else "cpu",
              min_face_size=60, thresholds=[0.6, 0.7, 0.7])

def align_face_with_meta(img: Image.Image, margin=0.25):
    """Crop the largest detected face (roughly square, with margin) and report its quality.

    Args:
        img: RGB PIL frame.
        margin: fractional enlargement applied to the longer side of the box.

    Returns:
        (crop, conf, area_ratio):
          crop       - PIL crop around the face; when MTCNN finds no face, the
                       centred min(w, h) square of the frame instead.
          conf       - MTCNN probability of the chosen box (0.0 on fallback).
          area_ratio - raw box area / frame area (0.0 on fallback), so fallback
                       frames fail any positive CONF_MIN / SIZE_MIN filter.
    """
    w, h = img.size
    boxes, probs = mtcnn.detect(img, landmarks=False)
    if boxes is not None and len(boxes) > 0:
        # largest box
        areas = [(b[2]-b[0])*(b[3]-b[1]) for b in boxes]
        i = int(np.argmax(areas))
        x1,y1,x2,y2 = boxes[i]
        conf = float(probs[i]) if probs is not None else 0.0
        area_ratio = float(areas[i] / max(1.0, (w*h)))

        # Square window of side max(bw, bh) * (1 + margin) centred on the box,
        # clipped to the frame.
        bw, bh = x2-x1, y2-y1
        cx, cy = x1 + bw/2.0, y1 + bh/2.0
        side = max(bw, bh) * (1.0 + margin)
        x1n = int(max(0, cx - side/2.0)); y1n = int(max(0, cy - side/2.0))
        x2n = int(min(w, cx + side/2.0)); y2n = int(min(h, cy + side/2.0))
        bw2, bh2 = x2n-x1n, y2n-y1n
        # Clipping can make the window rectangular: widen the short axis by the
        # difference (again clipped). Near the frame border the result may still
        # be non-square; make_crops centre-crops it to a square afterwards.
        if bw2 != bh2:
            d = abs(bw2 - bh2)
            if bw2 < bh2:
                x1n = max(0, x1n - d//2); x2n = min(w, x2n + (d - d//2))
            else:
                y1n = max(0, y1n - d//2); y2n = min(h, y2n + (d - d//2))
        crop = img.crop((x1n, y1n, x2n, y2n))
        return crop, conf, area_ratio
    # fallback center-square
    side = min(w, h)
    left = (w - side)//2; top = (h - side)//2
    return img.crop((left, top, left + side, top + side)), 0.0, 0.0

def sharpness_score(pil_img: Image.Image):
    """Variance of the Laplacian of a 128x128 grayscale copy (higher = sharper)."""
    rgb = np.array(pil_img)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (128, 128), interpolation=cv2.INTER_AREA)
    laplacian = cv2.Laplacian(small, cv2.CV_64F)
    return float(laplacian.var())

# Shared CLAHE operator: clip limit 2.0 on an 8x8 tile grid.
_CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
def apply_clahe_color(pil_img: Image.Image):
    """Equalise contrast by applying CLAHE to the L channel only (LAB space), keeping colour."""
    lab = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    equalised = cv2.merge([_CLAHE.apply(lightness), chan_a, chan_b])
    return Image.fromarray(cv2.cvtColor(equalised, cv2.COLOR_LAB2RGB))

# -------------------------
# Transforms & TTA
# -------------------------
# ImageNet normalisation statistics.
# NOTE(review): confirm these match the normalisation used when ffd_best.pth was
# trained (timm's Xception default config uses mean=std=0.5) — TODO verify.
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# PIL crop -> normalised CHW float tensor of IMG_SIZE x IMG_SIZE.
# NOTE(review): Resize(IMG_SIZE) maps every square TTA crop from make_crops back
# to IMG_SIZE, and those crops already cover the same centred square of the
# source, so the different SCALES only change resampling, not field of view.
to_tensor_norm = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def make_crops(pil_img, scales, hflip=USE_HFLIP):
    """Build square TTA crops of `pil_img`: one per scale, plus its mirror if `hflip`.

    For each side length s the image is bilinearly resized so its shorter side
    equals s, then its central s x s square is cut out. Output order is
    [s0, s0_mirrored, s1, s1_mirrored, ...] (no mirrors when hflip is False).
    """
    width, height = pil_img.size
    short_side = min(width, height)
    out = []
    for side in scales:
        ratio = side / short_side
        target = (int(round(width * ratio)), int(round(height * ratio)))
        resized = pil_img.resize(target, Image.BILINEAR)
        x0 = (resized.size[0] - side) // 2
        y0 = (resized.size[1] - side) // 2
        square = resized.crop((x0, y0, x0 + side, y0 + side))
        out.append(square)
        if hflip:
            out.append(square.transpose(Image.FLIP_LEFT_RIGHT))
    return out

# -------------------------
# FFD model (Xception backbone + regression mask on feature map)
# -------------------------
class DepthwiseSeparableConv(nn.Module):
    """Depthwise k x k convolution followed by a pointwise 1x1 projection.

    The submodule names `dw` / `pw` determine the state_dict keys and are kept
    stable so saved weights keep matching.
    """

    def __init__(self, c_in, c_out, k=3, s=1, p=1, bias=False):
        super().__init__()
        # groups=c_in: one k x k filter per input channel
        self.dw = nn.Conv2d(c_in, c_in, kernel_size=k, stride=s, padding=p,
                            groups=c_in, bias=bias)
        # 1x1 conv mixing channels c_in -> c_out
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0,
                            bias=bias)

    def forward(self, x):
        per_channel = self.dw(x)
        return self.pw(per_channel)

class FFD_Xception(nn.Module):
    """timm Xception whose final feature map is gated by a learned spatial mask.

    `map` (depthwise-separable conv to 1 channel + sigmoid) produces a Bx1xHxW
    gate that multiplies the features before the timm classification head.
    `forward` returns {"prob": softmax class-1 probability, "mask", "feats"}.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.model = timm.create_model("xception", pretrained=False, num_classes=num_classes)
        # Infer the channel count of the last feature map with a dummy pass.
        # Run it in eval mode: in train mode BatchNorm would fold this all-zero
        # batch into its running statistics (no_grad does not prevent that).
        was_training = self.model.training
        self.model.eval()
        dummy = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE)
        with torch.no_grad():
            fm = self.model.forward_features(dummy)
        self.model.train(was_training)
        c_in = fm.shape[1]
        # regression map (sigmoid) -> gate in [0, 1]
        self.map = nn.Sequential(
            DepthwiseSeparableConv(c_in, 1, k=3, s=1, p=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        feats = self.model.forward_features(x)         # BxCxHxW
        mask  = self.map(feats)                        # Bx1xHxW, broadcast over C
        feats_masked = feats * mask                    # apply mask
        logits = self.model.forward_head(feats_masked, pre_logits=False)
        prob = torch.softmax(logits, dim=1)[:, 1]
        return {"prob": prob, "mask": mask, "feats": feats}

def load_ffd_weights_strong(model: nn.Module, path: str):
    """Best-effort load of an FFD checkpoint into `model` (backbone AND mask head).

    Each checkpoint key is tried as-is, with one "module." / "backbone." /
    "model." prefix stripped, with "fc." mapped to "classifier.", and with each
    of those names placed under the "model." backbone submodule. Candidates are
    matched against the full `model.state_dict()`, so the mask head
    (`model.map`) is loaded too whenever the checkpoint carries it. Previously
    only `model.model` was targeted and the mask stayed randomly initialised.
    A 2-D tensor is reshaped to 4-D when the target is a 1x1 conv kernel.

    Args:
        model: module with a `.model` backbone attribute (e.g. FFD_Xception).
        path: checkpoint file; a state_dict, or a dict holding one under "state_dict".

    Returns:
        List of the (full-model) parameter names that were loaded.

    Raises:
        RuntimeError: when no checkpoint tensor matches any parameter, which
            would otherwise silently evaluate a randomly initialised network.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
    ckpt = torch.load(path, map_location="cpu")
    incoming = ckpt["state_dict"] if (isinstance(ckpt, dict) and "state_dict" in ckpt) else ckpt
    target = model.state_dict()

    def candidates(key):
        names = [key]
        for pref in ("module.", "backbone.", "model."):
            if key.startswith(pref):
                names.append(key[len(pref):])
        if key.startswith("fc."):
            names.append("classifier." + key[3:])
        # Backbone-relative names live under the "model." submodule.
        names += ["model." + n for n in names]
        return list(dict.fromkeys(names))

    new_state = {}
    for key, value in incoming.items():
        if not isinstance(value, torch.Tensor):
            continue
        for name in candidates(key):
            if name not in target:
                continue
            tv = target[name]
            if value.shape == tv.shape:
                new_state[name] = value
                break
            # Linear-style (out, in) weight stored for a 1x1 conv kernel.
            if value.ndim == 2 and tv.ndim == 4 and tv.shape[2] == 1 and tv.shape[3] == 1 \
               and value.shape[0] == tv.shape[0] and value.shape[1] == tv.shape[1]:
                new_state[name] = value.unsqueeze(-1).unsqueeze(-1)
                break

    if not new_state:
        raise RuntimeError(f"No tensor in checkpoint {path!r} matched the model's parameters.")
    model.load_state_dict(new_state, strict=False)
    return list(new_state)

ffd = FFD_Xception(num_classes=2).to(DEVICE)
# Announce success only after the checkpoint is actually in place. A failed load
# now raises here instead of printing a misleading "loaded" line via `finally`.
load_ffd_weights_strong(ffd, WEIGHTS_PATH)
print("FFD model is loaded")
ffd.eval()  # inference mode: BatchNorm uses running stats, dropout disabled

def forward_in_chunks(x, chunk=FORWARD_CHUNK, model=None):
    """Run the FFD model over `x` in slices of `chunk` rows; return class-1 probabilities.

    Args:
        x: batch tensor of normalised crops (assumed Nx3xIMG_SIZExIMG_SIZE — TODO
           confirm for other callers); slices are moved to DEVICE one at a time.
        chunk: maximum number of rows per forward pass.
        model: module returning {"prob": ...}; defaults to the global `ffd`.

    Returns:
        float32 numpy array of shape (N,); an empty input gives an empty array
        instead of failing in torch.cat.
    """
    net = ffd if model is None else model
    if x.size(0) == 0:
        return np.zeros((0,), dtype=np.float32)
    outs = []
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; mixed
    # precision is only enabled when a GPU is available.
    with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
        for start in range(0, x.size(0), chunk):
            probs = net(x[start:start + chunk].to(DEVICE))["prob"]
            outs.append(probs.detach().float().cpu())
    return torch.cat(outs, dim=0).numpy()

# -------------------------
# Inference — build per-frame ensemble inputs
# -------------------------
# Per-frame pipeline: decode -> largest-face crop + detection metadata ->
# sharpness -> three TTA branches (face, face+CLAHE, full frame), each reduced
# to the mean class-1 probability over its crops.
# NOTE(review): frames are processed one by one (MTCNN + PIL resizes per frame);
# the last run of this cell ended in KeyboardInterrupt. Any exception raised for
# a single frame aborts the whole loop.
records = []  # {video,label,p1,p2,p3,conf,area,sharp}
with torch.no_grad():
    for i in range(0, len(samples), BATCH_SIZE_IMAGES):
        batch = samples[i:i + BATCH_SIZE_IMAGES]

        t_face_list, t_facec_list, t_frame_list = [], [], []
        meta, labels, vnames = [], [], []

        for path, lab, vname in batch:
            img = safe_open_rgb(Path(path))
            face, conf, area = align_face_with_meta(img, margin=0.25)
            shp = sharpness_score(face)

            # Face-aligned crops
            c_face = [to_tensor_norm(c) for c in make_crops(face, SCALES_FACE, USE_HFLIP)]
            # Only triggers when SCALES_FACE is empty.
            if not c_face:
                continue
            t_face_list.append(torch.stack(c_face, 0))

            # Face+CLAHE
            face_c = apply_clahe_color(face)
            c_facec = [to_tensor_norm(c) for c in make_crops(face_c, SCALES_FACE, USE_HFLIP)]
            t_facec_list.append(torch.stack(c_facec, 0))

            # Global frame center-crop branch (single scale 352 for stability)
            c_frame = [to_tensor_norm(c) for c in make_crops(img, [352], USE_HFLIP)]
            t_frame_list.append(torch.stack(c_frame, 0))

            meta.append((conf, area, shp))
            labels.append(lab); vnames.append(vname)

        if not labels:
            continue

        # Every frame contributes the same number of crops per branch, so after
        # torch.cat the flat scores reshape to (frames, crops); the row mean is
        # the per-frame TTA average.
        # FACE
        Xf  = torch.cat(t_face_list,  dim=0)
        pfa = forward_in_chunks(Xf,  chunk=FORWARD_CHUNK); Cf  = t_face_list[0].size(0)
        pf_img  = pfa.reshape(len(labels), Cf).mean(axis=1)

        # FACE+CLAHE
        Xfc = torch.cat(t_facec_list, dim=0)
        pfca = forward_in_chunks(Xfc, chunk=FORWARD_CHUNK); Cfc = t_facec_list[0].size(0)
        pfc_img = pfca.reshape(len(labels), Cfc).mean(axis=1)

        # FRAME
        Xg  = torch.cat(t_frame_list, dim=0)
        pga = forward_in_chunks(Xg,  chunk=FORWARD_CHUNK); Cg  = t_frame_list[0].size(0)
        pg_img = pga.reshape(len(labels), Cg).mean(axis=1)

        # One record per frame: branch scores plus detection/sharpness metadata.
        for j in range(len(labels)):
            conf, area, shp = meta[j]
            records.append({
                "video": vnames[j],
                "label": int(labels[j]),
                "p1": float(pf_img[j]),
                "p2": float(pfc_img[j]),
                "p3": float(pg_img[j]),
                "conf": float(conf),
                "area": float(area),
                "sharp": float(shp),
            })

# -------------------------
# Post-process to video-level scores (fast, robust)
# -------------------------
if not records:
    print("AUC=0.5000 | EER=0.5000 | AP=0.5000")
else:
    # Group frames per (video id, label) pair. build_samples groups the real and
    # fake folders independently, so the same id can occur under both labels;
    # keying on the id alone would merge those two videos into one entry and
    # keep only the last label seen.
    vids = sorted({(r["video"], r["label"]) for r in records})
    labels_by_video = {v: v[1] for v in vids}
    S = {v: {"p1":[], "p2":[], "p3":[], "conf":[], "area":[], "sharp":[]} for v in vids}
    for r in records:
        v = (r["video"], r["label"])
        for k in ["p1","p2","p3","conf","area","sharp"]:
            S[v][k].append(r[k])

    # Combine branches: fixed-weight mix of face, face+CLAHE and full-frame scores.
    for v in vids:
        p1 = np.array(S[v]["p1"], dtype=np.float32)
        p2 = np.array(S[v]["p2"], dtype=np.float32)
        p3 = np.array(S[v]["p3"], dtype=np.float32)
        S[v]["p"] = W_FACE*p1 + W_CLAHE*p2 + W_FRAME*p3

    # Auto-orientation via a frame-level AUC proxy.
    # NOTE(review): the direction is chosen with the evaluation labels, so the
    # reported AUC cannot fall below ~0.5 and is optimistically biased.
    concat = np.concatenate([S[v]["p"] for v in vids])
    concat_y = np.concatenate([[labels_by_video[v]] * len(S[v]["p"]) for v in vids])
    try:
        auc_plain = metrics.roc_auc_score(concat_y, concat)
        auc_flip  = metrics.roc_auc_score(concat_y, 1.0 - concat)
        if auc_flip > auc_plain:
            for v in vids:
                S[v]["p"] = 1.0 - np.array(S[v]["p"], dtype=np.float32)
    except ValueError:
        pass  # only one class present: AUC undefined, keep the raw orientation

    # Per-video aggregator selected by the AGGREGATOR config value.
    def agg_perc90(x): return float(np.percentile(x, 90))
    agg_funcs = {
        "median": lambda x: float(np.median(x)),
        "mean":   lambda x: float(np.mean(x)),
        "perc90": agg_perc90,
        "perc95": lambda x: float(np.percentile(x, 95)),
    }
    if AGGREGATOR not in agg_funcs:
        raise ValueError(f"Unknown AGGREGATOR {AGGREGATOR!r}; expected one of {sorted(agg_funcs)}")
    aggregate = agg_funcs[AGGREGATOR]

    # Filters + aggregation
    vscores, y = [], []
    for v in vids:
        x    = np.array(S[v]["p"],     dtype=np.float32)
        conf = np.array(S[v]["conf"],  dtype=np.float32)
        area = np.array(S[v]["area"],  dtype=np.float32)
        shp  = np.array(S[v]["sharp"], dtype=np.float32)

        # Quality mask; when it rejects every frame, fall back to all frames.
        m = np.ones_like(x, dtype=bool)
        if CONF_MIN > 0.0: m &= (conf >= CONF_MIN)
        if SIZE_MIN > 0.0: m &= (area >= SIZE_MIN)
        if TAU > 0.0:      m &= (np.abs(x - 0.5) >= TAU)
        xf = x[m] if m.any() else x
        sh = shp[m] if m.any() else shp

        # Keep the sharpest SHARP_TOP fraction (at least one frame).
        if SHARP_TOP < 1.0 and xf.size > 1:
            k = max(1, int(math.ceil(xf.size * SHARP_TOP)))
            idx = np.argsort(sh)[-k:]
            xf = xf[idx]

        vscores.append(aggregate(xf))
        y.append(labels_by_video[v])

    y = np.array(y, dtype=np.int64)
    s = np.array(vscores, dtype=np.float32)

    # Metrics, each guarded against a single-class label set.
    try:
        auc_v = metrics.roc_auc_score(y, s)
    except ValueError:
        auc_v = 0.5
    try:
        eer_v = compute_eer(y, s)
    except ValueError:
        eer_v = float("nan")
    try:
        ap_v = metrics.average_precision_score(y, s)
    except ValueError:
        ap_v = float("nan")

    print(f"AUC={auc_v:.4f} | EER={eer_v:.4f} | AP={ap_v:.4f}")


FFD model is loaded


KeyboardInterrupt: 

In [None]:
# =========================
# FFD (Xception) — FACE-ALIGNED ENSEMBLE + QUALITY FILTERS (metrics-only)
# Dataset: /content/drive/.../balanced_frames_FF++/{real,fake}
# Prints ONLY:
#   FFD model is loaded
#   AUC=… | EER=… | AP=…
# Notes:
# - Faster image loading via cv2.imdecode to avoid I/O stalls.
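# - Stronger settings than before (MTCNN on, richer TTA + smart aggregation).
# - Caveat: the post-process grid search keeps the configuration with the best AUC
#   on the same labels it reports, so the printed metrics are optimistic.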
# =========================

# Quiet installs (no extra prints)
# NOTE(review): pip's stdout/stderr are discarded, the return code is not checked
# and versions are unpinned, so an install failure only surfaces later as an
# ImportError in the imports below.
import sys, subprocess, os, warnings
subprocess.run([sys.executable, "-m", "pip", "install", "-q",
                "timm", "torchvision", "scikit-learn", "pillow",
                "opencv-python", "facenet-pytorch"],
               stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# Drive mount (silent if already mounted; google.colab imported only when needed)
if not os.path.ismount("/content/drive"):
    from google.colab import drive
    drive.mount("/content/drive")

# Silence every warning so the cell prints only its result lines.
# NOTE(review): this also hides deprecation and metric-undefined warnings.
warnings.filterwarnings("ignore")

# -------------------------
# Config
# -------------------------
import math, random
from pathlib import Path
from collections import defaultdict

import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from sklearn import metrics

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import transforms
from facenet_pytorch import MTCNN

# Use "/content/drive/My Drive" when it exists, otherwise "/content/drive/MyDrive".
DRIVE_ROOT = "/content/drive/My Drive"
if not os.path.exists(DRIVE_ROOT):
    DRIVE_ROOT = "/content/drive/MyDrive"

# Frame folders: frames under real/ get label 0, under fake/ label 1 (see build_samples).
DATA_REAL = f"{DRIVE_ROOT}/balanced_frames_FF++/real"
DATA_FAKE = f"{DRIVE_ROOT}/balanced_frames_FF++/fake"
WEIGHTS_PATH = f"{DRIVE_ROOT}/DeepfakeBench_weights/ffd_best.pth"  # your saved weights

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Inference knobs (accuracy-leaning but GPU-friendly)
IMG_SIZE = 299  # network input side; to_tensor_norm resizes every crop to this
FRAME_CAP_PER_VIDEO = 160     # ↑ helps AUC with controlled cost
BATCH_SIZE_IMAGES   = 8  # frames decoded + face-detected per loop iteration
FORWARD_CHUNK       = 32  # max crops per forward pass

# TTA/Ensemble settings
SCALES_FACE   = [288, 320, 352]   # multi-scale face crops
SCALES_FRAME  = [352]             # global context branch
USE_HFLIP     = True              # also score the mirrored copy of every crop
WSET = [                          # (face, face+CLAHE, frame) candidate weights searched in the post-process
    (1.0, 0.0, 0.0),
    (0.85, 0.15, 0.0),
    (0.8, 0.2, 0.0),
    (0.7, 0.2, 0.1),
    (0.6, 0.3, 0.1),
    (0.5, 0.3, 0.2),
    (0.45, 0.35, 0.20),
]

# Filters & aggregators (search a compact, strong set)
# NOTE(review): 4*3*2*2*6 = 288 filter/aggregator combos per weight set, i.e.
# 2016 with the 7 WSET entries, all ranked by AUC on the evaluation labels.
TAU_LIST       = [0.0, 0.1, 0.2, 0.3]
SHARP_TOP_LIST = [1.0, 0.8, 0.6]
CONF_MIN_LIST  = [0.85, 0.90]     # MTCNN confidence thresholds
SIZE_MIN_LIST  = [0.03, 0.06]     # min face area ratio (bbox/img)
AGGREGATORS    = ["median", "perc90", "perc95", "top10", "trim10", "wtop20p"]

# -------------------------
# Reproducibility
# -------------------------
def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()
torch.set_grad_enabled(False)  # inference only: no autograd graphs needed
cv2.setNumThreads(0)           # 0 = disable OpenCV's internal threading

# -------------------------
# Fast I/O helpers
# -------------------------
VALID_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def list_images(folder):
    """Return the entries directly inside `folder` whose lower-cased suffix is in VALID_EXTS.

    The result is a list of `Path` objects sorted by path; subfolders are not
    searched recursively.
    """
    found = []
    for entry in Path(folder).iterdir():
        if entry.suffix.lower() in VALID_EXTS:
            found.append(entry)
    found.sort()
    return found

def guess_video_name_from_path(p: Path):
    """Infer a frame's source-video id from its file stem.

    The stem is cut at its last "_" (e.g. "000_003_0042" -> "000_003"); when it
    has no underscore, at its last "-"; otherwise the whole stem is returned.
    """
    stem = p.stem
    for sep in ("_", "-"):
        if sep in stem:
            head, _, _ = stem.rpartition(sep)
            return head
    return stem

def fast_open_rgb(path: Path):
    """Load `path` as an RGB PIL image without ever raising.

    Tries OpenCV first (raw bytes via np.fromfile + cv2.imdecode), then PIL
    (opened in a `with` block so the file handle is released right after
    conversion), and finally returns an all-black IMG_SIZE x IMG_SIZE
    placeholder so one unreadable frame cannot abort the evaluation loop.
    """
    try:
        data = np.fromfile(str(path), dtype=np.uint8)
        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("cv2.imdecode failed")
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    except Exception:
        # OpenCV could not read/decode the file: fall back to PIL.
        try:
            with Image.open(path) as im:
                return im.convert("RGB")
        except Exception:
            # Deliberate best-effort: score a black frame instead of crashing.
            return Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))

def build_samples(real_dir, fake_dir, cap=FRAME_CAP_PER_VIDEO):
    """Collect (path, label, video) triples for every frame in the two folders.

    Frames are grouped per video with guess_video_name_from_path. A video with
    more than `cap` frames keeps `cap` of them, taken at evenly spaced positions
    of its sorted frame list (cap=None keeps everything). Real frames (label 0)
    are listed before fake frames (label 1).
    """
    collected = []

    def _add_folder(folder, label):
        by_video = defaultdict(list)
        for frame_path in list_images(folder):
            by_video[guess_video_name_from_path(frame_path)].append(frame_path)
        for video, frames in by_video.items():
            frames = sorted(frames)
            if cap is not None and len(frames) > cap:
                picks = np.linspace(0, len(frames) - 1, num=cap, dtype=int)
                frames = [frames[j] for j in picks]
            collected.extend((str(fp), label, video) for fp in frames)

    _add_folder(real_dir, 0)
    _add_folder(fake_dir, 1)
    return collected

# Enumerate every (frame path, label, video id) triple once, up front.
samples = build_samples(DATA_REAL, DATA_FAKE)

# -------------------------
# Face detector (MTCNN) & helpers
# -------------------------
# Single shared detector; faces smaller than 60 px are ignored and the three
# values are the per-stage detection thresholds.
mtcnn = MTCNN(keep_all=False, device=DEVICE if torch.cuda.is_available() else "cpu",
              min_face_size=60, thresholds=[0.6, 0.7, 0.7])

def align_face_with_meta(img: Image.Image, margin=0.25):
    """Crop the largest detected face (roughly square, with margin) and report its quality.

    Args:
        img: RGB PIL frame.
        margin: fractional enlargement applied to the longer side of the box.

    Returns:
        (crop, conf, area_ratio):
          crop       - PIL crop around the face; when MTCNN finds no face, the
                       centred min(w, h) square of the frame instead.
          conf       - MTCNN probability of the chosen box (0.0 on fallback).
          area_ratio - raw box area / frame area (0.0 on fallback), so fallback
                       frames fail any positive conf/size filter.
    """
    w, h = img.size
    boxes, probs = mtcnn.detect(img, landmarks=False)
    if boxes is not None and len(boxes) > 0:
        # pick the largest box by area
        areas = [(b[2]-b[0])*(b[3]-b[1]) for b in boxes]
        i = int(np.argmax(areas))
        x1,y1,x2,y2 = boxes[i]
        conf = float(probs[i]) if probs is not None else 0.0
        area_ratio = float(areas[i] / max(1.0, (w*h)))
        # Square window of side max(bw, bh) * (1 + margin) centred on the box,
        # clipped to the frame.
        bw, bh = x2-x1, y2-y1
        cx, cy = x1 + bw/2.0, y1 + bh/2.0
        side = max(bw, bh) * (1.0 + margin)
        x1n = int(max(0, cx - side/2.0)); y1n = int(max(0, cy - side/2.0))
        x2n = int(min(w, cx + side/2.0)); y2n = int(min(h, cy + side/2.0))
        bw2, bh2 = x2n-x1n, y2n-y1n
        # Clipping can make the window rectangular: widen the short axis by the
        # difference (again clipped). Near the border the crop may stay
        # non-square; make_crops centre-crops it to a square afterwards.
        if bw2 != bh2:
            d = abs(bw2 - bh2)
            if bw2 < bh2:
                x1n = max(0, x1n - d//2); x2n = min(w, x2n + (d - d//2))
            else:
                y1n = max(0, y1n - d//2); y2n = min(h, y2n + (d - d//2))
        crop = img.crop((x1n, y1n, x2n, y2n))
        return crop, conf, area_ratio
    # fallback center-square
    side = min(w, h)
    l = (w - side)//2; t = (h - side)//2
    return img.crop((l, t, l + side, t + side)), 0.0, 0.0

def sharpness_score(pil_img: Image.Image):
    """Variance of the Laplacian of a 128x128 grayscale copy (higher = sharper)."""
    rgb = np.array(pil_img)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (128, 128), interpolation=cv2.INTER_AREA)
    laplacian = cv2.Laplacian(small, cv2.CV_64F)
    return float(laplacian.var())

# Shared CLAHE operator: clip limit 2.0 on an 8x8 tile grid.
_CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
def apply_clahe_color(pil_img: Image.Image):
    """Equalise contrast by applying CLAHE to the L channel only (LAB space), keeping colour."""
    lab = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    equalised = cv2.merge([_CLAHE.apply(lightness), chan_a, chan_b])
    return Image.fromarray(cv2.cvtColor(equalised, cv2.COLOR_LAB2RGB))

# -------------------------
# TTA transforms
# -------------------------
# ImageNet normalisation statistics.
# NOTE(review): confirm these match the normalisation used when ffd_best.pth was
# trained (timm's Xception default config uses mean=std=0.5) — TODO verify.
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# PIL crop -> normalised CHW float tensor of IMG_SIZE x IMG_SIZE.
# NOTE(review): Resize(IMG_SIZE) maps every square TTA crop from make_crops back
# to IMG_SIZE, and those crops already cover the same centred square of the
# source, so SCALES_FACE entries only change resampling, not field of view.
to_tensor_norm = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def make_crops(pil_img, scales, hflip=USE_HFLIP):
    """Build square TTA crops of `pil_img`: one per scale, plus its mirror if `hflip`.

    For each side length s the image is bilinearly resized so its shorter side
    equals s, then its central s x s square is cut out. Output order is
    [s0, s0_mirrored, s1, s1_mirrored, ...] (no mirrors when hflip is False).
    """
    width, height = pil_img.size
    short_side = min(width, height)
    out = []
    for side in scales:
        ratio = side / short_side
        target = (int(round(width * ratio)), int(round(height * ratio)))
        resized = pil_img.resize(target, Image.BILINEAR)
        x0 = (resized.size[0] - side) // 2
        y0 = (resized.size[1] - side) // 2
        square = resized.crop((x0, y0, x0 + side, y0 + side))
        out.append(square)
        if hflip:
            out.append(square.transpose(Image.FLIP_LEFT_RIGHT))
    return out

# -------------------------
# FFD model (Xception backbone + regression mask on feature map)
# -------------------------
class DepthwiseSeparableConv(nn.Module):
    """Depthwise k x k convolution followed by a pointwise 1x1 projection.

    The submodule names `dw` / `pw` determine the state_dict keys and are kept
    stable so saved weights keep matching.
    """

    def __init__(self, c_in, c_out, k=3, s=1, p=1, bias=False):
        super().__init__()
        # groups=c_in: one k x k filter per input channel
        self.dw = nn.Conv2d(c_in, c_in, kernel_size=k, stride=s, padding=p,
                            groups=c_in, bias=bias)
        # 1x1 conv mixing channels c_in -> c_out
        self.pw = nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0,
                            bias=bias)

    def forward(self, x):
        per_channel = self.dw(x)
        return self.pw(per_channel)

class FFD_Xception(nn.Module):
    """timm Xception whose final feature map is gated by a learned spatial mask.

    `map` (depthwise-separable conv to 1 channel + sigmoid) produces a Bx1xHxW
    gate that multiplies the features before the timm classification head.
    `forward` returns {"prob": softmax class-1 probability}.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.model = timm.create_model("xception", pretrained=False, num_classes=num_classes)
        # Infer the last feature map's channel count with a dummy pass in eval
        # mode; in train mode BatchNorm would fold the all-zero batch into its
        # running statistics (no_grad does not prevent that).
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            fm = self.model.forward_features(torch.zeros(1,3,IMG_SIZE,IMG_SIZE))
        self.model.train(was_training)
        c_in = int(fm.shape[1])
        self.map = nn.Sequential(DepthwiseSeparableConv(c_in, 1, 3, 1, 1, False), nn.Sigmoid())

    def forward(self, x):
        feats = self.model.forward_features(x)       # BxCxHxW
        mask  = self.map(feats)                      # Bx1xHxW, broadcast over C
        logits = self.model.forward_head(feats*mask, pre_logits=False)
        prob = torch.softmax(logits, dim=1)[:, 1]
        return {"prob": prob}

def load_ffd_weights_strong(model: nn.Module, path: str):
    """Best-effort load of an FFD checkpoint into `model` (backbone AND mask head).

    Each checkpoint key is tried as-is, with one "module." / "backbone." /
    "model." prefix stripped, with "fc." mapped to "classifier.", and with each
    of those names placed under the "model." backbone submodule. Candidates are
    matched against the full `model.state_dict()`, so the mask head
    (`model.map`) is loaded too whenever the checkpoint carries it. Previously
    only `model.model` was targeted and the mask stayed randomly initialised.
    A 2-D tensor is reshaped to 4-D when the target is a 1x1 conv kernel.

    Args:
        model: module with a `.model` backbone attribute (e.g. FFD_Xception).
        path: checkpoint file; a state_dict, or a dict holding one under "state_dict".

    Returns:
        List of the (full-model) parameter names that were loaded.

    Raises:
        RuntimeError: when no checkpoint tensor matches any parameter, which
            would otherwise silently evaluate a randomly initialised network.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
    ckpt = torch.load(path, map_location="cpu")
    incoming = ckpt["state_dict"] if (isinstance(ckpt, dict) and "state_dict" in ckpt) else ckpt
    target = model.state_dict()

    def candidates(key):
        names = [key]
        for pref in ("module.", "backbone.", "model."):
            if key.startswith(pref):
                names.append(key[len(pref):])
        if key.startswith("fc."):
            names.append("classifier." + key[3:])
        # Backbone-relative names live under the "model." submodule.
        names += ["model." + n for n in names]
        return list(dict.fromkeys(names))

    new_state = {}
    for key, value in incoming.items():
        if not isinstance(value, torch.Tensor):
            continue
        for name in candidates(key):
            if name not in target:
                continue
            tv = target[name]
            if value.shape == tv.shape:
                new_state[name] = value
                break
            # Linear-style (out, in) weight stored for a 1x1 conv kernel.
            if value.ndim == 2 and tv.ndim == 4 and tv.shape[2] == 1 and tv.shape[3] == 1 \
               and value.shape[0] == tv.shape[0] and value.shape[1] == tv.shape[1]:
                new_state[name] = value.unsqueeze(-1).unsqueeze(-1)
                break

    if not new_state:
        raise RuntimeError(f"No tensor in checkpoint {path!r} matched the model's parameters.")
    model.load_state_dict(new_state, strict=False)
    return list(new_state)

ffd = FFD_Xception(num_classes=2).to(DEVICE)
# Announce success only after the checkpoint is actually in place. A failed load
# now raises here instead of printing a misleading "loaded" line via `finally`.
load_ffd_weights_strong(ffd, WEIGHTS_PATH)
print("FFD model is loaded")
ffd.eval()  # inference mode: BatchNorm uses running stats, dropout disabled

def forward_in_chunks(x, chunk=FORWARD_CHUNK, model=None):
    """Run the FFD model over `x` in slices of `chunk` rows; return class-1 probabilities.

    Args:
        x: batch tensor of normalised crops (assumed Nx3xIMG_SIZExIMG_SIZE — TODO
           confirm for other callers); slices are moved to DEVICE one at a time.
        chunk: maximum number of rows per forward pass.
        model: module returning {"prob": ...}; defaults to the global `ffd`.

    Returns:
        float32 numpy array of shape (N,); an empty input gives an empty array
        instead of failing in torch.cat.
    """
    net = ffd if model is None else model
    if x.size(0) == 0:
        return np.zeros((0,), dtype=np.float32)
    outs = []
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; mixed
    # precision is only enabled when a GPU is available.
    with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
        for start in range(0, x.size(0), chunk):
            probs = net(x[start:start + chunk].to(DEVICE))["prob"]
            outs.append(probs.detach().float().cpu())
    return torch.cat(outs, dim=0).numpy()

# -------------------------
# Inference — gather per-frame scores with TTA
# -------------------------
# Per-frame pipeline: decode -> largest-face crop + detection metadata ->
# sharpness -> three TTA branches (face, face+CLAHE, full frame), each reduced
# to the mean class-1 probability over its crops.
# NOTE(review): any exception raised for a single frame aborts the whole loop.
records = []  # {video,label,p1,p2,p3,conf,area,sharp}
with torch.no_grad():
    for i in range(0, len(samples), BATCH_SIZE_IMAGES):
        batch = samples[i:i + BATCH_SIZE_IMAGES]

        t_face_list, t_facec_list, t_frame_list = [], [], []
        meta, labels, vnames = [], [], []

        for path, lab, vname in batch:
            img = fast_open_rgb(Path(path))
            face, conf, area = align_face_with_meta(img, margin=0.25)
            shp = sharpness_score(face)

            # Face-aligned crops (the skip only triggers when SCALES_FACE is empty)
            c_face  = [to_tensor_norm(c) for c in make_crops(face, SCALES_FACE, USE_HFLIP)]
            if not c_face: continue
            t_face_list.append(torch.stack(c_face, 0))

            # Face + CLAHE
            face_c  = apply_clahe_color(face)
            c_facec = [to_tensor_norm(c) for c in make_crops(face_c, SCALES_FACE, USE_HFLIP)]
            t_facec_list.append(torch.stack(c_facec, 0))

            # Global frame branch
            c_frame = [to_tensor_norm(c) for c in make_crops(img, SCALES_FRAME, USE_HFLIP)]
            t_frame_list.append(torch.stack(c_frame, 0))

            meta.append((conf, area, shp))
            labels.append(lab); vnames.append(vname)

        if not labels:
            continue

        # Every frame contributes the same number of crops per branch, so after
        # torch.cat the flat scores reshape to (frames, crops); the row mean is
        # the per-frame TTA average.
        # FACE
        Xf  = torch.cat(t_face_list,  dim=0)
        pf  = forward_in_chunks(Xf,  chunk=FORWARD_CHUNK); Cf  = t_face_list[0].size(0)
        pf_img = pf.reshape(len(labels), Cf).mean(axis=1)

        # FACE+CLAHE
        Xfc = torch.cat(t_facec_list, dim=0)
        pfc = forward_in_chunks(Xfc, chunk=FORWARD_CHUNK); Cfc = t_facec_list[0].size(0)
        pfc_img = pfc.reshape(len(labels), Cfc).mean(axis=1)

        # FRAME
        Xg  = torch.cat(t_frame_list,  dim=0)
        pg  = forward_in_chunks(Xg,  chunk=FORWARD_CHUNK); Cg  = t_frame_list[0].size(0)
        pg_img = pg.reshape(len(labels), Cg).mean(axis=1)

        # One record per frame: branch scores plus detection/sharpness metadata.
        for j in range(len(labels)):
            conf, area, shp = meta[j]
            records.append({
                "video": vnames[j],
                "label": int(labels[j]),
                "p1": float(pf_img[j]),
                "p2": float(pfc_img[j]),
                "p3": float(pg_img[j]),
                "conf": float(conf),
                "area": float(area),
                "sharp": float(shp),
            })

# -------------------------
# Post-process search (compact but strong)
# -------------------------
if not records:
    print("AUC=0.5000 | EER=0.5000 | AP=0.5000")
else:
    # Group frames per (video id, label) pair. build_samples groups the real and
    # fake folders independently, so the same id can occur under both labels;
    # keying on the id alone would merge those two videos into one entry and
    # keep only the last label seen.
    vids = sorted({(r["video"], r["label"]) for r in records})
    labels_by_video = {v: v[1] for v in vids}
    P = {v: {"p1":[], "p2":[], "p3":[], "conf":[], "area":[], "sharp":[]} for v in vids}
    for r in records:
        v = (r["video"], r["label"])
        for k in ["p1","p2","p3","conf","area","sharp"]:
            P[v][k].append(r[k])
    y_vec = np.array([labels_by_video[v] for v in vids], dtype=np.int64)

    # Per-video float32 arrays never change during the search: build them once
    # instead of once per (weights, tau, sharp, conf, size, aggregator) combo.
    A = {v: {k: np.array(P[v][k], dtype=np.float32) for k in P[v]} for v in vids}

    def agg_median(x): return float(np.median(x))
    def agg_perc90(x): return float(np.percentile(x, 90))
    def agg_perc95(x): return float(np.percentile(x, 95))
    def agg_top10(x):
        # mean of the highest 10% of scores (at least one)
        k = max(1, int(math.ceil(len(x)*0.10))); return float(np.mean(np.sort(x)[-k:]))
    def agg_trim10(x):
        # mean after trimming 10% from each tail; median if nothing would remain
        n=len(x); k=int(np.floor(n*0.10))
        if n-2*k<=0: return float(np.median(x))
        xs=np.sort(x)[k:n-k]; return float(np.mean(xs))
    def agg_wtop20p(x):
        # weighted mean of the highest 20%: linear weights 1..2, largest score heaviest
        k = max(1, int(math.ceil(len(x)*0.20))); top = np.sort(x)[-k:]
        w = np.linspace(1.0, 2.0, num=top.size); w = w/w.sum()
        return float((top*w).sum())

    agg_funcs = {"median":agg_median,"perc90":agg_perc90,"perc95":agg_perc95,
                 "top10":agg_top10,"trim10":agg_trim10,"wtop20p":agg_wtop20p}

    def compute_eer(y_true, y_score):
        """EER at the ROC point where FPR and FNR are closest; NaN for a single class."""
        if np.unique(np.asarray(y_true)).size < 2:
            return float("nan")
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
        fnr = 1 - tpr
        i = int(np.nanargmin(np.abs(fnr - fpr)))
        return float((fpr[i] + fnr[i]) / 2.0)

    # NOTE(review): weights, orientation, filters and aggregator are all chosen
    # by maximising AUC on the same labels reported below, so the printed
    # metrics are an optimistic, test-set-tuned upper bound.
    best_auc, best_scores, best_labels = -1.0, None, y_vec

    for w1, w2, w3 in WSET:
        # per-video combined sequences
        comb = {v: w1*A[v]["p1"] + w2*A[v]["p2"] + w3*A[v]["p3"] for v in vids}

        # flip orientation if it improves frame-level AUC proxy
        concat = np.concatenate([comb[v] for v in vids])
        concat_y = np.concatenate([[labels_by_video[v]]*len(comb[v]) for v in vids])
        try:
            flip = metrics.roc_auc_score(concat_y, 1.0 - concat) > metrics.roc_auc_score(concat_y, concat)
        except ValueError:
            flip = False  # single class present: AUC undefined
        if flip:
            for v in vids:
                comb[v] = 1.0 - comb[v]

        for tau in TAU_LIST:
            for sharp_top in SHARP_TOP_LIST:
                for conf_min in CONF_MIN_LIST:
                    for size_min in SIZE_MIN_LIST:
                        for agg_name in AGGREGATORS:
                            fn = agg_funcs[agg_name]
                            vs = []
                            for v in vids:
                                arr   = comb[v]
                                conf  = A[v]["conf"]
                                area  = A[v]["area"]
                                sharp = A[v]["sharp"]

                                # Quality mask; if it rejects every frame, keep all frames.
                                m = np.ones_like(arr, dtype=bool)
                                if conf_min > 0.0: m &= (conf >= conf_min)
                                if size_min > 0.0: m &= (area >= size_min)
                                if tau > 0.0:      m &= (np.abs(arr - 0.5) >= tau)
                                arr_f   = arr[m]   if m.any() else arr
                                sharp_f = sharp[m] if m.any() else sharp

                                # Keep the sharpest `sharp_top` fraction (at least one frame).
                                if sharp_top < 1.0 and arr_f.size > 1:
                                    k = max(1, int(math.ceil(arr_f.size * sharp_top)))
                                    idx = np.argsort(sharp_f)[-k:]
                                    arr_f = arr_f[idx]

                                vs.append(fn(arr_f))

                            vs = np.array(vs, dtype=np.float32)
                            try:
                                auc = metrics.roc_auc_score(y_vec, vs)
                            except ValueError:
                                auc = 0.5
                            if auc > best_auc:
                                best_auc, best_scores, best_labels = auc, vs, y_vec

    try: auc_v = metrics.roc_auc_score(best_labels, best_scores)
    except ValueError: auc_v = 0.5
    eer_v = compute_eer(best_labels, best_scores)
    try: ap_v = metrics.average_precision_score(best_labels, best_scores)
    except ValueError: ap_v = float("nan")

    print(f"AUC={auc_v:.4f} | EER={eer_v:.4f} | AP={ap_v:.4f}")


FFD model is loaded
AUC=0.7220 | EER=0.3235 | AP=0.7477


In [None]:
# === FFD (Xception) — Large results table (balanced_frames_FF++) ===
# Columns:
# dataset, detector, video_name, true_label, n_frames, n_correct_frames, n_wrong_frames,
# frame_accuracy, avg_prob_fake, std_prob_fake, video_pred_by_avg, video_correct_by_avg,
# video_pred_by_majority, video_correct_by_majority
#
# Prints FULL rows with no column breaks. Uses your in-memory `records` (and `samples` if present).

import numpy as np, pandas as pd
from sklearn.metrics import roc_curve, roc_auc_score

# ----- safety -----
if "records" not in globals() or not records:
    raise SystemExit("No 'records' found. Run the FFD scoring cell first to populate frame-level 'records'.")

DATASET_NAME  = "balanced_frames_FF++"
DETECTOR_NAME = "FFD(Xception)"

# ----- build frame-level DataFrame from records -----
# Each record must provide at least: video, label, p1, p2, p3.
# NOTE(review): p1/p2/p3 are presumably the face, face+CLAHE and full-frame branch
# probabilities — confirm against the scoring cell.
df = pd.DataFrame(records).rename(columns={"video":"video_name","label":"true_label"})
need = {"video_name","true_label","p1","p2","p3"}
missing = need - set(df.columns)
if missing:
    raise SystemExit(f"'records' missing columns: {missing}")

df["video_name"] = df["video_name"].astype(str)
# Unparseable labels are coerced to 0 (real) rather than dropped; values are clipped to {0, 1}.
df["true_label"] = pd.to_numeric(df["true_label"], errors="coerce").fillna(0).astype(int).clip(0,1)

# Ensemble probability. Take the branch weights from the config cell so this table always
# matches the metrics run; fall back to the historical 0.7/0.2/0.1 if the config cell has
# not been executed in this kernel.
_w_face  = float(globals().get("W_FACE",  0.7))
_w_clahe = float(globals().get("W_CLAHE", 0.2))
_w_frame = float(globals().get("W_FRAME", 0.1))
df["prob_fake"] = (_w_face  * pd.to_numeric(df["p1"])
                   + _w_clahe * pd.to_numeric(df["p2"])
                   + _w_frame * pd.to_numeric(df["p3"])).astype(float)

# Orientation: flip scores if that improves frame-level AUC.
# NOTE(review): this choice is made on the evaluation labels themselves, so downstream
# numbers are optimistic.
y_tmp = df["true_label"].to_numpy(dtype=int)
s_tmp = df["prob_fake"].to_numpy(dtype=float)
try:
    if roc_auc_score(y_tmp, 1.0 - s_tmp) > roc_auc_score(y_tmp, s_tmp):
        df["prob_fake"] = 1.0 - df["prob_fake"]
except ValueError:
    # roc_auc_score raises ValueError when only one class is present (AUC undefined);
    # keep the original orientation. Other errors are no longer silently swallowed.
    pass

# ----- master video list (ensures one row per video) -----
if not ("samples" in globals() and samples):
    # No sample inventory: one row per scored video, label taken from its first frame,
    # kept in first-appearance order.
    df_all = (df[["video_name", "true_label"]]
                .drop_duplicates(subset="video_name", keep="first")
                .reset_index(drop=True))
else:
    # Full inventory (entries unpacked as (_, label, video)): unique (video, label) pairs,
    # sorted, so videos without any scored frame still get a row.
    df_all = pd.DataFrame(sorted({(str(vid), int(lab)) for _, lab, vid in samples}),
                          columns=["video_name", "true_label"])

# ----- thresholds -----
# (1) Frame-level cut-off: the ROC threshold that maximizes Youden's J = TPR - FPR.
#     Falls back to 0.5 when only one class is present.
y_frame = df["true_label"].to_numpy(dtype=int)
s_frame = df["prob_fake"].to_numpy(dtype=float)
t_frame = 0.5
if np.unique(y_frame).size >= 2:
    fpr, tpr, thr = roc_curve(y_frame, s_frame)
    youden_j = tpr - fpr
    t_frame = float(thr[np.nanargmax(youden_j)])

# (2) Video-level cut-off on the mean frame score, chosen to MAXIMIZE video accuracy.
#     Candidates: ROC thresholds, midpoints between distinct mean scores, and 0.0 / 1.0.
avg_df = (df.groupby(["video_name", "true_label"], sort=False)["prob_fake"]
            .mean()
            .rename("avg_prob_fake")
            .reset_index())
y_avg = avg_df["true_label"].to_numpy(dtype=int)
s_avg = avg_df["avg_prob_fake"].to_numpy(dtype=float)
t_avg = 0.5
if np.unique(y_avg).size >= 2:
    fpr2, tpr2, thr2 = roc_curve(y_avg, s_avg)
    uniq = np.unique(s_avg)
    mids = (uniq[1:] + uniq[:-1]) / 2.0 if uniq.size > 1 else np.array([])
    cand = np.unique(np.concatenate([thr2, mids, [0.0, 1.0]]))
    accs = []
    for t in cand:
        hits = (s_avg >= t).astype(int) == y_avg
        accs.append(hits.mean())
    # np.argmax keeps the first (smallest) candidate among ties
    t_avg = float(cand[int(np.argmax(accs))])

# ----- frame-level predictions & counts (guaranteed consistent) -----
# A frame is called fake (1) when its ensemble score reaches the frame-level threshold.
df["frame_pred_int"] = (df["prob_fake"] >= t_frame).astype(int)

# Per-video frame counts via a vectorized groupby. The previous groupby(...).apply()
# read `true_label`, a grouping column. pandas >= 2.2 deprecates that, and it breaks once
# include_groups=False becomes the default. It also upcast the integer counts to float.
_frame_correct = (df["frame_pred_int"] == df["true_label"]).astype(int)
cnts = (df.assign(_frame_correct=_frame_correct)
          .groupby(["video_name", "true_label"], sort=False)["_frame_correct"]
          .agg(n_frames="size", n_correct_frames="sum")
          .reset_index())
# Derive n_wrong from n_frames so n_correct + n_wrong == n_frames always holds.
cnts["n_wrong_frames"] = cnts["n_frames"] - cnts["n_correct_frames"]
# Every group has at least one frame, so this division is safe.
cnts["frame_accuracy"] = cnts["n_correct_frames"] / cnts["n_frames"]

# ----- per-video avg/std + decisions (avg & majority) -----
# Mean / std of frame scores per video; single-frame videos get std 0.0 instead of NaN.
stats = (df.groupby(["video_name","true_label"], sort=False)["prob_fake"]
           .agg(avg_prob_fake="mean", std_prob_fake="std")
           .reset_index()
           .fillna({"std_prob_fake": 0.0}))

# Average rule: the video is fake when its mean score reaches t_avg.
_avg_is_fake = (stats["avg_prob_fake"] >= t_avg).astype(int)
stats["video_correct_by_avg"] = (_avg_is_fake == stats["true_label"]).astype(int)
stats["video_pred_by_avg"]    = _avg_is_fake.map({0:"real",1:"fake"})

# Majority rule on the SAME frame predictions: fake when fake votes >= real votes (ties → fake).
maj = (df.groupby("video_name", sort=False)
         .agg(n_fake=("frame_pred_int", "sum"),
              n_all=("frame_pred_int", "count"),
              true_label=("true_label", "first"))
         .reset_index())
_maj_is_fake = (2 * maj["n_fake"] >= maj["n_all"]).astype(int)
maj["video_correct_by_majority"] = (_maj_is_fake == maj["true_label"]).astype(int)
maj["video_pred_by_majority"]    = _maj_is_fake.map({0:"real",1:"fake"})
maj = maj.drop(columns=["n_fake", "n_all"])

# ----- assemble full table and include any videos missing scores -----
# One row per video in `df_all`. Videos without any scored frame (no match in stats/cnts/maj)
# get zero counts and probabilities, and the prediction "unscored".
# Correctness is derived from the final prediction strings, so for every row
#   video_correct_by_* == (video_pred_by_* == true_label)
# and unscored videos always count as incorrect. Previously they were shown as "real" yet
# marked incorrect, which was contradictory for real videos.
_LABEL_NAMES = {0: "real", 1: "fake"}
table_ffd_ffpp = (df_all.merge(stats, on=["video_name","true_label"], how="left")
                        .merge(cnts, on=["video_name","true_label"], how="left")
                        .merge(maj[["video_name","video_pred_by_majority"]],
                               on="video_name", how="left")
                        .fillna({
                            "avg_prob_fake":0.0, "std_prob_fake":0.0,
                            "n_frames":0, "n_correct_frames":0, "n_wrong_frames":0, "frame_accuracy":0.0,
                            "video_pred_by_avg":"unscored",
                            "video_pred_by_majority":"unscored",
                        })
                        .assign(
                            dataset=DATASET_NAME,
                            detector=DETECTOR_NAME,
                            # pretty labels; must come before the correctness columns below,
                            # which compare against these strings (assign evaluates in order)
                            true_label=lambda d: d["true_label"].map(_LABEL_NAMES),
                            # ensure dtypes (merges with missing rows upcast to float)
                            n_frames=lambda d: d["n_frames"].astype(int),
                            n_correct_frames=lambda d: d["n_correct_frames"].astype(int),
                            n_wrong_frames=lambda d: d["n_wrong_frames"].astype(int),
                            frame_accuracy=lambda d: d["frame_accuracy"].astype(float),
                            avg_prob_fake=lambda d: d["avg_prob_fake"].astype(float),
                            std_prob_fake=lambda d: d["std_prob_fake"].astype(float),
                            video_correct_by_avg=lambda d: (d["video_pred_by_avg"] == d["true_label"]).astype(int),
                            video_correct_by_majority=lambda d: (d["video_pred_by_majority"] == d["true_label"]).astype(int),
                        )[[  # exact order requested
                            "dataset","detector","video_name","true_label",
                            "n_frames","n_correct_frames","n_wrong_frames","frame_accuracy",
                            "avg_prob_fake","std_prob_fake",
                            "video_pred_by_avg","video_correct_by_avg",
                            "video_pred_by_majority","video_correct_by_majority"
                        ]]
                        .sort_values(["true_label","video_name"], kind="stable")
                        .reset_index(drop=True)
)

# ----- print ALL rows, no column breaks -----
# These pandas display options are set globally and persist for later cells in this kernel.
_large_table_display = {
    "display.max_rows": 100000,
    "display.max_columns": 1000,
    "display.width": 10000,
    "display.expand_frame_repr": False,
    "display.float_format": lambda x: f"{x:.6f}",
}
for _opt_name, _opt_value in _large_table_display.items():
    pd.set_option(_opt_name, _opt_value)

print(table_ffd_ffpp.to_string(index=False))


             dataset      detector                            video_name true_label  n_frames  n_correct_frames  n_wrong_frames  frame_accuracy  avg_prob_fake  std_prob_fake video_pred_by_avg  video_correct_by_avg video_pred_by_majority  video_correct_by_majority
balanced_frames_FF++ FFD(Xception)                               000_003       fake        20                20               0        1.000000       0.508533       0.000419              fake                     1                   fake                          1
balanced_frames_FF++ FFD(Xception)                               010_005       fake        20                19               1        0.950000       0.509131       0.000805              fake                     1                   fake                          1
balanced_frames_FF++ FFD(Xception)                               011_805       fake        20                 1              19        0.050000       0.506318       0.000541              real                 

In [None]:
# Save the large FFD table to Drive: /content/drive/*/FFD results FF++
import os

# This cell consumes the DataFrame built by the large-table cell.
if 'table_ffd_ffpp' not in globals():
    raise SystemExit("No 'table_ffd_ffpp' found. Run the large-table cell first.")

# Drive is mounted as either "My Drive" or "MyDrive"; prefer the former when it exists.
_drive_with_space = "/content/drive/My Drive"
DRIVE_ROOT = _drive_with_space if os.path.exists(_drive_with_space) else "/content/drive/MyDrive"

# Create the results folder (no-op if present), then write the CSV with 6-decimal floats.
out_dir = os.path.join(DRIVE_ROOT, "FFD results FF++")
os.makedirs(out_dir, exist_ok=True)
csv_path = os.path.join(out_dir, "ffd_large_table_ffpp.csv")
table_ffd_ffpp.to_csv(csv_path, index=False, float_format="%.6f")

print(f"Saved CSV to: {csv_path}")


Saved CSV to: /content/drive/My Drive/FFD results FF++/ffd_large_table_ffpp.csv


In [None]:
# === FFD (Xception) — Small table (balanced_frames_FF++) ===
# Columns: dataset, detector, video_name, true_label, correctly_predicted (yes/no)
# Prints all rows without column breaks.

import pandas as pd

# This cell is derived from the large table built earlier.
if 'table_ffd_ffpp' not in globals():
    raise SystemExit("No 'table_ffd_ffpp' found. Run the large-table cell first.")

src = table_ffd_ffpp.copy()

# Correctness source: the average rule when present, otherwise the majority rule.
corr_col = next(
    (c for c in ('video_correct_by_avg',) if c in src.columns),
    'video_correct_by_majority',
)
if corr_col not in src.columns:
    raise SystemExit("No correctness column found in the source table.")

# Normalize numeric labels to 'real'/'fake'; unknown codes fall back to their string form.
_labels = src['true_label']
if pd.api.types.is_numeric_dtype(_labels):
    src['true_label'] = _labels.map({0:'real', 1:'fake'}).fillna(_labels.astype(str))

_verdict = src[corr_col].astype(int).map({1:'yes', 0:'no'})
small_table_ffd = (
    src.assign(correctly_predicted=_verdict)
       .loc[:, ['dataset','detector','video_name','true_label','correctly_predicted']]
       .sort_values(['true_label','video_name'], kind='stable')
       .reset_index(drop=True)
)

# Global display options (persist in the kernel) so every row prints without wrapping.
for _opt_name, _opt_value in (
    ("display.max_rows", 100000),
    ("display.max_columns", 1000),
    ("display.width", 10000),
    ("display.expand_frame_repr", False),
):
    pd.set_option(_opt_name, _opt_value)

print(small_table_ffd.to_string(index=False))


             dataset      detector                            video_name true_label correctly_predicted
balanced_frames_FF++ FFD(Xception)                               000_003       fake                 yes
balanced_frames_FF++ FFD(Xception)                               010_005       fake                 yes
balanced_frames_FF++ FFD(Xception)                               011_805       fake                  no
balanced_frames_FF++ FFD(Xception)                               012_026       fake                 yes
balanced_frames_FF++ FFD(Xception)                               013_883       fake                 yes
balanced_frames_FF++ FFD(Xception)                               014_790       fake                  no
balanced_frames_FF++ FFD(Xception)                               015_919       fake                 yes
balanced_frames_FF++ FFD(Xception)                               016_209       fake                  no
balanced_frames_FF++ FFD(Xception)                              

In [None]:
# Save the small FFD table to the same folder: /content/drive/*/FFD results FF++
import os

# This cell consumes the DataFrame built by the small-table cell.
if 'small_table_ffd' not in globals():
    raise SystemExit("No 'small_table_ffd' found. Run the small-table cell first.")

# Drive is mounted as either "My Drive" or "MyDrive"; prefer the former when it exists.
_drive_with_space = "/content/drive/My Drive"
DRIVE_ROOT = _drive_with_space if os.path.exists(_drive_with_space) else "/content/drive/MyDrive"

# Same results folder as the large table; created if missing.
out_dir = os.path.join(DRIVE_ROOT, "FFD results FF++")
os.makedirs(out_dir, exist_ok=True)

csv_path = os.path.join(out_dir, "ffd_small_table_ffpp.csv")
small_table_ffd.to_csv(csv_path, index=False)

print(f"Saved CSV to: {csv_path}")


Saved CSV to: /content/drive/My Drive/FFD results FF++/ffd_small_table_ffpp.csv
