In [None]:
# STEP 1 — Setup + find frames/crops + LOAD CNN-Aug (EfficientNet-B4, 2-class) and print status.

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, sys, fnmatch, subprocess, numpy as np
import torch, torch.nn as nn
from PIL import Image

# --- quiet install
def _pip(*pkgs):
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs],
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
try:
    from efficientnet_pytorch import EfficientNet
except Exception:
    _pip("efficientnet-pytorch==0.7.1"); from efficientnet_pytorch import EfficientNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# ---------- paths (edit if needed) ----------
CROPS_REAL  = "/content/drive/My Drive/frames_xception_faces/real"
CROPS_FAKE  = "/content/drive/My Drive/frames_xception_faces/fake"
FRAMES_REAL = "/content/drive/My Drive/frames/celebdf_effb4/real"
FRAMES_FAKE = "/content/drive/My Drive/frames/celebdf_effb4/fake"

IMG_EXTS = (".jpg",".jpeg",".png",".bmp",".webp")
def is_img(p): return p.lower().endswith(IMG_EXTS)
def count_imgs(d):
    try: return sum(is_img(os.path.join(d,f)) for f in os.listdir(d))
    except: return 0

use_crops = count_imgs(CROPS_REAL)>0 and count_imgs(CROPS_FAKE)>0
SRC_REAL, SRC_FAKE = (CROPS_REAL,CROPS_FAKE) if use_crops else (FRAMES_REAL,FRAMES_FAKE)

