# In [None]:
# NOTE: Paste these cells into a Jupyter notebook (or save as .py and run stepwise).
# Deepfake detection — ViT (RGB) + Custom forensic CNN (ELA, HP, FFT, PRNU-like)
# Author: ChatGPT (GPT-5 Thinking mini)
# Date: 2025-10-19

# -------------------------- CONFIG --------------------------
# These imports are needed here: DEVICE and os.makedirs below run before the
# IMPORTS section, so without them this cell raises NameError on `os`.
import os
import torch

DATA_ROOT = "./dataset_split"   # must contain train/real, train/fake, val/real, val/fake, (optional) test/...
OUTPUT_DIR = "./checkpoints"
IMG_SIZE = 224                  # ViT standard: 224 (use 384 if you want higher res and have memory)
PATCHES_PER_IMAGE = 4           # random patches per image during training
BATCH_SIZE = 16                 # try 16 / 24 / 32 on RTX 3090
NUM_WORKERS = 6
EPOCHS = 20
LR = 1e-4
WEIGHT_DECAY = 1e-4
MIXUP_ALPHA = 0.2
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEBUG_RUN = False               # set True for a tiny quick sanity run
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------- IMPORTS --------------------------
import os, random, time, json
from glob import glob
from io import BytesIO
from PIL import Image, ImageChops, ImageFilter
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# timm for ViT
try:
    import timm
    TIMM_AVAILABLE = True
except Exception:
    TIMM_AVAILABLE = False

from sklearn.metrics import roc_auc_score
from torch.cuda.amp import GradScaler, autocast
import torch.optim as optim

print("Device:", DEVICE, "timm available:", TIMM_AVAILABLE)

# --------------------- DETERMINISM ---------------------
# Seed every RNG this file draws from: Python `random` (dataset shuffle, crop
# positions), NumPy (mixup lambda) and torch (randperm, weight init).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# NOTE: benchmark=True lets cuDNN autotune kernels for speed; the selected
# algorithms may differ between runs, so GPU results are not bit-for-bit
# reproducible even with the seeds above.
torch.backends.cudnn.benchmark = True

# ------------------ FORENSIC TRANSFORMS ------------------
def compute_ela(pil_img: Image.Image, quality=90) -> Image.Image:
    """Error Level Analysis map of ``pil_img``.

    The image is re-encoded as JPEG at ``quality`` in memory. The grayscale
    absolute difference to the original is then stretched so its brightest
    pixel is 255. An all-zero difference is left as zeros.

    Returns:
        A single-channel ('L') image with the same size as the input.
    """
    original = pil_img.convert('RGB')
    jpeg_buffer = BytesIO()
    original.save(jpeg_buffer, format='JPEG', quality=quality)
    jpeg_buffer.seek(0)
    recompressed = Image.open(jpeg_buffer).convert('RGB')
    diff = np.array(ImageChops.difference(original, recompressed).convert('L')).astype(np.float32)
    peak = diff.max()
    if not peak > 0:
        peak = 1.0  # avoid division by zero for a lossless-looking image
    scaled = (diff / peak * 255.0).astype(np.uint8)
    return Image.fromarray(scaled, mode='L')

def high_pass_residual(pil_img: Image.Image, radius=3) -> Image.Image:
    """High-frequency residual: |image - GaussianBlur(image, radius)|.

    The residual is taken in grayscale and min-max stretched to 0..255.

    Returns:
        A single-channel ('L') image with the same size as the input.
    """
    source = pil_img.convert('RGB')
    smooth = source.filter(ImageFilter.GaussianBlur(radius=radius))
    res = np.array(ImageChops.difference(source, smooth).convert('L')).astype(np.float32)
    lo, hi = res.min(), res.max()
    # +1e-8 keeps a flat residual (hi == lo) from dividing by zero
    stretched = (res - lo) / (hi - lo + 1e-8) * 255.0
    return Image.fromarray(stretched.astype(np.uint8), mode='L')

