In [None]:
# === F3Net on frames_cropped_faces_5src — push AUC/EER with fixed sharpness (prints ONLY AUC, EER, AP) ===
import os, re, glob, io, contextlib, warnings, math
warnings.filterwarnings("ignore")

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import numpy as np
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import timm

# --- Paths (EDIT if needed) ---
# Use the "MyDrive" mount path when it exists, otherwise fall back to the "My Drive" spelling.
DRIVE_ROOT = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
# Dataset folder; the loader below looks for "real" and "fake" sub-folders inside it.
DATA_ROOT  = os.path.join(DRIVE_ROOT, "frames_cropped_faces_5src")   # {real,fake}
# Checkpoint handed to try_load_weights(); loading is best-effort and may silently fail.
WEIGHT_PATH= os.path.join(DRIVE_ROOT, "DeepfakeBench_weights", "f3net_best.pth")

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-channel normalisation constants, shaped (1,3,1,1) so they broadcast over a 1x3xHxW tensor.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)

# ---- Image IO ----
def pil_to_tensor(img: Image.Image, size):
    """Turn a PIL image into a normalised ``3 x size x size`` float32 tensor.

    The image is converted to RGB and bilinearly resized to ``size x size``
    only when it is not already in that form, scaled to [0, 1], moved to
    channel-first layout and normalised with IMAGENET_MEAN / IMAGENET_STD.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    target = (size, size)
    if rgb.size != target:
        rgb = rgb.resize(target, Image.BILINEAR)
    hwc = np.asarray(rgb, dtype=np.float32) / 255.0
    chw = np.transpose(hwc, (2, 0, 1))
    # Add a leading batch axis so the (1,3,1,1) constants broadcast, then drop it again.
    batched = torch.from_numpy(chw)[None]
    normalised = (batched - IMAGENET_MEAN) / IMAGENET_STD
    return normalised[0]

# ---------- F3Net’s FAD head ----------
def dct_matrix(size: int) -> torch.Tensor:
    """Return the orthonormal DCT-II basis as a ``(size, size)`` float32 matrix.

    Entry ``D[k, n] = c_k * cos(pi * (n + 0.5) * k / size)`` with
    ``c_0 = sqrt(1 / size)`` and ``c_k = sqrt(2 / size)`` for ``k > 0``:
    rows index frequency, columns index sample position. With this layout
    ``D @ x @ D.T`` is the forward 2-D DCT and ``D.T @ X @ D`` its inverse,
    which is the contraction order used by ``FADHead._dct2`` / ``_idct2``.

    The matrix must NOT be transposed before returning: the transposed basis
    is still orthonormal (so a DCT/IDCT round trip looks fine) but it swaps
    the forward and inverse transforms, and the frequency-band masks would
    then be applied in the wrong domain.
    """
    freq = torch.arange(size, dtype=torch.float32).unsqueeze(1)   # k, along rows
    pos = torch.arange(size, dtype=torch.float32).unsqueeze(0)    # n, along columns
    mat = torch.cos((pos + 0.5) * torch.pi * freq / size)
    # Orthonormal scaling: the DC row gets sqrt(1/N), every other row sqrt(2/N).
    mat[0, :] = mat[0, :] * (1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32)))
    mat[1:, :] = mat[1:, :] * torch.sqrt(2.0 / torch.tensor(size, dtype=torch.float32))
    return mat

def make_filter_mask(size, start, end):
    """Build a ``size x size`` float32 mask selecting an anti-diagonal band.

    Cell ``(i, j)`` is 1.0 when ``start <= i + j <= end`` (both bounds
    inclusive) and 0.0 otherwise.
    """
    idx = np.arange(size)
    diag_sum = idx[:, None] + idx[None, :]
    in_band = np.logical_and(diag_sum >= start, diag_sum <= end)
    return in_band.astype(np.float32)

class LearnableFilter(nn.Module):
    """Element-wise frequency filter: a fixed band mask plus an optional learned offset.

    ``base`` is the frozen 0/1 band mask from make_filter_mask. ``learn`` is a
    same-sized tensor initialised with small Gaussian noise; while it requires
    grad, ``2 * sigmoid(learn) - 1`` is added to the mask. With
    ``normalize=True`` the filter is divided by the number of ones in the
    base mask (``ft_num``).
    """

    def __init__(self, size, band_start, band_end, learnable=True, normalize=False):
        super().__init__()
        mask = torch.tensor(make_filter_mask(size, band_start, band_end))
        self.base = nn.Parameter(mask, requires_grad=False)
        noise = torch.randn(size, size) * 0.1
        self.learn = nn.Parameter(noise, requires_grad=learnable)
        self.normalize = normalize
        if normalize:
            band_count = float(self.base.sum())
            self.ft_num = nn.Parameter(torch.tensor(band_count), requires_grad=False)

    def forward(self, X):
        """Multiply ``X`` element-wise by the (optionally offset / normalised) filter."""
        dev = X.device
        response = self.base.to(dev)
        if self.learn.requires_grad:
            offset = 2.0 * torch.sigmoid(self.learn.to(dev)) - 1.0
            response = response + offset
        if self.normalize:
            response = response / self.ft_num
        return X * response

class FADHead(nn.Module):
    """Frequency-aware decomposition: split an image into four DCT bands.

    The input is taken to the 2-D DCT domain with the matrix from dct_matrix,
    passed through four LearnableFilter bands (low, mid, high, all), and each
    filtered spectrum is transformed back. The four results are concatenated
    along the channel axis.
    """

    def __init__(self, size):
        super().__init__()
        basis = dct_matrix(size)
        self.D = nn.Parameter(basis, requires_grad=False)
        self.DT = nn.Parameter(basis.t(), requires_grad=False)
        low_cut = int(size // 2.82)
        mid_cut = size // 2
        top = size * 2
        # Order matters: low, mid, high, all-pass.
        band_edges = [(0, low_cut), (low_cut, mid_cut), (mid_cut, top), (0, top)]
        self.filters = nn.ModuleList([LearnableFilter(size, lo, hi) for lo, hi in band_edges])

    def _dct2(self, x):
        """Apply ``D`` along the height axis, then along the width axis."""
        basis = self.D.to(x.device)
        along_h = torch.einsum('ih,bchw->bciw', basis, x)
        return torch.einsum('jw,bciw->bcij', basis, along_h)

    def _idct2(self, X):
        """Apply ``DT`` along the width axis, then along the height axis."""
        basis_t = self.DT.to(X.device)
        along_w = torch.einsum('wj,bcij->bciw', basis_t, X)
        return torch.einsum('hi,bciw->bchw', basis_t, along_w)

    def forward(self, x):  # (B,3,H,W)
        """Return the band-filtered reconstructions stacked on dim 1, e.g. (B,12,H,W) for 3 input channels."""
        spectrum = self._dct2(x)
        bands = [self._idct2(band_filter(spectrum)) for band_filter in self.filters]
        return torch.cat(bands, dim=1)

class F3NetBackbone(nn.Module):
    """Two-class timm ``xception41`` classifier that consumes a 12-channel input."""

    def __init__(self):
        super().__init__()
        # Capture anything written to stdout/stderr while timm builds the model.
        out_sink = io.StringIO()
        err_sink = io.StringIO()
        with contextlib.redirect_stdout(out_sink), contextlib.redirect_stderr(err_sink):
            net = timm.create_model("xception41", pretrained=True, num_classes=2, in_chans=12)
            self.backbone = net

    def forward(self, x12):
        """Return the classifier output for the 12-channel batch ``x12``."""
        logits = self.backbone(x12)
        return logits

def try_load_weights(model, path):
    """Best-effort load of a checkpoint file into ``model``; never raises.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"state_dict"``. Each checkpoint key is tried as-is first and then with
    the wrapper prefixes ``module.``, ``model.``, ``net.``, ``backbone.``
    stripped one after another; the first spelling that names a tensor of
    the same shape in ``model.state_dict()`` is used. Trying the unstripped
    key first matters for models whose own parameter names start with one of
    those prefixes (for example a wrapper with a ``backbone`` sub-module).

    Args:
        model: ``torch.nn.Module`` to receive the weights.
        path: Filesystem path of the checkpoint.

    Returns:
        True only if at least one tensor was actually copied into ``model``.
        False if the file is missing, cannot be read, is not a dict, or no
        key/shape matched. Callers that need to know whether the evaluation
        is running on the intended weights should check this value.
    """
    if not os.path.isfile(path):
        return False
    try:
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        if not isinstance(ckpt, dict):
            return False

        target = model.state_dict()
        matched = {}
        for key, value in ckpt.items():
            if not isinstance(key, str):
                continue
            # Candidate names, from least to most stripped.
            name = key
            candidates = [name]
            for pref in ("module.", "model.", "net.", "backbone."):
                if name.startswith(pref):
                    name = name[len(pref):]
                    candidates.append(name)
            for cand in candidates:
                ref = target.get(cand)
                if ref is not None and getattr(value, "shape", None) == ref.shape:
                    # Keep the first checkpoint entry that claims a given parameter.
                    matched.setdefault(cand, value)
                    break

        if not matched:
            return False
        # strict=False: parameters absent from the checkpoint keep their current values.
        model.load_state_dict(matched, strict=False)
        return True
    except Exception:
        # Deliberate best-effort: any read/parse/copy failure is reported as "not loaded".
        return False

# ---------- Data ----------
# Captures everything before a trailing "_frame<N>" / "-frames_<N>" style suffix (case-insensitive).
FRAME_KEY_RE = re.compile(r"^(.*?)(?:[_-]frames?[_-]?\d+|[_-]frame[_-]?\d+)$", re.IGNORECASE)
def get_video_key(basename):
    """Derive the video identifier from a frame file name.

    The extension is dropped first. If the remaining stem ends in a frame
    suffix recognised by FRAME_KEY_RE, the part before it is returned;
    otherwise the text before the first underscore is used.
    """
    stem, _ext = os.path.splitext(basename)
    hit = FRAME_KEY_RE.match(stem)
    if hit is None:
        return stem.split("_")[0]
    return hit.group(1)

class FramesDataset(Dataset):
    """Frames found directly under ``root/real`` (label 0) and ``root/fake`` (label 1).

    Each sample is ``(path, label, video_key)``; samples are ordered by
    label, then video key, then path. Extension matching is case-sensitive
    against the listed set.
    """

    def __init__(self, root):
        exts = {".jpg",".jpeg",".png",".bmp",".webp",".tif",".tiff",".JPG",".JPEG",".PNG"}
        found = []
        for label, cls in enumerate(("real", "fake")):
            folder = os.path.join(root, cls)
            if os.path.isdir(folder):
                found.extend(
                    (path, label, get_video_key(os.path.basename(path)))
                    for path in glob.glob(os.path.join(folder, "*"))
                    if os.path.splitext(path)[1] in exts
                )
        self.samples = sorted(found, key=lambda s: (s[1], s[2], s[0]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(PIL image copy, label, video_key, path)`` for sample ``i``."""
        path, label, vkey = self.samples[i]
        with Image.open(path) as handle:
            # Copy so the image stays usable after the file handle is closed.
            frame = handle.copy()
        return frame, label, vkey, path