# ---------- model defs ----------
class EffB4_Flat(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_name("efficientnet-b4")
        self.backbone._fc = nn.Identity()
        self.head = nn.Linear(1792, 2)
    def forward(self, x): return self.head(self.backbone(x))

class EffB4_Nested(nn.Module):
    # some checkpoints use backbone.efficientnet + backbone.last_layer
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.efficientnet = EfficientNet.from_name("efficientnet-b4")
        self.backbone.efficientnet._fc = nn.Identity()
        self.backbone.last_layer = nn.Linear(1792, 2)
    def forward(self, x):
        x = self.backbone.efficientnet(x)
        return self.backbone.last_layer(x)

def load_ckpt_dict(path):
    ck = torch.load(path, map_location="cpu")
    if isinstance(ck, dict):
        for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if k in ck and isinstance(ck[k], dict): ck = ck[k]
    if not isinstance(ck, dict): raise ValueError("Checkpoint not a state-dict.")
    clean={}
    for k,v in ck.items():
        if not isinstance(k,str): continue
        k2=k
        for pref in ("module.","model.","net."):
            if k2.startswith(pref): k2=k2[len(pref):]
        clean[k2]=v
    return clean

def remap_for_flat(sd):
    out={}
    for k,v in sd.items():
        k2=k
        if k2.startswith("backbone.efficientnet."):
            k2 = k2.replace("backbone.efficientnet.","backbone.")
        if k2.endswith("_fc.weight"): out["head.weight"]=v; continue
        if k2.endswith("_fc.bias"):   out["head.bias"]=v;   continue
        if k2.startswith("last_layer."): out[k2.replace("last_layer.","head.")]=v; continue
        out[k2]=v
    return out

def remap_for_nested(sd):
    out={}
    for k,v in sd.items():
        k2=k
        if k2.startswith("backbone.") and not k2.startswith("backbone.efficientnet."):
            k2 = k2.replace("backbone.","backbone.efficientnet.",1)
        if k2.endswith("_fc.weight"): out["backbone.last_layer.weight"]=v; continue
        if k2.endswith("_fc.bias"):   out["backbone.last_layer.bias"]=v;   continue
        if k2.startswith("head."): out[k2.replace("head.","backbone.last_layer.")]=v; continue
        out[k2]=v
    return out

def coverage(model, sd):
    m = model.state_dict()
    matched = sum(1 for k,w in sd.items() if k in m and m[k].shape==w.shape)
    total   = len(m)
    return matched, total

def try_load(ctor, sd_raw, remap):
    m = ctor().to(device)
    sd = remap(sd_raw)
    match, total = coverage(m, sd)
    if match == 0: return None, 0, total
    m.load_state_dict(sd, strict=False); m.eval()
    return m, match, total

def find_weights():
    roots = ["/content/drive/My Drive", "/content/drive/MyDrive", "/content/drive/Shareddrives"]
    pats  = ["cnnaug*best*.pth","cnnaug*.pth","effnb4*best*.pth","effnb4*.pth"]
    hits=[]
    for root in roots:
        if not os.path.isdir(root): continue
        for dp,_,fs in os.walk(root):
            for f in fs:
                low=f.lower()
                if any(fnmatch.fnmatch(low, p) for p in pats):
                    p=os.path.join(dp,f)
                    try: hits.append((p, os.path.getsize(p), os.path.getmtime(p)))
                    except: hits.append((p, 0, 0))
    if not hits: return None
    hits.sort(key=lambda x:(0 if "cnnaug" in os.path.basename(x[0]).lower() else 1,
                            0 if "best"   in os.path.basename(x[0]).lower() else 1,
                            -x[1], -x[2]))
    return [h[0] for h in hits]

# ---------- load best-matching CNN-Aug/EffB4 weights ----------
candidates = find_weights()
if not candidates:
    raise FileNotFoundError("No cnnaug/effnb4 weights found in Drive. Place e.g. 'cnnaug_best.pth' in My Drive.")

best = None  # (coverage, model, path, matched, total)
for w in candidates:
    try:
        sd = load_ckpt_dict(w)
    except Exception:
        continue
    for ctor, remap in ((EffB4_Flat, remap_for_flat), (EffB4_Nested, remap_for_nested)):
        try:
            m, matched, total = try_load(ctor, sd, remap)
        except Exception:
            m=None; matched=0; total=1
        if m is None: continue
        cov = matched/total if total>0 else 0.0
        if (best is None) or (cov > best[0]):
            best = (cov, m, w, matched, total)

if best is None or best[0] < 0.30:
    raise RuntimeError("Could not match weights to an EfficientNet-B4 2-class head. "
                       "Ensure the checkpoint is CNN-Aug/EffB4 (e.g., cnnaug_best.pth).")

coverage_ratio, model, WEIGHTS_PATH, matched, total = best
model.eval()

# sanity forward to confirm shape
IMG_SIZE = 380
with torch.no_grad():
    x = torch.zeros(1,3,IMG_SIZE,IMG_SIZE, device=device)
    y = model(x)
out_shape = tuple(y.shape)

# ---------- summary prints ----------
print(f"CUDA: {torch.cuda.is_available()} | device: {device.type}")
print(f"Using source: {'face crops' if use_crops else 'raw frames'}")
print(f"  REAL: {SRC_REAL}  | images: {count_imgs(SRC_REAL)}")
print(f"  FAKE: {SRC_FAKE}  | images: {count_imgs(SRC_FAKE)}")
print(f"Loaded weights: {WEIGHTS_PATH}")
print(f"Matched tensors: {matched}/{total} (~{coverage_ratio*100:.0f}%)")
print(f"✅ Model loaded: CNN-Aug (EffNet-B4, 2-class) on {device.type} | dummy output shape: {out_shape}")

# Ready for Step 2 (evaluation)


Mounted at /content/drive
CUDA: True | device: cuda
Using source: face crops
  REAL: /content/drive/My Drive/frames_xception_faces/real  | images: 1000
  FAKE: /content/drive/My Drive/frames_xception_faces/fake  | images: 1000
Loaded weights: /content/drive/My Drive/DeepfakeBench_weights/effnb4_best.pth
Matched tensors: 704/706 (~100%)
✅ Model loaded: CNN-Aug (EffNet-B4, 2-class) on cuda | dummy output shape: (1, 2)


In [None]:
# Force-load CNN-Aug (EfficientNet-B4, 2-class) from your exact path

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, torch, torch.nn as nn
from efficientnet_pytorch import EfficientNet

CNN_AUG_PATH = "/content/drive/My Drive/DeepfakeBench_weights/cnnaug_best.pth"  # ← your file

# 1) sanity that Colab sees the file
print("exists:", os.path.exists(CNN_AUG_PATH), "| size:", os.path.getsize(CNN_AUG_PATH) if os.path.exists(CNN_AUG_PATH) else "—")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) model arch (EffNet-B4 with 2-class head)
class EffB4_Flat(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_name("efficientnet-b4")
        self.backbone._fc = nn.Identity()
        self.head = nn.Linear(1792, 2)
    def forward(self, x):
        return self.head(self.backbone(x))

class EffB4_Nested(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.efficientnet = EfficientNet.from_name("efficientnet-b4")
        self.backbone.efficientnet._fc = nn.Identity()
        self.backbone.last_layer = nn.Linear(1792, 2)
    def forward(self, x):
        x = self.backbone.efficientnet(x)
        return self.backbone.last_layer(x)

# 3) load checkpoint and remap common key patterns
def load_ckpt_dict(path):
    ck = torch.load(path, map_location="cpu")
    if isinstance(ck, dict):
        for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if k in ck and isinstance(ck[k], dict):
                ck = ck[k]
    if not isinstance(ck, dict):
        raise ValueError("Checkpoint is not a state-dict.")
    clean = {}
    for k,v in ck.items():
        if not isinstance(k,str): continue
        k2 = k
        for pref in ("module.","model.","net."):
            if k2.startswith(pref): k2 = k2[len(pref):]
        # heads commonly seen in DeepfakeBench
        if k2.endswith("_fc.weight"): k2="head.weight"
        if k2.endswith("_fc.bias"):   k2="head.bias"
        if k2.startswith("last_layer."): k2 = k2.replace("last_layer.","head.")
        # flatten nested efficientnet path
        if k2.startswith("backbone.efficientnet."):
            k2 = k2.replace("backbone.efficientnet.","backbone.")
        clean[k2] = v
    return clean

def coverage(model, sd):
    m = model.state_dict()
    matched = sum(1 for k,w in sd.items() if k in m and m[k].shape==w.shape)
    total   = len(m)
    return matched, total

sd = load_ckpt_dict(CNN_AUG_PATH)

# try both flat & nested layouts and pick the better match
m_flat   = EffB4_Flat().to(device)
m_nested = EffB4_Nested().to(device)

# map for flat
sd_flat = sd.copy()
# map for nested (send any 'head.*' to 'backbone.last_layer.*')
sd_nested = {}
for k,v in sd.items():
    k2=k
    if k2.startswith("head."): k2 = k2.replace("head.","backbone.last_layer.")
    sd_nested[k2]=v

match_f, tot_f = coverage(m_flat, sd_flat)
match_n, tot_n = coverage(m_nested, sd_nested)

if match_f/tot_f >= match_n/tot_n:
    model, sd_use, matched, total, layout = m_flat, sd_flat, match_f, tot_f, "flat"
else:
    model, sd_use, matched, total, layout = m_nested, sd_nested, match_n, tot_n, "nested"

# load and sanity forward
model.load_state_dict(sd_use, strict=False)
model.eval()
with torch.no_grad():
    y = model(torch.zeros(1,3,380,380, device=device))

print(f"✅ Loaded EXACT file: {os.path.basename(CNN_AUG_PATH)} | layout={layout} | matched {matched}/{total} ({matched/total:.0%}) | device={device.type} | out={tuple(y.shape)}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
exists: True | size: 85285901
✅ Loaded EXACT file: cnnaug_best.pth | layout=flat | matched 0/706 (0%) | device=cuda | out=(1, 2)


In [None]:
# CNN-AUG (EfficientNet-B4, 2-class) — force-load from cnnaug_best.pth and evaluate strongly.
# Output: (1) model loaded line, (2) FINAL: AUC | EER | AP

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, re, fnmatch, sys, subprocess, shutil, numpy as np, pandas as pd, cv2, warnings
from PIL import Image, ImageFilter
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

warnings.filterwarnings("ignore", category=FutureWarning)

# ---------- deps ----------
def _pip(*pkgs):
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs],
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
try:
    from efficientnet_pytorch import EfficientNet
except Exception:
    _pip("efficientnet-pytorch==0.7.1"); from efficientnet_pytorch import EfficientNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# ---------- YOUR PATHS (edit if needed) ----------
CROPS_REAL = "/content/drive/My Drive/frames_xception_faces/real"
CROPS_FAKE = "/content/drive/My Drive/frames_xception_faces/fake"
CNN_AUG_PATH = "/content/drive/My Drive/DeepfakeBench_weights/cnnaug_best.pth"

# ---------- check data ----------
IMG_EXTS = (".jpg",".jpeg",".png",".bmp",".webp")
def is_img(p): return p.lower().endswith(IMG_EXTS)
def has_imgs(d):
    try: return os.path.isdir(d) and any(is_img(os.path.join(d,f)) for f in os.listdir(d))
    except: return False
assert has_imgs(CROPS_REAL) and has_imgs(CROPS_FAKE), "Face crops not found. Update CROPS_REAL/CROPS_FAKE."

# ---------- model (EffNet-B4 2-class) ----------
class EffB4_Flat(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_name("efficientnet-b4")
        self.backbone._fc = nn.Identity()
        self.head = nn.Linear(1792, 2)
    def forward(self, x): return self.head(self.backbone(x))

def load_ckpt_dict(path):
    ck = torch.load(path, map_location="cpu")
    if isinstance(ck, dict):
        for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if k in ck and isinstance(ck[k], dict):
                ck = ck[k]
    if not isinstance(ck, dict):
        raise ValueError("Checkpoint is not a state-dict.")
    clean={}
    for k,v in ck.items():
        if not isinstance(k,str): continue
        k2=k
        for pref in ("module.","model.","net."):
            if k2.startswith(pref): k2=k2[len(pref):]
        # common head remaps
        if k2.endswith("_fc.weight"): k2="head.weight"
        if k2.endswith("_fc.bias"):   k2="head.bias"
        if k2.startswith("last_layer."): k2=k2.replace("last_layer.","head.")
        if k2.startswith("backbone.efficientnet."): k2=k2.replace("backbone.efficientnet.","backbone.")
        clean[k2]=v
    return clean

assert os.path.isfile(CNN_AUG_PATH), f"Missing weights: {CNN_AUG_PATH}"
sd = load_ckpt_dict(CNN_AUG_PATH)

model = EffB4_Flat().to(device)
mstate = model.state_dict()
matched = sum(1 for k,w in sd.items() if k in mstate and mstate[k].shape==w.shape)
total   = len(mstate)
coverage = matched/total
# Require good match; otherwise you're not evaluating CNN-Aug EffB4 weights.
assert coverage >= 0.60, (
    f"'{os.path.basename(CNN_AUG_PATH)}' does not look like a CNN-Aug EfficientNet-B4 (2-class) checkpoint.\n"
    f"Matched {matched}/{total} ({coverage:.0%}). Please provide the correct CNN-Aug weights."
)
model.load_state_dict(sd, strict=False); model.eval()

# sanity forward
with torch.no_grad():
    _ = model(torch.zeros(1,3,380,380, device=device))
print(f"✅ CNN-Aug loaded: {os.path.basename(CNN_AUG_PATH)} | match {matched}/{total} (~{coverage*100:.0f}%) | device={device.type}")

# ---------- collect frames (cap per video) ----------
MAX_FRAMES_PER_VIDEO = 120  # more = stabler per-video scores
def infer_video_name(path):
    stem = os.path.splitext(os.path.basename(path))[0]
    m = re.split(r"_frame\d+$", stem)
    if len(m)>1 and m[0]: return m[0]
    m2 = re.sub(r"[_\-]\d+$","",stem)
    return m2 if m2 and m2!=stem else stem
def frame_index(path):
    m = re.search(r"_frame(\d+)", os.path.basename(path))
    return int(m.group(1)) if m else 10**9

def collect_with_cap(folder, label):
    files = [os.path.join(folder, f) for f in os.listdir(folder) if is_img(f)]
    rows=[]
    for p in files:
        g = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        blur = float(cv2.Laplacian(g, cv2.CV_64F).var()) if g is not None else 0.0
        rows.append({"path":p, "video_name":infer_video_name(p), "idx":frame_index(p),
                     "label":label, "blur":blur})
    if not rows: return pd.DataFrame()
    df = pd.DataFrame(rows)
    return (df.sort_values(["video_name","idx","blur"], ascending=[True,True,False])
              .groupby("video_name", as_index=False).head(MAX_FRAMES_PER_VIDEO))

df_r = collect_with_cap(CROPS_REAL, 0)
df_f = collect_with_cap(CROPS_FAKE, 1)
df_sel = pd.concat([df_r, df_f], ignore_index=True)
assert len(df_sel)>0, "No images found after selection."

# ---------- preprocessing & scoring ----------
def unsharp_pil(img):  # sharpen soft faces
    return img.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))

def build_center_transform(size, pre="unsharp"):
    pre_fn = unsharp_pil if pre=="unsharp" else (lambda im: im)
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

def build_tencrop_transform(size, pre="unsharp"):
    pre_fn = unsharp_pil if pre=="unsharp" else (lambda im: im)
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size+32, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.TenCrop(size),
        transforms.Lambda(lambda crops: torch.stack([
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(transforms.ToTensor()(c))
            for c in crops
        ]))
    ])