def prnu_like_residual(pil_img: Image.Image) -> Image.Image:
    """Noise residual: |image - MedianFilter(image, 3)|.

    This is a cheap stand-in for a PRNU-style sensor-noise map. The residual is
    taken in grayscale and min-max stretched to 0..255.

    Returns:
        A single-channel ('L') image with the same size as the input.
    """
    source = pil_img.convert('RGB')
    median = source.filter(ImageFilter.MedianFilter(size=3))
    noise = np.array(ImageChops.difference(source, median).convert('L')).astype(np.float32)
    lo, hi = noise.min(), noise.max()
    # +1e-8 keeps a flat residual (hi == lo) from dividing by zero
    stretched = (noise - lo) / (hi - lo + 1e-8) * 255.0
    return Image.fromarray(stretched.astype(np.uint8), mode='L')

def fft_magnitude_map(pil_img: Image.Image) -> Image.Image:
    """Centred log-magnitude 2-D FFT spectrum of the grayscale image.

    Returns:
        A single-channel ('L') image min-max stretched to 0..255, with the
        zero-frequency component shifted to the centre.
    """
    luminance = np.array(pil_img.convert('L')).astype(np.float32)
    spectrum = np.fft.fftshift(np.fft.fft2(luminance))
    log_mag = np.log(np.abs(spectrum) + 1e-8)  # 1e-8 avoids log(0)
    lo, hi = log_mag.min(), log_mag.max()
    stretched = (log_mag - lo) / (hi - lo + 1e-8) * 255.0
    return Image.fromarray(stretched.astype(np.uint8), mode='L')

# --------------------- DATASET ---------------------
IMG_EXTS = ['.jpg', '.jpeg', '.png', '.bmp']