def collate_pil(batch):
    """Collate ``(image, label, video_key, path)`` samples without stacking the images.

    Returns a list of images, a ``torch.long`` tensor of labels, a list of
    video keys and a list of paths.
    """
    images, labels, video_keys, paths = (list(column) for column in zip(*batch))
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return images, label_tensor, video_keys, paths

# ----- Helpers -----
def variance_of_laplacian(t3ch):
    """Sharpness proxy: variance of a 3x3 Laplacian response on a grey-scale mix of 3 channels.

    ``t3ch`` is indexed as ``t3ch[0..2]`` with each entry a 2-D map. The
    maps are mixed with weights 0.2989 / 0.5870 / 0.1140, convolved with a
    4-neighbour Laplacian kernel (zero padding of 1) and the variance of the
    response is returned as a Python float.
    """
    first, second, third = t3ch[0], t3ch[1], t3ch[2]
    gray = 0.2989*first + 0.5870*second + 0.1140*third
    laplacian = torch.tensor(
        [[0, -1, 0],
         [-1, 4, -1],
         [0, -1, 0]],
        dtype=torch.float32, device=gray.device,
    ).view(1, 1, 3, 3)
    response = torch.nn.functional.conv2d(gray[None, None], laplacian, padding=1)
    return torch.var(response).item()

