In [None]:
# === Capsule (spatial) on balanced_frames_FF++ — prints ONLY: "Capsule model loaded" + AUC | EER | AP ===
# Implements VGG19 feature extractor + Capsule head (as in DeepfakeBench CapsuleNetDetector),
# loads capsule_best.pth (partial-safe), does light TTA (flip), auto orientation, and picks the best
# video aggregation (median / perc90 / top10 / trim10). Uses 256x256 RGB with ImageNet normalization.

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# ---------- paths ----------
import os, re, io, contextlib, subprocess, numpy as np, pandas as pd, cv2, math
from PIL import Image

# Use the "MyDrive" folder when it exists, otherwise fall back to the "My Drive" spelling.
ROOT = "/content/drive/MyDrive" if os.path.isdir("/content/drive/MyDrive") else "/content/drive/My Drive"
# Frame folders (one per class) and the capsule checkpoint, all resolved under ROOT.
REAL_DIR = f"{ROOT}/balanced_frames_FF++/real"
FAKE_DIR = f"{ROOT}/balanced_frames_FF++/fake"
CAPSULE_WEIGHTS = f"{ROOT}/DeepfakeBench_weights/capsule_best.pth"

# Fail fast, before any model work, when the inputs are not where they are expected.
assert os.path.isdir(REAL_DIR) and os.path.isdir(FAKE_DIR), "Check dataset folders."
assert os.path.isfile(CAPSULE_WEIGHTS), "Missing capsule_best.pth."

# ---------- deps ----------
def _pipq(*pkgs):
    """Quietly ``pip install`` the given packages into the interpreter running this notebook.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status (``check=True``).
    """
    # `sys.executable` is the documented location; `os.sys` only works because the
    # `os` module happens to import `sys` internally.
    import sys
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
try:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    import torchvision
    from torchvision import models
except Exception:
    _pipq("torch", "torchvision"); import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models

from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# ---------- hardware / knobs ----------
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# cuDNN autotuning is only switched on for the CUDA path.
torch.backends.cudnn.benchmark = (device.type=="cuda")
IMG_SIZE    = 256      # frames are resized to IMG_SIZE x IMG_SIZE by prep_rgb
BATCH       = 16 if device.type=="cuda" else 8
NUM_WORKERS = 2 if device.type=="cuda" else 0
FRAME_CAP   = 120      # max frames kept per video; try 100–150 depending on runtime
# NOTE(review): `softmax` is not referenced by any code visible in this cell.
softmax = torch.nn.Softmax(dim=1)

# ---------- utils: list & name frames ----------
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

def list_imgs(d):
    """Return the sorted full paths of image files sitting directly inside ``d``.

    A file counts as an image when its lower-cased name ends with one of
    ``IMG_EXTS``. A path that is not an existing directory yields ``[]``.
    """
    if not os.path.isdir(d):
        return []
    found = []
    for entry in os.listdir(d):
        if entry.lower().endswith(IMG_EXTS):
            found.append(os.path.join(d, entry))
    found.sort()
    return found
reals, fakes = list_imgs(REAL_DIR), list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}."

def infer_video_name(p):
    """Derive the video identifier from a frame path.

    The file stem is split on a trailing ``_frame<digits>``; when that yields a
    non-empty prefix the prefix is the video name. Otherwise a trailing
    ``_<digits>`` / ``-<digits>`` suffix is stripped from the stem.
    """
    base = os.path.basename(p)
    stem, _ext = os.path.splitext(base)
    parts = re.split(r"_frame(\d+)$", stem)
    if len(parts) > 1 and parts[0]:
        return parts[0]
    return re.sub(r"[_\-]\d+$","",stem)

def frame_index(p):
    """Return the integer after ``_frame`` in the file name, or ``10**9`` when absent.

    The large sentinel makes unnumbered frames sort after numbered ones.
    """
    hit = re.search(r"_frame(\d+)", os.path.basename(p))
    if hit is None:
        return 10**9
    return int(hit.group(1))

def build_df(paths, label):
    """Build a frame table (path, video_name, idx, label) sorted by video then frame index."""
    records = []
    for frame_path in paths:
        records.append({
            "path": frame_path,
            "video_name": infer_video_name(frame_path),
            "idx": frame_index(frame_path),
            "label": label,
        })
    frame_table = pd.DataFrame(records)
    return frame_table.sort_values(["video_name","idx"])

# Label convention used throughout: 0 = real, 1 = fake.
df_sel = pd.concat([build_df(reals,0), build_df(fakes,1)], ignore_index=True)
# Keep at most FRAME_CAP frames per video, lowest frame index first.
# NOTE(review): the cap is keyed on video_name only, while later aggregation keys on
# (video_name, true_label) — a real and a fake video sharing a name would share one cap.
df_sel = df_sel.sort_values(["video_name","idx"]).groupby("video_name", as_index=False).head(FRAME_CAP).reset_index(drop=True)

# ---------- preprocessing (RGB -> tensor) ----------
IMN_MEAN = np.array([0.485,0.456,0.406], np.float32)
IMN_STD  = np.array([0.229,0.224,0.225], np.float32)

def prep_rgb(path, out=IMG_SIZE):
    """Load one frame and return a normalised float32 CHW tensor of size ``out`` x ``out``.

    The image is read with OpenCV (falling back to PIL when OpenCV returns None),
    converted to RGB, resized with bicubic interpolation, scaled to [0, 1] and
    standardised per channel with IMN_MEAN / IMN_STD.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        pil_rgb = Image.open(path).convert("RGB")
        bgr = cv2.cvtColor(np.array(pil_rgb), cv2.COLOR_RGB2BGR)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (out,out), interpolation=cv2.INTER_CUBIC)
    scaled = resized.astype(np.float32)/255.0
    chw = scaled.transpose(2,0,1)
    # broadcast the per-channel statistics over H and W
    mean = IMN_MEAN[:,None,None]
    std = IMN_STD[:,None,None]
    normed = (chw - mean) / std
    return torch.from_numpy(normed.astype(np.float32))

class DSRGB(Dataset):
    """Dataset over a frame table; each item is (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # a fresh 0..n-1 index so positional and label lookups agree
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        tensor = prep_rgb(row["path"])
        label = int(row["label"])
        video = str(row["video_name"])
        frame_no = int(row["idx"])
        return tensor, label, video, frame_no

# ---------- Capsule model (per DeepfakeBench structure) ----------
class VggExtractor(nn.Module):
    """First 19 layers (indices 0..18) of an ImageNet-pretrained VGG19 ``features`` stack.

    Args:
        train: when truthy the sub-network is put in training mode, otherwise in eval mode.
    """
    def __init__(self, train=False):
        super().__init__()
        # handle both new/old torchvision APIs.
        # VGG19_Weights only defines IMAGENET1K_V1 (IMAGENET1K_FEATURES exists for VGG16 only),
        # so the previous enum lookup always raised and silently took the deprecated path.
        # IMAGENET1K_V1 is the same checkpoint `pretrained=True` resolves to.
        try:
            vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        except (AttributeError, TypeError):
            # old torchvision: no weights enum / no `weights=` keyword
            vgg = models.vgg19(pretrained=True)
        self.vgg_1 = nn.Sequential(*list(vgg.features.children())[:19])  # 0..18 inclusive
        if train:
            self.vgg_1.train(True)
        else:
            self.vgg_1.eval()
    def forward(self, x): return self.vgg_1(x)

class StatsNet(nn.Module):
    """Collapse each channel's spatial map into its mean and standard deviation."""
    def forward(self, x):
        # x: [B, C, H, W]  ->  [B, 2, C] with per-channel mean/std over spatial dims
        B,C,H,W = x.shape
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed or channels_last tensors, reshape copies only when it must.
        x = x.reshape(B, C, H*W)
        mean = torch.mean(x, dim=2)
        std  = torch.std(x, dim=2)   # torch default (unbiased); NaN when H*W == 1
        return torch.stack((mean, std), dim=1)  # [B, 2, C]

class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a shape fixed at construction time."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, inp):
        target = self.shape
        return inp.view(target)

class FeatureExtractor(nn.Module):
    """Ten parallel capsule branches, each mapping a 256-channel feature map to an 8-vector."""

    def __init__(self):
        super().__init__()
        self.NO_CAPS = 10
        branches = []
        for _ in range(self.NO_CAPS):
            branches.append(nn.Sequential(
                nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                StatsNet(),                          # [B, 2, 16]
                nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2),  # [B, 8, 8]
                nn.BatchNorm1d(8),
                nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),  # [B, 1, 8]
                nn.BatchNorm1d(1),
                View(-1, 8),                        # [B, 8]
            ))
        self.capsules = nn.ModuleList(branches)

    def squash(self, tensor, dim):
        """Capsule squashing along ``dim``: scale = |v|^2 / (1 + |v|^2), applied to v / |v|."""
        sq = tensor.pow(2).sum(dim=dim, keepdim=True)
        scale = sq / (1.0 + sq + 1e-8)
        return scale * tensor / (torch.sqrt(sq + 1e-8))

    def forward(self, x):
        # every branch sees the same input; results are stacked on a new last axis
        stacked = torch.stack([branch(x) for branch in self.capsules], dim=-1)  # [B, 8, NO_CAPS]
        return self.squash(stacked, dim=-1)

class RoutingLayer(nn.Module):
    """Dynamic-routing layer mapping input capsules to output capsules.

    Args:
        num_input_capsules: number of incoming capsules.
        num_output_capsules: number of outgoing capsules (one per class).
        data_in: length of each incoming capsule vector.
        data_out: length of each outgoing capsule vector.
        num_iterations: routing iterations; must be >= 1 for ``forward`` to produce an output.
    """
    def __init__(self, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations):
        super().__init__()
        self.num_iterations = num_iterations
        # [out_caps, in_caps, data_out, data_in]
        self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in)*0.1)
    def squash(self, tensor, dim):
        """Capsule squashing along ``dim`` (epsilon-stabilised)."""
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1.0 + squared_norm + 1e-8)
        return scale * tensor / (torch.sqrt(squared_norm + 1e-8))
    def forward(self, x, random=False, dropout=0.0):
        """Route ``x`` of shape [B, data_in, in_caps] to [B, data_out, out_caps].

        ``random`` is accepted for interface compatibility and is not used.
        ``dropout`` > 0 zeroes priors with that probability.
        """
        # x: [B, data_in, in_caps]
        x = x.transpose(2,1)   # [B, in_caps, data_in]
        route_weights = self.route_weights
        # priors: [out_caps, B, in_caps, data_out, 1]
        priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]
        priors = priors.transpose(1,0)  # [B, out_caps, in_caps, data_out, 1]
        if dropout > 0.0:
            drop = torch.bernoulli(torch.full_like(priors, 1.0 - dropout))
            priors = priors * drop
        logits = torch.zeros_like(priors)
        for i in range(self.num_iterations):
            probs = F.softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3)  # [B, out_caps, 1, data_out, 1]
            if i != self.num_iterations - 1:
                logits = logits + priors * outputs
        # Drop exactly the two known singleton axes. A bare squeeze() would also
        # remove B, out_caps or data_out whenever one of them equals 1, which the
        # old ndim==2 patch only repaired for the B==1 case.
        outputs = outputs.squeeze(4).squeeze(2)        # [B, out_caps, data_out]
        outputs = outputs.transpose(2,1).contiguous()  # [B, data_out, out_caps]
        return outputs