class DeepfakeForensicDataset(Dataset):
    """Real/fake image dataset yielding an RGB tensor plus 4 stacked forensic maps.

    Expects ``<root_dir>/<split>/real`` and ``<root_dir>/<split>/fake`` folders.
    Files whose extension (case-insensitive) is in ``IMG_EXTS`` are used.
    Labels are real -> 0.0 and fake -> 1.0.

    Each item is a dict with keys 'rgb', 'forensic', 'label' and 'path':

    * ``mode='train'`` with ``patches_per_image > 1``: the item holds P random
      crops stacked as 'rgb' [P,3,H,W] and 'forensic' [P,4,H,W].
    * Otherwise it holds one crop, 'rgb' [3,H,W] and 'forensic' [4,H,W]. The
      crop is random in train mode and a deterministic centre crop in any other
      mode, so evaluation metrics are reproducible.

    Forensic channel order is ELA, high-pass, FFT magnitude, PRNU-like residual.
    """

    def __init__(self, root_dir, split='train', img_size=IMG_SIZE, patches_per_image=PATCHES_PER_IMAGE, mode='train'):
        self.root = os.path.join(root_dir, split)
        assert os.path.exists(self.root), f"{self.root} not found"
        exts = {e.lower() for e in IMG_EXTS}
        self.files = []
        for cls, lbl in [('real', 0), ('fake', 1)]:
            folder = os.path.join(self.root, cls)
            if not os.path.isdir(folder):
                continue
            # Use os.listdir with a lower-cased suffix match instead of
            # glob('*.jpg'): glob is case-sensitive on Linux and silently skips
            # '.JPG' / '.PNG' files. Dot-files (e.g. macOS '._x.jpg') stay
            # excluded, as they were with glob('*').
            # Sorting removes dependence on filesystem enumeration order, so the
            # seeded shuffle below is reproducible across machines.
            names = sorted(
                n for n in os.listdir(folder)
                if not n.startswith('.') and os.path.splitext(n)[1].lower() in exts
            )
            self.files += [(os.path.join(folder, n), lbl) for n in names]
        random.shuffle(self.files)
        self.img_size = img_size
        self.patches = patches_per_image if mode == 'train' else 1
        self.mode = mode
        self.base_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor()])
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.files)

    def read_image(self, path):
        """Load ``path`` as RGB; the context manager closes the file handle deterministically."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def random_patch_coords(self, w, h, pw, ph):
        """Return a random (left, upper, right, lower) box of size at most pw x ph inside a w x h image.

        The box is clamped to the image. Without the clamp, PIL would fill the
        overflow of an out-of-bounds crop with black. Those artificial edges
        would then dominate the ELA / high-pass / FFT maps.
        """
        pw, ph = min(pw, w), min(ph, h)
        if w == pw and h == ph:
            return (0, 0, pw, ph)
        x = random.randint(0, w - pw)
        y = random.randint(0, h - ph)
        return (x, y, x + pw, y + ph)

    def center_patch_coords(self, w, h, pw, ph):
        """Return the centred box of size at most pw x ph inside a w x h image (deterministic)."""
        pw, ph = min(pw, w), min(ph, h)
        x, y = (w - pw) // 2, (h - ph) // 2
        return (x, y, x + pw, y + ph)

    def crop_and_compute(self, img, crop_box=None):
        """Crop ``img`` and return (normalised RGB [3,H,W], forensic maps [4,H,W])."""
        crop = img.crop(crop_box) if crop_box is not None else img.copy()
        size = (self.img_size, self.img_size)
        to_tensor = transforms.ToTensor()
        # Forensic maps are computed at the crop's native resolution, then resized.
        maps = (
            compute_ela(crop, quality=90),
            high_pass_residual(crop, radius=3),
            fft_magnitude_map(crop),
            prnu_like_residual(crop),
        )
        forensic = torch.cat([to_tensor(m.resize(size)) for m in maps], dim=0)  # [4,H,W]
        rgb_t = self.norm(self.base_transform(crop.resize(size)))  # [3,H,W]
        return rgb_t, forensic

    def __getitem__(self, idx):
        path, label = self.files[idx]
        img = self.read_image(path)
        w, h = img.size
        s = self.img_size
        target = torch.tensor(label, dtype=torch.float32)
        if self.patches == 1:
            # Evaluation must be deterministic: use a centre crop. A single
            # training patch keeps a random crop. Images no larger than img_size
            # are used whole, because both helpers return (0, 0, w, h) then.
            if self.mode == 'train':
                coords = self.random_patch_coords(w, h, s, s)
            else:
                coords = self.center_patch_coords(w, h, s, s)
            rgb_t, for_t = self.crop_and_compute(img, crop_box=coords)
            return {'rgb': rgb_t, 'forensic': for_t, 'label': target, 'path': path}
        pairs = [
            self.crop_and_compute(img, crop_box=self.random_patch_coords(w, h, s, s))
            for _ in range(self.patches)
        ]
        rgb_stack = torch.stack([r for r, _ in pairs])   # [P,3,H,W]
        for_stack = torch.stack([f for _, f in pairs])   # [P,4,H,W]
        return {'rgb': rgb_stack, 'forensic': for_stack, 'label': target, 'path': path}

# --------------------- MODEL ---------------------
class ForensicCNN(nn.Module):
    """Small conv encoder for stacked forensic maps.

    Args:
        in_ch: number of input channels (4 in this file: ELA, HP, FFT, PRNU-like).
        out_dim: length of the output feature vector. It is the channel count of
            the last conv block; previously it was ignored and hardcoded to 128.

    Input ``[B,C,H,W]`` gives ``[B,out_dim]``. Input ``[B,P,C,H,W]`` is encoded
    per patch and mean-pooled over P, also giving ``[B,out_dim]``.
    """

    def __init__(self, in_ch=4, out_dim=128):
        super().__init__()
        # Layer order/indices are unchanged, so checkpoints saved with the
        # default out_dim=128 still load.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, out_dim, 3, padding=1), nn.BatchNorm2d(out_dim), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.out_dim = out_dim

    def forward(self, x):
        if x.dim() == 5:
            b, p = x.shape[:2]
            # reshape (not view) also accepts non-contiguous patch stacks
            feats = self.net(x.reshape(b * p, *x.shape[2:])).flatten(1)
            return feats.view(b, p, -1).mean(dim=1)
        return self.net(x).flatten(1)

class ViTForensicsFusion(nn.Module):
    """Two-branch real/fake classifier: a timm ViT on RGB plus ForensicCNN on forensic maps.

    The pooled ViT embedding (``self.rgb_dim`` wide) and the forensic embedding
    are concatenated. An MLP head then produces one logit per sample.
    """

    def __init__(self, vit_name='vit_base_patch16_224', pretrained=True):
        super().__init__()
        assert TIMM_AVAILABLE, "Install timm to use ViT backbones"
        # num_classes=0 drops timm's classifier head, so the backbone returns
        # pooled features of width num_features.
        self.vit = timm.create_model(vit_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.rgb_dim = self.vit.num_features
        self.forensic = ForensicCNN(in_ch=4, out_dim=128)
        fused_width = self.rgb_dim + self.forensic.out_dim
        # Attribute names and layer order define state_dict keys; keep them stable.
        self.fusion = nn.Sequential(
            nn.Linear(fused_width, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def _rgb_features(self, rgb):
        """Embed [B,3,H,W]; a [B,P,3,H,W] stack is embedded per patch, then averaged over P."""
        if rgb.dim() != 5:
            return self.vit(rgb)
        n_img, n_patch = rgb.shape[:2]
        per_patch = self.vit(rgb.view(n_img * n_patch, *rgb.shape[2:]))
        return per_patch.view(n_img, n_patch, -1).mean(dim=1)

    def forward(self, rgb, forensic):
        """Return raw logits of shape [B]."""
        joint = torch.cat([self._rgb_features(rgb), self.forensic(forensic)], dim=1)
        return self.fusion(joint).squeeze(1)

# ------------------ TRAIN / EVAL UTILITIES ------------------
def mixup_data(x1, x2, y, alpha=0.2):
    """Mixup applied jointly to two input tensors that share one label vector.

    Draws ``lam ~ Beta(alpha, alpha)`` and one random batch permutation. Both
    inputs are blended with their permuted copies using the same ``lam`` and
    permutation, so the two views of a sample stay paired.

    Returns:
        ``(x1_mixed, x2_mixed, y_a, y_b, lam)``, where ``y_a`` is ``y`` and
        ``y_b`` is ``y`` permuted. When ``alpha <= 0`` the inputs come back
        untouched with ``y_b=None`` and ``lam=1.0``.
    """
    if alpha <= 0:
        return x1, x2, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(y.size(0)).to(y.device)

    def _blend(t):
        # keep `tensor * lam` operand order: lam is a NumPy scalar
        return t * lam + t[perm] * (1 - lam)

    return _blend(x1), _blend(x2), y, y[perm], lam

def build_loaders(root, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Build (train_loader, val_loader) from ``<root>/train`` and ``<root>/val``.

    Train items carry PATCHES_PER_IMAGE random crops and are shuffled. Val items
    carry a single crop, in fixed order.

    Args:
        root: dataset root containing ``train/`` and ``val/``.
        batch_size: images per batch (defaults to BATCH_SIZE).
        num_workers: DataLoader worker processes (defaults to NUM_WORKERS).
    """
    train_ds = DeepfakeForensicDataset(root, split='train', img_size=IMG_SIZE, patches_per_image=PATCHES_PER_IMAGE, mode='train')
    val_ds = DeepfakeForensicDataset(root, split='val', img_size=IMG_SIZE, patches_per_image=1, mode='val')
    # Pinned host memory only speeds up host->GPU copies; without CUDA,
    # PyTorch just warns that pin_memory has no effect.
    pin = torch.cuda.is_available()
    # Keep workers alive across epochs instead of re-spawning them each epoch.
    persist = num_workers > 0
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              pin_memory=pin, persistent_workers=persist)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            pin_memory=pin, persistent_workers=persist)
    return train_loader, val_loader

