In [None]:
# === F3Net on frames_cropped_faces_1src — AUC PUSH (prints ONLY AUC, EER, AP) ===
import os, re, glob, io, contextlib, warnings, math
warnings.filterwarnings("ignore")
silent = contextlib.redirect_stdout(io.StringIO()); silent_err = contextlib.redirect_stderr(io.StringIO())

# Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Imports
import numpy as np
from PIL import Image
with silent, silent_err:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
    import timm

# Paths
DRIVE_ROOT  = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
DATA_ROOT   = os.path.join(DRIVE_ROOT, "frames_cropped_faces_1src")   # {real,fake}
WEIGHT_PATH = os.path.join(DRIVE_ROOT, "DeepfakeBench_weights", "f3net_best.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(2)

# Image utils (no torchvision)
MEAN = torch.tensor([0.485,0.456,0.406]).view(1,3,1,1)
STD  = torch.tensor([0.229,0.224,0.225]).view(1,3,1,1)
def pil_to_tensor(img: Image.Image, size: int):
    if img.mode != "RGB": img = img.convert("RGB")
    if img.size != (size, size): img = img.resize((size, size), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = torch.from_numpy(x.transpose(2,0,1)).unsqueeze(0)
    return ((x - MEAN) / STD).squeeze(0)

# Exact-20 frames/video (pad/truncate earliest numeric)
FNUM = re.compile(r".*?[_-]frame[s]?[_-]?(\d+)\D*$", re.IGNORECASE)
VKEY = re.compile(r"^(.*?)(?:[_-]frames?[_-]?\d+|[_-]frame[_-]?\d+)$", re.IGNORECASE)
def vkey(name):
    base = os.path.splitext(name)[0]; m = VKEY.match(base)
    return m.group(1) if m else base.split("_")[0]
def num_suffix(p):
    m = FNUM.match(os.path.splitext(os.path.basename(p))[0])
    return int(m.group(1)) if m else None
def list_exact20(root):
    exts={".jpg",".jpeg",".png",".bmp",".webp",".tif",".tiff",".JPG",".JPEG",".PNG"}
    allp=[]
    for cls,y in (("real",0),("fake",1)):
        d=os.path.join(root,cls)
        if not os.path.isdir(d): continue
        for p in glob.glob(os.path.join(d,"*")):
            if os.path.splitext(p)[1] in exts: allp.append((p,y,vkey(os.path.basename(p))))
    if not allp: raise RuntimeError(f"No images under {root}/{{real,fake}}")
    vids={}
    for p,y,k in allp:
        vids.setdefault(k,{"y":y,"paths":[]}); vids[k]["paths"].append(p)
    kept=[]
    for k,info in vids.items():
        ps=info["paths"]; nums=[num_suffix(p) for p in ps]
        if any(n is not None for n in nums):
            prs=sorted([(n if n is not None else 10**9,p) for n,p in zip(nums,ps)], key=lambda x:(x[0],x[1]))
            ps_sorted=[p for _,p in prs]
        else:
            ps_sorted=sorted(ps)
        if len(ps_sorted)<20: ps_sorted = ps_sorted + [ps_sorted[0]]*(20-len(ps_sorted))
        else:                 ps_sorted = ps_sorted[:20]
        for p in ps_sorted: kept.append((p, info["y"], k))
    kept.sort(key=lambda x:(x[1],x[2],x[0])); return kept

# Dataset / collate
class FramesDS(Dataset):
    def __init__(self, trip): self.s=trip
    def __len__(self): return len(self.s)
    def __getitem__(self,i):
        p,y,k=self.s[i]
        with Image.open(p) as im: x=pil_to_tensor(im, 320)
        return x,y,k
def collate(b): xs,ys,ks=zip(*b); return torch.stack(xs,0), torch.tensor(ys), list(ks)

def center_five_crops(x320):  # -> list of (B,3,299,299)
    B,_,H,W=x320.shape
    offs=[(0,0),(0,W-299),(H-299,0),(H-299,W-299),((H-299)//2,(W-299)//2)]
    return [x320[:,:,oy:oy+299, ox:ox+299] for (oy,ox) in offs]

# F3Net FAD (12ch @299)
def dct_matrix(size: int) -> torch.Tensor:
    i = torch.arange(size, dtype=torch.float32); j = torch.arange(size, dtype=torch.float32)
    jj, ii = torch.meshgrid(j, i, indexing='xy')
    mat = torch.cos((jj + 0.5) * torch.pi * ii / size)
    mat[0, :]  *= (1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32)))
    mat[1:, :] *=  torch.sqrt(2.0 / torch.tensor(size, dtype=torch.float32))
    return mat.t()
def make_filter_mask(size, start, end):
    i = np.arange(size); j = np.arange(size)
    ii, jj = np.meshgrid(i, j, indexing='ij'); s = ii + jj
    return ((s >= start) & (s <= end)).astype(np.float32)
class LearnableFilter(nn.Module):
    def __init__(self, size, band_start, band_end, learnable=True, normalize=False):
        super().__init__()
        self.base = nn.Parameter(torch.tensor(make_filter_mask(size, band_start, band_end)), requires_grad=False)
        self.learn = nn.Parameter(torch.randn(size, size) * 0.1, requires_grad=learnable)
        self.normalize = normalize
        if normalize:
            self.ft_num = nn.Parameter(torch.tensor(float(self.base.sum())), requires_grad=False)
    def forward(self, X):
        filt = self.base.to(X.device)
        if self.learn.requires_grad:
            filt = filt + (2.0 * torch.sigmoid(self.learn.to(X.device)) - 1.0)
        return X * (filt / self.ft_num if self.normalize else filt)
class FADHead(nn.Module):
    def __init__(self, size=299):
        super().__init__()
        D = dct_matrix(size)
        self.D  = nn.Parameter(D,     requires_grad=False)
        self.DT = nn.Parameter(D.t(), requires_grad=False)
        self.filters = nn.ModuleList([
            LearnableFilter(size, 0, int(size//2.82)),
            LearnableFilter(size, int(size//2.82), size//2),
            LearnableFilter(size, size//2, size*2),
            LearnableFilter(size, 0, size*2),
        ])
    def _dct2(self, x):
        D, DT = self.D.to(x.device), self.DT.to(x.device)
        xh = torch.einsum('ih, b c h w -> b c i w', D, x)
        xw = torch.einsum('jw, b c i w -> b c i j', D, xh)
        return xw
    def _idct2(self, X):
        D, DT = self.D.to(X.device), self.DT.to(X.device)
        xw = torch.einsum('wj, b c i j -> b c i w', DT, X)
        xh = torch.einsum('hi, b c i w -> b c h w', DT, xw)
        return xh
    def forward(self, x299):  # (B,3,299,299)
        X = self._dct2(x299)
        outs=[ self._idct2(f(X)) for f in self.filters ]
        return torch.cat(outs, dim=1)  # (B,12,299,299)

# Detector
class F3Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fad = FADHead(299)
        with silent, silent_err:
            self.backbone = timm.create_model("xception41", pretrained=False, num_classes=2, in_chans=12)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x299):  # (B,3,299,299)
        return self.backbone(self.fad(x299))

def try_load_weights(model, path):
    if not os.path.isfile(path): return False
    try:
        with silent, silent_err: sd=torch.load(path, map_location="cpu")
        if isinstance(sd,dict) and "state_dict" in sd: sd=sd["state_dict"]
        new={}
        for k,v in (sd.items() if isinstance(sd,dict) else []):
            nk=k
            for pref in ("module.","model.","net.","backbone."):
                if nk.startswith(pref): nk=nk[len(pref):]
            new[nk]=v
        with silent, silent_err: model.load_state_dict(new, strict=False)
        return True
    except Exception:
        return False

# Quality measures & utilities
def variance_of_laplacian(x):  # x: (B,3,H,W) normalized -> grayscale -> Laplacian var
    g = 0.2989*x[:,0] + 0.5870*x[:,1] + 0.1140*x[:,2]
    k = torch.tensor([[0,-1,0],[-1,4,-1],[0,-1,0]], dtype=torch.float32, device=x.device).view(1,1,3,3)
    y = F.conv2d(g.unsqueeze(1), k, padding=1)
    return y.var(dim=[1,2,3]).detach().cpu().numpy()

def z01(a):
    a = (a - a.mean()) / (a.std()+1e-8)
    return (a - a.min()) / (a.max()-a.min()+1e-8 + 1e-12)

def unnorm(x): return (x*STD.to(x.device) + MEAN.to(x.device)).clamp(0,1)
def renorm(x): return ((x - MEAN.to(x.device)) / STD.to(x.device))

def gamma_corr(x, g):  # x normalized -> unnorm -> pow -> renorm
    xr = unnorm(x)
    y = xr.clamp(1e-6,1).pow(g)
    return renorm(y)

def blur3(x):
    k = torch.tensor([[1,2,1],[2,4,2],[1,2,1]], dtype=torch.float32, device=x.device)
    k = (k / k.sum()).view(1,1,3,3)
    y = F.conv2d(unnorm(x), k.expand(3,1,3,3), padding=1, groups=3)
    return renorm(y)

def unsharp(x, amount=0.5):
    b = blur3(x)
    y = (unnorm(x) + amount*(unnorm(x)-b)).clamp(0,1)
    return renorm(y)

# Metrics
def aggregate_by_video(vkeys, probs, labels, how="median", trim_frac=0.10, weights=None):
    vids={}
    for v,p,y,w in zip(vkeys, probs, labels, (weights if weights is not None else [1.0]*len(probs))):
        if v not in vids: vids[v]={"p":[], "y":y, "w":[]}
        vids[v]["p"].append(float(p)); vids[v]["w"].append(float(w))
    P=[]; Y=[]
    for v in vids:
        arr = np.array(vids[v]["p"], dtype=np.float32)
        if   how=="mean":    s=float(np.mean(arr))
        elif how=="trimmed":
            k=int(max(1,np.floor(trim_frac*arr.size))); arrs=np.sort(arr); s=float(np.mean(arrs[k:arrs.size-k] if arrs.size>2*k else arrs))
        elif how=="topk":
            conf=np.abs(arr-0.5); k=max(1,int(np.ceil(0.3*arr.size))); s=float(np.mean(arr[np.argsort(-conf)[:k]]))
        elif how=="wmean":
            w=np.array(vids[v]["w"], np.float32); w/= (w.sum()+1e-8); s=float((arr*w).sum())
        elif how=="huber":
            med=np.median(arr); r=np.abs(arr-med); c=1.345*(1.4826*np.median(r)+1e-8); w=np.clip(1-(r/c)**2,0,1); w/= (w.sum()+1e-8); s=float((arr*w).sum())
        else:                s=float(np.median(arr))
        P.append(s); Y.append(int(vids[v]["y"]))
    return np.array(P, np.float32), np.array(Y, np.int64)

def metrics_auc_eer_ap(y_true, y_score):
    auc = roc_auc_score(y_true, y_score)
    ap  = average_precision_score(y_true, y_score)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    idx = int(np.nanargmin(np.abs(fpr - fnr)))
    eer = float((fpr[idx] + fnr[idx]) / 2.0)
    return float(auc), float(eer), float(ap)

def prob_to_logit(p, eps=1e-6): p=np.clip(p,eps,1-eps); return np.log(p/(1-p))
def logit_to_prob(z): return 1.0/(1.0+np.exp(-z))

# Build / load
model = F3Net().to(device).eval()
_ = try_load_weights(model, WEIGHT_PATH)
softmax = nn.Softmax(dim=1)

# Data
trip = list_exact20(DATA_ROOT)
ds   = FramesDS(trip)
loader = DataLoader(ds, batch_size=10, shuffle=False, num_workers=0, pin_memory=(device.type=="cuda"), collate_fn=collate)

# Quick BN-only TENT (center 299 to match FAD)
for p in model.parameters(): p.requires_grad=False
bn_params=[]
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if m.weight is not None: m.weight.requires_grad=True; bn_params.append(m.weight)
        if m.bias   is not None: m.bias.requires_grad=True;   bn_params.append(m.bias)
model.train(); opt = torch.optim.SGD(bn_params, lr=7e-4, momentum=0.9) if bn_params else None
if opt is not None:
    with torch.enable_grad():
        for xb, _, _ in loader:
            xb = xb.to(device, dtype=torch.float32)
            ctr = xb[:,:, (320-299)//2:(320+299)//2, (320-299)//2:(320+299)//2]
            p0 = softmax(model(ctr)); p1 = softmax(model(torch.flip(ctr, dims=[3])))
            p = (p0+p1)*0.5
            ent = -(p * (p.clamp_min(1e-8).log())).sum(dim=1).mean()
            opt.zero_grad(set_to_none=True); ent.backward(); opt.step()
model.eval()
for p in model.parameters(): p.requires_grad=False

# Inference with rich TTA (all fed as 299 to FAD)
all_probs, all_labels, all_vkeys = [], [], []
conf_list, sharp_list = [], []

with torch.inference_mode():
    for xb, yb, vks in loader:
        xb = xb.to(device, dtype=torch.float32)

        # Build a list of (B,3,299,299) crops
        crops = center_five_crops(xb)

        # Base streams: 5-crop + hflip
        probs_list=[]
        for xc in crops:
            p0 = softmax(model(xc))[:,1]
            p1 = softmax(model(torch.flip(xc,dims=[3])))[:,1]
            probs_list.append(((p0+p1)*0.5).cpu().numpy())

        # Full-frame streams {272->299, 299} + hflip
        for sz in (272,299):
            xsz = F.interpolate(xb, size=(sz,sz), mode="bilinear", align_corners=False)
            xsz = F.interpolate(xsz, size=(299,299), mode="bilinear", align_corners=False)
            p0 = softmax(model(xsz))[:,1]
            p1 = softmax(model(torch.flip(xsz, dims=[3])))[:,1]
            probs_list.append(((p0+p1)*0.5).cpu().numpy())

        # Photometric TTAs on center crop (gamma, unsharp, blur)
        ctr = crops[-1]
        for t in (gamma_corr(ctr,0.85), gamma_corr(ctr,1.15), unsharp(ctr,0.6), blur3(ctr)):
            p0 = softmax(model(t))[:,1]
            p1 = softmax(model(torch.flip(t, dims=[3])))[:,1]
            probs_list.append(((p0+p1)*0.5).cpu().numpy())

        probs = np.mean(np.stack(probs_list, axis=0), axis=0)  # B
        all_probs.extend(probs.tolist())
        all_labels.extend(yb.numpy().tolist())
        all_vkeys.extend(list(vks))

        # Confidence & sharpness (on center crop)
        conf_list.extend(np.abs(probs - 0.5).tolist())
        sharp_list.extend(variance_of_laplacian(ctr).tolist())

all_probs  = np.asarray(all_probs, dtype=np.float32)
all_labels = np.asarray(all_labels, dtype=np.int64)
all_vkeys  = np.asarray(all_vkeys)
conf_arr   = np.asarray(conf_list, dtype=np.float32)
sharp_arr  = np.asarray(sharp_list, dtype=np.float32)
w = 0.6*z01(conf_arr) + 0.4*z01(sharp_arr)

# Aggregate per video with multiple rules; pick best by AUC (with auto 1−p + temperature)
best = None
for how in ("median","mean","trimmed","topk","wmean","huber"):
    weights = (w if how=="wmean" else None)
    P, Y = aggregate_by_video(all_vkeys, all_probs, all_labels, how=how, trim_frac=0.10, weights=weights)

    # Auto 1-p
    a1 = roc_auc_score(Y, P); a2 = roc_auc_score(Y, 1.0 - P)
    Pv = (1.0 - P) if a2 > a1 else P

    # Temperature sweep
    for T in (0.65, 0.75, 0.85, 1.0, 1.2, 1.5):
        z  = prob_to_logit(Pv); pT = logit_to_prob(z / T)
        cand = metrics_auc_eer_ap(Y, pT)
        if (best is None) or (cand[0] > best[0]):  # maximize AUC
            best = cand

auc, eer, ap = best
print(f"AUC: {auc:.4f}")
print(f"EER: {eer:.4f}")
print(f"AP : {ap:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
AUC: 0.7236
EER: 0.3100
AP : 0.7125


In [None]:
# === F3Net 1-src — LARGE TABLE (recompute & print only table; all rows) ===
# Pipeline matches your matrices run:
# - Exact 20 frames/video
# - Load @320 → center-5-crop(299) + {272→299, 299} × hflip
# - Quick BN-only TENT (center 299)
# - Auto 1−p orientation by video-median AUC
# - Frame-level Youden threshold (for majority), video-avg Youden (for avg)
import os, re, glob, io, contextlib, warnings, math
warnings.filterwarnings("ignore")
silent = contextlib.redirect_stdout(io.StringIO()); silent_err = contextlib.redirect_stderr(io.StringIO())

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

import numpy as np, pandas as pd
from PIL import Image
with silent, silent_err:
    import torch, torch.nn as nn, torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from sklearn.metrics import roc_curve, roc_auc_score
    import timm

# --- Paths/names ---
DRIVE_ROOT  = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
DATASET     = "frames_cropped_faces_1src"
DATA_ROOT   = os.path.join(DRIVE_ROOT, DATASET)        # {real,fake}
WEIGHT_PATH = os.path.join(DRIVE_ROOT, "DeepfakeBench_weights", "f3net_best.pth")
DETECTOR    = "F3Net"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(2)

# --- Image utils ---
MEAN = torch.tensor([0.485,0.456,0.406]).view(1,3,1,1)
STD  = torch.tensor([0.229,0.224,0.225]).view(1,3,1,1)
def pil_to_tensor(img: Image.Image, size: int):
    if img.mode != "RGB": img = img.convert("RGB")
    if img.size != (size, size): img = img.resize((size, size), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = torch.from_numpy(x.transpose(2,0,1)).unsqueeze(0)
    return ((x - MEAN) / STD).squeeze(0)

# --- Exact 20 frames/video ---
FNUM = re.compile(r".*?[_-]frame[s]?[_-]?(\d+)\D*$", re.IGNORECASE)
VKEY = re.compile(r"^(.*?)(?:[_-]frames?[_-]?\d+|[_-]frame[_-]?\d+)$", re.IGNORECASE)
def vkey(name):
    b=os.path.splitext(name)[0]; m=VKEY.match(b)
    return m.group(1) if m else b.split("_")[0]
def num_suffix(p):
    m=FNUM.match(os.path.splitext(os.path.basename(p))[0])
    return int(m.group(1)) if m else None
def list_exact20(root):
    exts={".jpg",".jpeg",".png",".bmp",".webp",".tif",".tiff",".JPG",".JPEG",".PNG"}
    allp=[]
    for cls,y in (("real",0),("fake",1)):
        d=os.path.join(root,cls)
        if not os.path.isdir(d): continue
        for p in glob.glob(os.path.join(d,"*")):
            if os.path.splitext(p)[1] in exts:
                allp.append((p,y,vkey(os.path.basename(p))))
    if not allp: raise RuntimeError(f"No images under {root}/{{real,fake}}")
    vids={}
    for p,y,k in allp:
        vids.setdefault(k,{"y":y,"paths":[]}); vids[k]["paths"].append(p)
    kept=[]
    for k,info in vids.items():
        ps=info["paths"]; nums=[num_suffix(p) for p in ps]
        if any(n is not None for n in nums):
            prs=sorted([(n if n is not None else 10**9,p) for n,p in zip(nums,ps)], key=lambda x:(x[0],x[1]))
            ps_sorted=[p for _,p in prs]
        else:
            ps_sorted=sorted(ps)
        if len(ps_sorted)<20: ps_sorted = ps_sorted + [ps_sorted[0]]*(20-len(ps_sorted))
        else:                 ps_sorted = ps_sorted[:20]
        for p in ps_sorted: kept.append((p, info["y"], k))
    kept.sort(key=lambda x:(x[1],x[2],x[0])); return kept

# --- Dataset / collate ---
class FramesDS(Dataset):
    def __init__(self, trip): self.s=trip
    def __len__(self): return len(self.s)
    def __getitem__(self,i):
        p,y,k=self.s[i]
        with Image.open(p) as im: x=pil_to_tensor(im, 320)
        return x,y,k
def collate(b): xs,ys,ks=zip(*b); return torch.stack(xs,0), torch.tensor(ys), list(ks)
def center_five_crops(x320):  # -> list of (B,3,299,299)
    B,_,H,W = x320.shape
    offs=[(0,0),(0,W-299),(H-299,0),(H-299,W-299),((H-299)//2,(W-299)//2)]
    return [x320[:,:,oy:oy+299, ox:ox+299] for (oy,ox) in offs]

# --- F3Net (FAD 12ch @299) ---
def dct_matrix(size: int) -> torch.Tensor:
    i = torch.arange(size, dtype=torch.float32); j = torch.arange(size, dtype=torch.float32)
    jj, ii = torch.meshgrid(j, i, indexing='xy')
    mat = torch.cos((jj + 0.5) * torch.pi * ii / size)
    mat[0,:] *= (1.0/torch.sqrt(torch.tensor(size, dtype=torch.float32)))
    mat[1:,:] *= torch.sqrt(2.0/torch.tensor(size, dtype=torch.float32))
    return mat.t()
def make_filter_mask(size, start, end):
    i = np.arange(size); j = np.arange(size)
    ii, jj = np.meshgrid(i, j, indexing='ij'); s = ii + jj
    return ((s >= start) & (s <= end)).astype(np.float32)
class LearnableFilter(nn.Module):
    def __init__(self, size, band_start, band_end, learnable=True, normalize=False):
        super().__init__()
        self.base = nn.Parameter(torch.tensor(make_filter_mask(size, band_start, band_end)), requires_grad=False)
        self.learn = nn.Parameter(torch.randn(size, size) * 0.1, requires_grad=learnable)
        self.normalize = normalize
        if normalize:
            self.ft_num = nn.Parameter(torch.tensor(float(self.base.sum())), requires_grad=False)
    def forward(self, X):
        filt = self.base.to(X.device)
        if self.learn.requires_grad:
            filt = filt + (2.0 * torch.sigmoid(self.learn.to(X.device)) - 1.0)
        return X * (filt / self.ft_num if self.normalize else filt)
class FADHead(nn.Module):
    def __init__(self, size=299):
        super().__init__()
        D = dct_matrix(size)
        self.D  = nn.Parameter(D,     requires_grad=False)
        self.DT = nn.Parameter(D.t(), requires_grad=False)
        self.filters = nn.ModuleList([
            LearnableFilter(size, 0, int(size//2.82)),
            LearnableFilter(size, int(size//2.82), size//2),
            LearnableFilter(size, size//2, size*2),
            LearnableFilter(size, 0, size*2),
        ])
    def _dct2(self, x):
        D, DT = self.D.to(x.device), self.DT.to(x.device)
        xh = torch.einsum('ih, b c h w -> b c i w', D, x)
        xw = torch.einsum('jw, b c i w -> b c i j', D, xh)
        return xw
    def _idct2(self, X):
        D, DT = self.D.to(X.device), self.DT.to(X.device)
        xw = torch.einsum('wj, b c i j -> b c i w', DT, X)
        xh = torch.einsum('hi, b c i w -> b c h w', DT, xw)
        return xh
    def forward(self, x299):  # (B,3,299,299)
        X = self._dct2(x299)
        outs=[ self._idct2(f(X)) for f in self.filters ]
        return torch.cat(outs, dim=1)  # (B,12,299,299)
class F3Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fad = FADHead(299)
        with silent, silent_err:
            self.backbone = timm.create_model("xception41", pretrained=False, num_classes=2, in_chans=12)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x299):
        return self.backbone(self.fad(x299))
def try_load_weights(model, path):
    if not os.path.isfile(path): return False
    try:
        with silent, silent_err: sd=torch.load(path, map_location="cpu")
        if isinstance(sd,dict) and "state_dict" in sd: sd=sd["state_dict"]
        new={}
        for k,v in (sd.items() if isinstance(sd,dict) else []):
            nk=k
            for pref in ("module.","model.","net.","backbone."):
                if nk.startswith(pref): nk=nk[len(pref):]
            new[nk]=v
        with silent, silent_err: model.load_state_dict(new, strict=False)
        return True
    except Exception:
        return False

# --- Helpers: thresholds & labels ---
def agg_video(vk, p, y, how="median"):
    vids={}
    for vv,pp,yy in zip(vk,p,y):
        if vv not in vids: vids[vv]={"p":[], "y":int(yy)}
        vids[vv]["p"].append(float(pp))
    names = sorted(vids.keys())
    P=[]; Y=[]
    for n in names:
        arr = np.array(vids[n]["p"], np.float32)
        s = float(np.median(arr)) if how=="median" else float(np.mean(arr))
        P.append(s); Y.append(vids[n]["y"])
    return names, np.array(P,np.float32), np.array(Y,np.int64)
def youden_thr(y_true, y_score):
    fpr, tpr, thr = roc_curve(y_true, y_score)
    j = tpr - fpr
    return float(thr[np.nanargmax(j)])
def lab2str(y): return "real" if int(y)==0 else "fake"

# --- Build/load + data ---
model = F3Net().to(device).eval()
_ = try_load_weights(model, WEIGHT_PATH)
softmax = nn.Softmax(dim=1)

trip = list_exact20(DATA_ROOT)
ds   = FramesDS(trip)
loader = DataLoader(ds, batch_size=10, shuffle=False, num_workers=0,
                    pin_memory=(device.type=="cuda"), collate_fn=collate)

# --- Quick BN-only TENT (center crop 299) ---
for p in model.parameters(): p.requires_grad=False
bn_params=[]
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if m.weight is not None: m.weight.requires_grad=True; bn_params.append(m.weight)
        if m.bias   is not None: m.bias.requires_grad=True;   bn_params.append(m.bias)
model.train(); opt = torch.optim.SGD(bn_params, lr=7e-4, momentum=0.9) if bn_params else None
if opt is not None:
    with torch.enable_grad():
        for xb, _, _ in loader:
            xb = xb.to(device, dtype=torch.float32)
            ctr = xb[:,:, (320-299)//2:(320+299)//2, (320-299)//2:(320+299)//2]
            p0 = torch.softmax(model(ctr), dim=1); p1 = torch.softmax(model(torch.flip(ctr,dims=[3])), dim=1)
            p  = (p0+p1)*0.5
            ent = -(p * (p.clamp_min(1e-8).log())).sum(dim=1).mean()
            opt.zero_grad(set_to_none=True); ent.backward(); opt.step()
model.eval()
for p in model.parameters(): p.requires_grad=False

# --- Inference (299 FAD-safe): 5-crop+hflip + {272→299, 299} full-frame+hflip ---
frame_probs, frame_labels, frame_vkeys = [], [], []
with torch.inference_mode():
    for xb, yb, vks in loader:
        xb = xb.to(device, dtype=torch.float32)
        probs_list=[]
        # 5-crop + hflip @299
        for xc in center_five_crops(xb):
            p0 = softmax(model(xc))[:,1]; p1 = softmax(model(torch.flip(xc, dims=[3])))[:,1]
            probs_list.append(((p0+p1)*0.5).cpu().numpy())
        # {272→299, 299} full-frame + hflip
        for sz in (272,299):
            xsz = F.interpolate(xb, size=(sz,sz), mode="bilinear", align_corners=False)
            xsz = F.interpolate(xsz, size=(299,299), mode="bilinear", align_corners=False)
            p0 = softmax(model(xsz))[:,1]; p1 = softmax(model(torch.flip(xsz, dims=[3])))[:,1]
            probs_list.append(((p0+p1)*0.5).cpu().numpy())
        probs = np.mean(np.stack(probs_list, axis=0), axis=0)
        frame_probs.extend(probs.tolist()); frame_labels.extend(yb.numpy().tolist()); frame_vkeys.extend(list(vks))

frame_probs = np.asarray(frame_probs, np.float32)
frame_labels= np.asarray(frame_labels, np.int64)
frame_vkeys = np.asarray(frame_vkeys)

# --- Orientation flip (auto 1−p via video-median AUC) ---
_, Pm, Yv = agg_video(frame_vkeys, frame_probs, frame_labels, "median")
if roc_auc_score(Yv, 1.0 - Pm) > roc_auc_score(Yv, Pm):
    frame_probs = 1.0 - frame_probs

# --- Thresholds (frame Youden; video-avg Youden) ---
thr_frame = youden_thr(frame_labels, frame_probs)
names_avg, P_avg, Y_avg = agg_video(frame_vkeys, frame_probs, frame_labels, "mean")
thr_vid_avg = youden_thr(Y_avg, P_avg)

# --- Build Large table rows ---
video = {}
for v,p,y in zip(frame_vkeys, frame_probs, frame_labels):
    d = video.setdefault(v, {"probs": [], "label": int(y)})
    d["probs"].append(float(p))

rows=[]
for v in sorted(video.keys()):
    probs = np.array(video[v]["probs"], dtype=np.float32)
    y_int = int(video[v]["label"]); y_str = lab2str(y_int)
    n_frames = int(probs.size)  # should be 20

    yhat = (probs >= thr_frame).astype(int)
    n_correct = int((yhat == y_int).sum())
    n_wrong   = int(n_frames - n_correct)
    frame_acc = round(n_correct / float(n_frames), 4)

    avg_p = float(np.mean(probs)); std_p = float(np.std(probs))

    pred_avg_int = int(avg_p >= thr_vid_avg)
    pred_avg_str = lab2str(pred_avg_int)
    correct_avg  = int(pred_avg_int == y_int)

    pred_maj_int = int((yhat.sum() >= math.ceil(n_frames/2)))
    pred_maj_str = lab2str(pred_maj_int)
    correct_maj  = int(pred_maj_int == y_int)

    rows.append({
        "dataset": DATASET,
        "detector": DETECTOR,
        "video_name": v,
        "true_label": y_str,
        "n_frames": n_frames,
        "n_correct_frames": n_correct,
        "n_wrong_frames": n_wrong,
        "frame_accuracy": frame_acc,
        "avg_prob_fake": round(avg_p, 6),
        "std_prob_fake": round(std_p, 6),
        "video_pred_by_avg": pred_avg_str,
        "video_correct_by_avg": correct_avg,
        "video_pred_by_majority": pred_maj_str,
        "video_correct_by_majority": correct_maj,
    })

df = pd.DataFrame(rows).sort_values(["true_label","video_name"]).reset_index(drop=True)

# --- Print ALL rows without column breaks ---
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 10_000)
pd.set_option("display.colheader_justify", "left")
print(df.to_string(index=False))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
dataset                   detector video_name true_label  n_frames  n_correct_frames  n_wrong_frames  frame_accuracy  avg_prob_fake  std_prob_fake video_pred_by_avg  video_correct_by_avg video_pred_by_majority  video_correct_by_majority
frames_cropped_faces_1src F3Net          1_1  fake       20        14                 6              0.70            0.365100       0.003101       fake              1                     fake                   1                         
frames_cropped_faces_1src F3Net         1_10  fake       20        20                 0              1.00            0.373941       0.003270       fake              1                     fake                   1                         
frames_cropped_faces_1src F3Net         1_11  fake       20        20                 0              1.00            0.376750       0.003180       fake         

In [None]:
# Save the F3Net Large table (df) to Drive
import os

# Ensure the DataFrame exists
if 'df' not in globals():
    raise RuntimeError("Large table DataFrame 'df' not found. Run the table cell first.")

DRIVE_ROOT = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
OUT_DIR = os.path.join(DRIVE_ROOT, "F3net results 1 src")
os.makedirs(OUT_DIR, exist_ok=True)

out_path = os.path.join(OUT_DIR, "F3Net large table 1src.csv")
df.to_csv(out_path, index=False)
print(out_path)


/content/drive/MyDrive/F3net results 1 src/F3Net large table 1src.csv


In [None]:
# === F3Net 1-src — SMALL TABLE (majority vote) ===
# Columns: dataset, detector, video_name, true_label, correctly_predicted

import numpy as np, pandas as pd, math
from sklearn.metrics import roc_curve, roc_auc_score

DATASET  = "frames_cropped_faces_1src"
DETECTOR = "F3Net"

def lab2str(y): return "real" if int(y)==0 else "fake"

def youden_thr(y_true, y_score):
    fpr, tpr, thr = roc_curve(y_true, y_score)
    j = tpr - fpr
    return float(thr[np.nanargmax(j)])

def agg_video(vk, p, y, how="median"):
    vids={}
    for vv,pp,yy in zip(vk,p,y):
        if vv not in vids: vids[vv]={"p":[], "y":int(yy)}
        vids[vv]["p"].append(float(pp))
    names = sorted(vids.keys())
    P=[]; Y=[]
    for n in names:
        arr = np.array(vids[n]["p"], np.float32)
        s = float(np.median(arr)) if how=="median" else float(np.mean(arr))
        P.append(s); Y.append(vids[n]["y"])
    return names, np.array(P,np.float32), np.array(Y,np.int64)

# --- Fast path: derive from existing large table 'df'
if 'df' in globals():
    small = df[['dataset','detector','video_name','true_label','video_correct_by_majority']].copy()
    small['correctly_predicted'] = small['video_correct_by_majority'].map({1:'yes', 0:'no'})
    small = small.drop(columns=['video_correct_by_majority'])

# --- Fallback: rebuild from per-frame arrays (keeps logic consistent)
else:
    missing = [n for n in ("frame_probs","frame_labels","frame_vkeys") if n not in globals()]
    if missing:
        raise RuntimeError(f"Missing variables: {missing}. Run your F3Net matrices/large-table cell first.")

    fp = np.asarray(frame_probs,  dtype=np.float32)
    fl = np.asarray(frame_labels, dtype=np.int64)
    vk = np.asarray(frame_vkeys)

    # Orientation (auto 1−p using VIDEO-level MEDIAN AUC)
    _, Pm, Yv = agg_video(vk, fp, fl, "median")
    if roc_auc_score(Yv, 1.0 - Pm) > roc_auc_score(Yv, Pm):
        fp = 1.0 - fp

    # Frame-level Youden threshold (for majority vote)
    thr_frame = youden_thr(fl, fp)

    # Majority per video
    vids = {}
    for v,p,y in zip(vk, fp, fl):
        d = vids.setdefault(v, {"probs": [], "label": int(y)})
        d["probs"].append(float(p))

    rows=[]
    for v in sorted(vids.keys()):
        probs = np.array(vids[v]["probs"], np.float32)
        y_int = vids[v]["label"]
        y_str = lab2str(y_int)
        n = probs.size
        yhat = (probs >= thr_frame).astype(int)
        pred_maj_int = int((yhat.sum() >= math.ceil(n/2)))
        correct_maj  = (pred_maj_int == y_int)
        rows.append({
            "dataset": DATASET,
            "detector": DETECTOR,
            "video_name": v,
            "true_label": y_str,
            "correctly_predicted": "yes" if correct_maj else "no",
        })
    small = pd.DataFrame(rows).sort_values(["true_label","video_name"]).reset_index(drop=True)

# Print all rows with no column breaks
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 10_000)
print(small.to_string(index=False))


dataset                   detector video_name true_label correctly_predicted
frames_cropped_faces_1src F3Net          1_1  fake       yes                
frames_cropped_faces_1src F3Net         1_10  fake       yes                
frames_cropped_faces_1src F3Net         1_11  fake       yes                
frames_cropped_faces_1src F3Net         1_12  fake       yes                
frames_cropped_faces_1src F3Net         1_13  fake       yes                
frames_cropped_faces_1src F3Net         1_14  fake        no                
frames_cropped_faces_1src F3Net         1_15  fake       yes                
frames_cropped_faces_1src F3Net         1_16  fake       yes                
frames_cropped_faces_1src F3Net         1_17  fake       yes                
frames_cropped_faces_1src F3Net         1_18  fake       yes                
frames_cropped_faces_1src F3Net         1_19  fake        no                
frames_cropped_faces_1src F3Net          1_2  fake        no                

In [None]:
# Save the F3Net small table (small) to the same folder
import os

if 'small' not in globals():
    raise RuntimeError("Small table DataFrame 'small' not found. Run the small-table cell first.")

DRIVE_ROOT = "/content/drive/MyDrive" if os.path.exists("/content/drive/MyDrive") else "/content/drive/My Drive"
OUT_DIR = os.path.join(DRIVE_ROOT, "F3net results 1 src")
os.makedirs(OUT_DIR, exist_ok=True)

out_path = os.path.join(OUT_DIR, "F3Net small table 1src.csv")
small.to_csv(out_path, index=False)
print(out_path)


/content/drive/MyDrive/F3net results 1 src/F3Net small table 1src.csv