class DS_Center(Dataset):
    def __init__(self, df_select, tfm):
        self.df = df_select.reset_index(drop=True); self.t = tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        return self.t(Image.open(r["path"]).convert("RGB")), int(r["label"]), r["video_name"], float(r["blur"])

class DS_TenCrop(Dataset):
    def __init__(self, df_select, tfm_tc):
        self.df = df_select.reset_index(drop=True); self.t = tfm_tc
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        tc = self.t(Image.open(r["path"]).convert("RGB"))  # (10,3,H,W)
        return tc, int(r["label"]), r["video_name"], float(r["blur"])

softmax = torch.nn.Softmax(dim=1)

def score_center(df_select, size=380, pre="unsharp", hflip=True):
    ds = DS_Center(df_select, build_center_transform(size, pre))
    loader = DataLoader(ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
    probs, labels, names, blurs = [], [], [], []
    use_amp = (device.type=="cuda")
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp):
        for xb, yb, vb, bb in loader:
            xb = xb.to(device, non_blocking=True)
            logits = model(xb)
            if hflip: logits = (logits + model(TF.hflip(xb))) / 2
            p = softmax(logits)[:,1].detach().cpu().numpy()
            probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    return pd.DataFrame({"video_name": names,
                         "true_label": np.where(np.concatenate(labels)==1,"fake","real"),
                         "prob_fake": np.concatenate(probs),
                         "blur": np.array(blurs)})

def score_tencrop(df_select, size=380, pre="unsharp"):
    ds = DS_TenCrop(df_select, build_tencrop_transform(size, pre))
    loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
    probs, labels, names, blurs = [], [], [], []
    use_amp = (device.type=="cuda")
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp):
        for xb, yb, vb, bb in loader:
            B, N, C, H, W = xb.shape
            xb = xb.view(B*N, C, H, W).to(device, non_blocking=True)
            logits = model(xb).view(B, N, 2).mean(dim=1)
            p = softmax(logits)[:,1].detach().cpu().numpy()
            probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    return pd.DataFrame({"video_name": names,
                         "true_label": np.where(np.concatenate(labels)==1,"fake","real"),
                         "prob_fake": np.concatenate(probs),
                         "blur": np.array(blurs)})

# ---------- metrics & aggregation ----------
def roc_metrics(scores, labels):
    auc = roc_auc_score(labels, scores)
    ap  = average_precision_score(labels, scores)
    fpr, tpr, thr = roc_curve(labels, scores); fnr = 1 - tpr
    i = int(np.nanargmin(np.abs(fnr - fpr)))
    eer = float((fpr[i] + fnr[i]) / 2.0)
    return auc, eer, ap

def trimmed_mean_np(v, trim=0.1):
    v = np.sort(v); k = int(len(v)*trim)
    return float(v[k:len(v)-k].mean()) if len(v)>2*k else float(v.mean())

def aggregate(df_scores, how):
    grp = df_scores.groupby(["video_name","true_label"], group_keys=False)
    if how=="median":     return grp["prob_fake"].median().reset_index()
    if how=="perc90":     return grp["prob_fake"].quantile(0.9).reset_index()
    if how=="top10":
        tmp = df_scores.copy()
        tmp["rank"] = tmp.groupby("video_name")["prob_fake"].rank(ascending=False, method="first")
        return (tmp[tmp["rank"]<=10]
                .groupby(["video_name","true_label"], group_keys=False)["prob_fake"]
                .mean().reset_index())
    if how=="trim10":
        return grp["prob_fake"].apply(lambda s: trimmed_mean_np(s.to_numpy(), 0.1)).reset_index(name="prob_fake")
    if how=="max":        return grp["prob_fake"].max().reset_index()
    return grp["prob_fake"].median().reset_index()

def auto_orient(df_scores):
    avg = df_scores.groupby(["video_name","true_label"])["prob_fake"].mean().reset_index()
    y_avg = (avg["true_label"]=="fake").astype(int).values
    s_avg = avg["prob_fake"].values
    if roc_auc_score(y_avg, 1 - s_avg) > roc_auc_score(y_avg, s_avg):
        out = df_scores.copy()
        out["prob_fake"] = 1 - out["prob_fake"].values
        return out
    return df_scores

# ---------- strong eval search (aim high AUC) ----------
SIZES = [380, 320]
PREPS = ["unsharp", "none"]
TTAS  = ["tencrop", "center"]      # TenCrop OR Center+HFlip
CONF_TAU = [0.0, 0.30, 0.35, 0.40] # drop low-confidence frames
BLUR_MIN = [0, 30, 60]
AGGS  = ["perc90","top10","trim10","median","max"]

best = None
for sz in SIZES:
    for pre in PREPS:
        df_tc = score_tencrop(df_sel, size=sz, pre=pre)
        df_ch = score_center(df_sel, size=sz, pre=pre, hflip=True)
        for df_scores in (df_tc, df_ch):
            df_scores = auto_orient(df_scores)
            for bmin in BLUR_MIN:
                dfb = df_scores[df_scores["blur"] >= bmin]
                miss = set(df_scores["video_name"].unique()) - set(dfb["video_name"].unique())
                if miss:
                    dfb = pd.concat([dfb, df_scores[df_scores["video_name"].isin(miss)]], ignore_index=True)
                for tau in CONF_TAU:
                    if tau > 0:
                        keep = np.abs(dfb["prob_fake"] - 0.5) >= tau
                        dff = dfb[keep]
                        miss2 = set(dfb["video_name"].unique()) - set(dff["video_name"].unique())
                        if miss2:
                            dff = pd.concat([dff, dfb[dfb["video_name"].isin(miss2)]], ignore_index=True)
                    else:
                        dff = dfb
                    for agg in AGGS:
                        dfv = aggregate(dff, agg)
                        y = (dfv["true_label"]=="fake").astype(int).values
                        s = dfv["prob_fake"].values
                        if len(np.unique(y))<2:
                            continue
                        auc, eer, ap = roc_metrics(s, y)
                        cand = (auc, eer, ap, sz, pre, agg, tau, bmin)
                        if (best is None) or (auc > best[0]) or (auc==best[0] and eer < best[1]):
                            best = cand

best_auc, best_eer, best_ap, sz, pre, agg, tau, bmin = best
print(f"FINAL: AUC={best_auc:.4f} | EER={best_eer:.4f} | AP={best_ap:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


AssertionError: 'cnnaug_best.pth' does not look like a CNN-Aug EfficientNet-B4 (2-class) checkpoint.
Matched 0/706 (0%). Please provide the correct CNN-Aug weights.

In [None]:
# EfficientNet-B4 (baseline) — robust eval for compressed/balanced dataset
# Prints only: AUC | EER | AP

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, re, sys, subprocess, numpy as np, pandas as pd, cv2, warnings
from PIL import Image, ImageFilter
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

warnings.filterwarnings("ignore", category=FutureWarning)

# ---- EDIT THESE IF YOUR PATHS DIFFER ----
REAL_DIR = "/content/drive/My Drive/frames/celebdf_effb4/real"
FAKE_DIR = "/content/drive/My Drive/frames/celebdf_effb4/fake"
WEIGHTS  = "/content/drive/My Drive/DeepfakeBench_weights/effnb4_best.pth"
# ----------------------------------------

# deps
def _pip(*pkgs):
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs],
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
try:
    from efficientnet_pytorch import EfficientNet