def train_epoch(model, loader, optimizer, scaler, device, mixup_alpha=0.2):
    """Run one training epoch with BCE-with-logits loss and AMP via ``scaler``.

    Args:
        model: module called as ``model(rgb, forensic)`` that returns logits [B].
        loader: iterable of dicts with 'rgb', 'forensic' and 'label' tensors.
        optimizer: optimizer stepped through ``scaler``.
        scaler: a GradScaler (a disabled one simply passes through).
        device: device for the batch tensors; autocast is enabled only on CUDA.
        mixup_alpha: Beta parameter for mixup; ``<= 0`` disables mixup.

    Returns:
        ``(mean_loss, auc)``. ``auc`` is the ROC-AUC of sigmoid outputs against
        the unmixed labels, or 0.0 when only one class was seen. With mixup on,
        outputs come from mixed inputs, so this AUC is only a rough training signal.
    """
    model.train()
    # torch.cuda.amp.autocast only acts on CUDA; on a CPU-only machine it
    # warns and disables itself on every batch.
    use_amp = torch.device(device).type == 'cuda'
    losses = []
    preds = []
    targs = []
    for batch in loader:
        rgb = batch['rgb'].to(device)
        forensic = batch['forensic'].to(device)
        label = batch['label'].to(device)
        use_mixup = mixup_alpha > 0
        if use_mixup:
            rgb, forensic, y_a, y_b, lam = mixup_data(rgb, forensic, label, mixup_alpha)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            out = model(rgb, forensic)
            if use_mixup:
                loss = (lam * F.binary_cross_entropy_with_logits(out, y_a)
                        + (1 - lam) * F.binary_cross_entropy_with_logits(out, y_b))
            else:
                loss = F.binary_cross_entropy_with_logits(out, label)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
        preds += torch.sigmoid(out).detach().cpu().numpy().tolist()
        targs += label.detach().cpu().numpy().tolist()
    auc = roc_auc_score(targs, preds) if len(set(targs)) > 1 else 0.0
    return float(np.mean(losses)), float(auc)