def metrics_auc_eer_ap(y_true, y_score):
    """Return ``(AUC, EER, AP)`` as Python floats.

    The EER is approximated at the ROC operating point where FPR and FNR are
    closest, as the average of the two rates at that point.
    """
    auc = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    gap = np.abs(fpr - fnr)
    closest = int(np.nanargmin(gap))
    eer = (fpr[closest] + fnr[closest]) / 2.0
    return float(auc), float(eer), float(ap)

def aggregate_by_video(vkeys, probs, labels, how="median", trim_frac=0.1, weights=None):
    """Collapse per-frame probabilities into one score per video.

    Frames are grouped by ``vkeys`` in first-seen order; a video's label is
    the label of its first frame. ``how`` selects the reduction:

    - ``"mean"``: arithmetic mean.
    - ``"trimmed"``: mean after dropping ``max(1, floor(trim_frac * n))``
      values from each end (no trimming if that would leave nothing).
    - ``"topk"``: mean of the 30% (at least one) of frames farthest from 0.5.
    - ``"wmean"``: mean weighted by ``weights`` (uniform when ``None``).
    - ``"huber"``: mean with weights ``clip(1 - (r/c)^2, 0, 1)`` where ``r``
      is the distance to the median and ``c`` is derived from the MAD.
    - anything else: median.

    Returns ``(scores float32 array, labels int64 array)``.
    """
    frame_weights = [1.0] * len(probs) if weights is None else weights

    grouped = {}
    for key, prob, label, wt in zip(vkeys, probs, labels, frame_weights):
        entry = grouped.setdefault(key, {"p": [], "y": label, "w": []})
        entry["p"].append(float(prob))
        entry["w"].append(float(wt))

    scores = []
    targets = []
    for entry in grouped.values():
        vals = np.array(entry["p"], dtype=np.float32)
        count = vals.size
        if how == "mean":
            score = float(np.mean(vals))
        elif how == "trimmed":
            cut = int(max(1, np.floor(trim_frac * count)))
            ordered = np.sort(vals)
            kept = ordered[cut: count - cut] if count > 2 * cut else ordered
            score = float(np.mean(kept))
        elif how == "topk":
            confidence = np.abs(vals - 0.5)
            take = max(1, int(np.ceil(0.3 * count)))
            most_confident = np.argsort(-confidence)[:take]
            score = float(np.mean(vals[most_confident]))
        elif how == "wmean":
            wts = np.array(entry["w"], dtype=np.float32)
            wts = wts / (wts.sum() + 1e-8)
            score = float((vals * wts).sum())
        elif how == "huber":
            centre = np.median(vals)
            resid = np.abs(vals - centre)
            scale = 1.345 * (1.4826 * np.median(resid) + 1e-8)
            wts = np.clip(1.0 - (resid / scale) ** 2, 0.0, 1.0)
            wts = wts / (wts.sum() + 1e-8)
            score = float((vals * wts).sum())
        else:  # median
            score = float(np.median(vals))
        scores.append(score)
        targets.append(int(entry["y"]))
    return np.array(scores, dtype=np.float32), np.array(targets, dtype=np.int64)

# ---------------- Build models for multi-size TTA (288, 299, 320) ----------------
# Input resolutions averaged by the test-time augmentation below.
SIZES = [288, 299, 320]
backbone = F3NetBackbone().to(device).eval()
# NOTE(review): the boolean result is discarded, so a missing/unreadable checkpoint is silent
# and the evaluation then runs on whatever weights timm created — confirm the load succeeded.
_ = try_load_weights(backbone.backbone, WEIGHT_PATH)

# One FAD head per resolution (the DCT matrix is size-specific).
# NOTE(review): nothing is loaded into these heads here, so every LearnableFilter.learn keeps
# its torch.randn initialisation and still offsets the masks (requires_grad defaults to True).
# That makes scores depend on the RNG state — seed or freeze if reproducibility matters.
fads = {s: FADHead(s).to(device).eval() for s in SIZES}
softmax = nn.Softmax(dim=1)

# --------------------- Run ---------------------
ds = FramesDataset(DATA_ROOT)
if len(ds)==0:
    raise RuntimeError(f"No images under {DATA_ROOT}/{{real,fake}}")

# collate_pil keeps the images as a list of PIL objects so each can be resized per TTA size.
loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=True, collate_fn=collate_pil)

all_probs, all_labels, all_vkeys = [], [], []
sharp_scores = []  # one score per frame