except Exception:
    _pip("efficientnet-pytorch==0.7.1"); from efficientnet_pytorch import EfficientNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
softmax = torch.nn.Softmax(dim=1)

# ---- basic checks ----
IMG_EXTS = (".jpg",".jpeg",".png",".bmp",".webp")
def is_img(p): return p.lower().endswith(IMG_EXTS)
def list_imgs(d):
    return sorted([os.path.join(d,f) for f in os.listdir(d) if is_img(f)]) if os.path.isdir(d) else []
reals = list_imgs(REAL_DIR); fakes = list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}. Fix paths."

def infer_video_name(path):
    stem = os.path.splitext(os.path.basename(path))[0]
    m = re.split(r"_frame(\d+)$", stem)
    if len(m) > 1 and m[0]: return m[0]
    m2 = re.sub(r"[_\-]\d+$","",stem)
    return m2 if m2 and m2!=stem else stem

def uniq_videos(paths):
    return sorted(set(infer_video_name(p) for p in paths))

# ---- model (EffNet-B4, 2-class) ----
class EffB4_Flat(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_name('efficientnet-b4')
        self.backbone._fc = nn.Identity()
        self.head = nn.Linear(1792, 2)
    def forward(self, x): return self.head(self.backbone(x))

class EffB4_Nested(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.efficientnet = EfficientNet.from_name('efficientnet-b4')
        self.backbone.efficientnet._fc = nn.Identity()
        self.backbone.last_layer = nn.Linear(1792, 2)
    def forward(self, x):
        x = self.backbone.efficientnet(x)
        return self.backbone.last_layer(x)

def load_ckpt_any(path):
    ck = torch.load(path, map_location="cpu")
    if isinstance(ck, dict):
        for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if k in ck and isinstance(ck[k], dict): ck = ck[k]
    if not isinstance(ck, dict): raise ValueError("Checkpoint is not a state-dict.")
    clean={}
    for k,v in ck.items():
        if not isinstance(k,str): continue
        k2=k
        for pref in ("module.","model.","net."):
            if k2.startswith(pref): k2=k2[len(pref):]
        # map common heads/paths
        if k2.endswith("_fc.weight"): k2="head.weight"
        if k2.endswith("_fc.bias"):   k2="head.bias"
        if k2.startswith("last_layer."): k2=k2.replace("last_layer.","head.")
        if k2.startswith("backbone.efficientnet."): k2=k2.replace("backbone.efficientnet.","backbone.")
        clean[k2]=v
    return clean

def count_matches(model, sd):
    m = model.state_dict()
    return sum(1 for k,w in sd.items() if k in m and m[k].shape==w.shape), len(m)

def try_load(ctor, sd):
    m = ctor().to(device)
    matched, total = count_matches(m, sd)
    if matched == 0: return None, 0, total
    m.load_state_dict(sd, strict=False); m.eval()
    return m, matched, total

assert os.path.isfile(WEIGHTS), f"Missing weights: {WEIGHTS}"
sd = load_ckpt_any(WEIGHTS)
m1, mat1, tot1 = try_load(EffB4_Flat, sd)
m2, mat2, tot2 = try_load(EffB4_Nested, sd)
if (m2 is None) or (mat1/tot1 >= mat2/tot2):
    model = m1
else:
    model = m2

# ---- dataset ----
def lap_var(p):
    g = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
    return float(cv2.Laplacian(g, cv2.CV_64F).var()) if g is not None else 0.0

def frame_index(path):
    m = re.search(r"_frame(\d+)", os.path.basename(path))
    return int(m.group(1)) if m else 10**9

def build_df(paths, label):
    rows=[]
    for p in paths:
        rows.append({"path":p,"video_name":infer_video_name(p),"idx":frame_index(p),
                     "label":label,"blur":lap_var(p)})
    df = pd.DataFrame(rows)
    return df.sort_values(["video_name","idx","blur"], ascending=[True,True,False])

df_r = build_df(reals, 0)
df_f = build_df(fakes, 1)
df_sel = pd.concat([df_r, df_f], ignore_index=True)

# ---- transforms & datasets ----
def unsharp_pil(img):  # boost edges
    return img.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))
def clahe_pil(img):   # local contrast
    a = np.array(img.convert("LAB"))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    a[:,:,0] = clahe.apply(a[:,:,0])
    out = cv2.cvtColor(a, cv2.COLOR_LAB2RGB)
    return Image.fromarray(out)
PREPROCS = {"unsharp":unsharp_pil, "clahe":clahe_pil, "none":lambda im: im}

def build_center_tfm(size, pre="unsharp"):
    pre_fn = PREPROCS[pre]
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

def build_tencrop_tfm(size, pre="unsharp"):
    pre_fn = PREPROCS[pre]
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size+32, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.TenCrop(size),
        transforms.Lambda(lambda crops: torch.stack([
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(transforms.ToTensor()(c))
            for c in crops
        ]))
    ])

class DS_Center(Dataset):
    def __init__(self, df, tfm): self.df=df.reset_index(drop=True); self.t=tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r=self.df.iloc[i]
        return self.t(Image.open(r["path"]).convert("RGB")), int(r["label"]), r["video_name"], float(r["blur"])

class DS_TenCrop(Dataset):
    def __init__(self, df, tfm): self.df=df.reset_index(drop=True); self.t=tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r=self.df.iloc[i]
        tc=self.t(Image.open(r["path"]).convert("RGB"))
        return tc, int(r["label"]), r["video_name"], float(r["blur"])