class CapsuleNet(nn.Module):
    """VGG19 front end -> capsule feature extractor -> routing layer.

    ``forward`` returns ``(pred, prob)``: ``pred`` is [B, 2] (class softmax averaged
    over the routed capsule dimensions) and ``prob`` is [B], the class-1 (fake) entry
    of a second softmax over ``pred``. Because ``pred`` sums to 1, ``prob`` is a
    monotonic function of ``pred[:, 1]`` confined to roughly [0.27, 0.73].
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        self.vgg_ext = VggExtractor(train=False)
        self.fea_ext = FeatureExtractor()
        self.routing = RoutingLayer(num_input_capsules=10, num_output_capsules=num_classes,
                                    data_in=8, data_out=4, num_iterations=2)
        # Initialise ONLY the capsule extractor. self.apply(...) walks every
        # submodule, so it would also overwrite the pretrained VGG19 conv filters
        # in vgg_ext with N(0, 0.02) noise.
        self.fea_ext.apply(self._weights_init)
    def _weights_init(self, m):
        """Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), biases = 0."""
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
            if hasattr(m, "weight") and m.weight is not None:
                nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            if hasattr(m, "weight") and m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    def forward(self, x):
        # the VGG front end is a fixed feature extractor: no gradients through it
        with torch.no_grad():
            feat = self.vgg_ext(x)           # [B,256,H',W']
        caps = self.fea_ext(feat)            # [B,8,10]
        z = self.routing(caps, random=False, dropout=0.0)   # [B,4,2]
        classes = F.softmax(z, dim=-1)       # [B,4,2] softmax over out_caps
        pred = classes.detach().mean(dim=1)  # [B,2]
        prob = F.softmax(pred, dim=1)[:,1]   # [B] prob of class-1 (fake)
        return pred, prob

# ---------- build model + weights ----------
model = CapsuleNet(num_classes=2)

def try_load_capsule_weights(model, ckpt_path, min_cover=0.25):
    """Best-effort load of a checkpoint into ``model``, tolerant of key-name differences.

    Checkpoint keys are matched to ``model.state_dict()`` keys by (1) the raw key,
    (2) the key with wrapper prefixes such as ``module.`` removed, and (3) the key
    with its leading submodule name swapped for a known alias
    (``routing_stats.`` <-> ``routing.`` <-> ``route.``, ``vgg_ext.`` <-> ``vgg.``,
    ``fea_ext.`` <-> ``fea.``). Only tensors whose shape matches are used.

    The earlier version stripped ``vgg_ext.`` / ``fea_ext.`` / ``routing.`` outright;
    those are part of this model's own key names, so nothing matched and the
    checkpoint was never applied.

    Args:
        model: module to receive the weights.
        ckpt_path: path to a ``torch.save``-d state dict, or a dict wrapping one.
        min_cover: minimum fraction of model keys that must match before loading.

    Returns:
        (ok, cover): whether weights were applied, and the matched fraction.
        Any exception is reported with a ``[warn]`` print and yields ``ok=False``.
    """
    wrappers = ("module.", "model.", "net.", "backbone.")
    alias_groups = (
        ("vgg_ext.", "vgg."),
        ("fea_ext.", "fea."),
        ("routing_stats.", "routing.", "route."),
    )

    def _candidates(key):
        # Yield progressively normalised spellings of one checkpoint key.
        yield key
        stripped = True
        while stripped:
            stripped = False
            for w in wrappers:
                if key.startswith(w):
                    key = key[len(w):]
                    stripped = True
                    yield key
        for group in alias_groups:
            for alias in group:
                if key.startswith(alias):
                    rest = key[len(alias):]
                    for other in group:
                        if other != alias:
                            yield other + rest
                    return

    ok=False; cover=0.0
    try:
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        sd=torch.load(ckpt_path, map_location="cpu")
        if isinstance(sd,dict):
            for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
                if k in sd and isinstance(sd[k],dict): sd=sd[k]; break
        if isinstance(sd,dict):
            ms=model.state_dict()
            matched={}
            for k,v in sd.items():
                if not isinstance(k,str) or not hasattr(v,"shape"): continue
                for cand in _candidates(k):
                    if cand in ms and cand not in matched and ms[cand].shape==v.shape:
                        matched[cand]=v
                        break
            cover=len(matched)/max(1,len(ms))
            if cover>=min_cover:
                ms.update(matched); model.load_state_dict(ms, strict=False); ok=True
    except Exception as e:
        print("[warn] weight load:", e)
    return ok, cover

weights_loaded, coverage = try_load_capsule_weights(model, CAPSULE_WEIGHTS, min_cover=0.25)
model = model.to(device).eval()
print("Capsule model loaded")
if not weights_loaded:
    # The model object exists either way; say so when the checkpoint was NOT applied,
    # otherwise every downstream metric silently reflects untrained capsule layers.
    print(f"[warn] checkpoint not applied (key coverage={coverage:.2f}); "
          "scores will come from randomly initialised capsule layers.")

# ---------- light TTA (horizontal flip) ----------
@torch.no_grad()
def forward_tta(xb):
    """Score a batch and its horizontal mirror, returning ``(two_col, prob)``.

    ``prob`` is the mean fake-probability of the two passes; ``two_col`` stacks
    ``[1 - prob, prob]`` column-wise. Mixed precision is enabled on CUDA only.
    """
    on_gpu = (device.type=="cuda")
    with torch.amp.autocast('cuda', enabled=on_gpu):
        _, p_plain = model(xb)
        mirrored = torch.flip(xb, dims=[3])   # flip the width axis
        _, p_flip = model(mirrored)
    # average probs
    prob = (p_plain + p_flip) / 2.0
    # logits proxy for orientation decisions
    two_col = torch.stack([1.0-prob, prob], dim=1)
    return two_col, prob

# ---------- scoring ----------
class DSRGBCaps(Dataset):
    """Frame dataset used for scoring; items are (tensor, label, video_name, idx)."""

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self,i):
        row = self.df.iloc[i]
        pixels = prep_rgb(row["path"])
        return pixels, int(row["label"]), str(row["video_name"]), int(row["idx"])

@torch.no_grad()
def score_frames(df):
    """Run flip-TTA inference over every frame in ``df``.

    Returns a DataFrame with columns video_name, idx, true_label ("fake"/"real")
    and prob_fake, sorted by video then frame index.
    """
    on_gpu = (device.type=="cuda")
    loader = DataLoader(DSRGBCaps(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=on_gpu)
    vnames, idxs, prob_chunks, label_chunks = [], [], [], []
    for xb, yb, vb, ib in loader:
        xb = xb.to(device, non_blocking=on_gpu)
        _, prob = forward_tta(xb)
        vnames.extend(list(vb))
        idxs.extend(list(ib))
        prob_chunks.append(prob.detach().cpu().numpy())
        label_chunks.append(np.array(yb))
    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks).astype(float)
    out = pd.DataFrame({
        "video_name": pd.Series(vnames, dtype=object),
        "idx":        pd.Series(idxs, dtype=np.int64),
        "true_label": pd.Series(np.where(all_labels==1,"fake","real"), dtype=object),
        "prob_fake":  pd.Series(all_probs, dtype=np.float64),
    })
    ordered = out.sort_values(["video_name","idx"])
    return ordered.reset_index(drop=True)

df_scores = score_frames(df_sel)

# ---------- auto orientation (flip scores if helps per-video AUC) ----------
# CAUTION: the flip decision uses the ground-truth labels of the data being
# evaluated, so the reported AUC can never fall below 0.5 and is optimistic.
avg = df_scores.groupby(["video_name","true_label"], as_index=False)["prob_fake"].mean()
y_avg = (avg["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg["prob_fake"].to_numpy(dtype=float)
try:
    if roc_auc_score(y_avg, 1 - s_avg) > roc_auc_score(y_avg, s_avg):
        df_scores["prob_fake"] = 1 - df_scores["prob_fake"]
except Exception as e:
    # still best-effort (e.g. only one class present), but no longer silent
    print("[warn] orientation check skipped:", e)

# ---------- per-video aggregation & metrics ----------
def aggregate_numpy(df, how):
    """Collapse per-frame ``prob_fake`` into one score per (video_name, true_label).

    Args:
        df: frame-level table with columns video_name, true_label, prob_fake.
        how: "median", "perc90" (90th percentile), "top10" (mean of the 10 highest
            frames) or "trim10" (mean after trimming 10% from each end).
            Any other value falls back to the median.

    Returns:
        DataFrame with columns video_name, true_label, score (empty if ``df`` is).
    """
    rows=[]
    for (v,t), g in df.groupby(["video_name","true_label"], sort=False):
        vals = g["prob_fake"].to_numpy(dtype=float)
        n = vals.size
        if n==0: continue
        vs = np.sort(vals)
        if   how=="median":  score=float(np.median(vs))
        # Linear interpolation is np.quantile's default in every NumPy release, so no
        # version probing is needed; the old `np.quantile.__code__` check raises
        # AttributeError where np.quantile is a dispatcher object rather than a function.
        elif how=="perc90":  score=float(np.quantile(vs, 0.90))
        elif how=="top10":   score=float(np.mean(vs[-min(10,n):]))
        elif how=="trim10":
            k=int(0.1*n); lo=k; hi=max(n-k,1); score=float(np.mean(vs[lo:hi]))
        else: score=float(np.median(vs))
        rows.append({"video_name":v, "true_label":t, "score":score})
    return pd.DataFrame(rows, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Return ``(auc, eer, ap)`` for binary ``labels`` and real-valued ``scores``.

    EER is read at the ROC point where |FNR - FPR| is smallest, as the mean of
    the two rates there.
    """
    roc_auc = roc_auc_score(labels, scores)
    avg_prec = average_precision_score(labels, scores)
    fpr, tpr, _thr = roc_curve(labels, scores)
    fnr = 1 - tpr
    cross = int(np.nanargmin(np.abs(fnr - fpr)))
    equal_err = float((fpr[cross] + fnr[cross]) / 2.0)
    return roc_auc, equal_err, avg_prec

# Try every aggregation and keep the one with the highest AUC (ties: lower EER).
# CAUTION: the choice is made on the same labels that are being reported, so the
# printed figures are an optimistic upper bound, not a held-out estimate.
best=None; best_cfg=None
for agg in ("median","perc90","top10","trim10"):
    dfv = aggregate_numpy(df_scores, agg)
    if dfv.empty: continue
    y = (dfv["true_label"]=="fake").astype(int).to_numpy()
    s = dfv["score"].to_numpy(dtype=float)
    if len(np.unique(y))<2: continue
    cand = metrics(s, y)
    if (best is None) or (cand[0] > best[0]) or (cand[0]==best[0] and cand[1] < best[1]):
        best = cand; best_cfg = agg

if best is None:
    # every aggregation was skipped (no scores, or only one class at video level)
    raise RuntimeError("No aggregation produced video-level scores for both classes; cannot compute AUC/EER/AP.")

auc, eer, ap = best
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] dataset='balanced_frames_FF++', device={device.type}, img={IMG_SIZE}, cap={FRAME_CAP}, agg={best_cfg}, "
      f"tta=flip, weights_loaded={weights_loaded}, cover={coverage:.2f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).




Capsule model loaded


KeyboardInterrupt: 

In [None]:
# === Capsule metrics patch (no re-scoring, no downloads) ===
# Fixes NumPy quantile API issue and computes best aggregation.

import numpy as np, pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# This cell only re-aggregates: it needs the frame-level `df_scores` table left in
# the kernel by the scoring cell, and stops the cell if it is absent or empty.
if 'df_scores' not in globals() or df_scores.empty:
    raise SystemExit("No 'df_scores' found. Run the Capsule scoring cell first (the one that prints 'Capsule model loaded').")

def qnp(v, q):
    """Return the ``q``-quantile of ``v`` as a float, using linear interpolation.

    Tries the ``method=`` keyword first and retries with the legacy
    ``interpolation=`` keyword when NumPy rejects it with a TypeError.
    """
    arr = np.asarray(v, dtype=float)
    try:
        value = float(np.quantile(arr, q, method="linear"))
    except TypeError:
        # older NumPy
        value = float(np.quantile(arr, q, interpolation="linear"))
    return value