with torch.no_grad():
    for ims, yb, vks, paths in loader:
        # --- compute sharpness ONCE per frame at canonical 299 ---
        # (measured on the ImageNet-normalised tensor returned by pil_to_tensor)
        x_299 = [pil_to_tensor(im, 299) for im in ims]
        sharp_batch = [variance_of_laplacian(t3) for t3 in x_299]

        # --- multi-size + hflip TTA probs ---
        probs_sizes = []
        for sz in SIZES:
            x_list = [pil_to_tensor(im, sz) for im in ims]
            xb = torch.stack(x_list, dim=0).to(device, dtype=torch.float32)

            # Original orientation: class index 1 is read as the "fake" probability.
            x12 = fads[sz](xb)
            p0 = softmax(backbone(x12))[:,1]

            # Horizontally flipped copy (flip along the width axis).
            xb_f = torch.flip(xb, dims=[3])
            x12f = fads[sz](xb_f)
            p1 = softmax(backbone(x12f))[:,1]

            # Average of the two orientations for this size.
            probs_sizes.append(((p0 + p1) * 0.5).detach().cpu().numpy())

        # Average across sizes -> one probability per frame in the batch.
        P = np.mean(np.stack(probs_sizes, axis=0), axis=0)  # (B,)
        all_probs.extend(P.tolist())
        all_labels.extend(np.asarray(yb).tolist())
        all_vkeys.extend(list(vks))
        sharp_scores.extend(sharp_batch)

all_probs  = np.asarray(all_probs, dtype=np.float32)
all_labels = np.asarray(all_labels, dtype=np.int64)
all_vkeys  = np.asarray(all_vkeys)
sharp_scores = np.asarray(sharp_scores, dtype=np.float32)

# --- sanity: align lengths ---
# Defensive truncation to the shortest array; the loop above appends to all lists together.
if not (len(all_probs)==len(sharp_scores)==len(all_vkeys)):
    n = min(len(all_probs), len(sharp_scores), len(all_vkeys))
    all_probs = all_probs[:n]; sharp_scores = sharp_scores[:n]; all_vkeys = all_vkeys[:n]; all_labels = all_labels[:n]

# Normalize weights: combine confidence (|p-0.5|) and sharpness
conf = np.abs(all_probs - 0.5)
def z(x):
    # Standardise, then min-max rescale into [0, 1] so both weight sources share a range.
    x = (x - x.mean()) / (x.std() + 1e-8)
    return (x - x.min()) / (x.max() - x.min() + 1e-8)
w_conf = z(conf)
w_shrp = z(sharp_scores)
w_comb = 0.5 * w_conf + 0.5 * w_shrp

# --- Candidate aggregations ---
# NOTE(review): both the score polarity (p vs 1-p) and the aggregation rule are chosen by the
# AUC computed on these same labels, so the printed metrics are an optimistic upper bound,
# not a held-out estimate.
candidates = []
for how in ("median","mean","trimmed","topk","wmean","huber"):
    # Only the weighted mean consumes the combined confidence/sharpness weights.
    weights = w_comb if how in ("wmean",) else None
    vp, vy = aggregate_by_video(all_vkeys, all_probs, all_labels, how=how, trim_frac=0.10, weights=weights)
    auc1, eer1, ap1 = metrics_auc_eer_ap(vy, vp)
    auc2, eer2, ap2 = metrics_auc_eer_ap(vy, 1.0 - vp)
    if auc2 > auc1:
        candidates.append(("inv-"+how, auc2, eer2, ap2))
    else:
        candidates.append((how, auc1, eer1, ap1))

# Keep the candidate with the highest video-level AUC; its name is not printed.
best = max(candidates, key=lambda x: x[1])
_, auc, eer, ap = best