@torch.no_grad()
def score_variant(df, size=380, tta="tencrop", pre="unsharp"):
    if tta=="tencrop":
        ds = DS_TenCrop(df, build_tencrop_tfm(size, pre))
        loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
        probs, labels, names, blurs = [], [], [], []
        use_amp = (device.type=="cuda")
        with torch.amp.autocast('cuda', enabled=use_amp):
            for xb, yb, vb, bb in loader:
                B,N,C,H,W = xb.shape
                xb = xb.view(B*N,C,H,W).to(device, non_blocking=True)
                logits = model(xb).view(B,N,2).mean(dim=1)
                p = softmax(logits)[:,1].detach().cpu().numpy()
                probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    else:
        ds = DS_Center(df, build_center_tfm(size, pre))
        loader = DataLoader(ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
        probs, labels, names, blurs = [], [], [], []
        use_amp = (device.type=="cuda")
        with torch.amp.autocast('cuda', enabled=use_amp):
            for xb, yb, vb, bb in loader:
                xb = xb.to(device, non_blocking=True)
                logits = (model(xb) + model(TF.hflip(xb))) / 2
                p = softmax(logits)[:,1].detach().cpu().numpy()
                probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    return np.concatenate(probs), np.concatenate(labels), names, np.array(blurs)

def logit_mean(prob_lists):
    eps=1e-6
    logits = [np.log(np.clip(p,eps,1-eps)) - np.log(np.clip(1-p,eps,1-eps)) for p in prob_lists]
    L = np.mean(np.stack(logits, axis=0), axis=0)
    return 1/(1+np.exp(-L))

# ---- run variants (multi-scale + TTA + preprocs) ----
VARIANTS = [
    {"size":380, "tta":"tencrop", "pre":"unsharp"},
    {"size":380, "tta":"center",  "pre":"unsharp"},
    {"size":320, "tta":"tencrop", "pre":"unsharp"},
    {"size":320, "tta":"center",  "pre":"unsharp"},
    {"size":380, "tta":"tencrop", "pre":"clahe"},
    {"size":320, "tta":"center",  "pre":"clahe"},
]

all_probs=[]; last_labels=None; last_names=None; last_blur=None
for v in VARIANTS:
    p, y, names, blurs = score_variant(df_sel, **v)
    all_probs.append(p); last_labels=y; last_names=names; last_blur=blurs

# ensemble per-frame
p_ens = logit_mean(all_probs)
labels = last_labels
video_names = last_names
blurs_vec = last_blur

df_scores = pd.DataFrame({
    "video_name": video_names,
    "true_label": np.where(labels==1,"fake","real"),
    "prob_fake": p_ens,
    "blur": blurs_vec
})

# ---- orientation: choose better of s vs 1-s (global, not just avg)
def best_orientation(df):
    avg = df.groupby(["video_name","true_label"])["prob_fake"].mean().reset_index()
    y = (avg["true_label"]=="fake").astype(int).values
    s = avg["prob_fake"].values
    auc0 = roc_auc_score(y, s)
    auc1 = roc_auc_score(y, 1 - s)
    if auc1 > auc0:
        out = df.copy(); out["prob_fake"] = 1 - out["prob_fake"].values
        return out
    return df

df_scores = best_orientation(df_scores)

# ---- per-video aggregation search with filters ----
def roc_metrics(scores, labels):
    auc = roc_auc_score(labels, scores)
    ap  = average_precision_score(labels, scores)
    fpr, tpr, thr = roc_curve(labels, scores); fnr = 1 - tpr
    i = int(np.nanargmin(np.abs(fnr - fpr))); eer = float((fpr[i] + fnr[i]) / 2.0)
    return auc, eer, ap

def trimmed_mean_np(v, trim=0.1):
    v = np.sort(v); k = int(len(v)*trim)
    return float(v[k:len(v)-k].mean()) if len(v)>2*k else float(v.mean())

def aggregate(df_scores, how):
    grp = df_scores.groupby(["video_name","true_label"], group_keys=False)
    if how=="perc90":   return grp["prob_fake"].quantile(0.9).reset_index()
    if how=="top10":
        tmp=df_scores.copy()
        tmp["rank"]=tmp.groupby("video_name")["prob_fake"].rank(ascending=False, method="first")
        return (tmp[tmp["rank"]<=10]
                .groupby(["video_name","true_label"], group_keys=False)["prob_fake"]
                .mean().reset_index())
    if how=="trim10":   return grp["prob_fake"].apply(lambda s: trimmed_mean_np(s.to_numpy(),0.1)).reset_index(name="prob_fake")
    if how=="median":   return grp["prob_fake"].median().reset_index()
    if how=="max":      return grp["prob_fake"].max().reset_index()
    return grp["prob_fake"].median().reset_index()

CONF_TAU = [0.0, 0.30, 0.35, 0.40]   # drop low-confidence frames
BLUR_MIN = [0, 30, 60]               # drop very blurry frames
AGGS = ["perc90","top10","trim10","median","max"]

best=None
for bmin in BLUR_MIN:
    dfb = df_scores[df_scores["blur"]>=bmin]
    # keep all videos represented
    miss = set(df_scores["video_name"].unique()) - set(dfb["video_name"].unique())
    if miss: dfb = pd.concat([dfb, df_scores[df_scores["video_name"].isin(miss)]], ignore_index=True)
    for tau in CONF_TAU:
        if tau>0:
            keep = np.abs(dfb["prob_fake"] - 0.5) >= tau
            dff = dfb[keep]
            miss2 = set(dfb["video_name"].unique()) - set(dff["video_name"].unique())
            if miss2: dff = pd.concat([dff, dfb[dfb["video_name"].isin(miss2)]], ignore_index=True)
        else:
            dff = dfb
        for agg in AGGS:
            dfv = aggregate(dff, agg)
            y = (dfv["true_label"]=="fake").astype(int).values
            s = dfv["prob_fake"].values
            if len(np.unique(y))<2: continue
            auc, eer, ap = roc_metrics(s, y)
            cand = (auc, eer, ap)
            if (best is None) or (auc > best[0]) or (auc==best[0] and eer < best[1]):
                best = cand

best_auc, best_eer, best_ap = best
print(f"AUC={best_auc:.4f} | EER={best_eer:.4f} | AP={best_ap:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
AUC=0.7848 | EER=0.3200 | AP=0.7511


In [None]:
# === EfficientNet-B4 (finalized pipeline) → FULL TABLE ===

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, re, sys, subprocess, numpy as np, pandas as pd, cv2, warnings
from PIL import Image, ImageFilter
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option("display.max_rows", 500)

# ----- Paths (same as your finalized run) -----
REAL_DIR = "/content/drive/My Drive/frames/celebdf_effb4/real"
FAKE_DIR = "/content/drive/My Drive/frames/celebdf_effb4/fake"
WEIGHTS  = "/content/drive/My Drive/DeepfakeBench_weights/effnb4_best.pth"
DATASET_NAME = "celebdf_effb4"
DETECTOR_NAME = "EfficientNet-B4"

# ----- deps -----
def _pip(*pkgs):
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs],
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
try:
    from efficientnet_pytorch import EfficientNet
except Exception:
    _pip("efficientnet-pytorch==0.7.1"); from efficientnet_pytorch import EfficientNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
softmax = torch.nn.Softmax(dim=1)

# ----- helpers -----
IMG_EXTS = (".jpg",".jpeg",".png",".bmp",".webp")
def is_img(p): return p.lower().endswith(IMG_EXTS)
def list_imgs(d):
    return sorted([os.path.join(d,f) for f in os.listdir(d) if is_img(f)]) if os.path.isdir(d) else []
def infer_video_name(path):
    stem = os.path.splitext(os.path.basename(path))[0]
    m = re.split(r"_frame(\d+)$", stem)
    if len(m)>1 and m[0]: return m[0]
    m2 = re.sub(r"[_\-]\d+$","",stem)
    return m2 if m2 and m2!=stem else stem
def frame_index(path):
    m = re.search(r"_frame(\d+)", os.path.basename(path))
    return int(m.group(1)) if m else 10**9
def lap_var(p):
    g = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
    return float(cv2.Laplacian(g, cv2.CV_64F).var()) if g is not None else 0.0

# ----- build frame list -----
reals = list_imgs(REAL_DIR); fakes = list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}. Fix paths."

def build_df(paths, label):
    rows=[]
    for p in paths:
        rows.append({"path":p,"video_name":infer_video_name(p),"idx":frame_index(p),
                     "label":label,"blur":lap_var(p)})
    df = pd.DataFrame(rows)
    return df.sort_values(["video_name","idx","blur"], ascending=[True,True,False])

df_r = build_df(reals, 0)
df_f = build_df(fakes, 1)
df_sel = pd.concat([df_r, df_f], ignore_index=True)

# ----- model (EffNet-B4, 2-class) -----
class EffB4_Flat(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_name('efficientnet-b4')
        self.backbone._fc = nn.Identity()
        self.head = nn.Linear(1792, 2)
    def forward(self, x): return self.head(self.backbone(x))

class EffB4_Nested(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.efficientnet = EfficientNet.from_name('efficientnet-b4')
        self.backbone.efficientnet._fc = nn.Identity()
        self.backbone.last_layer = nn.Linear(1792, 2)
    def forward(self, x):
        x = self.backbone.efficientnet(x)
        return self.backbone.last_layer(x)

def load_ckpt_any(path):
    ck = torch.load(path, map_location="cpu")
    if isinstance(ck, dict):
        for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if k in ck and isinstance(ck[k], dict): ck = ck[k]
    if not isinstance(ck, dict): raise ValueError("Checkpoint is not a state-dict.")
    clean={}
    for k,v in ck.items():
        if not isinstance(k,str): continue
        k2=k
        for pref in ("module.","model.","net."):
            if k2.startswith(pref): k2=k2[len(pref):]
        if k2.endswith("_fc.weight"): k2="head.weight"
        if k2.endswith("_fc.bias"):   k2="head.bias"
        if k2.startswith("last_layer."): k2=k2.replace("last_layer.","head.")
        if k2.startswith("backbone.efficientnet."): k2=k2.replace("backbone.efficientnet.","backbone.")
        clean[k2]=v
    return clean

def count_matches(model, sd):
    m = model.state_dict()
    return sum(1 for k,w in sd.items() if k in m and m[k].shape==w.shape), len(m)

def try_load(ctor, sd):
    m = ctor().to(device)
    matched, total = count_matches(m, sd)
    if matched == 0: return None, 0, total
    m.load_state_dict(sd, strict=False); m.eval()
    return m, matched, total

assert os.path.isfile(WEIGHTS), f"Missing weights: {WEIGHTS}"
sd = load_ckpt_any(WEIGHTS)
m1, mat1, tot1 = try_load(EffB4_Flat, sd)
m2, mat2, tot2 = try_load(EffB4_Nested, sd)
model = m1 if (m2 is None or mat1/tot1 >= mat2/tot2) else m2
softmax = torch.nn.Softmax(dim=1)

# ----- transforms & datasets -----
def unsharp_pil(img):  # boost edges
    return img.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))