def eval_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean_loss, auc)``: mean per-batch BCE-with-logits loss, and the
        ROC-AUC of sigmoid outputs (0.0 when only one class is present).
    """
    model.eval()
    batch_losses, scores, truths = [], [], []
    with torch.no_grad():
        for batch in loader:
            inputs = {key: batch[key].to(device) for key in ('rgb', 'forensic', 'label')}
            logits = model(inputs['rgb'], inputs['forensic'])
            batch_losses.append(F.binary_cross_entropy_with_logits(logits, inputs['label']).item())
            scores.extend(torch.sigmoid(logits).detach().cpu().numpy().tolist())
            truths.extend(inputs['label'].detach().cpu().numpy().tolist())
    auc = 0.0 if len(set(truths)) <= 1 else roc_auc_score(truths, scores)
    return float(np.mean(batch_losses)), float(auc)

# In [None]:

# ------------------ TRAINING LAUNCH (call cells step-by-step) ------------------
if DEBUG_RUN:
    # Quick dry run: build the loaders, pull one batch and run a forward pass.
    print("DEBUG_RUN: building loaders and running tiny pass")
    tr_loader, val_loader = build_loaders(DATA_ROOT)
    it = iter(tr_loader); batch = next(it)
    print("Sample batch keys:", batch.keys())
    print("Shapes: rgb", tuple(batch['rgb'].shape), "forensic", tuple(batch['forensic'].shape))
    # prepare model (optimizer/scaler are created so later cells can reuse them)
    assert TIMM_AVAILABLE, "install timm for ViT"
    model = ViTForensicsFusion(vit_name='vit_base_patch16_224', pretrained=True).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scaler = GradScaler()
    # forward pass (may be heavy - if it OOMs reduce BATCH_SIZE & IMG_SIZE)
    rgb = batch['rgb'].to(DEVICE); forensic = batch['forensic'].to(DEVICE); label = batch['label'].to(DEVICE)
    # Forward-only smoke test: no_grad skips keeping activations for backward,
    # which lowers peak memory. Autocast is only meaningful on CUDA; on CPU it
    # would just warn.
    with torch.no_grad(), autocast(enabled=(DEVICE == "cuda")):
        out = model(rgb, forensic)
        loss = F.binary_cross_entropy_with_logits(out, label)
    print("DEBUG forward OK - loss:", loss.item())
else:
    print("DEBUG_RUN is False. Set DEBUG_RUN = True to perform a quick check run before full training.")

# ------------------ USAGE / TRAINING LOOP (run manually) ------------------
# Example training loop you can copy into a cell to run full training:
#
# train_loader, val_loader = build_loaders(DATA_ROOT)
# model = ViTForensicsFusion(vit_name='vit_base_patch16_224', pretrained=True).to(DEVICE)
# optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# scaler = GradScaler()
# best_auc = -1.0
# for epoch in range(1, EPOCHS+1):
#     t0 = time.time()
#     tr_loss, tr_auc = train_epoch(model, train_loader, optimizer, scaler, DEVICE, mixup_alpha=MIXUP_ALPHA)
#     val_loss, val_auc = eval_epoch(model, val_loader, DEVICE)
#     scheduler.step()
#     print(f"Epoch {epoch} | tr_loss {tr_loss:.4f} tr_auc {tr_auc:.4f} | val_loss {val_loss:.4f} val_auc {val_auc:.4f} | time {(time.time()-t0):.1f}s")
#     torch.save({'epoch':epoch,'model_state':model.state_dict(),'optimizer':optimizer.state_dict(),'val_auc':val_auc},
#                os.path.join(OUTPUT_DIR, f"epoch{epoch:02d}_auc{val_auc:.4f}.pth"))
#     if val_auc > best_auc + 1e-4:
#         best_auc = val_auc
#         torch.save({'epoch':epoch,'model_state':model.state_dict(),'optimizer':optimizer.state_dict(),'val_auc':val_auc},
#                    os.path.join(OUTPUT_DIR, "best.pth"))
#     # optional early stopping check here
#
# After training: load best.pth and use inference helper shown below.

# In [None]:


# ------------------ INFERENCE HELPER ------------------
def predict_image(model, path, device, top_k_patches=PATCHES_PER_IMAGE, seed=None):
    """Score a single image file with a (rgb, forensic) fusion model.

    Samples ``top_k_patches`` random crops of up to IMG_SIZE x IMG_SIZE. Despite
    the name, the crops are chosen uniformly at random, not ranked. Each crop is
    preprocessed like ``DeepfakeForensicDataset.crop_and_compute``: forensic maps
    are computed at the crop's native size and then resized. The crops are fed
    as one [1,P,...] stack.

    Args:
        model: module called as ``model(rgb[1,P,3,H,W], forensic[1,P,4,H,W])``.
        path: image file path.
        device: device to run inference on.
        top_k_patches: number of crops to sample (must be >= 1).
        seed: optional seed for crop sampling. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        ``(mean_prob, probs)``, where ``probs`` is the list of sigmoid outputs.
        NOTE(review): ViTForensicsFusion averages patch features internally, so
        with that model ``probs`` holds a single entry.

    Raises:
        ValueError: if ``top_k_patches < 1``.
    """
    if top_k_patches < 1:
        raise ValueError(f"top_k_patches must be >= 1, got {top_k_patches}")
    rng = random.Random(seed) if seed is not None else random
    model.eval()
    with Image.open(path) as src:
        img = src.convert('RGB')
    w, h = img.size
    # Never crop past the image border: PIL pads out-of-bounds regions with
    # black, and those edges would dominate the ELA / high-pass / FFT maps.
    pw, ph = min(IMG_SIZE, w), min(IMG_SIZE, h)
    size = (IMG_SIZE, IMG_SIZE)
    to_tensor = transforms.ToTensor()
    rgb_transform = transforms.Compose([
        to_tensor,
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    rgb_list, for_list = [], []
    for _ in range(top_k_patches):
        x = rng.randint(0, w - pw)
        y = rng.randint(0, h - ph)
        patch = img.crop((x, y, x + pw, y + ph))
        rgb_list.append(rgb_transform(patch.resize(size)))
        maps = (compute_ela(patch), high_pass_residual(patch),
                fft_magnitude_map(patch), prnu_like_residual(patch))
        for_list.append(torch.cat([to_tensor(m.resize(size)) for m in maps], dim=0))
    rgbs = torch.stack(rgb_list).unsqueeze(0)  # [1,P,3,H,W]
    fors = torch.stack(for_list).unsqueeze(0)  # [1,P,4,H,W]
    with torch.no_grad():
        logits = model(rgbs.to(device), fors.to(device))
        probs = torch.sigmoid(logits).cpu().numpy().tolist()
    return float(np.mean(probs)), probs

# ------------------ QUICK VISUALIZATION OF A SAMPLE (optional) ------------------
def visualize_sample(path):
    """Plot the image at ``path`` next to its FFT log-magnitude map."""
    rgb_img = Image.open(path).convert('RGB')
    panels = [
        (rgb_img, {}, 'RGB'),
        (fft_magnitude_map(rgb_img), {'cmap': 'gray'}, 'FFT magnitude'),
    ]
    plt.figure(figsize=(8, 4))
    for pos, (picture, imshow_kwargs, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.imshow(picture, **imshow_kwargs)
        plt.axis('off')
        plt.title(title)
    plt.show()