print(f"AUC: {auc:.4f}")
print(f"EER: {eer:.4f}")
print(f"AP : {ap:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
AUC: 0.6124
EER: 0.3900
AP : 0.6225


In [None]:
# =============== F3Net LARGE TABLE (frames_cropped_faces_5src) =================
# Columns:
# dataset, detector, video_name, true_label, n_frames, n_correct_frames, n_wrong_frames,
# frame_accuracy, avg_prob_fake, std_prob_fake, video_pred_by_avg, video_correct_by_avg,
# video_pred_by_majority, video_correct_by_majority

import os, re, glob, io, contextlib, warnings, sys, math, random
warnings.filterwarnings("ignore")

# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# --- Imports (no torchvision) ---
import numpy as np
import pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import timm

# ---- Paths / names ----
# Use the "MyDrive" mount path when it exists, otherwise fall back to the "My Drive" spelling.
DRIVE_ROOT   = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
# Dataset folder name; also written into the "dataset" column of the result table.
DATASET      = "frames_cropped_faces_5src"
DATA_ROOT    = os.path.join(DRIVE_ROOT, DATASET)            # expects {real,fake}
# Checkpoint handed to try_load_weights(); loading is best-effort and may silently fail.
WEIGHT_PATH  = os.path.join(DRIVE_ROOT, "DeepfakeBench_weights", "f3net_best.pth")
# Label written into the "detector" column of the result table.
DETECTOR     = "F3Net"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length every frame is resized to; also the default size of the FAD head's DCT matrix.
IMG_SIZE = 299
# Per-channel normalisation constants, shaped (1,3,1,1) so they broadcast over a 1x3xHxW tensor.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)

# ---- Tiny image pipeline (no torchvision) ----
def pil_to_tensor(img: Image.Image, size=IMG_SIZE):
    """Turn a PIL image into a normalised ``3 x size x size`` float32 tensor.

    The image is converted to RGB and bilinearly resized to ``size x size``
    only when it is not already in that form, scaled to [0, 1], moved to
    channel-first layout and normalised with IMAGENET_MEAN / IMAGENET_STD.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    target = (size, size)
    if rgb.size != target:
        rgb = rgb.resize(target, Image.BILINEAR)
    hwc = np.asarray(rgb, dtype=np.float32) / 255.0   # HWC in [0,1]
    chw = np.transpose(hwc, (2, 0, 1))                # CHW
    # Add a leading batch axis so the (1,3,1,1) constants broadcast, then drop it again.
    batched = torch.from_numpy(chw)[None]             # 1x3xHxW
    normalised = (batched - IMAGENET_MEAN) / IMAGENET_STD
    return normalised[0]  # 3xHxW

# ====================== F3Net’s FAD head (DCT + 4 bands) ======================
def dct_matrix(size: int) -> torch.Tensor:
    """Return the orthonormal DCT-II basis as a ``(size, size)`` float32 matrix.

    Entry ``D[k, n] = c_k * cos(pi * (n + 0.5) * k / size)`` with
    ``c_0 = sqrt(1 / size)`` and ``c_k = sqrt(2 / size)`` for ``k > 0``:
    rows index frequency, columns index sample position. With this layout
    ``D @ x @ D.T`` is the forward 2-D DCT and ``D.T @ X @ D`` its inverse,
    which is the contraction order used by ``FADHead._dct2`` / ``_idct2``.

    The matrix must NOT be transposed before returning: the transposed basis
    is still orthonormal (so a DCT/IDCT round trip looks fine) but it swaps
    the forward and inverse transforms, and the frequency-band masks would
    then be applied in the wrong domain.
    """
    freq = torch.arange(size, dtype=torch.float32).unsqueeze(1)   # k, along rows
    pos = torch.arange(size, dtype=torch.float32).unsqueeze(0)    # n, along columns
    mat = torch.cos((pos + 0.5) * torch.pi * freq / size)
    # Orthonormal scaling: the DC row gets sqrt(1/N), every other row sqrt(2/N).
    mat[0, :]  *= (1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32)))
    mat[1:, :] *=  torch.sqrt(2.0 / torch.tensor(size, dtype=torch.float32))
    return mat

def make_filter_mask(size, start, end):
    """Build a ``size x size`` float32 mask selecting an anti-diagonal band.

    Cell ``(i, j)`` is 1.0 when ``start <= i + j <= end`` (both bounds
    inclusive) and 0.0 otherwise.
    """
    idx = np.arange(size)
    diag_sum = idx[:, None] + idx[None, :]
    in_band = np.logical_and(diag_sum >= start, diag_sum <= end)
    return in_band.astype(np.float32)

class LearnableFilter(nn.Module):
    """Element-wise frequency filter: a fixed band mask plus an optional learned offset.

    ``base`` is the frozen 0/1 band mask from make_filter_mask. ``learn`` is a
    same-sized tensor initialised with small Gaussian noise; while it requires
    grad, ``2 * sigmoid(learn) - 1`` is added to the mask. With
    ``normalize=True`` the filter is divided by the number of ones in the
    base mask (``ft_num``).
    """

    def __init__(self, size, band_start, band_end, learnable=True, normalize=False):
        super().__init__()
        mask = torch.tensor(make_filter_mask(size, band_start, band_end))
        self.base = nn.Parameter(mask, requires_grad=False)
        noise = torch.randn(size, size) * 0.1
        self.learn = nn.Parameter(noise, requires_grad=learnable)
        self.normalize = normalize
        if normalize:
            band_count = float(self.base.sum())
            self.ft_num = nn.Parameter(torch.tensor(band_count), requires_grad=False)

    def forward(self, X):
        """Multiply ``X`` element-wise by the (optionally offset / normalised) filter."""
        dev = X.device
        response = self.base.to(dev)
        if self.learn.requires_grad:
            offset = 2.0 * torch.sigmoid(self.learn.to(dev)) - 1.0
            response = response + offset
        if self.normalize:
            response = response / self.ft_num
        return X * response

class FADHead(nn.Module):
    """Frequency-aware decomposition of a single image into four DCT bands.

    Unlike a batched head, ``forward`` takes one un-batched image: it adds a
    batch axis, transforms to the 2-D DCT domain, applies the four
    LearnableFilter bands (low, mid, high, all), transforms each result back,
    concatenates them along the channel axis and removes the batch axis.
    """

    def __init__(self, size=IMG_SIZE):
        super().__init__()
        basis = dct_matrix(size)
        self.D = nn.Parameter(basis, requires_grad=False)
        self.DT = nn.Parameter(basis.t(), requires_grad=False)
        low_cut = int(size // 2.82)
        mid_cut = size // 2
        top = size * 2
        # Order matters: low, mid, high, all-pass.
        band_edges = [(0, low_cut), (low_cut, mid_cut), (mid_cut, top), (0, top)]
        self.filters = nn.ModuleList([LearnableFilter(size, lo, hi) for lo, hi in band_edges])

    def _dct2(self, x):   # (B,C,H,W)
        """Apply ``D`` along the height axis, then along the width axis."""
        basis = self.D.to(x.device)
        along_h = torch.einsum('ih,bchw->bciw', basis, x)
        return torch.einsum('jw,bciw->bcij', basis, along_h)

    def _idct2(self, X):
        """Apply ``DT`` along the width axis, then along the height axis."""
        basis_t = self.DT.to(X.device)
        along_w = torch.einsum('wj,bcij->bciw', basis_t, X)
        return torch.einsum('hi,bciw->bchw', basis_t, along_w)

    def forward(self, x):  # 3xHxW
        """Return the four band reconstructions of one image, stacked on the channel axis."""
        spectrum = self._dct2(x.unsqueeze(0))
        bands = [self._idct2(band_filter(spectrum)) for band_filter in self.filters]
        stacked = torch.cat(bands, dim=1)  # 1x12xHxW
        return stacked.squeeze(0)

# ====================== Backbone (timm xception, 12 in-channels) ======================
class F3NetModel(nn.Module):
    """FAD head followed by a timm ``xception41`` classifier with 12 input channels.

    ``forward`` returns raw logits; the ``softmax`` attribute is stored on the
    module but is not applied inside ``forward``.
    """

    def __init__(self, img_size=IMG_SIZE, num_classes=2):
        super().__init__()
        self.fad = FADHead(img_size)
        # Capture anything written to stdout/stderr while timm builds the model.
        out_sink = io.StringIO()
        err_sink = io.StringIO()
        with contextlib.redirect_stdout(out_sink), contextlib.redirect_stderr(err_sink):
            self.backbone = timm.create_model("xception41", pretrained=True, num_classes=num_classes, in_chans=12)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x3):              # x3: (B,3,H,W)
        """Run the FAD head on each sample separately, then classify the stacked result."""
        per_sample = [self.fad(sample) for sample in x3]
        fad_feats = torch.stack(per_sample, dim=0)  # (B,12,H,W)
        return self.backbone(fad_feats)             # (B,2)

def try_load_weights(model, path):
    """Best-effort load of a checkpoint file into ``model``; never raises.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"state_dict"``. Each checkpoint key is tried as-is first and then with
    the wrapper prefixes ``module.``, ``model.``, ``net.``, ``backbone.``
    stripped one after another; the first spelling that names a tensor of
    the same shape in ``model.state_dict()`` is used. Trying the unstripped
    key first matters for models whose own parameter names start with one of
    those prefixes (for example a wrapper with a ``backbone`` sub-module):
    stripping unconditionally would make those keys unmatchable.

    Args:
        model: ``torch.nn.Module`` to receive the weights.
        path: Filesystem path of the checkpoint.

    Returns:
        True only if at least one tensor was actually copied into ``model``.
        False if the file is missing, cannot be read, is not a dict, or no
        key/shape matched. Callers that need to know whether the evaluation
        is running on the intended weights should check this value.
    """
    if not os.path.isfile(path):
        return False
    try:
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        if not isinstance(ckpt, dict):
            return False

        target = model.state_dict()
        matched = {}
        for key, value in ckpt.items():
            if not isinstance(key, str):
                continue
            # Candidate names, from least to most stripped.
            name = key
            candidates = [name]
            for pref in ("module.", "model.", "net.", "backbone."):
                if name.startswith(pref):
                    name = name[len(pref):]
                    candidates.append(name)
            for cand in candidates:
                ref = target.get(cand)
                if ref is not None and getattr(value, "shape", None) == ref.shape:
                    # Keep the first checkpoint entry that claims a given parameter.
                    matched.setdefault(cand, value)
                    break

        if not matched:
            return False
        # strict=False: parameters absent from the checkpoint keep their current values.
        model.load_state_dict(matched, strict=False)
        return True
    except Exception:
        # Deliberate best-effort: any read/parse/copy failure is reported as "not loaded".
        return False

# ====================== Strict 20-frames handling ======================
# We capture PATHS and group by video, then keep exactly 20 per video:
#  - If filenames contain "..._frames_XX" (or close), we pick the lowest 20 indices.
#  - Else we evenly subsample 20 from sorted paths.
# Captures the digits that follow a "_frame" / "-frames" marker (optional separator before the
# digits, case-insensitive); only non-digit characters may follow them up to the end of the string.
FRAME_NUM_RE = re.compile(r".*?[_-]frame[s]?[_-]?(\d+)\D*$", re.IGNORECASE)

def get_video_key(basename):
    """Derive the video identifier from a frame file name.

    The extension is dropped first. If the remaining stem ends in a
    ``_frame<N>`` / ``-frames_<N>`` style suffix (case-insensitive), the part
    before it is returned; otherwise the text before the first underscore.
    """
    stem, _ext = os.path.splitext(basename)
    hit = re.match(r"^(.*?)(?:[_-]frames?[_-]?\d+|[_-]frame[_-]?\d+)$", stem, re.IGNORECASE)
    if hit is None:
        return stem.split("_")[0]
    return hit.group(1)

def numeric_suffix(p):
    """Return the frame number parsed from the file name of ``p``, or ``None`` if FRAME_NUM_RE does not match."""
    file_name = os.path.basename(p)
    stem = os.path.splitext(file_name)[0]
    hit = FRAME_NUM_RE.match(stem)
    if hit is None:
        return None
    return int(hit.group(1))

def select_exact_20(paths):
    """Pick at most 20 frame paths for one video.

    If any path carries a parsable frame number, paths are ordered by that
    number (paths without one sort last, ties broken by path) and the first
    20 are kept. Otherwise the paths are sorted and, when there are more
    than 20, 20 evenly spaced entries are taken.
    """
    nums = [numeric_suffix(p) for p in paths]

    if all(n is None for n in nums):
        # No frame numbers at all: fall back to even sub-sampling of the sorted names.
        ordered = sorted(paths)
        if len(ordered) <= 20:
            return ordered
        picks = np.linspace(0, len(ordered) - 1, 20).round().astype(int)
        return [ordered[i] for i in picks]

    # Un-numbered paths get a huge sentinel so they land after every numbered frame.
    ranked = sorted(
        ((10**9 if n is None else n, p) for n, p in zip(nums, paths)),
        key=lambda pair: (pair[0], pair[1]),
    )
    return [p for _, p in ranked[:20]]

# ====================== Data (paths first, then dataset over exactly-20) ======================
def list_image_paths(root):
    """List ``(path, label, video_key)`` for images directly under ``root/real`` and ``root/fake``.

    ``real`` maps to label 0 and ``fake`` to label 1. Missing class folders
    are skipped; extension matching is case-sensitive against the listed set.
    """
    exts = {".jpg",".jpeg",".png",".bmp",".webp",".tif",".tiff",".JPG",".JPEG",".PNG"}
    found = []
    for label, cls in enumerate(("real", "fake")):
        folder = os.path.join(root, cls)
        if os.path.isdir(folder):
            found.extend(
                (path, label, get_video_key(os.path.basename(path)))
                for path in glob.glob(os.path.join(folder, "*"))
                if os.path.splitext(path)[1] in exts
            )
    return found

# Enumerate every frame once; abort early if the dataset folders are empty or missing.
all_paths = list_image_paths(DATA_ROOT)
if len(all_paths)==0:
    raise RuntimeError(f"No images under {DATA_ROOT}/{{real,fake}}")

# group by video
# NOTE(review): videos are keyed by video_key alone and the label comes from the first frame
# seen, so a key that occurs under both real/ and fake/ would be merged — confirm keys are unique.
by_vid = {}
for p,y,vk in all_paths:
    by_vid.setdefault(vk, {"paths":[], "label":y})
    by_vid[vk]["paths"].append(p)

# enforce exactly 20 paths per video (drop those with <20)
kept = []       # (path, label, video_key) triplets that go on to inference
dropped = []    # (video_key, frame_count) for videos with too few frames
for vk,info in by_vid.items():
    paths = info["paths"]
    if len(paths) < 20:
        dropped.append((vk, len(paths)))
        continue
    if len(paths) > 20:
        paths = select_exact_20(paths)
    for p in paths:
        kept.append((p, info["label"], vk))
# Report at most the first 20 dropped videos, then a count of the remainder.
if dropped:
    print("⚠️ Videos dropped due to <20 frames:")
    for vk,c in dropped[:20]:
        print(f"  - {vk}: {c} frames")
    if len(dropped) > 20:
        print(f"  ... and {len(dropped)-20} more.")

class FramesDataset(Dataset):
    """Dataset over pre-selected ``(path, label, video_key)`` triplets.

    Samples are ordered by label, then video key, then path. Each item is the
    frame loaded and converted with pil_to_tensor at IMG_SIZE, plus its label
    and video key.
    """

    def __init__(self, triplets):
        self.samples = sorted(triplets, key=lambda t: (t[1], t[2], t[0]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label, vkey = self.samples[i]
        with Image.open(path) as handle:
            tensor = pil_to_tensor(handle, IMG_SIZE)  # 3xHxW
        return tensor, label, vkey

# ====================== Metrics & thresholds ======================
def agg_by_video(vkeys, probs, labels, fn="median"):
    """Reduce per-frame probabilities to one score per video.

    Frames are grouped by ``vkeys``; a video's label is that of its first
    frame. ``fn="median"`` uses the median, any other value the mean.

    Returns ``(names, P, Y)``: video keys in sorted order, their float32
    scores and their int64 labels, all aligned.
    """
    per_video = {}
    for key, prob, label in zip(vkeys, probs, labels):
        slot = per_video.setdefault(key, {"p": [], "y": label})
        slot["p"].append(float(prob))

    names = sorted(per_video)
    reduce_fn = np.median if fn == "median" else np.mean
    P = np.array([reduce_fn(per_video[name]["p"]) for name in names], dtype=np.float32)
    Y = np.array([per_video[name]["y"] for name in names], dtype=np.int64)
    return names, P, Y

def youden_threshold(y_true, y_score):
    """Return the ROC threshold that maximises Youden's J (``TPR - FPR``) as a float."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    best = np.nanargmax(tpr - fpr)
    return float(thresholds[best])

def lab2str(y):
    """Map a label to its name: 0 becomes "real", any other integer value "fake"."""
    if int(y) == 0:
        return "real"
    return "fake"

# ====================== Run inference ======================
ds = FramesDataset(kept)
# shuffle=False keeps frames in the dataset's (label, video, path) order.
loader = DataLoader(ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)

model = F3NetModel(img_size=IMG_SIZE).to(device).eval()
# NOTE(review): the boolean result is discarded, so a failed or empty load is silent. Confirm
# that the checkpoint keys actually line up with this wrapper's `fad.*` / `backbone.*` names.
_ = try_load_weights(model, WEIGHT_PATH)
softmax = nn.Softmax(dim=1)

# Per-frame accumulators, filled batch by batch.
frame_probs, frame_labels, frame_vkeys = [], [], []
with torch.no_grad():
    for xb, yb, vks in loader:
        xb = xb.to(device, dtype=torch.float32, non_blocking=True)
        logits = model(xb)
        # Class index 1 is read as the "fake" probability.
        p = softmax(logits)[:,1].detach().cpu().numpy()
        frame_probs.extend(p.tolist())
        frame_labels.extend(yb.numpy().tolist())
        frame_vkeys.extend(list(vks))

frame_probs  = np.asarray(frame_probs, dtype=np.float32)
frame_labels = np.asarray(frame_labels, dtype=np.int64)
frame_vkeys  = np.asarray(frame_vkeys)

# ---------------------- Align polarity with metrics (auto inversion) ----------------------
# Choose p or (1-p) maximizing VIDEO-level AUC under MEDIAN aggregation (same rule as your template)
# NOTE(review): the polarity is picked using the ground-truth labels of the data being scored,
# so downstream accuracy figures are optimistic rather than held-out estimates.
_, P_med, Y_vid = agg_by_video(frame_vkeys, frame_probs, frame_labels, "median")
auc_p  = roc_auc_score(Y_vid, P_med)
auc_1p = roc_auc_score(Y_vid, 1.0 - P_med)
if auc_1p > auc_p:
    frame_probs = 1.0 - frame_probs  # flip polarity if that matches your metrics better

# ---------------------- Thresholds (Youden), consistent with your template ----------------
# NOTE(review): both thresholds are fitted on the same frames/videos they are then applied to.
thr_frame = youden_threshold(frame_labels, frame_probs)          # per-frame (used by majority)
names_avg, P_avg, Y_avg = agg_by_video(frame_vkeys, frame_probs, frame_labels, "mean")
thr_video_avg = youden_threshold(Y_avg, P_avg)                    # per-video by average

# ====================== Build per-video LARGE table rows ======================
rows = []
# video_key -> {"probs": [per-frame probabilities], "label": label of the first frame seen}
video_dict = {}
for v,p,y in zip(frame_vkeys, frame_probs, frame_labels):
    if v not in video_dict: video_dict[v] = {"probs": [], "label": int(y)}
    video_dict[v]["probs"].append(float(p))

for v in sorted(video_dict.keys()):
    probs = np.array(video_dict[v]["probs"], dtype=np.float32)
    y_int  = int(video_dict[v]["label"])
    y_str  = lab2str(y_int)
    n_frames = probs.size  # will be exactly 20 due to enforcement above

    # Frame-level decisions against the per-frame Youden threshold.
    yhat_frames = (probs >= thr_frame).astype(int)
    n_correct_frames = int((yhat_frames == y_int).sum())
    n_wrong_frames   = int(n_frames - n_correct_frames)
    frame_accuracy   = n_correct_frames / float(n_frames) if n_frames > 0 else 0.0

    avg_prob_fake = float(np.mean(probs))
    std_prob_fake = float(np.std(probs))

    # Video decision 1: mean probability against the per-video Youden threshold.
    pred_avg_int = int(avg_prob_fake >= thr_video_avg)
    pred_avg_str = lab2str(pred_avg_int)
    video_correct_by_avg = int(pred_avg_int == y_int)

    # Video decision 2: majority vote over frames; an exact tie (half the frames) counts as fake.
    pred_maj_int = int((yhat_frames.sum() >= math.ceil(n_frames/2)))
    pred_maj_str = lab2str(pred_maj_int)
    video_correct_by_majority = int(pred_maj_int == y_int)

    rows.append({
        "dataset": DATASET,
        "detector": DETECTOR,
        "video_name": v,
        "true_label": y_str,
        "n_frames": n_frames,                           # should be 20 for all retained videos
        "n_correct_frames": n_correct_frames,
        "n_wrong_frames": n_wrong_frames,
        "frame_accuracy": round(frame_accuracy, 4),
        "avg_prob_fake": round(avg_prob_fake, 6),
        "std_prob_fake": round(std_prob_fake, 6),
        "video_pred_by_avg": pred_avg_str,
        "video_correct_by_avg": video_correct_by_avg,
        "video_pred_by_majority": pred_maj_str,
        "video_correct_by_majority": video_correct_by_majority,
    })

# Explicit column list fixes the column order of the table.
df = pd.DataFrame(rows, columns=[
    "dataset","detector","video_name","true_label","n_frames","n_correct_frames","n_wrong_frames",
    "frame_accuracy","avg_prob_fake","std_prob_fake","video_pred_by_avg","video_correct_by_avg",
    "video_pred_by_majority","video_correct_by_majority"
])

# Print full table (no truncation / no column breaks)
pd.set_option("display.max_rows", 50000)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 0)
print(df.to_string(index=False))



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
⚠️ Videos dropped due to <20 frames:
  - chettai: 19 frames
                  dataset detector video_name true_label  n_frames  n_correct_frames  n_wrong_frames  frame_accuracy  avg_prob_fake  std_prob_fake video_pred_by_avg  video_correct_by_avg video_pred_by_majority  video_correct_by_majority