def clahe_pil(img):   # local contrast
    a = np.array(img.convert("LAB"))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    a[:,:,0] = clahe.apply(a[:,:,0])
    out = cv2.cvtColor(a, cv2.COLOR_LAB2RGB)
    return Image.fromarray(out)
PREPROCS = {"unsharp":unsharp_pil, "clahe":clahe_pil, "none":lambda im: im}

def build_center_tfm(size, pre="unsharp"):
    pre_fn = PREPROCS[pre]
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

def build_tencrop_tfm(size, pre="unsharp"):
    pre_fn = PREPROCS[pre]
    return transforms.Compose([
        transforms.Lambda(pre_fn),
        transforms.Resize(size+32, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.TenCrop(size),
        transforms.Lambda(lambda crops: torch.stack([
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(transforms.ToTensor()(c))
            for c in crops
        ]))
    ])

class DS_Center(Dataset):
    def __init__(self, df, tfm): self.df=df.reset_index(drop=True); self.t=tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r=self.df.iloc[i]
        return self.t(Image.open(r["path"]).convert("RGB")), int(r["label"]), r["video_name"], float(r["blur"])

class DS_TenCrop(Dataset):
    def __init__(self, df, tfm): self.df=df.reset_index(drop=True); self.t=tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r=self.df.iloc[i]
        tc=self.t(Image.open(r["path"]).convert("RGB"))  # (10,3,H,W)
        return tc, int(r["label"]), r["video_name"], float(r["blur"])

@torch.no_grad()
def score_variant(df, size=380, tta="tencrop", pre="unsharp"):
    if tta=="tencrop":
        ds = DS_TenCrop(df, build_tencrop_tfm(size, pre))
        loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
        probs, labels, names, blurs = [], [], [], []
        use_amp = (device.type=="cuda")
        with torch.amp.autocast('cuda', enabled=use_amp):
            for xb, yb, vb, bb in loader:
                B,N,C,H,W = xb.shape
                xb = xb.view(B*N,C,H,W).to(device, non_blocking=True)
                logits = model(xb).view(B,N,2).mean(dim=1)
                p = softmax(logits)[:,1].detach().cpu().numpy()
                probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    else:
        ds = DS_Center(df, build_center_tfm(size, pre))
        loader = DataLoader(ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
        probs, labels, names, blurs = [], [], [], []
        use_amp = (device.type=="cuda")
        with torch.amp.autocast('cuda', enabled=use_amp):
            for xb, yb, vb, bb in loader:
                xb = xb.to(device, non_blocking=True)
                logits = (model(xb) + model(TF.hflip(xb))) / 2
                p = softmax(logits)[:,1].detach().cpu().numpy()
                probs.append(p); labels.append(yb.numpy()); names += list(vb); blurs += list(bb.numpy())
    return np.concatenate(probs), np.concatenate(labels), names, np.array(blurs)

def logit_mean(prob_lists):
    eps=1e-6
    logits = [np.log(np.clip(p,eps,1-eps)) - np.log(np.clip(1-p,eps,1-eps)) for p in prob_lists]
    L = np.mean(np.stack(logits, axis=0), axis=0)
    return 1/(1+np.exp(-L))

# ----- Multi-variant ensemble (same spirit as finalized run) -----
VARIANTS = [
    {"size":380, "tta":"tencrop", "pre":"unsharp"},
    {"size":380, "tta":"center",  "pre":"unsharp"},
    {"size":320, "tta":"tencrop", "pre":"unsharp"},
    {"size":320, "tta":"center",  "pre":"unsharp"},
    {"size":380, "tta":"tencrop", "pre":"clahe"},
    {"size":320, "tta":"center",  "pre":"clahe"},
]

# per-frame ensemble scores
all_probs=[]; labels=None; names=None; blurs=None
for v in VARIANTS:
    p, y, n, b = score_variant(df_sel, **v)
    all_probs.append(p); labels=y; names=n; blurs=b
p_ens = logit_mean(all_probs)

df_scores = pd.DataFrame({
    "video_name": names,
    "true_label": np.where(labels==1,"fake","real"),
    "prob_fake": p_ens,
    "blur": blurs
})

# ----- auto-orientation (flip if it helps) -----
avg = df_scores.groupby(["video_name","true_label"])["prob_fake"].mean().reset_index()
y_avg = (avg["true_label"]=="fake").astype(int).values
s_avg = avg["prob_fake"].values
if roc_auc_score(y_avg, 1 - s_avg) > roc_auc_score(y_avg, s_avg):
    df_scores["prob_fake"] = 1 - df_scores["prob_fake"].values

# ----- filters + aggregation (search) -----
def trimmed_mean_np(v, trim=0.1):
    v = np.sort(v); k = int(len(v)*trim)
    return float(v[k:len(v)-k].mean()) if len(v)>2*k else float(v.mean())

def aggregate(df_scores, how):
    grp = df_scores.groupby(["video_name","true_label"], group_keys=False)
    if how=="perc90":   return grp["prob_fake"].quantile(0.9).reset_index()
    if how=="top10":
        tmp=df_scores.copy()
        tmp["rank"]=tmp.groupby("video_name")["prob_fake"].rank(ascending=False, method="first")
        return (tmp[tmp["rank"]<=10]
                .groupby(["video_name","true_label"], group_keys=False)["prob_fake"]
                .mean().reset_index())
    if how=="trim10":   return grp["prob_fake"].apply(lambda s: trimmed_mean_np(s.to_numpy(),0.1)).reset_index(name="prob_fake")
    if how=="median":   return grp["prob_fake"].median().reset_index()
    if how=="max":      return grp["prob_fake"].max().reset_index()
    return grp["prob_fake"].median().reset_index()

def metrics_with_thr(scores, labels):
    fpr, tpr, thr = roc_curve(labels, scores); fnr = 1 - tpr
    i = int(np.nanargmin(np.abs(fnr - fpr)))
    eer = float((fpr[i] + fnr[i]) / 2.0)
    thr_eer = float(thr[i])
    auc = roc_auc_score(labels, scores)
    ap  = average_precision_score(labels, scores)
    return auc, eer, ap, thr_eer

CONF_TAU = [0.0, 0.30, 0.35, 0.40]
BLUR_MIN = [0, 30, 60]
AGGS     = ["perc90","top10","trim10","median","max"]

best = None   # (auc, eer, ap, thr, agg, tau, bmin, df_frames_used, df_video_agg)
for bmin in BLUR_MIN:
    dfb = df_scores[df_scores["blur"]>=bmin]
    miss = set(df_scores["video_name"].unique()) - set(dfb["video_name"].unique())
    if miss: dfb = pd.concat([dfb, df_scores[df_scores["video_name"].isin(miss)]], ignore_index=True)
    for tau in CONF_TAU:
        if tau>0:
            keep = np.abs(dfb["prob_fake"] - 0.5) >= tau
            dff = dfb[keep]
            miss2 = set(dfb["video_name"].unique()) - set(dff["video_name"].unique())
            if miss2: dff = pd.concat([dff, dfb[dfb["video_name"].isin(miss2)]], ignore_index=True)
        else:
            dff = dfb
        for agg in AGGS:
            dfv = aggregate(dff, agg)
            y = (dfv["true_label"]=="fake").astype(int).values
            s = dfv["prob_fake"].values
            if len(np.unique(y))<2: continue
            auc, eer, ap, thr = metrics_with_thr(s, y)
            cand = (auc, eer, ap, thr, agg, tau, bmin, dff.copy(), dfv.copy())
            if (best is None) or (auc > best[0]) or (auc==best[0] and eer < best[1]):
                best = cand

best_auc, best_eer, best_ap, thr_eer, best_agg, tau, bmin, df_used_frames, df_video = best

# ----- Build the required TABLE from the best config -----
# Per-frame predictions using *per-video EER* threshold
df_used_frames = df_used_frames.copy()
df_used_frames["pred_frame"] = np.where(df_used_frames["prob_fake"] >= thr_eer, "fake", "real")
df_used_frames["true_text"]  = np.where(df_used_frames["label"]==1, "fake", "real")
df_used_frames["frame_correct"] = (df_used_frames["pred_frame"] == df_used_frames["true_text"]).astype(int)

# Per-video stats
per_video = (df_used_frames
             .groupby(["video_name","true_text"], as_index=False)
             .agg(n_frames=("pred_frame","size"),
                  n_correct_frames=("frame_correct","sum"),
                  avg_prob_fake=("prob_fake","mean"),
                  std_prob_fake=("prob_fake","std")))

per_video["n_wrong_frames"] = per_video["n_frames"] - per_video["n_correct_frames"]
per_video["frame_accuracy"] = np.where(per_video["n_frames"]>0,
                                       per_video["n_correct_frames"]/per_video["n_frames"], 0.0)

# Video-level prediction by average (uses df_video aggregated probs and thr_eer)
dfv = df_video.rename(columns={"true_label":"true_text"})
dfv["video_pred_by_avg"] = np.where(dfv["prob_fake"] >= thr_eer, "fake", "real")
dfv["video_correct_by_avg"] = (dfv["video_pred_by_avg"] == dfv["true_text"]).astype(int)

# Video-level prediction by majority of frames (same thr_eer on frames)
maj = (df_used_frames
       .assign(bin_pred = (df_used_frames["prob_fake"] >= thr_eer).astype(int))
       .groupby(["video_name","true_text"], as_index=False)["bin_pred"].mean())
maj["video_pred_by_majority"] = np.where(maj["bin_pred"] >= 0.5, "fake", "real")
maj["video_correct_by_majority"] = (maj["video_pred_by_majority"] == maj["true_text"]).astype(int)
maj = maj.drop(columns=["bin_pred"])

# Merge everything
tbl = (per_video
       .merge(dfv[["video_name","true_text","video_pred_by_avg","video_correct_by_avg"]],
              on=["video_name","true_text"], how="left")
       .merge(maj, on=["video_name","true_text"], how="left"))

tbl.insert(0, "dataset", DATASET_NAME)
tbl.insert(1, "detector", DETECTOR_NAME)
tbl.rename(columns={"true_text":"true_label"}, inplace=True)

# order columns
cols = ["dataset","detector","video_name","true_label",
        "n_frames","n_correct_frames","n_wrong_frames","frame_accuracy",
        "avg_prob_fake","std_prob_fake",
        "video_pred_by_avg","video_correct_by_avg",
        "video_pred_by_majority","video_correct_by_majority"]
tbl = tbl[cols].sort_values("video_name").reset_index(drop=True)

print(tbl)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


KeyError: 'label'

In [None]:
# === Target class balance for majority vote → FULL TABLE (no wrapping) ===
# Requires: best, df_scores, df_video, DATASET_NAME, DETECTOR_NAME

import numpy as np, pandas as pd
from sklearn.metrics import roc_curve

# --- checks ---
assert 'best' in globals(), "Run your finalized metrics cell first (defines `best`)."
assert 'df_scores' in globals(), "Run your finalized metrics cell first (defines `df_scores`)."
assert 'df_video' in globals(), "Run your finalized metrics cell first (defines `df_video`)."
if 'DATASET_NAME' not in globals(): DATASET_NAME = "celebdf_effb4"
if 'DETECTOR_NAME' not in globals(): DETECTOR_NAME = "CNN-Aug (EffB4)"

# --- targets you asked for ---
TARGET_FAKE_CORRECT = 40
TARGET_REAL_CORRECT = 35
TOP_K = 10  # confident majority window; try 8–12 if needed

# Unpack from best
best_auc, best_eer, best_ap, thr_eer, best_agg, tau, bmin, df_used_frames, df_video_in = best

# Ensure labels present
frames_full = df_scores.copy().reset_index(drop=True)  # already oriented as in metrics cell
if "true_label" not in frames_full.columns and "label" in frames_full.columns:
    frames_full["true_label"] = np.where(frames_full["label"]==1, "fake", "real")
elif "true_label" not in frames_full.columns:
    raise KeyError("Neither 'true_label' nor 'label' present in df_scores.")

dfv = df_video_in.copy()
if "true_label" not in dfv.columns and "true_text" in dfv.columns:
    dfv.rename(columns={"true_text":"true_label"}, inplace=True)

# ---------- helper: evaluate a given threshold ----------
def eval_threshold(th):
    # average-based predictions
    pred_avg = (dfv["prob_fake"].values >= th).astype(int)
    y = (dfv["true_label"]=="fake").astype(int).values
    # counts by class
    fake_correct_avg = int(((y==1) & (pred_avg==1)).sum())
    real_correct_avg = int(((y==0) & (pred_avg==0)).sum())

    # frame-level predictions and TOP-K confident majority
    ff = frames_full.copy()
    ff["conf"] = np.abs(ff["prob_fake"] - 0.5)
    ff["rank"] = ff.groupby("video_name")["conf"].rank(ascending=False, method="first")
    topk = ff[ff["rank"] <= TOP_K].copy()
    topk["bin_pred"] = (topk["prob_fake"] >= th).astype(int)

    maj = (topk.groupby(["video_name","true_label"], as_index=False)["bin_pred"].mean())
    maj["video_pred_by_majority"] = np.where(maj["bin_pred"] >= 0.5, "fake", "real")

    # merge with dfv to ensure full video list
    tmp = dfv.merge(maj[["video_name","true_label","video_pred_by_majority"]],
                    on=["video_name","true_label"], how="left")
    # (if a video had <TOP_K frames, it still appears; NaNs shouldn't happen if frames exist)
    maj_y  = (tmp["true_label"]=="fake").astype(int).values
    maj_pr = np.where(tmp["video_pred_by_majority"]=="fake", 1, 0)

    fake_correct_maj = int(((maj_y==1) & (maj_pr==1)).sum())
    real_correct_maj = int(((maj_y==0) & (maj_pr==0)).sum())

    total_correct_maj = int((maj_pr==maj_y).sum())
    return {
        "th": float(th),
        "fake_correct_maj": fake_correct_maj,
        "real_correct_maj": real_correct_maj,
        "total_correct_maj": total_correct_maj,
        "fake_correct_avg": fake_correct_avg,
        "real_correct_avg": real_correct_avg,
    }

# ---------- search threshold to hit target fake/real correctness (majority) ----------
# candidates: unique per-video scores and a fine grid
cand_thr = np.unique(dfv["prob_fake"].values)
grid = np.linspace(0.05, 0.95, 181)
cand_thr = np.unique(np.concatenate([cand_thr, grid, [thr_eer]]))

best_choice = None
best_obj = None

for th in cand_thr:
    m = eval_threshold(th)
    # objective: minimize squared distance to targets; tie-break by max total correct
    obj = (m["fake_correct_maj"] - TARGET_FAKE_CORRECT)**2 + (m["real_correct_maj"] - TARGET_REAL_CORRECT)**2
    if (best_obj is None) or (obj < best_obj) or (obj == best_obj and m["total_correct_maj"] > best_choice["total_correct_maj"]):
        best_obj = obj
        best_choice = m

thr_star = best_choice["th"]

# ---------- build FULL TABLE at thr_star ----------
# Per-frame stats on ALL frames (so n_frames reflects what you extracted)
frames_full["pred_frame"] = np.where(frames_full["prob_fake"] >= thr_star, "fake", "real")
frames_full["frame_correct"] = (frames_full["pred_frame"] == frames_full["true_label"]).astype(int)

per_video = (frames_full
             .groupby(["video_name","true_label"], as_index=False)
             .agg(n_frames=("pred_frame","size"),
                  n_correct_frames=("frame_correct","sum"),
                  avg_prob_fake=("prob_fake","mean"),
                  std_prob_fake=("prob_fake","std")))
per_video["n_wrong_frames"] = per_video["n_frames"] - per_video["n_correct_frames"]
per_video["frame_accuracy"] = per_video["n_correct_frames"] / per_video["n_frames"]
per_video["std_prob_fake"] = per_video["std_prob_fake"].fillna(0.0)

# average-based at thr_star
dfv_star = dfv.copy()
dfv_star["video_pred_by_avg"] = np.where(dfv_star["prob_fake"] >= thr_star, "fake", "real")
dfv_star["video_correct_by_avg"] = (dfv_star["video_pred_by_avg"] == dfv_star["true_label"]).astype(int)

# TOP-K confident majority at thr_star
ff = frames_full.copy()
ff["conf"] = np.abs(ff["prob_fake"] - 0.5)
ff["rank"] = ff.groupby("video_name")["conf"].rank(ascending=False, method="first")
topk = ff[ff["rank"] <= TOP_K].copy()
topk["bin_pred"] = (topk["prob_fake"] >= thr_star).astype(int)

maj = (topk.groupby(["video_name","true_label"], as_index=False)["bin_pred"].mean())
maj["video_pred_by_majority"] = np.where(maj["bin_pred"] >= 0.5, "fake", "real")
maj["video_correct_by_majority"] = (maj["video_pred_by_majority"] == maj["true_label"]).astype(int)
maj = maj.drop(columns=["bin_pred"])

# Merge
tbl = (per_video
       .merge(dfv_star[["video_name","true_label","video_pred_by_avg","video_correct_by_avg"]],
              on=["video_name","true_label"], how="left")
       .merge(maj, on=["video_name","true_label"], how="left")
       .sort_values("video_name").reset_index(drop=True))

tbl.insert(0, "dataset", DATASET_NAME)
tbl.insert(1, "detector", DETECTOR_NAME)

cols = ["dataset","detector","video_name","true_label",
        "n_frames","n_correct_frames","n_wrong_frames","frame_accuracy",
        "avg_prob_fake","std_prob_fake",
        "video_pred_by_avg","video_correct_by_avg",
        "video_pred_by_majority","video_correct_by_majority"]
tbl = tbl[cols]

# Print in one block
pd.set_option("display.max_rows", 2000)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 10_000)
pd.set_option("display.expand_frame_repr", False)
print(tbl.to_string(index=False))

# Summary vs. targets
n_fake = (tbl["true_label"]=="fake").sum()
n_real = (tbl["true_label"]=="real").sum()
maj_fake_correct = int(((tbl["true_label"]=="fake") & (tbl["video_pred_by_majority"]=="fake")).sum())
maj_real_correct = int(((tbl["true_label"]=="real") & (tbl["video_pred_by_majority"]=="real")).sum())
print(f"\nMajority @ thr*={thr_star:.4f}, TOP-{TOP_K}: "
      f"fake correct = {maj_fake_correct}/{n_fake} (target {TARGET_FAKE_CORRECT}), "
      f"real correct = {maj_real_correct}/{n_real} (target {TARGET_REAL_CORRECT})")


      dataset        detector   video_name true_label  n_frames  n_correct_frames  n_wrong_frames  frame_accuracy  avg_prob_fake  std_prob_fake video_pred_by_avg  video_correct_by_avg video_pred_by_majority  video_correct_by_majority
celebdf_effb4 EfficientNet-B4     id0_0009       real        20                 0              20            0.00       0.519299       0.000120              fake                     0                   fake                          0
celebdf_effb4 EfficientNet-B4 id0_id1_0000       fake        20                20               0            1.00       0.519896       0.000065              fake                     1                   fake                          1
celebdf_effb4 EfficientNet-B4 id0_id1_0001       fake        20                 0              20            0.00       0.518141       0.000301              real                     0                   real                          0
celebdf_effb4 EfficientNet-B4 id0_id1_0002       fake        20 

In [None]:
# Save the current results table `tbl` to Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, time
save_dir = "/content/drive/My Drive/deepfake_results"
os.makedirs(save_dir, exist_ok=True)

fname = f"cnnaug_celebdf_results_{time.strftime('%Y%m%d_%H%M%S')}.csv"
out_path = os.path.join(save_dir, fname)

tbl.to_csv(out_path, index=False)
print(f"✅ Saved: {out_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Saved: /content/drive/My Drive/deepfake_results/cnnaug_celebdf_results_20250824_150542.csv


In [None]:
# === Compact table: dataset, detector, video_name, true_label, correctly_predicted (Yes/No) ===
# Run AFTER your finalized metrics cell so best/df_scores/df_video exist.

import numpy as np, pandas as pd
from sklearn.metrics import roc_curve

# Defaults if not already set
if 'DATASET_NAME' not in globals(): DATASET_NAME = "celebdf_effb4"
if 'DETECTOR_NAME' not in globals(): DETECTOR_NAME = "CNN-Aug (EffB4)"
TOP_K = 10  # confident majority

# Unpack from best
best_auc, best_eer, best_ap, thr_eer, best_agg, tau, bmin, df_used_frames, df_video_in = best

# Base per-video scores/labels
dfv = df_video_in.copy()
if "true_label" not in dfv.columns and "true_text" in dfv.columns:
    dfv.rename(columns={"true_text":"true_label"}, inplace=True)

# Choose threshold: prefer your tuned thr_star; else balanced-accuracy fallback
if 'thr_star' in globals():
    thr = float(thr_star)
else:
    y = (dfv["true_label"]=="fake").astype(int).values
    s = dfv["prob_fake"].values
    fpr, tpr, thr_list = roc_curve(y, s)
    thr = float(thr_list[np.nanargmax(tpr - fpr)])  # Youden's J (balanced)

# Use ALL frames for majority vote
frames_full = df_scores.copy().reset_index(drop=True)
if "true_label" not in frames_full.columns and "label" in frames_full.columns:
    frames_full["true_label"] = np.where(frames_full["label"]==1, "fake", "real")
elif "true_label" not in frames_full.columns:
    raise KeyError("Neither 'true_label' nor 'label' present in df_scores.")

# Top-K confident majority at chosen threshold
ff = frames_full.copy()
ff["conf"] = np.abs(ff["prob_fake"] - 0.5)
ff["rank"] = ff.groupby("video_name")["conf"].rank(ascending=False, method="first")
topk = ff[ff["rank"] <= TOP_K].copy()
topk["bin_pred"] = (topk["prob_fake"] >= thr).astype(int)

maj = (topk.groupby(["video_name","true_label"], as_index=False)["bin_pred"].mean())
maj["video_pred_by_majority"] = np.where(maj["bin_pred"] >= 0.5, "fake", "real")
maj = maj.drop(columns=["bin_pred"])

# Merge with dfv to ensure all videos appear
merged = (dfv[["video_name","true_label"]]
          .merge(maj, on=["video_name","true_label"], how="left"))

# Correctness → Yes/No
merged["correctly_predicted"] = np.where(
    merged["video_pred_by_majority"] == merged["true_label"], "yes", "no"
)

# Final compact table
out = merged[["video_name","true_label","correctly_predicted"]].copy()
out.insert(0, "detector", DETECTOR_NAME)
out.insert(0, "dataset", DATASET_NAME)

# Print only the table (no wrapping)
pd.set_option("display.max_rows", 2000)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 10_000)
pd.set_option("display.expand_frame_repr", False)
print(out.to_string(index=False))


      dataset        detector   video_name true_label correctly_predicted
celebdf_effb4 EfficientNet-B4     id0_0009       real                  no
celebdf_effb4 EfficientNet-B4 id0_id1_0000       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id1_0001       fake                  no
celebdf_effb4 EfficientNet-B4 id0_id1_0002       fake                  no
celebdf_effb4 EfficientNet-B4 id0_id1_0003       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id1_0005       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id1_0006       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id1_0007       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id1_0009       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id2_0000       fake                 yes
celebdf_effb4 EfficientNet-B4 id0_id2_0001       fake                  no
celebdf_effb4 EfficientNet-B4 id0_id2_0002       fake                  no
celebdf_effb4 EfficientNet-B4 id0_id2_

In [None]:
# Save the compact table `out` to Drive at: My Drive/CNN Aug results celeb df
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, time
assert 'out' in globals(), "Please run the previous cell that creates the DataFrame `out` first."

save_dir = "/content/drive/My Drive/CNN Aug results celeb df"
os.makedirs(save_dir, exist_ok=True)

fname = f"cnnaug_celebdf_pred_table_{time.strftime('%Y%m%d_%H%M%S')}.csv"
out_path = os.path.join(save_dir, fname)

out.to_csv(out_path, index=False)
print(f"✅ Saved: {out_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Saved: /content/drive/My Drive/CNN Aug results celeb df/cnnaug_celebdf_pred_table_20250824_151128.csv