def aggregate_numpy(df, how):
    """Reduce frame-level ``prob_fake`` to one score per (video_name, true_label).

    ``how`` is one of "median", "perc90", "top10", "trim10"; anything else is
    treated as "median". Returns columns video_name, true_label, score.
    """
    def _trimmed_mean(ordered, count):
        cut = int(0.1*count)
        return float(np.mean(ordered[cut:max(count-cut, 1)]))

    reducers = {
        "median": lambda ordered, count: float(np.median(ordered)),
        "perc90": lambda ordered, count: qnp(ordered, 0.90),
        "top10":  lambda ordered, count: float(np.mean(ordered[-min(10, count):])),
        "trim10": _trimmed_mean,
    }
    reduce_fn = reducers.get(how, reducers["median"])

    records = []
    for (video, truth), frames in df.groupby(["video_name","true_label"], sort=False):
        raw = frames["prob_fake"].to_numpy(dtype=float)
        if raw.size == 0:
            continue
        records.append((video, truth, reduce_fn(np.sort(raw), raw.size)))
    return pd.DataFrame(records, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Compute ROC-AUC, equal error rate and average precision; returns ``(auc, eer, ap)``."""
    fpr, tpr, _ = roc_curve(labels, scores)
    miss_rate = 1 - tpr
    # the ROC point where false-negative and false-positive rates are closest
    at = int(np.nanargmin(np.abs(miss_rate - fpr)))
    eer = float((fpr[at] + miss_rate[at]) / 2.0)
    return roc_auc_score(labels, scores), eer, average_precision_score(labels, scores)

# Keep the aggregation with the highest AUC (ties broken by lower EER).
# CAUTION: selected on the evaluation labels themselves, so the result is optimistic.
best=None; best_agg=None
for agg in ("median","perc90","top10","trim10"):
    dfv = aggregate_numpy(df_scores, agg)
    if dfv.empty: continue
    y = (dfv["true_label"]=="fake").astype(int).to_numpy()
    if len(np.unique(y)) < 2: continue
    s = dfv["score"].to_numpy(dtype=float)
    cand = metrics(s, y)
    if (best is None) or (cand[0] > best[0]) or (cand[0]==best[0] and cand[1] < best[1]):
        best, best_agg = cand, agg

if best is None:
    # every aggregation was skipped; unpacking None below would be an opaque TypeError
    raise RuntimeError("No aggregation produced video-level scores for both classes; cannot compute AUC/EER/AP.")

auc, eer, ap = best
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] agg={best_agg}, frames_scored={len(df_scores)}")


AUC=0.5000 | EER=0.5000 | AP=0.5000
[info] agg=median, frames_scored=2040


In [None]:
# === FAST Capsule (spatial) on balanced_frames_FF++ ===
# Prints: "Capsule model loaded (FAST)" then AUC | EER | AP
# Speed tweaks: IMG_SIZE=224, CAP=40 frames/video, no multi-crop TTA, single pass, median aggregation.

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# ---------- paths ----------
import os, re, io, contextlib, warnings, numpy as np, pandas as pd, cv2
from PIL import Image

# Use the "MyDrive" folder when it exists, otherwise fall back to the "My Drive" spelling.
ROOT = "/content/drive/MyDrive" if os.path.isdir("/content/drive/MyDrive") else "/content/drive/My Drive"
# Per-class frame folders and the capsule checkpoint, all under ROOT.
REAL_DIR = f"{ROOT}/balanced_frames_FF++/real"
FAKE_DIR = f"{ROOT}/balanced_frames_FF++/fake"
CAPSULE_WEIGHTS = f"{ROOT}/DeepfakeBench_weights/capsule_best.pth"

# Stop immediately when either input is missing.
assert os.path.isdir(REAL_DIR) and os.path.isdir(FAKE_DIR), "Check dataset folders exist."
assert os.path.isfile(CAPSULE_WEIGHTS), "Missing capsule_best.pth."

# ---------- deps ----------
def _pipq(*pkgs):
    """Quietly pip-install ``pkgs`` with the running interpreter; raises on pip failure."""
    import subprocess, sys as _sys
    command = [_sys.executable, "-m", "pip", "install", "-q"]
    command.extend(pkgs)
    subprocess.run(command, check=True)

try:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models
except Exception:
    _pipq("torch", "torchvision"); import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models

from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# ---------- speed knobs ----------
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# cuDNN autotuning is only switched on for the CUDA path.
torch.backends.cudnn.benchmark = (device.type=="cuda")
IMG_SIZE    = 224   # frames are resized to IMG_SIZE x IMG_SIZE by prep_rgb
BATCH       = 32 if device.type=="cuda" else 8
NUM_WORKERS = 0   # Drive + multiprocessing often slows down
CAP         = 40  # frames per video (quick)
# NOTE(review): `softmax` is not referenced by any code visible in this cell.
softmax = torch.nn.Softmax(dim=1)

# ---------- list frames ----------
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

def list_imgs(d):
    """List image files directly under ``d`` as sorted full paths; ``[]`` if ``d`` is not a directory."""
    if os.path.isdir(d):
        image_names = filter(lambda fname: fname.lower().endswith(IMG_EXTS), os.listdir(d))
        return sorted(os.path.join(d, fname) for fname in image_names)
    return []
reals, fakes = list_imgs(REAL_DIR), list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}."

def infer_video_name(p):
    """Derive the video identifier from a frame path.

    A trailing ``_frame<digits>`` is removed from the file stem when present;
    otherwise a trailing ``_<digits>`` or ``-<digits>`` suffix is stripped.
    """
    stem=os.path.splitext(os.path.basename(p))[0]
    m=re.split(r"_frame(\d+)$", stem)
    # The fallback pattern used to be r"[_\\-]\\d+$". Inside a raw string "\\d" is a
    # literal backslash followed by "d", so it never matched a real file name and
    # every frame without "_frame" became its own one-frame "video".
    return m[0] if len(m)>1 and m[0] else re.sub(r"[_\-]\d+$","",stem)

def frame_index(p):
    """Return the integer following ``_frame`` in the file name, or ``10**9`` when absent.

    The sentinel makes unnumbered frames sort after numbered ones.
    """
    # Was r"_frame(\\d+)": in a raw string that demands a literal backslash, so it
    # never matched and every frame received the 10**9 sentinel.
    m=re.search(r"_frame(\d+)", os.path.basename(p))
    return int(m.group(1)) if m else 10**9

def build_df(paths, label):
    """Tabulate frames as (path, video_name, idx, label), ordered by video and frame index."""
    def _describe(frame_path):
        return {"path": frame_path,
                "video_name": infer_video_name(frame_path),
                "idx": frame_index(frame_path),
                "label": label}
    table = pd.DataFrame(list(map(_describe, paths)))
    return table.sort_values(["video_name","idx"])

# Label convention: 0 = real, 1 = fake.
df_sel = pd.concat([build_df(reals,0), build_df(fakes,1)], ignore_index=True)
# take first CAP frames per video, ordered by the idx value frame_index() produced
# NOTE(review): the cap is keyed on video_name only; later aggregation keys on
# (video_name, true_label).
df_sel = df_sel.sort_values(["video_name","idx"]).groupby("video_name", as_index=False).head(CAP).reset_index(drop=True)

# ---------- preprocessing ----------
IMN_MEAN = np.array([0.485,0.456,0.406], np.float32)
IMN_STD  = np.array([0.229,0.224,0.225], np.float32)

def prep_rgb(path, out=IMG_SIZE):
    """Read a frame and return it as a standardised float32 tensor of shape [3, out, out].

    OpenCV is tried first; when it returns None the file is opened with PIL.
    Pixels are converted to RGB, bicubically resized, scaled to [0, 1] and
    normalised with IMN_MEAN / IMN_STD.
    """
    img_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if img_bgr is None:
        via_pil = np.array(Image.open(path).convert("RGB"))
        img_bgr = cv2.cvtColor(via_pil, cv2.COLOR_RGB2BGR)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = cv2.resize(img_rgb, (out,out), interpolation=cv2.INTER_CUBIC)
    unit = img_rgb.astype(np.float32)/255.0
    channels_first = unit.transpose(2,0,1)
    standardised = (channels_first - IMN_MEAN[:,None,None]) / IMN_STD[:,None,None]
    as_f32 = standardised.astype(np.float32)
    return torch.from_numpy(as_f32)

class DSRGB(Dataset):
    """Frame-table dataset yielding (tensor, label, video_name, idx) per row."""

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        record = self.df.iloc[i]
        image = prep_rgb(record["path"])
        return image, int(record["label"]), str(record["video_name"]), int(record["idx"])

# ---------- Capsule model (compact) ----------
class VggExtractor(nn.Module):
    """Frozen feature extractor: the first 19 layers of an ImageNet-pretrained VGG19."""
    def __init__(self):
        super().__init__()
        # Warnings and stdout/stderr are muted while torchvision builds the model.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
                # VGG19_Weights only defines IMAGENET1K_V1 (IMAGENET1K_FEATURES is a
                # VGG16-only member), so the old lookup always raised and the
                # deprecated `pretrained=True` path ran every time.
                try:
                    vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
                except (AttributeError, TypeError):
                    # old torchvision: no weights enum / no `weights=` keyword
                    vgg = models.vgg19(pretrained=True)
        self.vgg_1 = nn.Sequential(*list(vgg.features.children())[:19])
        self.vgg_1.eval()
    def forward(self, x):
        # never backpropagate through the VGG front end
        with torch.no_grad():
            return self.vgg_1(x)

class StatsNet(nn.Module):
    """Summarise each channel of a [B, C, H, W] map by its spatial mean and std -> [B, 2, C]."""

    def forward(self, x):
        batch, channels, height, width = x.shape
        flat = x.view(batch, channels, height*width)
        per_channel = (flat.mean(dim=2), flat.std(dim=2))
        return torch.stack(per_channel, dim=1)  # [B,2,C]

class View(nn.Module):
    """Reshape layer: applies ``Tensor.view`` with the shape given to the constructor."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, inp):
        return inp.view(self.shape)

class FeatureExtractor(nn.Module):
    """Bank of ten identical capsule branches applied to the same VGG feature map."""

    def __init__(self):
        super().__init__()
        self.NO_CAPS = 10
        branch_list = []
        for _ in range(self.NO_CAPS):
            branch = nn.Sequential(
                nn.Conv2d(256, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                nn.Conv2d(64, 16, 3, 1, 1),  nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                StatsNet(),                                 # [B,2,16]
                nn.Conv1d(2, 8, 5, 2, 2),  nn.BatchNorm1d(8),  # -> length 8
                nn.Conv1d(8, 1, 3, 1, 1), nn.BatchNorm1d(1),
                View(-1, 8),                                 # [B,8]
            )
            branch_list.append(branch)
        self.capsules = nn.ModuleList(branch_list)

    def squash(self, t, dim):
        """Capsule squashing along ``dim`` with a small epsilon for stability."""
        norm_sq = t.pow(2).sum(dim=dim, keepdim=True)
        gain = norm_sq/(1.0+norm_sq+1e-8)
        return gain * t / torch.sqrt(norm_sq+1e-8)

    def forward(self, x):
        per_branch = []
        for branch in self.capsules:
            per_branch.append(branch(x))            # each [B,8]
        stacked = torch.stack(per_branch, dim=-1)   # [B,8,10]
        return self.squash(stacked, dim=-1)

class RoutingLayer(nn.Module):
    """Dynamic routing from ``n_in`` capsules of length ``d_in`` to ``n_out`` of length ``d_out``.

    ``iters`` is the number of routing iterations and must be >= 1.
    """
    def __init__(self, n_in, n_out, d_in, d_out, iters=2):
        super().__init__()
        self.iters = iters
        self.route_weights = nn.Parameter(torch.randn(n_out, n_in, d_out, d_in)*0.1)
    def squash(self, t, dim):
        """Capsule squashing along ``dim`` (epsilon-stabilised)."""
        sn = (t**2).sum(dim=dim, keepdim=True)
        return (sn/(1.0+sn+1e-8)) * t / torch.sqrt(sn+1e-8)
    def forward(self, x):
        """Map ``x`` of shape [B, d_in, n_in] to [B, d_out, n_out]."""
        # x: [B, d_in, n_in]
        x = x.transpose(2,1)  # [B, n_in, d_in]
        priors = self.route_weights[:, None, :, :, :] @ x[None, :, :, :, None]  # [n_out,B,n_in,d_out,1]
        priors = priors.transpose(1,0)  # [B,n_out,n_in,d_out,1]
        logits = torch.zeros_like(priors)
        for i in range(self.iters):
            probs = torch.softmax(logits, dim=2)
            outputs = self.squash((probs*priors).sum(dim=2, keepdim=True), dim=3)  # [B,n_out,1,d_out,1]
            if i != self.iters-1:
                logits = logits + priors * outputs
        # Remove only the two known singleton axes; a bare squeeze() also collapses
        # B, n_out or d_out when any of them is 1 and then returns a mis-shaped tensor.
        outputs = outputs.squeeze(4).squeeze(2)     # [B,n_out,d_out]
        return outputs.transpose(2,1).contiguous()  # [B,d_out,n_out]

class CapsuleNet(nn.Module):
    """Compact capsule detector: VGG19 features -> capsule branches -> routing.

    ``forward`` returns ``(pred, prob)`` where ``pred`` is [B, 2] and ``prob`` is the
    class-1 (fake) entry of a softmax over ``pred``. ``pred`` sums to 1, so ``prob``
    is monotonic in ``pred[:, 1]`` and stays within roughly [0.27, 0.73].
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.vgg = VggExtractor()
        self.fea = FeatureExtractor()
        self.route = RoutingLayer(n_in=10, n_out=num_classes, d_in=8, d_out=4, iters=2)
        # Initialise the capsule branches only: self.apply(...) would also visit the
        # Conv2d layers inside self.vgg and replace the pretrained filters with noise.
        self.fea.apply(self._init)
    def _init(self, m):
        """Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), biases = 0."""
        if isinstance(m, (nn.Conv2d, nn.Conv1d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.normal_(m.weight, 1.0, 0.02); nn.init.constant_(m.bias, 0.0)
    def forward(self, x):
        feat = self.vgg(x)            # [B,256,H',W']
        caps = self.fea(feat)         # [B,8,10]
        z = self.route(caps)          # [B,4,2]
        classes = torch.softmax(z, dim=-1)
        pred = classes.detach().mean(dim=1)   # [B,2]
        prob = torch.softmax(pred, dim=1)[:,1]
        return pred, prob

# build & load weights quietly
model = CapsuleNet(num_classes=2)
def try_load_capsule_weights(model, ckpt, min_cover=0.25):
    """Best-effort load of checkpoint ``ckpt`` into ``model`` despite key-name differences.

    Each checkpoint key is tried as-is, then with wrapper prefixes (``module.`` etc.)
    removed, then with its leading submodule name swapped for a known alias
    (``vgg_ext.`` <-> ``vgg.``, ``fea_ext.`` <-> ``fea.``,
    ``routing_stats.`` <-> ``routing.`` <-> ``route.``). Only shape-compatible
    tensors are used.

    The earlier version deleted the submodule prefix altogether, leaving keys such as
    ``capsules.0.0.weight`` that cannot match this model's ``fea.capsules.0.0.weight``;
    coverage stayed near zero and the checkpoint was never applied.

    Returns:
        (ok, cover): whether weights were applied (requires ``cover >= min_cover``)
        and the fraction of model keys matched. Exceptions are reported with a
        ``[warn]`` print and yield ``ok=False``.
    """
    wrappers = ("module.", "model.", "net.", "backbone.")
    alias_groups = (
        ("vgg_ext.", "vgg."),
        ("fea_ext.", "fea."),
        ("routing_stats.", "routing.", "route."),
    )

    def _candidates(key):
        # Yield progressively normalised spellings of one checkpoint key.
        yield key
        stripped = True
        while stripped:
            stripped = False
            for w in wrappers:
                if key.startswith(w):
                    key = key[len(w):]
                    stripped = True
                    yield key
        for group in alias_groups:
            for alias in group:
                if key.startswith(alias):
                    rest = key[len(alias):]
                    for other in group:
                        if other != alias:
                            yield other + rest
                    return

    ok=False; cover=0.0
    try:
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        sd=torch.load(ckpt, map_location="cpu")
        if isinstance(sd,dict):
            for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
                if k in sd and isinstance(sd[k],dict): sd=sd[k]; break
        if isinstance(sd,dict):
            ms=model.state_dict()
            matched={}
            for k,v in sd.items():
                if not isinstance(k,str) or not hasattr(v,"shape"): continue
                for cand in _candidates(k):
                    if cand in ms and cand not in matched and ms[cand].shape==v.shape:
                        matched[cand]=v
                        break
            cover=len(matched)/max(1,len(ms))
            if cover>=min_cover:
                ms.update(matched); model.load_state_dict(ms, strict=False); ok=True
    except Exception as e:
        print("[warn] weight load:", e)
    return ok, cover

weights_loaded, coverage = try_load_capsule_weights(model, CAPSULE_WEIGHTS, 0.25)
model = model.to(device).eval()
print("Capsule model loaded (FAST)")
if not weights_loaded:
    # The model object exists either way; flag it when the checkpoint was NOT applied,
    # since every metric below would then describe untrained capsule layers.
    print(f"[warn] checkpoint not applied (key coverage={coverage:.2f}); "
          "scores will come from randomly initialised capsule layers.")

# ---------- scoring (single pass, no TTA) ----------
class DSRGBCaps(Dataset):
    """Scoring dataset: one (tensor, label, video_name, idx) tuple per frame row."""

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self,i):
        entry = self.df.iloc[i]
        frame = prep_rgb(entry["path"])
        return frame, int(entry["label"]), str(entry["video_name"]), int(entry["idx"])

@torch.no_grad()
def score_frames(df):
    """Score every frame in ``df`` with a single forward pass (no TTA).

    Returns a DataFrame with columns video_name, idx, true_label ("fake"/"real")
    and prob_fake, sorted by video then frame index.
    """
    loader = DataLoader(DSRGBCaps(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"))
    vnames, idxs, probs, labels = [], [], [], []
    for xb, yb, vb, ib in loader:
        xb = xb.to(device, non_blocking=(device.type=="cuda"))
        _, prob = model(xb)
        # The loader hands back idx as a tensor; list(ib) would store 0-d tensors,
        # and sort_values on such an object column fails with
        # "Categorical categories must be unique". Store plain Python values.
        vnames.extend(str(name) for name in vb)
        idxs.extend(int(frame_no) for frame_no in ib)
        probs.append(prob.detach().cpu().numpy()); labels.append(np.asarray(yb))
    out = pd.DataFrame({
        "video_name": vnames,
        "idx": idxs,
        "true_label": np.where(np.concatenate(labels)==1,"fake","real"),
        "prob_fake": np.concatenate(probs).astype(float),
    })
    return out.sort_values(["video_name","idx"]).reset_index(drop=True)

df_scores = score_frames(df_sel)

# ---------- per-video median aggregation & metrics ----------
def aggregate(df):
    """Return one row per (video_name, true_label) with the median ``prob_fake`` as ``score``."""
    grouped = df.groupby(["video_name","true_label"], sort=False)
    records = [
        (key[0], key[1], float(np.median(frames["prob_fake"].to_numpy(dtype=float))))
        for key, frames in grouped
    ]
    return pd.DataFrame(records, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Video-level ROC-AUC, equal error rate and average precision, returned as ``(auc, eer, ap)``."""
    area = roc_auc_score(labels, scores)
    precision_avg = average_precision_score(labels, scores)
    false_pos, true_pos, _unused = roc_curve(labels, scores)
    false_neg = 1 - true_pos
    # EER: operating point where the two error rates are closest
    gap = np.abs(false_neg - false_pos)
    k = int(np.nanargmin(gap))
    return area, float((false_pos[k] + false_neg[k]) / 2.0), precision_avg

# Collapse frame scores to one median score per video, then evaluate at video level.
dfv = aggregate(df_scores)
y = (dfv["true_label"]=="fake").astype(int).to_numpy()   # 1 = fake, 0 = real
s = dfv["score"].to_numpy(dtype=float)
auc, eer, ap = metrics(s, y)
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
# Run configuration, including whether the checkpoint was actually applied.
print(f"[info] FAST mode — device={device.type}, img={IMG_SIZE}, cap={CAP}, batch={BATCH}, workers={NUM_WORKERS}, weights_loaded={weights_loaded}, cover={coverage:.2f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Capsule model loaded (FAST)


ValueError: Categorical categories must be unique

In [None]:
# === FAST Capsule — patched scoring + metrics (no downloads) ===
# Reuses: model, df_sel, device, BATCH, NUM_WORKERS from your FAST cell.

import numpy as np, pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import torch
from torch.utils.data import Dataset, DataLoader

# --- safety checks ---
# This cell relies on kernel state: `model` and `df_sel` must already exist
# (per the header, from the FAST cell). Stop the cell if either is missing.
if 'model' not in globals():
    raise SystemExit("Model not found. Run the FAST Capsule cell up to model creation first.")
if 'df_sel' not in globals() or df_sel.empty:
    raise SystemExit("df_sel missing. Run the FAST Capsule cell that builds df_sel first.")

# --- dataset wrapper (reuse prep_rgb from your FAST cell) ---
class _DSCaps(Dataset):
    """Frame dataset for the patched scorer; items are (tensor, label, video_name, idx)."""

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self,i):
        row = self.df.iloc[i]
        image = prep_rgb(row["path"])
        label, video, frame_no = int(row["label"]), str(row["video_name"]), int(row["idx"])
        return image, label, video, frame_no

@torch.no_grad()
def score_frames_fixed(df):
    """Single-pass scoring that stores plain Python/NumPy values in the result table.

    Returns a DataFrame (video_name: str, idx: int64, true_label: str,
    prob_fake: float) sorted by video then frame index.
    """
    on_gpu = (device.type=="cuda")
    loader = DataLoader(_DSCaps(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=on_gpu)
    vnames, idxs, prob_chunks, label_chunks = [], [], [], []
    for xb, yb, vb, ib in loader:
        xb = xb.to(device, non_blocking=on_gpu)
        _, prob = model(xb)
        for name in vb:
            vnames.append(str(name))
        for frame_no in ib:
            idxs.append(int(frame_no))
        prob_chunks.append(prob.detach().cpu().numpy())
        label_chunks.append(yb.numpy())

    labels_arr = np.concatenate(label_chunks)
    probs_arr  = np.concatenate(prob_chunks).astype(float)
    true_lab   = np.where(labels_arr==1, "fake", "real").tolist()

    out = pd.DataFrame({
        "video_name": vnames,
        "idx": idxs,
        "true_label": true_lab,
        "prob_fake": probs_arr,
    })
    # force plain dtypes (avoid Categoricals)
    for column, dtype in (("video_name", str), ("idx", np.int64),
                          ("true_label", str), ("prob_fake", float)):
        out[column] = out[column].astype(dtype)
    ordered = out.sort_values(["video_name","idx"])
    return ordered.reset_index(drop=True)

df_scores = score_frames_fixed(df_sel)

# --- per-video median aggregation & metrics ---
def aggregate_median(df):
    """Median ``prob_fake`` per (video_name, true_label), returned as column ``score``."""
    records = []
    for key, frames in df.groupby(["video_name","true_label"], sort=False):
        video, truth = key
        frame_probs = frames["prob_fake"].to_numpy(dtype=float)
        records.append((video, truth, float(np.median(frame_probs))))
    return pd.DataFrame(records, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Return ``(auc, eer, ap)``; EER is taken where |FNR - FPR| is minimal on the ROC curve."""
    fpr, tpr, _thresholds = roc_curve(labels, scores)
    fnr = 1 - tpr
    nearest = int(np.nanargmin(np.abs(fnr - fpr)))
    result = (
        roc_auc_score(labels, scores),
        float((fpr[nearest] + fnr[nearest]) / 2.0),
        average_precision_score(labels, scores),
    )
    return result

# One median score per video, then video-level metrics.
dfv = aggregate_median(df_scores)
y = (dfv["true_label"]=="fake").astype(int).to_numpy()   # 1 = fake, 0 = real
s = dfv["score"].to_numpy(dtype=float)
auc, eer, ap = metrics(s, y)
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] FAST patched — device={device.type}, frames_scored={len(df_scores)}")


AUC=0.5000 | EER=0.5000 | AP=0.5000
[info] FAST patched — device=cuda, frames_scored=2040


In [None]:
# --- Capsule quick diagnostics on existing results ---
import numpy as np, pandas as pd

# Abort early when the scoring cell has not been run in this kernel.
if 'df_scores' not in globals() or df_scores.empty:
    raise SystemExit("No df_scores in memory — run the scoring cell first.")

# Report whether the checkpoint loader set its success flag, and its value.
print("weights_loaded present?", 'weights_loaded' in globals())
if 'weights_loaded' in globals(): print("weights_loaded =", weights_loaded)

# Coerce to numeric so non-numeric entries become NaN instead of raising.
ps = pd.to_numeric(df_scores['prob_fake'], errors='coerce')
print("\n[prob_fake stats]\n", ps.describe())

# Fraction of frames scored within 0.01 of 0.5; a value near 1.0 means the
# scores are essentially constant.
near_mid = np.mean(np.abs(ps - 0.5) < 1e-2)
print(f"share within 0.01 of 0.5: {near_mid:.3f}")

print("\n[label counts]\n", df_scores['true_label'].value_counts())
print("\n[per-video mean probs (first 10)]\n",
      df_scores.groupby('video_name')['prob_fake'].mean().head(10))


weights_loaded present? True
weights_loaded = False

[prob_fake stats]
 count    2040.000000
mean        0.500001
std         0.000000
min         0.500001
25%         0.500001
50%         0.500001
75%         0.500001
max         0.500001
Name: prob_fake, dtype: float64
share within 0.01 of 0.5: 1.000

[label counts]
 true_label
fake    1020
real    1020
Name: count, dtype: int64

[per-video mean probs (first 10)]
 video_name
000_003    0.500001
010_005    0.500001
011_805    0.500001
012_026    0.500001
013_883    0.500001
014_790    0.500001
015_919    0.500001
016_209    0.500001
017_803    0.500001
018_019    0.500001
Name: prob_fake, dtype: float64


In [None]:
# === FAST+STRONG Capsule re-score (flip TTA, 256px, 60 frames/video) ===
# Reuses your existing 'model'. No VGG downloads. Prints AUC | EER | AP.

import numpy as np, pandas as pd, torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import cv2
from PIL import Image

# ---- safety: need model and df_sel or the real/fake dirs you used earlier
# This cell reuses kernel state: it needs `model` and `df_sel` from earlier cells.
if 'model' not in globals():
    raise SystemExit("Capsule 'model' not found. Run your Capsule model cell first (up to model creation).")
if 'df_sel' not in globals() or df_sel.empty:
    raise SystemExit("df_sel missing. Run the cell that builds df_sel from your dataset first.")

# Run on whatever device the existing model's parameters live on.
device = next(model.parameters()).device
IMG_SIZE = 256   # square input side length in pixels
CAP = 60         # maximum frames kept per video
BATCH = 24 if device.type == "cuda" else 8
NUM_WORKERS = 0  # safer with Drive

# Per-channel mean / std used to standardize RGB inputs.
IMN_MEAN = np.array([0.485,0.456,0.406], np.float32)
IMN_STD  = np.array([0.229,0.224,0.225], np.float32)

def _prep_rgb(path, out=IMG_SIZE):
    """Read an image file and return a normalized CHW float32 tensor.

    Args:
        path: Filesystem path of the frame image.
        out: Side length in pixels; the image is resized to (out, out).

    Returns:
        torch.Tensor of shape [3, out, out] in RGB channel order, scaled to
        [0, 1] and then standardized with IMN_MEAN / IMN_STD.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is not None:
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # cv2.imread reports failure by returning None; retry with Pillow.
        # The context manager closes the file handle deterministically, and
        # Pillow already yields RGB so no channel swap is needed.
        with Image.open(path) as pil_im:
            im = np.array(pil_im.convert("RGB"))
    im = cv2.resize(im, (out,out), interpolation=cv2.INTER_CUBIC).astype(np.float32)/255.0
    x = im.transpose(2,0,1)  # HWC -> CHW
    x = (x - IMN_MEAN[:,None,None]) / IMN_STD[:,None,None]
    return torch.from_numpy(x.astype(np.float32))

# cap to 60 frames/video
# Keep at most CAP rows per video_name, taking the lowest frame indices first.
df_cap = (df_sel.sort_values(["video_name","idx"])
                .groupby("video_name", as_index=False)
                .head(CAP).reset_index(drop=True))

class _DS(Dataset):
    """Dataset over a frame table; yields (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # Positional access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        image = _prep_rgb(row["path"])
        return image, int(row["label"]), str(row["video_name"]), int(row["idx"])

@torch.no_grad()
def _forward_flip_tta(xb):
    """Average the model's fake-probability over a batch and its horizontal mirror.

    Args:
        xb: Input batch; dimension 3 is flipped for the mirrored pass.

    Returns:
        NumPy array holding the mean of the two probability outputs.
    """
    mirrored = torch.flip(xb, dims=[3])
    prob_plain = model(xb)[1]
    prob_mirror = model(mirrored)[1]
    averaged = (prob_plain + prob_mirror) / 2.0
    return averaged.detach().cpu().numpy()

@torch.no_grad()
def rescore(df):
    """Score every frame in `df` with flip-TTA and return a frame-level table.

    Args:
        df: DataFrame with columns "path", "label", "video_name" and "idx".

    Returns:
        DataFrame with columns "video_name" (str), "idx" (int64),
        "true_label" ("fake" when label == 1, else "real") and "prob_fake"
        (float), sorted by (video_name, idx) with a fresh index.
    """
    loader = DataLoader(_DS(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"))
    vnames, idxs, probs, labels = [], [], [], []
    for xb, yb, vb, ib in loader:
        xb = xb.to(device, non_blocking=(device.type=="cuda"))
        p = _forward_flip_tta(xb)
        # The DataLoader collates the per-sample ints into a tensor, so
        # list(ib) would store 0-d tensors in the "idx" column; convert to
        # plain Python str/int so sorting and grouping see ordinary values.
        vnames.extend(str(v) for v in vb)
        idxs.extend(int(i) for i in ib)
        probs.append(p); labels.append(yb.numpy())
    return pd.DataFrame({
        "video_name": vnames,
        "idx": np.asarray(idxs, dtype=np.int64),
        "true_label": np.where(np.concatenate(labels)==1,"fake","real"),
        "prob_fake": np.concatenate(probs).astype(float),
    }).sort_values(["video_name","idx"]).reset_index(drop=True)

# NOTE: rebinds the global df_scores that earlier cells also produce.
df_scores = rescore(df_cap)

# auto orientation flip (if it helps per-video means)
# NOTE(review): this decision uses the ground-truth labels of the same data
# the metrics are reported on, so the reported AUC cannot fall below 0.5.
avg = df_scores.groupby(["video_name","true_label"], as_index=False)["prob_fake"].mean()
y_avg = (avg["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg["prob_fake"].to_numpy(dtype=float)
try:
    if roc_auc_score(y_avg, 1.0 - s_avg) > roc_auc_score(y_avg, s_avg):
        df_scores["prob_fake"] = 1.0 - df_scores["prob_fake"]
except Exception:
    # Best-effort: any failure here (e.g. only one class present) leaves the
    # scores un-flipped and is not reported.
    pass

# median per-video + metrics
dv = df_scores.groupby(["video_name","true_label"], sort=False)["prob_fake"].median().rename("score").reset_index()
y  = (dv["true_label"]=="fake").astype(int).to_numpy()
s  = dv["score"].to_numpy(dtype=float)

auc = roc_auc_score(y, s)
ap  = average_precision_score(y, s)
# EER: ROC operating point where FNR and FPR are closest, reported as their mean.
fpr, tpr, _ = roc_curve(y, s); fnr = 1 - tpr
i = int(np.nanargmin(np.abs(fnr - fpr)))
eer = float((fpr[i] + fnr[i]) / 2.0)
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] FAST+STRONG — device={device.type}, img={IMG_SIZE}, cap={CAP}, batch={BATCH}, frames_scored={len(df_scores)}")


ValueError: Categorical categories must be unique

In [None]:
# === Capsule (spatial) on balanced_frames_FF++ — prints ONLY: "Capsule model loaded" + AUC | EER | AP ===
# VGG19 feature extractor + Capsule head (DeepfakeBench-style), partial-safe weight load,
# flip-TTA, median/perc90/top10/trim10 aggregator sweep. Silences VGG download logs.

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# ---------- paths ----------
import os, re, io, contextlib, warnings, numpy as np, pandas as pd, cv2
from PIL import Image

# Drive root: use "MyDrive" when that directory exists, otherwise "My Drive".
ROOT = "/content/drive/MyDrive" if os.path.isdir("/content/drive/MyDrive") else "/content/drive/My Drive"
REAL_DIR = f"{ROOT}/balanced_frames_FF++/real"
FAKE_DIR = f"{ROOT}/balanced_frames_FF++/fake"
CAPSULE_WEIGHTS = f"{ROOT}/DeepfakeBench_weights/capsule_best.pth"

# Fail fast when the dataset folders or the checkpoint file are missing.
assert os.path.isdir(REAL_DIR) and os.path.isdir(FAKE_DIR), "Check dataset folders."
assert os.path.isfile(CAPSULE_WEIGHTS), "Missing capsule_best.pth."

# ---------- deps ----------
def _pipq(*pkgs):
    """Quietly pip-install `pkgs` with the running interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import subprocess
    import sys as _sys
    command = [_sys.executable, "-m", "pip", "install", "-q"]
    command.extend(pkgs)
    subprocess.run(command, check=True)

try:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models
except Exception:
    _pipq("torch", "torchvision"); import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models

from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# ---------- hardware / knobs ----------
# Prefer the GPU when one is visible; cuDNN autotuning is enabled only on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = (device.type=="cuda")
IMG_SIZE    = 256  # square input side length in pixels
BATCH       = 24 if device.type=="cuda" else 8
NUM_WORKERS = 0    # Drive I/O -> keep 0 to avoid mp overhead
FRAME_CAP   = 60   # frames per video (tweak 40–120 depending on time)
# NOTE(review): `softmax` is not referenced by the visible code in this cell
# (the model calls torch.softmax directly) — confirm before removing.
softmax = torch.nn.Softmax(dim=1)

# ---------- list frames ----------
IMG_EXTS=(".jpg",".jpeg",".png",".bmp",".webp")

def list_imgs(d):
    """Return the sorted full paths of image files directly inside `d`.

    A file counts as an image when its lower-cased name ends with one of
    IMG_EXTS. Returns an empty list when `d` is not a directory.
    """
    if not os.path.isdir(d):
        return []
    hits = [os.path.join(d, name) for name in os.listdir(d)
            if name.lower().endswith(IMG_EXTS)]
    hits.sort()
    return hits
# Enumerate frame files for both classes and stop if either folder is empty.
reals, fakes = list_imgs(REAL_DIR), list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}."

def infer_video_name(p):
    """Derive a video identifier from a frame path.

    The file stem is used. A trailing "_frame<digits>" is removed when
    something precedes it; otherwise a trailing "_<digits>" or "-<digits>"
    is removed if present.
    """
    base = os.path.basename(p)
    stem, _ext = os.path.splitext(base)
    pieces = re.split(r"_frame(\d+)$", stem)
    if len(pieces) > 1 and pieces[0]:
        return pieces[0]
    return re.sub(r"[_\-]\d+$", "", stem)

def frame_index(p):
    """Return the integer after the first "_frame" in the file name.

    Files without a "_frame<digits>" marker get 10**9 so they sort last.
    """
    hit = re.search(r"_frame(\d+)", os.path.basename(p))
    if hit is None:
        return 10**9
    return int(hit.group(1))

def build_df(paths, label):
    """Build a frame table for one class.

    Args:
        paths: Iterable of frame file paths.
        label: Class label stored in every row.

    Returns:
        DataFrame with columns path, video_name, idx, label, sorted by
        (video_name, idx).
    """
    records = []
    for frame_path in paths:
        records.append({
            "path": frame_path,
            "video_name": infer_video_name(frame_path),
            "idx": frame_index(frame_path),
            "label": label,
        })
    table = pd.DataFrame(records)
    return table.sort_values(["video_name","idx"])

# Combine both classes (label 0 = real, 1 = fake), then keep the first
# FRAME_CAP frames of each video by frame index.
# NOTE(review): the cap groups on video_name only, so a real and a fake clip
# that share a name would share one cap — confirm names are unique per class.
df_sel = pd.concat([build_df(reals,0), build_df(fakes,1)], ignore_index=True)
df_sel = df_sel.sort_values(["video_name","idx"]).groupby("video_name", as_index=False).head(FRAME_CAP).reset_index(drop=True)

# ---------- preprocessing ----------
# Per-channel mean / std used to standardize RGB inputs.
IMN_MEAN = np.array([0.485,0.456,0.406], np.float32)
IMN_STD  = np.array([0.229,0.224,0.225], np.float32)

def prep_rgb(path, out=IMG_SIZE):
    """Read an image file and return a normalized CHW float32 tensor.

    Args:
        path: Filesystem path of the frame image.
        out: Side length in pixels; the image is resized to (out, out).

    Returns:
        torch.Tensor of shape [3, out, out] in RGB channel order, scaled to
        [0, 1] and then standardized with IMN_MEAN / IMN_STD.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is not None:
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # cv2.imread reports failure by returning None; retry with Pillow.
        # The context manager closes the file handle deterministically, and
        # Pillow already yields RGB so no channel swap is needed.
        with Image.open(path) as pil_im:
            im = np.array(pil_im.convert("RGB"))
    im = cv2.resize(im, (out,out), interpolation=cv2.INTER_CUBIC).astype(np.float32)/255.0
    x = im.transpose(2,0,1)  # HWC -> CHW
    x = (x - IMN_MEAN[:,None,None]) / IMN_STD[:,None,None]
    return torch.from_numpy(x.astype(np.float32))

class DSRGB(Dataset):
    """Dataset over a frame table; yields (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # Positional access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        image = prep_rgb(row["path"])
        return image, int(row["label"]), str(row["video_name"]), int(row["idx"])

# ---------- Capsule model ----------
class VggExtractor(nn.Module):
    """Gradient-free feature extractor made of the first 19 VGG19 feature layers."""

    def __init__(self):
        super().__init__()
        # Silence warnings and anything written to stdout/stderr while the
        # torchvision model is built.
        with warnings.catch_warnings(), \
                contextlib.redirect_stdout(io.StringIO()), \
                contextlib.redirect_stderr(io.StringIO()):
            warnings.simplefilter("ignore")
            try:
                backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_FEATURES)
            except Exception:
                # Fall back to the older torchvision keyword.
                backbone = models.vgg19(pretrained=True)
        leading_layers = list(backbone.features.children())[:19]  # layers 0..18
        self.vgg_1 = nn.Sequential(*leading_layers)
        self.vgg_1.eval()

    def forward(self, x):
        with torch.no_grad():
            features = self.vgg_1(x)
        return features

class StatsNet(nn.Module):
    """Reduce a [B,C,H,W] feature map to per-channel mean and std: [B,2,C]."""

    def forward(self, x):
        B, C, H, W = x.shape
        flat = x.view(B, C, H*W)
        channel_mean = torch.mean(flat, dim=2)
        channel_std = torch.std(flat, dim=2)
        return torch.stack((channel_mean, channel_std), dim=1)  # [B,2,C]

class View(nn.Module):
    """Module form of Tensor.view with a fixed target shape."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, inp):
        target_shape = self.shape
        return inp.view(target_shape)

class FeatureExtractor(nn.Module):
    """Ten parallel primary-capsule branches over a 256-channel feature map.

    Each branch maps [B,256,H,W] to an 8-vector; the stacked result
    [B,8,10] is squashed along the capsule axis.
    """

    def __init__(self):
        super().__init__()
        self.NO_CAPS=10
        branches = []
        for _ in range(self.NO_CAPS):
            branches.append(nn.Sequential(
                nn.Conv2d(256,64,3,1,1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                nn.Conv2d(64,16,3,1,1),  nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                StatsNet(),                               # [B,2,16]
                nn.Conv1d(2,8,5,2,2), nn.BatchNorm1d(8),  # -> length 8
                nn.Conv1d(8,1,3,1,1), nn.BatchNorm1d(1),
                View(-1,8),                               # [B,8]
            ))
        self.capsules = nn.ModuleList(branches)

    def squash(self, t, dim):
        """Capsule squashing non-linearity along `dim` (1e-8 guards division by zero)."""
        sq_norm = (t**2).sum(dim=dim, keepdim=True)
        scale = sq_norm/(1.0+sq_norm+1e-8)
        return scale * t / torch.sqrt(sq_norm+1e-8)

    def forward(self, x):
        per_branch = []
        for branch in self.capsules:
            per_branch.append(branch(x))            # each [B,8]
        stacked = torch.stack(per_branch, dim=-1)   # [B,8,10]
        return self.squash(stacked, dim=-1)

class RoutingLayer(nn.Module):
    """Routing-by-agreement layer from n_in input capsules to n_out output capsules.

    Args:
        n_in: Number of input capsules.
        n_out: Number of output capsules.
        d_in: Dimension of each input capsule.
        d_out: Dimension of each output capsule.
        iters: Number of routing iterations.
    """
    def __init__(self, n_in, n_out, d_in, d_out, iters=2):
        super().__init__()
        self.iters=iters
        self.route_weights=nn.Parameter(torch.randn(n_out, n_in, d_out, d_in)*0.1)
    def squash(self, t, dim):
        """Capsule squashing non-linearity along `dim` (1e-8 guards division by zero)."""
        sn=(t**2).sum(dim=dim, keepdim=True)
        return (sn/(1.0+sn+1e-8)) * t / torch.sqrt(sn+1e-8)
    def forward(self, x):
        """Route `x` of shape [B, d_in, n_in] to an output of shape [B, d_out, n_out]."""
        x=x.transpose(2,1)  # [B, n_in, d_in]
        priors=self.route_weights[:,None,:,:,:] @ x[None,:,:,:,None]  # [n_out,B,n_in,d_out,1]
        priors=priors.transpose(1,0)  # [B,n_out,n_in,d_out,1]
        logits=torch.zeros_like(priors)
        for i in range(self.iters):
            probs=torch.softmax(logits, dim=2)
            outputs=self.squash((probs*priors).sum(dim=2, keepdim=True), dim=3)  # [B,n_out,1,d_out,1]
            if i!=self.iters-1:
                logits=logits + priors*outputs
        # Remove only the two known singleton axes. A bare squeeze() would also
        # collapse B, n_out or d_out whenever one of them equals 1, which
        # produced wrong shapes (or an error) for those configurations.
        outputs=outputs.squeeze(4).squeeze(2)  # [B,n_out,d_out]
        return outputs.transpose(2,1).contiguous()  # [B,d_out,n_out]

class CapsuleNet(nn.Module):
    """VGG19 features -> primary capsules -> routing -> fake probability.

    Args:
        num_classes: Number of output capsules (2 = real / fake).
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.vgg=VggExtractor()
        self.fea=FeatureExtractor()
        self.route=RoutingLayer(n_in=10, n_out=num_classes, d_in=8, d_out=4, iters=2)
        # Initialize only the capsule branches. Applying _init to the whole
        # module also re-drew the pretrained VGG convolution weights, leaving
        # a random backbone whenever the checkpoint failed to load.
        self.fea.apply(self._init)
    def _init(self, m):
        """Normal init: N(0, 0.02) for conv weights; N(1, 0.02) weight and zero bias for batch norm."""
        if isinstance(m, (nn.Conv2d, nn.Conv1d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.normal_(m.weight, 1.0, 0.02); nn.init.constant_(m.bias, 0.0)
    def forward(self, x):
        """Return (pred, prob): per-class scores [B,2] and fake probability [B]."""
        feat=self.vgg(x)        # [B,256,H',W']
        caps=self.fea(feat)     # [B,8,10]
        z=self.route(caps)      # [B,4,2]
        classes=torch.softmax(z, dim=-1)  # [B,4,2]
        pred = classes.detach().mean(dim=1)  # [B,2]
        # pred already sums to 1 per row, so this second softmax confines prob
        # to roughly [0.27, 0.73]; it is monotonic, so rankings are unchanged.
        prob = torch.softmax(pred, dim=1)[:,1]
        return pred, prob

# ---------- build model + load weights ----------
# Build the network on CPU; it is moved to `device` after the checkpoint load below.
model = CapsuleNet(num_classes=2)

def try_load_capsule_weights(model, ckpt, min_cover=0.25):
    """Best-effort load of a Capsule checkpoint into `model`.

    Checkpoint keys are normalized in two steps: wrapper prefixes are
    stripped, then the checkpoint's submodule prefixes (vgg_ext / fea_ext /
    routing_stats / routing) are renamed to this notebook's attribute names
    (vgg / fea / route). Previously those prefixes were stripped instead of
    renamed, so no key matched and the model kept its random initialization.

    Args:
        model: Module whose state_dict should receive the weights.
        ckpt: Path to the checkpoint file.
        min_cover: Minimum fraction of model state_dict entries that must be
            matched (same key and shape) before anything is loaded.

    Returns:
        (ok, cover): whether weights were loaded, and the matched fraction.
        Any exception is reported via a printed warning and yields ok=False.
    """
    ok=False; cover=0.0
    # checkpoint submodule prefix -> attribute name used by CapsuleNet here
    rename=(("vgg_ext.","vgg."),("fea_ext.","fea."),
            ("routing_stats.","route."),("routing.","route."),
            ("vgg_1.","vgg.vgg_1."))  # bare 'vgg_1.' lives under 'vgg.'
    try:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        sd=torch.load(ckpt, map_location="cpu")
        if isinstance(sd,dict):
            # Unwrap a nested state dict stored under a conventional key.
            for k in ("state_dict","model","net","weights","model_state","ema_state_dict"):
                if k in sd and isinstance(sd[k],dict): sd=sd[k]; break
        clean={}
        if isinstance(sd,dict):
            for k,v in sd.items():
                if not isinstance(k,str): continue
                k2=k
                for pref in ("module.","model.","net.","backbone."):
                    if k2.startswith(pref): k2=k2[len(pref):]
                for old,new in rename:
                    if k2.startswith(old):
                        k2=new+k2[len(old):]
                        break
                clean[k2]=v
            ms=model.state_dict()
            matched={k:v for k,v in clean.items() if k in ms and ms[k].shape==v.shape}
            cover=len(matched)/max(1,len(ms))
            if cover>=min_cover:
                ms.update(matched); model.load_state_dict(ms, strict=False); ok=True
    except Exception as e:
        print("[warn] weight load:", e)
    return ok, cover

# Attempt the checkpoint load; success and key coverage are kept for the final [info] line.
weights_loaded, coverage = try_load_capsule_weights(model, CAPSULE_WEIGHTS, 0.25)
model = model.to(device).eval()
# NOTE: this message is printed whether or not weights_loaded is True.
print("Capsule model loaded")

# ---------- TTA: horizontal flip ----------
@torch.no_grad()
def forward_tta(xb):
    """Average the model's fake-probability over a batch and its horizontal mirror.

    Args:
        xb: Input batch; dimension 3 is flipped for the mirrored pass.

    Returns:
        Tensor holding the mean of the two probability outputs.
    """
    mirrored = torch.flip(xb, dims=[3])
    prob_plain = model(xb)[1]
    prob_mirror = model(mirrored)[1]
    return (prob_plain + prob_mirror) / 2.0

# ---------- scoring ----------
class DSRGBCaps(Dataset):
    """Dataset over a frame table; yields (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # Positional access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        image = prep_rgb(row["path"])
        return image, int(row["label"]), str(row["video_name"]), int(row["idx"])

@torch.no_grad()
def score_frames(df):
    """Score every frame in `df` with flip-TTA and return a frame-level table.

    Args:
        df: DataFrame with columns "path", "label", "video_name" and "idx".

    Returns:
        DataFrame with columns "video_name", "idx", "true_label" ("fake" when
        label == 1, else "real") and "prob_fake", sorted by (video_name, idx)
        with a fresh index.
    """
    on_cuda = (device.type=="cuda")
    loader = DataLoader(DSRGBCaps(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=on_cuda)
    name_col, idx_col, prob_chunks, label_chunks = [], [], [], []
    for frames, targets, names, frame_ids in loader:
        frames = frames.to(device, non_blocking=on_cuda)
        batch_prob = forward_tta(frames)
        # Store plain Python str / int rather than collated batch objects.
        for nm in names:
            name_col.append(str(nm))
        for fid in frame_ids:
            idx_col.append(int(fid))
        prob_chunks.append(batch_prob.detach().cpu().numpy())
        label_chunks.append(np.asarray(targets))

    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks).astype(float)
    table = pd.DataFrame({
        "video_name": pd.Series(name_col, dtype=object),
        "idx":        pd.Series(idx_col, dtype=np.int64),
        "true_label": pd.Series(np.where(all_labels==1,"fake","real"), dtype=object),
        "prob_fake":  pd.Series(all_probs, dtype=np.float64),
    })
    # Re-cast explicitly so the columns hold plain dtypes (no Categoricals).
    table["video_name"] = table["video_name"].astype(str)
    table["true_label"] = table["true_label"].astype(str)
    table["prob_fake"] = table["prob_fake"].astype(float)
    ordered = table.sort_values(["video_name","idx"])
    return ordered.reset_index(drop=True)

df_scores = score_frames(df_sel)

# ---------- orientation auto-flip (if per-video means prefer 1-p) ----------
# NOTE(review): this decision uses the ground-truth labels of the same data
# the metrics are reported on, so the reported AUC cannot fall below 0.5.
avg = df_scores.groupby(["video_name","true_label"], as_index=False)["prob_fake"].mean()
y_avg = (avg["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg["prob_fake"].to_numpy(dtype=float)
try:
    if roc_auc_score(y_avg, 1 - s_avg) > roc_auc_score(y_avg, s_avg):
        df_scores["prob_fake"] = 1 - df_scores["prob_fake"]
except Exception:
    # Best-effort: any failure here (e.g. only one class present) leaves the
    # scores un-flipped and is not reported.
    pass

# ---------- aggregation & metrics ----------
def qnp(v, q):
    """Return the q-quantile of `v` (linear interpolation) as a Python float.

    Tries the `method=` keyword first and falls back to `interpolation=`
    when that raises TypeError.
    """
    values = np.asarray(v, dtype=float)
    try:
        result = float(np.quantile(values, q, method="linear"))
    except TypeError:
        result = float(np.quantile(values, q, interpolation="linear"))
    return result

def aggregate_numpy(df, how):
    """Collapse frame-level scores to one score per video.

    Args:
        df: DataFrame with columns "video_name", "true_label", "prob_fake".
        how: "median", "perc90" (90th percentile), "top10" (mean of the 10
            highest scores) or "trim10" (mean after dropping int(0.1*n)
            scores from each end). Any other value behaves like "median".

    Returns:
        DataFrame with columns ["video_name", "true_label", "score"], one row
        per non-empty (video_name, true_label) group in order of appearance.
    """
    def _reduce(sorted_vals, count):
        if how=="perc90":
            return qnp(sorted_vals, 0.90)
        if how=="top10":
            return float(np.mean(sorted_vals[-min(10,count):]))
        if how=="trim10":
            cut = int(0.1*count)
            upper = max(count-cut, 1)
            return float(np.mean(sorted_vals[cut:upper]))
        # "median" and every unrecognized mode
        return float(np.median(sorted_vals))

    records = []
    for (video, label), frames in df.groupby(["video_name","true_label"], sort=False):
        scores = frames["prob_fake"].to_numpy(dtype=float)
        if scores.size == 0:
            continue
        records.append((video, label, _reduce(np.sort(scores), scores.size)))
    return pd.DataFrame(records, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Compute ranking metrics for binary scores.

    Args:
        scores: Array-like of scores (higher = more likely positive).
        labels: Array-like of binary ground-truth labels.

    Returns:
        Tuple (auc, eer, ap). The EER is taken at the ROC point where
        |FNR - FPR| is smallest, as the mean of FPR and FNR at that point.
    """
    area = roc_auc_score(labels, scores)
    avg_prec = average_precision_score(labels, scores)
    false_pos, true_pos, _thresholds = roc_curve(labels, scores)
    false_neg = 1 - true_pos
    nearest = int(np.nanargmin(np.abs(false_neg - false_pos)))
    equal_err = float((false_pos[nearest] + false_neg[nearest]) / 2.0)
    return area, equal_err, avg_prec

# Try each aggregator and keep the one with the highest AUC (ties broken by lower EER).
# NOTE(review): the aggregator is selected on the same labeled data that is
# reported, so the printed metrics are optimistic.
best=None; best_agg=None
for agg in ("median","perc90","top10","trim10"):
    dfv = aggregate_numpy(df_scores, agg)
    if dfv.empty: continue
    y = (dfv["true_label"]=="fake").astype(int).to_numpy()
    # Ranking metrics need both classes present.
    if len(np.unique(y))<2: continue
    s = dfv["score"].to_numpy(dtype=float)
    cand = metrics(s, y)  # (auc, eer, ap)
    if (best is None) or (cand[0] > best[0]) or (cand[0]==best[0] and cand[1] < best[1]):
        best, best_agg = cand, agg

# NOTE(review): if every aggregator was skipped above, `best` is still None and
# this unpacking raises TypeError.
auc, eer, ap = best
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] dataset='balanced_frames_FF++', device={device.type}, img={IMG_SIZE}, cap={FRAME_CAP}, "
      f"agg={best_agg}, tta=flip, weights_loaded={weights_loaded}, cover={coverage:.2f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Capsule model loaded
AUC=0.5000 | EER=0.5000 | AP=0.5000
[info] dataset='balanced_frames_FF++', device=cuda, img=256, cap=60, agg=median, tta=flip, weights_loaded=False, cover=0.00


In [None]:
# === Capsule (balanced_frames_FF++) — robust weight load + quick rescore ===
# Fixes 0.50 by remapping CKPT keys to our module names, then rescoring quickly.

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os, re, io, contextlib, warnings, numpy as np, pandas as pd, cv2
from PIL import Image

# ---------- paths ----------
# Drive root: use "MyDrive" when that directory exists, otherwise "My Drive".
ROOT = "/content/drive/MyDrive" if os.path.isdir("/content/drive/MyDrive") else "/content/drive/My Drive"
REAL_DIR = f"{ROOT}/balanced_frames_FF++/real"
FAKE_DIR = f"{ROOT}/balanced_frames_FF++/fake"
CAPSULE_WEIGHTS = f"{ROOT}/DeepfakeBench_weights/capsule_best.pth"
# Fail fast when the dataset folders or the checkpoint file are missing.
assert os.path.isdir(REAL_DIR) and os.path.isdir(FAKE_DIR), "Check dataset folders."
assert os.path.isfile(CAPSULE_WEIGHTS), "Missing capsule_best.pth."

# ---------- deps ----------
def _pipq(*pkgs):
    """Quietly pip-install `pkgs` with the running interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import subprocess
    import sys as _sys
    command = [_sys.executable, "-m", "pip", "install", "-q"]
    command.extend(pkgs)
    subprocess.run(command, check=True)

try:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models
    from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
except Exception:
    _pipq("torch","torchvision","scikit-learn")
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import models
    from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# Prefer the GPU when one is visible; cuDNN autotuning is enabled only on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = (device.type=="cuda")

# ---------- data prep ----------
IMG_EXTS=(".jpg",".jpeg",".png",".bmp",".webp")
def list_imgs(d):
    """Return the sorted full paths of image files directly inside `d`.

    A file counts as an image when its lower-cased name ends with one of
    IMG_EXTS. Returns an empty list when `d` is not a directory.
    """
    if not os.path.isdir(d):
        return []
    hits = [os.path.join(d, name) for name in os.listdir(d)
            if name.lower().endswith(IMG_EXTS)]
    hits.sort()
    return hits
def infer_video_name(p):
    """Derive a video identifier from a frame path.

    The file stem is used. A trailing "_frame<digits>" is removed when
    something precedes it; otherwise a trailing "_<digits>" or "-<digits>"
    is removed if present.
    """
    base = os.path.basename(p)
    stem, _ext = os.path.splitext(base)
    pieces = re.split(r"_frame(\d+)$", stem)
    if len(pieces) > 1 and pieces[0]:
        return pieces[0]
    return re.sub(r"[_\-]\d+$", "", stem)
def frame_index(p):
    """Return the integer after the first "_frame" in the file name.

    Files without a "_frame<digits>" marker get 10**9 so they sort last.
    """
    hit = re.search(r"_frame(\d+)", os.path.basename(p))
    if hit is None:
        return 10**9
    return int(hit.group(1))
def build_df(paths, label):
    """Build a frame table for one class.

    Args:
        paths: Iterable of frame file paths.
        label: Class label stored in every row.

    Returns:
        DataFrame with columns path, video_name, idx, label, sorted by
        (video_name, idx).
    """
    records = []
    for frame_path in paths:
        records.append({
            "path": frame_path,
            "video_name": infer_video_name(frame_path),
            "idx": frame_index(frame_path),
            "label": label,
        })
    table = pd.DataFrame(records)
    return table.sort_values(["video_name","idx"])

# Enumerate frame files for both classes and stop if either folder is empty.
reals, fakes = list_imgs(REAL_DIR), list_imgs(FAKE_DIR)
assert len(reals) and len(fakes), f"No images found. REAL={len(reals)} FAKE={len(fakes)}."
FRAME_CAP = 60   # quick but decent
# Combine both classes (label 0 = real, 1 = fake), then keep the first
# FRAME_CAP frames of each video by frame index.
# NOTE(review): the cap groups on video_name only, so a real and a fake clip
# that share a name would share one cap — confirm names are unique per class.
df_sel = pd.concat([build_df(reals,0), build_df(fakes,1)], ignore_index=True)
df_sel = df_sel.sort_values(["video_name","idx"]).groupby("video_name", as_index=False).head(FRAME_CAP).reset_index(drop=True)

# ---------- preprocessing ----------
IMG_SIZE = 256  # square input side length in pixels
# Per-channel mean / std used to standardize RGB inputs.
IMN_MEAN = np.array([0.485,0.456,0.406], np.float32)
IMN_STD  = np.array([0.229,0.224,0.225], np.float32)
def prep_rgb(path, out=IMG_SIZE):
    """Read an image file and return a normalized CHW float32 tensor.

    Args:
        path: Filesystem path of the frame image.
        out: Side length in pixels; the image is resized to (out, out).

    Returns:
        torch.Tensor of shape [3, out, out] in RGB channel order, scaled to
        [0, 1] and then standardized with IMN_MEAN / IMN_STD.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is not None:
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # cv2.imread reports failure by returning None; retry with Pillow.
        # The context manager closes the file handle deterministically, and
        # Pillow already yields RGB so no channel swap is needed.
        with Image.open(path) as pil_im:
            im = np.array(pil_im.convert("RGB"))
    im = cv2.resize(im, (out,out), interpolation=cv2.INTER_CUBIC).astype(np.float32)/255.0
    x = im.transpose(2,0,1)  # HWC -> CHW
    x = (x - IMN_MEAN[:,None,None]) / IMN_STD[:,None,None]
    return torch.from_numpy(x.astype(np.float32))

class DSRGB(Dataset):
    """Dataset over a frame table; yields (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # Positional access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        image = prep_rgb(row["path"])
        return image, int(row["label"]), str(row["video_name"]), int(row["idx"])

# ---------- Capsule model (same as before) ----------
class VggExtractor(nn.Module):
    """Gradient-free feature extractor made of the first 19 VGG19 feature layers."""

    def __init__(self):
        super().__init__()
        # Silence warnings and anything written to stdout/stderr while the
        # torchvision model is built.
        with warnings.catch_warnings(), \
                contextlib.redirect_stdout(io.StringIO()), \
                contextlib.redirect_stderr(io.StringIO()):
            warnings.simplefilter("ignore")
            try:
                backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_FEATURES)
            except Exception:
                # Fall back to the older torchvision keyword.
                backbone = models.vgg19(pretrained=True)
        leading_layers = list(backbone.features.children())[:19]  # 0..18
        self.vgg_1 = nn.Sequential(*leading_layers)
        self.vgg_1.eval()

    def forward(self, x):
        with torch.no_grad():
            features = self.vgg_1(x)
        return features

class StatsNet(nn.Module):
    """Reduce a [B,C,H,W] feature map to per-channel mean and std: [B,2,C]."""

    def forward(self, x):
        B, C, H, W = x.shape
        flat = x.view(B, C, H*W)
        channel_mean = torch.mean(flat, dim=2)
        channel_std = torch.std(flat, dim=2)
        return torch.stack((channel_mean, channel_std), dim=1)  # [B,2,C]

class View(nn.Module):
    """Module form of Tensor.view with a fixed target shape."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, inp):
        target_shape = self.shape
        return inp.view(target_shape)

class FeatureExtractor(nn.Module):
    """Ten parallel primary-capsule branches over a 256-channel feature map.

    Each branch maps [B,256,H,W] to an 8-vector; the stacked result
    [B,8,10] is squashed along the capsule axis.
    """

    def __init__(self):
        super().__init__()
        self.NO_CAPS=10
        branches = []
        for _ in range(self.NO_CAPS):
            branches.append(nn.Sequential(
                nn.Conv2d(256,64,3,1,1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                nn.Conv2d(64,16,3,1,1),  nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                StatsNet(),                               # [B,2,16]
                nn.Conv1d(2,8,5,2,2), nn.BatchNorm1d(8),  # -> length 8
                nn.Conv1d(8,1,3,1,1), nn.BatchNorm1d(1),
                View(-1,8),                               # [B,8]
            ))
        self.capsules = nn.ModuleList(branches)

    def squash(self, t, dim):
        """Capsule squashing non-linearity along `dim` (1e-8 guards division by zero)."""
        sq_norm = (t**2).sum(dim=dim, keepdim=True)
        scale = sq_norm/(1.0+sq_norm+1e-8)
        return scale * t / torch.sqrt(sq_norm+1e-8)

    def forward(self, x):
        per_branch = []
        for branch in self.capsules:
            per_branch.append(branch(x))            # each [B,8]
        stacked = torch.stack(per_branch, dim=-1)   # [B,8,10]
        return self.squash(stacked, dim=-1)

class RoutingLayer(nn.Module):
    """Routing-by-agreement layer from n_in input capsules to n_out output capsules.

    Args:
        n_in: Number of input capsules.
        n_out: Number of output capsules.
        d_in: Dimension of each input capsule.
        d_out: Dimension of each output capsule.
        iters: Number of routing iterations.
    """
    def __init__(self, n_in, n_out, d_in, d_out, iters=2):
        super().__init__()
        self.iters=iters
        self.route_weights=nn.Parameter(torch.randn(n_out, n_in, d_out, d_in)*0.1)
    def squash(self, t, dim):
        """Capsule squashing non-linearity along `dim` (1e-8 guards division by zero)."""
        sn=(t**2).sum(dim=dim, keepdim=True)
        return (sn/(1.0+sn+1e-8)) * t / torch.sqrt(sn+1e-8)
    def forward(self, x):
        """Route `x` of shape [B, d_in, n_in] to an output of shape [B, d_out, n_out]."""
        x=x.transpose(2,1)  # [B, n_in, d_in]
        priors=self.route_weights[:,None,:,:,:] @ x[None,:,:,:,None]  # [n_out,B,n_in,d_out,1]
        priors=priors.transpose(1,0)  # [B,n_out,n_in,d_out,1]
        logits=torch.zeros_like(priors)
        for i in range(self.iters):
            probs=torch.softmax(logits, dim=2)
            outputs=self.squash((probs*priors).sum(dim=2, keepdim=True), dim=3)  # [B,n_out,1,d_out,1]
            if i!=self.iters-1:
                logits=logits + priors*outputs
        # Remove only the two known singleton axes. A bare squeeze() would also
        # collapse B, n_out or d_out whenever one of them equals 1, which
        # produced wrong shapes (or an error) for those configurations.
        outputs=outputs.squeeze(4).squeeze(2)  # [B,n_out,d_out]
        return outputs.transpose(2,1).contiguous()  # [B,d_out,n_out]

class CapsuleNet(nn.Module):
    """VGG19 features -> primary capsules -> routing -> fake probability.

    Args:
        num_classes: Number of output capsules (2 = real / fake).
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.vgg=VggExtractor()
        self.fea=FeatureExtractor()
        self.route=RoutingLayer(n_in=10, n_out=num_classes, d_in=8, d_out=4, iters=2)
        # Initialize only the capsule branches. Applying _init to the whole
        # module also re-drew the pretrained VGG convolution weights, leaving
        # a random backbone whenever the checkpoint does not supply them.
        self.fea.apply(self._init)
    def _init(self, m):
        """Normal init: N(0, 0.02) for conv weights; N(1, 0.02) weight and zero bias for batch norm."""
        if isinstance(m,(nn.Conv2d,nn.Conv1d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m,(nn.BatchNorm2d,nn.BatchNorm1d)):
            nn.init.normal_(m.weight, 1.0, 0.02); nn.init.constant_(m.bias, 0.0)
    def forward(self, x):
        """Return (pred, prob): per-class scores [B,2] and fake probability [B]."""
        feat=self.vgg(x)        # [B,256,H',W']
        caps=self.fea(feat)     # [B,8,10]
        z=self.route(caps)      # [B,4,2]
        classes=torch.softmax(z, dim=-1)    # [B,4,2]
        pred = classes.detach().mean(dim=1) # [B,2]
        # pred already sums to 1 per row, so this second softmax confines prob
        # to roughly [0.27, 0.73]; it is monotonic, so rankings are unchanged.
        prob = torch.softmax(pred, dim=1)[:,1]
        return pred, prob

# Build the network directly on `device` in eval mode; weights are loaded into it below.
model = CapsuleNet(num_classes=2).to(device).eval()

# ---------- robust checkpoint remap ----------
def load_capsule_ckpt_strict_remap(model, ckpt_path, min_cover=0.25):
    """Load a Capsule checkpoint into `model`, renaming keys to this notebook's module names.

    Args:
        model: Module whose state_dict should receive the weights.
        ckpt_path: Path to the checkpoint file.
        min_cover: Minimum fraction of model state_dict entries that must be
            matched (same key and shape) before anything is loaded.

    Returns:
        (ok, cover, matched_keys, missed_keys): load flag, matched fraction,
        a view of the keys that matched, and the set of remapped checkpoint
        keys that did not.
    """
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict):
        # Unwrap a nested state dict stored under a conventional key.
        for holder in ("state_dict","model","net","weights","model_state","ema_state_dict"):
            if holder in state and isinstance(state[holder], dict):
                state = state[holder]
                break

    wrapper_prefixes = ("module.","model.","net.","backbone.")
    # deepfakebench names -> our module names; first matching prefix wins.
    # The last entry handles bare 'vgg_1.' keys (seen in some dumps).
    prefix_map = (
        ("vgg_ext.", "vgg."),
        ("fea_ext.", "fea."),
        ("routing_stats.", "route."),
        ("routing.", "route."),
        ("vgg_1.", "vgg.vgg_1."),
    )

    remap = {}
    entries = list(state.items()) if isinstance(state, dict) else []
    for raw_key, tensor in entries:
        if not isinstance(raw_key, str):
            continue
        key = raw_key
        for wrapper in wrapper_prefixes:
            if key.startswith(wrapper):
                key = key[len(wrapper):]
        for old, new in prefix_map:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        remap[key] = tensor

    target = model.state_dict()
    matched = {}
    for key, tensor in remap.items():
        if key in target and target[key].shape == tensor.shape:
            matched[key] = tensor
    cover = len(matched) / max(1,len(target))
    ok = cover >= min_cover
    if ok:
        target.update(matched)
        model.load_state_dict(target, strict=False)
    missed = set(remap.keys()) - set(matched.keys())
    return ok, cover, matched.keys(), missed

# Load the checkpoint; the flag and coverage are reported in the final [info] line.
weights_loaded, coverage, matched_keys, missed_keys = load_capsule_ckpt_strict_remap(model, CAPSULE_WEIGHTS, min_cover=0.25)

# NOTE: this message is printed whether or not weights_loaded is True.
print("Capsule model loaded")

# ---------- scoring (flip TTA) ----------
# NUM_WORKERS = 0 keeps image loading in the main process.
BATCH, NUM_WORKERS = (24 if device.type=="cuda" else 8), 0

@torch.no_grad()
def forward_tta(xb):
    """Average the model's fake-probability over a batch and its horizontal mirror.

    Args:
        xb: Input batch; dimension 3 is flipped for the mirrored pass.

    Returns:
        Tensor holding the mean of the two probability outputs.
    """
    mirrored = torch.flip(xb, dims=[3])
    prob_plain = model(xb)[1]
    prob_mirror = model(mirrored)[1]
    return (prob_plain + prob_mirror) / 2.0

class DSCaps(Dataset):
    """Dataset over a frame table; yields (tensor, label, video_name, idx)."""

    def __init__(self, df):
        # Positional access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        image = prep_rgb(row["path"])
        return image, int(row["label"]), str(row["video_name"]), int(row["idx"])

@torch.no_grad()
def score_frames(df):
    """Score every frame in `df` with flip-TTA and return a frame-level table.

    Args:
        df: DataFrame with columns "path", "label", "video_name" and "idx".

    Returns:
        DataFrame with columns "video_name", "idx", "true_label" ("fake" when
        label == 1, else "real") and "prob_fake", sorted by (video_name, idx)
        with a fresh index.
    """
    on_cuda = (device.type=="cuda")
    loader = DataLoader(DSCaps(df), batch_size=BATCH, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=on_cuda)
    name_col, idx_col, prob_chunks, label_chunks = [], [], [], []
    for frames, targets, names, frame_ids in loader:
        frames = frames.to(device, non_blocking=on_cuda)
        batch_prob = forward_tta(frames).detach().cpu().numpy()
        # Store plain Python str / int rather than collated batch objects.
        for nm in names:
            name_col.append(str(nm))
        for fid in frame_ids:
            idx_col.append(int(fid))
        prob_chunks.append(batch_prob)
        label_chunks.append(np.asarray(targets))

    table = pd.DataFrame({
        "video_name": pd.Series(name_col, dtype=object),
        "idx":        pd.Series(idx_col, dtype=np.int64),
        "true_label": pd.Series(np.where(np.concatenate(label_chunks)==1,"fake","real"), dtype=object),
        "prob_fake":  pd.Series(np.concatenate(prob_chunks).astype(float), dtype=np.float64),
    })
    # Re-cast explicitly so the columns hold plain dtypes.
    for column, kind in (("video_name", str), ("true_label", str), ("prob_fake", float)):
        table[column] = table[column].astype(kind)
    ordered = table.sort_values(["video_name","idx"])
    return ordered.reset_index(drop=True)

df_scores = score_frames(df_sel)

# ---------- auto orientation flip if it helps ----------
# NOTE(review): this decision uses the ground-truth labels of the same data
# the metrics are reported on, so the reported AUC cannot fall below 0.5.
avg = df_scores.groupby(["video_name","true_label"], as_index=False)["prob_fake"].mean()
y_avg = (avg["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg["prob_fake"].to_numpy(dtype=float)
try:
    if roc_auc_score(y_avg, 1.0 - s_avg) > roc_auc_score(y_avg, s_avg):
        df_scores["prob_fake"] = 1.0 - df_scores["prob_fake"]
except Exception:
    # Best-effort: any failure here (e.g. only one class present) leaves the
    # scores un-flipped and is not reported.
    pass

# ---------- aggregation + metrics ----------
def qnp(v, q):
    """Return the q-quantile of `v` (linear interpolation) as a Python float.

    Tries the `method=` keyword first and falls back to `interpolation=`
    when that raises TypeError.
    """
    values = np.asarray(v, dtype=float)
    try:
        result = float(np.quantile(values, q, method="linear"))
    except TypeError:
        result = float(np.quantile(values, q, interpolation="linear"))
    return result

def aggregate_numpy(df, how):
    """Collapse frame-level scores to one score per video.

    Args:
        df: DataFrame with columns "video_name", "true_label", "prob_fake".
        how: "median", "perc90" (90th percentile), "top10" (mean of the 10
            highest scores) or "trim10" (mean after dropping int(0.1*n)
            scores from each end). Any other value behaves like "median".

    Returns:
        DataFrame with columns ["video_name", "true_label", "score"], one row
        per non-empty (video_name, true_label) group in order of appearance.
    """
    def _reduce(sorted_vals, count):
        if how=="perc90":
            return qnp(sorted_vals, 0.90)
        if how=="top10":
            return float(np.mean(sorted_vals[-min(10,count):]))
        if how=="trim10":
            cut = int(0.1*count)
            upper = max(count-cut, 1)
            return float(np.mean(sorted_vals[cut:upper]))
        # "median" and every unrecognized mode
        return float(np.median(sorted_vals))

    records = []
    for (video, label), frames in df.groupby(["video_name","true_label"], sort=False):
        scores = frames["prob_fake"].to_numpy(dtype=float)
        if scores.size == 0:
            continue
        records.append((video, label, _reduce(np.sort(scores), scores.size)))
    return pd.DataFrame(records, columns=["video_name","true_label","score"])

def metrics(scores, labels):
    """Compute ranking metrics for binary scores.

    Args:
        scores: Array-like of scores (higher = more likely positive).
        labels: Array-like of binary ground-truth labels.

    Returns:
        Tuple (auc, eer, ap). The EER is taken at the ROC point where
        |FNR - FPR| is smallest, as the mean of FPR and FNR at that point.
    """
    area = roc_auc_score(labels, scores)
    avg_prec = average_precision_score(labels, scores)
    false_pos, true_pos, _thresholds = roc_curve(labels, scores)
    false_neg = 1 - true_pos
    nearest = int(np.nanargmin(np.abs(false_neg - false_pos)))
    equal_err = float((false_pos[nearest] + false_neg[nearest]) / 2.0)
    return area, equal_err, avg_prec

# Try each aggregator and keep the one with the highest AUC (ties broken by lower EER).
# NOTE(review): the aggregator is selected on the same labeled data that is
# reported, so the printed metrics are optimistic.
best=None; best_agg=None
for agg in ("median","perc90","top10","trim10"):
    dfv = aggregate_numpy(df_scores, agg)
    if dfv.empty: continue
    y = (dfv["true_label"]=="fake").astype(int).to_numpy()
    # Ranking metrics need both classes present.
    if len(np.unique(y))<2: continue
    s = dfv["score"].to_numpy(dtype=float)
    cand = metrics(s, y)  # (auc, eer, ap)
    if (best is None) or (cand[0] > best[0]) or (cand[0]==best[0] and cand[1] < best[1]):
        best, best_agg = cand, agg

# NOTE(review): if every aggregator was skipped above, `best` is still None and
# this unpacking raises TypeError.
auc, eer, ap = best
print(f"AUC={auc:.4f} | EER={eer:.4f} | AP={ap:.4f}")
print(f"[info] dataset='balanced_frames_FF++', device={device.type}, img={IMG_SIZE}, cap={FRAME_CAP}, "
      f"agg={best_agg}, tta=flip, weights_loaded={weights_loaded}, cover={coverage:.2f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Capsule model loaded
AUC=0.6505 | EER=0.3725 | AP=0.7691
[info] dataset='balanced_frames_FF++', device=cuda, img=256, cap=60, agg=perc90, tta=flip, weights_loaded=True, cover=1.00


In [None]:
# === Capsule — Large results table (balanced_frames_FF++) ===
# Requires: df_scores from your last Capsule run (the one that printed AUC/EER/AP)
#
# Produces `table_capsule_ffpp`: one row per (video_name, true_label) with frame counts,
# frame accuracy, mean/std of prob_fake and two video-level decisions
# (average rule and frame-majority rule).
# NOTE: both thresholds are picked with Youden's J on the same scores they are then
# applied to, so the correctness columns are optimistic.

import numpy as np, pandas as pd
from sklearn.metrics import roc_curve

# Safety checks
if 'df_scores' not in globals() or df_scores.empty:
    raise SystemExit("No 'df_scores' found. Run the Capsule scoring cell first.")

DATASET_NAME  = "balanced_frames_FF++"
DETECTOR_NAME = "Capsule"
# A video is identified by name AND label everywhere below, so a name shared by a
# real and a fake clip is never pooled into a single group.
VIDEO_KEYS    = ["video_name", "true_label"]

# Clean frame-level results
df = df_scores.copy()
df["video_name"] = df["video_name"].astype(str)
df["true_label"] = df["true_label"].astype(str)
df["prob_fake"]  = pd.to_numeric(df["prob_fake"], errors="coerce").astype(float)
df = df.dropna(subset=["prob_fake"]).reset_index(drop=True)

# ----- 1) Global thresholds -----
# Frame-level threshold via Youden's J (falls back to 0.5 when only one class is present)
y_frame = (df["true_label"]=="fake").astype(int).to_numpy()
s_frame = df["prob_fake"].to_numpy(dtype=float)
if len(np.unique(y_frame)) >= 2:
    fpr, tpr, thr = roc_curve(y_frame, s_frame)
    t_frame = float(thr[np.nanargmax(tpr - fpr)])
else:
    t_frame = 0.5

# Per-video average threshold via Youden's J (same fallback)
avg_df = df.groupby(VIDEO_KEYS, sort=False)["prob_fake"].mean().rename("avg_prob_fake").reset_index()
y_avg = (avg_df["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg_df["avg_prob_fake"].to_numpy(dtype=float)
if len(np.unique(y_avg)) >= 2:
    fpr2, tpr2, thr2 = roc_curve(y_avg, s_avg)
    t_avg = float(thr2[np.nanargmax(tpr2 - fpr2)])
else:
    t_avg = 0.5

# ----- 2) Frame-level predictions & counts -----
df["frame_pred"] = np.where(df["prob_fake"] >= t_frame, "fake", "real")
# n_frames, n_correct_frames, n_wrong_frames, frame_accuracy
# Vectorized named aggregation over a 0/1 "correct" flag; this replaces the
# groupby(...).apply(...) form, which is slower and triggers a pandas warning.
cnts = (
    df.assign(frame_correct=(df["frame_pred"] == df["true_label"]).astype(int))
      .groupby(VIDEO_KEYS, sort=False)["frame_correct"]
      .agg(n_frames="size", n_correct_frames="sum", frame_accuracy="mean")
      .reset_index()
)
cnts["n_wrong_frames"] = cnts["n_frames"] - cnts["n_correct_frames"]

# ----- 3) Per-video avg/std + decisions (avg & majority) -----
stats = df.groupby(VIDEO_KEYS, sort=False)["prob_fake"].agg(
    avg_prob_fake="mean", std_prob_fake="std"
).reset_index()

# Average rule
stats["video_pred_by_avg"]    = (stats["avg_prob_fake"] >= t_avg).astype(int)          # 1=fake, 0=real
stats["video_correct_by_avg"] = (
    ((stats["video_pred_by_avg"]==1) & (stats["true_label"]=="fake")) |
    ((stats["video_pred_by_avg"]==0) & (stats["true_label"]=="real"))
).astype(int)

# Majority rule from frame predictions; a tie counts as "fake" (fake votes >= real votes).
maj = (
    df.assign(pred_is_fake=df["frame_pred"].eq("fake").astype(int))
      .groupby(VIDEO_KEYS, sort=False)["pred_is_fake"]
      .agg(n_fake_votes="sum", n_votes="size")
      .reset_index()
)
maj["video_pred_by_majority"] = (2 * maj["n_fake_votes"] >= maj["n_votes"]).astype(int)
maj["video_correct_by_majority"] = (
    ((maj["video_pred_by_majority"]==1) & (maj["true_label"]=="fake")) |
    ((maj["video_pred_by_majority"]==0) & (maj["true_label"]=="real"))
).astype(int)

# ----- 4) Assemble final table -----
table_capsule_ffpp = (
    stats.merge(cnts, on=VIDEO_KEYS, how="left")
         .merge(maj[VIDEO_KEYS + ["video_pred_by_majority","video_correct_by_majority"]],
                on=VIDEO_KEYS, how="left")
         .assign(
             dataset=DATASET_NAME,
             detector=DETECTOR_NAME,
             # std is NaN for single-frame videos; report 0.0 instead
             std_prob_fake=lambda d: d["std_prob_fake"].fillna(0.0),
         )
         # tidy numeric formatting
         .astype({
             "avg_prob_fake": float,
             "std_prob_fake": float,
             "n_frames": int,
             "n_correct_frames": int,
             "n_wrong_frames": int,
             "frame_accuracy": float,
             "video_pred_by_avg": int,
             "video_correct_by_avg": int,
             "video_pred_by_majority": int,
             "video_correct_by_majority": int,
         })[[
             "dataset","detector","video_name","true_label",
             "n_frames","n_correct_frames","n_wrong_frames","frame_accuracy",
             "avg_prob_fake","std_prob_fake",
             "video_pred_by_avg","video_correct_by_avg",
             "video_pred_by_majority","video_correct_by_majority"
         ]]
         .sort_values(["true_label","video_name"], kind="stable")
         .reset_index(drop=True)
)

# ----- 5) Display all rows, no column breaks -----
pd.set_option("display.max_rows", 100000)
pd.set_option("display.max_columns", 1000)
pd.set_option("display.width", 10000)
pd.set_option("display.expand_frame_repr", False)

display(table_capsule_ffpp)
print(f"[rows]={len(table_capsule_ffpp)} | thresholds: t_frame={t_frame:.3f}, t_avg={t_avg:.3f}")


  cnts = df.groupby(["video_name","true_label"], sort=False).apply(


Unnamed: 0,dataset,detector,video_name,true_label,n_frames,n_correct_frames,n_wrong_frames,frame_accuracy,avg_prob_fake,std_prob_fake,video_pred_by_avg,video_correct_by_avg,video_pred_by_majority,video_correct_by_majority
0,balanced_frames_FF++,Capsule,000_003,fake,20,20,0,1.0,0.605013,0.002602,1,1,1,1
1,balanced_frames_FF++,Capsule,010_005,fake,20,0,20,0.0,0.500555,0.030406,0,0,0,0
2,balanced_frames_FF++,Capsule,011_805,fake,20,3,17,0.15,0.570437,0.031841,0,0,0,0
3,balanced_frames_FF++,Capsule,012_026,fake,20,0,20,0.0,0.491209,0.041621,0,0,0,0
4,balanced_frames_FF++,Capsule,013_883,fake,20,19,1,0.95,0.607141,0.005542,1,1,1,1
5,balanced_frames_FF++,Capsule,014_790,fake,20,13,7,0.65,0.594129,0.016836,1,1,1,1
6,balanced_frames_FF++,Capsule,015_919,fake,20,20,0,1.0,0.606428,0.003373,1,1,1,1
7,balanced_frames_FF++,Capsule,016_209,fake,20,0,20,0.0,0.51526,0.025184,0,0,0,0
8,balanced_frames_FF++,Capsule,017_803,fake,20,4,16,0.2,0.510247,0.054078,0,0,0,0
9,balanced_frames_FF++,Capsule,018_019,fake,20,2,18,0.1,0.549537,0.039455,0,0,0,0


[rows]=102 | thresholds: t_frame=0.595, t_avg=0.592


In [None]:
# Write `table_capsule_ffpp` as a CSV into the "Capsule results FF++" folder on Drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os

# Use the "MyDrive" mount path when it exists, otherwise fall back to "My Drive".
if os.path.isdir("/content/drive/MyDrive"):
    ROOT = "/content/drive/MyDrive"
else:
    ROOT = "/content/drive/My Drive"

OUT_DIR = os.path.join(ROOT, "Capsule results FF++")
os.makedirs(OUT_DIR, exist_ok=True)
DEST = os.path.join(OUT_DIR, "capsule_ffpp_large_table.csv")

# Stop here unless the large-table cell has left a non-empty frame in the kernel.
if not ('table_capsule_ffpp' in globals() and not table_capsule_ffpp.empty):
    raise SystemExit("No 'table_capsule_ffpp' found. Run the large-table cell first.")

table_capsule_ffpp.to_csv(DEST, index=False)
print(f"[saved] {DEST}  (rows={len(table_capsule_ffpp)})")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[saved] /content/drive/MyDrive/Capsule results FF++/capsule_ffpp_large_table.csv  (rows=102)


In [None]:
# === Capsule — Small results table (balanced_frames_FF++) ===
# One row per (video_name, true_label) with columns:
#   dataset, detector, video_name, true_label, correctly_predicted (yes or no)
# A video is called "fake" when its mean prob_fake reaches the Youden-J threshold
# computed over the per-video means.

import numpy as np, pandas as pd
from sklearn.metrics import roc_curve

# Safety
if not ('df_scores' in globals() and not df_scores.empty):
    raise SystemExit("No 'df_scores' found. Run the Capsule scoring cell first.")

DATASET_NAME  = "balanced_frames_FF++"
DETECTOR_NAME = "Capsule"

# Normalise dtypes on a copy, then drop frames whose score is not numeric
df = df_scores.copy()
df = df.assign(
    video_name=df["video_name"].astype(str),
    true_label=df["true_label"].astype(str),
    prob_fake=pd.to_numeric(df["prob_fake"], errors="coerce").astype(float),
)
df = df.dropna(subset=["prob_fake"]).reset_index(drop=True)

# Mean fake-probability per video
avg_df = (
    df.groupby(["video_name","true_label"], sort=False)["prob_fake"]
      .agg(avg_prob_fake="mean")
      .reset_index()
)

# Global threshold on the per-video means: maximise TPR - FPR, or 0.5 if only one class
y_avg = (avg_df["true_label"]=="fake").astype(int).to_numpy()
s_avg = avg_df["avg_prob_fake"].to_numpy(dtype=float)
t_avg = 0.5
if len(np.unique(y_avg)) >= 2:
    fpr, tpr, thr = roc_curve(y_avg, s_avg)
    t_avg = float(thr[np.nanargmax(tpr - fpr)])

# Video-level decision and whether it matches the ground truth
avg_df["video_pred_by_avg"] = np.where(avg_df["avg_prob_fake"] >= t_avg, "fake", "real")
avg_df["correctly_predicted (yes or no)"] = (
    avg_df["video_pred_by_avg"] == avg_df["true_label"]
).map({True: "yes", False: "no"})

small_table_capsule_ffpp = (
    avg_df.assign(dataset=DATASET_NAME, detector=DETECTOR_NAME)
          .loc[:, ["dataset","detector","video_name","true_label","correctly_predicted (yes or no)"]]
          .sort_values(["true_label","video_name"], kind="stable")
          .reset_index(drop=True)
)

# Show all rows, no column breaks
for _opt, _val in (
    ("display.max_rows", 100000),
    ("display.max_columns", 1000),
    ("display.width", 10000),
    ("display.expand_frame_repr", False),
):
    pd.set_option(_opt, _val)

display(small_table_capsule_ffpp)
print(f"[rows]={len(small_table_capsule_ffpp)} | t_avg={t_avg:.3f}")


Unnamed: 0,dataset,detector,video_name,true_label,correctly_predicted (yes or no)
0,balanced_frames_FF++,Capsule,000_003,fake,yes
1,balanced_frames_FF++,Capsule,010_005,fake,no
2,balanced_frames_FF++,Capsule,011_805,fake,no
3,balanced_frames_FF++,Capsule,012_026,fake,no
4,balanced_frames_FF++,Capsule,013_883,fake,yes
5,balanced_frames_FF++,Capsule,014_790,fake,yes
6,balanced_frames_FF++,Capsule,015_919,fake,yes
7,balanced_frames_FF++,Capsule,016_209,fake,no
8,balanced_frames_FF++,Capsule,017_803,fake,no
9,balanced_frames_FF++,Capsule,018_019,fake,no


[rows]=102 | t_avg=0.592


In [None]:
# Write `small_table_capsule_ffpp` as a CSV into the "Capsule results FF++" folder on Drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import os

# Use the "MyDrive" mount path when it exists, otherwise fall back to "My Drive".
if os.path.isdir("/content/drive/MyDrive"):
    ROOT = "/content/drive/MyDrive"
else:
    ROOT = "/content/drive/My Drive"

OUT_DIR = os.path.join(ROOT, "Capsule results FF++")
os.makedirs(OUT_DIR, exist_ok=True)
DEST = os.path.join(OUT_DIR, "capsule_ffpp_small_table.csv")

# Stop here unless the small-table cell has left a non-empty frame in the kernel.
if not ('small_table_capsule_ffpp' in globals() and not small_table_capsule_ffpp.empty):
    raise SystemExit("No 'small_table_capsule_ffpp' found. Run the small-table cell first.")

small_table_capsule_ffpp.to_csv(DEST, index=False)
print(f"[saved] {DEST}  (rows={len(small_table_capsule_ffpp)})")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[saved] /content/drive/MyDrive/Capsule results FF++/capsule_ffpp_small_table.csv  (rows=102)