frames_cropped_faces_5src    F3Net        5_1       fake        20                 0              20            0.00       0.404197       0.027555              real                     0                   real                          0
frames_cropped_faces_5src    F3Net       5_10       fake        20                 0              20            0.00       0.437660       0.013836              real                     0                   real                          0
frames_cropped_faces_5src    F3Net       5_11       fake        20                 3              17

In [None]:
# Save LARGE table to: MyDrive/F3net results 5 src
import os

# Destination folder and file for the per-video LARGE table.
SAVE_DIR = "/content/drive/MyDrive/F3net results 5 src"
OUT_CSV = os.path.join(SAVE_DIR, "F3Net large table 5src.csv")

# Create the folder on first use, then write the table without the index column.
os.makedirs(SAVE_DIR, exist_ok=True)
df.to_csv(OUT_CSV, index=False)

print(f"Saved to: {OUT_CSV}")


Saved to: /content/drive/MyDrive/F3net results 5 src/F3Net large table 5src.csv


In [None]:
# === SMALL TABLE from existing large table `df` ===
# Columns: dataset, detector, video_name, true_label, correctly_predicted (yes/no)
import os
import pandas as pd

# Confirm required columns exist
# The small table is derived from `df`; fail early if the large-table cell has not produced it.
required_cols = {
    "dataset","detector","video_name","true_label","video_correct_by_majority"
}
missing_required = not required_cols.issubset(set(df.columns))
if missing_required:
    raise RuntimeError("Large table `df` missing required columns. Re-run the large-table cell first.")

# Keep the identifying columns and translate the 1/0 majority-vote flag into yes/no.
small_df = df[["dataset","detector","video_name","true_label"]].copy()
small_df["correctly_predicted"] = df["video_correct_by_majority"].map({1: "yes", 0: "no"})

# Show every row and column without truncation or wrapping.
for option, value in (
    ("display.max_rows", 50000),
    ("display.max_columns", None),
    ("display.width", 0),
):
    pd.set_option(option, value)
print(small_df.to_string(index=False))

# Write next to the large table.
SAVE_DIR = "/content/drive/MyDrive/F3net results 5 src"
OUT_CSV = os.path.join(SAVE_DIR, "F3Net small table 5src.csv")
os.makedirs(SAVE_DIR, exist_ok=True)
small_df.to_csv(OUT_CSV, index=False)
print(f"\nSaved to: {OUT_CSV}")


                  dataset detector video_name true_label correctly_predicted
frames_cropped_faces_5src    F3Net        5_1       fake                  no
frames_cropped_faces_5src    F3Net       5_10       fake                  no
frames_cropped_faces_5src    F3Net       5_11       fake                  no
frames_cropped_faces_5src    F3Net       5_12       fake                  no
frames_cropped_faces_5src    F3Net       5_13       fake                  no
frames_cropped_faces_5src    F3Net       5_14       fake                  no
frames_cropped_faces_5src    F3Net       5_15       fake                  no
frames_cropped_faces_5src    F3Net       5_16       fake                  no
frames_cropped_faces_5src    F3Net       5_17       fake                 yes
frames_cropped_faces_5src    F3Net       5_18       fake                 yes
frames_cropped_faces_5src    F3Net       5_19       fake                  no
frames_cropped_faces_5src    F3Net        5_2       fake                 yes