In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
"""
ResNet-50 Deepfake Detector - Kaggle GPU Optimized
VERSION FOR COMBINED IMAGE (e.g., WildDeepfake) & VIDEO (e.g., FF++) DATASETS
"""

import argparse, io, json, math, os, random, time, sys
from pathlib import Path
from collections import defaultdict

# --- [SETUP] Install necessary libraries ---
try:
    import cv2
except ImportError:
    print("OpenCV not found. Installing opencv-python-headless...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "opencv-python-headless"])
    import cv2

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler, ConcatDataset
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score, roc_curve
)

# ----------------------- Utilities & Repro -----------------------
def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.benchmark = True
    print(f"Seed set to {seed}.")

# ----------------------- Social-style Augmentations -----------------------
class RandomJPEGCompression:
    def __init__(self, qmin=35, qmax=92, p=0.7): self.qmin, self.qmax, self.p = qmin, qmax, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        buf = io.BytesIO(); img.save(buf, format="JPEG", quality=random.randint(self.qmin, self.qmax))
        buf.seek(0); return Image.open(buf).convert("RGB")
class RandomDownscale:
    def __init__(self, scale_min=0.4, scale_max=0.85, p=0.6): self.scale_min, self.scale_max, self.p = scale_min, scale_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        w, h = img.size; s = random.uniform(self.scale_min, self.scale_max)
        nw, nh = max(8, int(w*s)), max(8, int(h*s))
        img = img.resize((nw, nh), resample=Image.BILINEAR); img = img.resize((w, h), resample=Image.BILINEAR)
        return img
class RandomGaussianNoise:
    def __init__(self, sigma_min=0.0, sigma_max=0.03, p=0.4): self.sigma_min, self.sigma_max, self.p = sigma_min, sigma_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        arr = np.asarray(img).astype(np.float32)/255.0; sigma = random.uniform(self.sigma_min, self.sigma_max)
        noise = np.random.normal(0.0, sigma, arr.shape).astype(np.float32); arr = np.clip(arr + noise, 0.0, 1.0)
        arr = (arr*255.0 + 0.5).astype(np.uint8); return Image.fromarray(arr)

# ----------------------- Custom Video Dataset -----------------------
class VideoFrameDataset(torch.utils.data.Dataset):
    def __init__(self, root: Path, transform=None, class_to_idx=None):
        self.root = root; self.transform = transform; self.class_to_idx = class_to_idx
        self.samples = []
        # ***** THE FIX IS HERE (1/2): Initialize targets list *****
        self.targets = []
        
        for class_name in self.class_to_idx.keys():
            class_dir = self.root / class_name
            if not class_dir.exists(): continue
            for video_path in sorted(list(class_dir.rglob('*.mp4'))):
                label = self.class_to_idx[class_name]
                self.samples.append((str(video_path), label))
                # ***** THE FIX IS HERE (2/2): Add label to targets list *****
                self.targets.append(label)

        if not self.samples: print(f"  --> WARNING: No .mp4 videos found in {self.root}")
        
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames < 1: return self.__getitem__((idx + 1) % len(self))
            frame_idx = random.randint(0, total_frames - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret: return self.__getitem__((idx + 1) % len(self))
        except Exception as e:
            print(f"  --> Error processing {video_path}: {e}. Skipping."); return self.__getitem__((idx + 1) % len(self))
        finally:
            if cap: cap.release()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB); img = Image.fromarray(frame_rgb)
        if self.transform: img = self.transform(img)
        return img, label

# ----------------------- Data Loading Pipeline -----------------------
def build_indices_limit_per_seq(dataset: ConcatDataset, max_per_seq=60, seed=42):
    if max_per_seq is None or max_per_seq <= 0: return list(range(len(dataset)))
    by_seq = defaultdict(list)
    print("Building sequence-limited sampler for combined dataset...")
    for idx, (path, _) in enumerate(tqdm(dataset.samples, desc="Indexing sequences")):
        path_obj = Path(path)
        seq_id = path_obj.stem if path_obj.suffix.lower() == '.mp4' else path_obj.parent.name
        by_seq[seq_id].append(idx)
    rng = random.Random(seed); indices = []
    for seq, idxs in by_seq.items(): rng.shuffle(idxs); indices.extend(idxs[:max_per_seq])
    rng.shuffle(indices); return indices

def build_loaders(image_data_root, video_data_root, img_size=224, batch_size=64, workers=2, per_seq_cap=60):
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    train_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.2)), transforms.RandomResizedCrop(img_size, scale=(0.55, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=10),
        transforms.RandomPerspective(distortion_scale=0.25, p=0.25), RandomJPEGCompression(qmin=35, qmax=92, p=0.7),
        RandomDownscale(scale_min=0.4, scale_max=0.85, p=0.6), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
        transforms.ColorJitter(0.15, 0.15, 0.15, 0.05), transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.3),
        transforms.RandomAutocontrast(p=0.3), RandomGaussianNoise(sigma_min=0.0, sigma_max=0.03, p=0.4),
        transforms.ToTensor(), normalize, transforms.RandomErasing(p=0.25, scale=(0.02,0.2), ratio=(0.3,3.3), value='random'),
    ])
    val_tf = transforms.Compose([transforms.Resize(int(img_size*1.15)), transforms.CenterCrop(img_size), transforms.ToTensor(), normalize])
    train_datasets, val_datasets = [], []; master_class_to_idx = {'fake': 0, 'real': 1}
    if image_data_root:
        root = Path(image_data_root); print(f"Loading IMAGE data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(datasets.ImageFolder(root/"train", transform=train_tf))
            val_datasets.append(datasets.ImageFolder(root/"val", transform=val_tf))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if video_data_root:
        root = Path(video_data_root); print(f"Loading VIDEO data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(VideoFrameDataset(root/"train", transform=train_tf, class_to_idx=master_class_to_idx))
            val_datasets.append(VideoFrameDataset(root/"val", transform=val_tf, class_to_idx=master_class_to_idx))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if not train_datasets: raise ValueError("No valid datasets were loaded. Check paths in USER CONFIGURATION.")
    train_ds = ConcatDataset(train_datasets); val_ds = ConcatDataset(val_datasets)
    train_ds.class_to_idx = master_class_to_idx; train_ds.classes = list(master_class_to_idx.keys())
    train_ds.samples = [s for ds in train_ds.datasets for s in ds.samples]
    train_ds.targets = [t for ds in train_ds.datasets for t in ds.targets]
    # Also create combined targets for the validation set for completeness
    val_ds.targets = [t for ds in val_ds.datasets for t in ds.targets]

    print(f"\nCombined {len(train_datasets)} dataset(s). Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")
    indices = build_indices_limit_per_seq(train_ds, max_per_seq=per_seq_cap, seed=42)
    num_workers = 2 if os.cpu_count() <= 4 else 4
    train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=SubsetRandomSampler(indices), num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    return train_ds, val_ds, train_dl, val_dl

# --- [The rest of the helper functions are unchanged] ---
class WarmupThenCosine(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1):
        self.warmup_steps=max(1,warmup_steps); self.total_steps=max(self.warmup_steps+1,total_steps); super().__init__(optimizer, last_epoch)
    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.warmup_steps: return [base * (step/self.warmup_steps) for base in self.base_lrs]
        t = (step-self.warmup_steps)/(self.total_steps-self.warmup_steps); return [base*0.5*(1+math.cos(math.pi*t)) for base in self.base_lrs]
class EMA:
    def __init__(self, model, decay=0.999): self.decay = decay; self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if not v.is_floating_point(): self.shadow[k] = v.detach().clone(); continue
            if (self.shadow[k].dtype!=v.dtype)or(self.shadow[k].device!=v.device): self.shadow[k]=v.detach().clone(); continue
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
def rand_bbox(W, H, lam):
    cut_w=int(W*math.sqrt(1-lam)); cut_h=int(H*math.sqrt(1-lam)); cx,cy=random.randint(0,W-1),random.randint(0,H-1)
    x1,y1=max(cx-cut_w//2,0),max(cy-cut_h//2,0); x2,y2=min(cx+cut_w//2,W),min(cy+cut_h//2,H); return x1,y1,x2,y2
def mixup_data(x, y, alpha=0.2):
    lam=np.random.beta(alpha,alpha) if alpha>0 else 1.0; idx=torch.randperm(x.size(0),device=x.device)
    mixed_x=lam*x+(1-lam)*x[idx]; y_a,y_b=y,y[idx]; return mixed_x,y_a,y_b,lam
def mixup_criterion(criterion, pred, y_a, y_b, lam): return lam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)
@torch.no_grad()
def tta_logits(model,x,device,tta=4):
    outs=[]
    for op in range(tta):
        xx=x;
        if op%2==1: xx=torch.flip(xx,dims=[-1])
        with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")): outs.append(model(xx))
    return torch.stack(outs,0).mean(0)
@torch.no_grad()
def evaluate(model,dl,device,pos_index,tta=4):
    model.eval(); y_true,y_pred,y_prob_pos=[],[],[]; total_loss,n=0.0,0
    criterion=nn.CrossEntropyLoss()
    for x,y in dl:
        x,y=x.to(device,non_blocking=True),y.to(device,non_blocking=True)
        logits=tta_logits(model,x,device,tta=tta) if tta and tta>1 else model(x)
        with torch.amp.autocast(device_type=device.type,enabled=(device.type=="cuda")): loss=criterion(logits,y)
        prob=torch.softmax(logits,dim=1)
        y_true.extend(y.cpu().tolist());y_pred.extend(prob.argmax(1).cpu().tolist());y_prob_pos.extend(prob[:,pos_index].cpu().tolist())
        bs=y.size(0);total_loss+=loss.item()*bs;n+=bs
    acc=accuracy_score(y_true,y_pred); prec,rec,f1,_=precision_recall_fscore_support(y_true,y_pred,average="binary",pos_label=pos_index,zero_division=0)
    try: auc=roc_auc_score([1 if t==pos_index else 0 for t in y_true], y_prob_pos)
    except Exception: auc=float("nan")
    cm=confusion_matrix(y_true,y_pred,labels=[0,1]); return (total_loss/n,acc,prec,rec,f1,auc,np.array(cm),np.array(y_true),np.array(y_prob_pos))
def plot_curves(history,out_dir:Path):
    out_dir.mkdir(parents=True,exist_ok=True);epochs=np.arange(1,len(history["train_loss"])+1)
    plt.figure();plt.plot(epochs,history["train_loss"],label="train loss");plt.plot(epochs,history["val_loss"],label="val loss");plt.xlabel("epoch");plt.ylabel("loss");plt.legend();plt.tight_layout();plt.savefig(out_dir/"loss_curves.png");plt.close()
    plt.figure();plt.plot(epochs,history["train_acc"],label="train acc");plt.plot(epochs,history["val_acc"],label="val acc");plt.xlabel("epoch");plt.ylabel("accuracy");plt.legend();plt.tight_layout();plt.savefig(out_dir/"acc_curves.png");plt.close()
def plot_cm(cm,out_dir:Path,class_names):
    plt.figure();plt.imshow(cm,interpolation="nearest");plt.title("Confusion Matrix");plt.colorbar();ticks=np.arange(len(class_names));plt.xticks(ticks,class_names,rotation=45);plt.yticks(ticks,class_names);thresh=cm.max()/2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]): plt.text(j,i,int(cm[i,j]),ha="center",va="center",color="white" if cm[i,j]>thresh else "black")
    plt.ylabel("True");plt.xlabel("Predicted");plt.tight_layout();plt.savefig(out_dir/"confusion_matrix.png");plt.close()
def plot_roc_and_save(y_true_bin,y_prob_pos,out_dir:Path,pos_label_name="fake"):
    fpr,tpr,_=roc_curve(y_true_bin,y_prob_pos);auc=roc_auc_score(y_true_bin,y_prob_pos)
    plt.figure();plt.plot(fpr,tpr,label=f"AUC={auc:.3f}");plt.plot([0,1],[0,1],"--");plt.xlabel("FPR");plt.ylabel("TPR");plt.title(f"ROC ({pos_label_name} positive)");plt.legend();plt.tight_layout();plt.savefig(out_dir/"roc_curve.png");plt.close();return auc

# ----------------------- Main -----------------------
def main():
    
    # =================================================================================
    # ===> KAGGLE USER CONFIGURATION SECTION (CRITICAL!) <===
    # THIS SECTION HAS BEEN PRE-FILLED BASED ON YOUR SCREENSHOT.
    
    # Path to the IMAGE dataset folder (the one containing train/ and val/).
    USER_IMAGE_DATA_ROOT = "/kaggle/input/faceforensics/WD_subset_png"

    # Path to the VIDEO dataset folder (the one containing train/ and val/).
    USER_VIDEO_DATA_ROOT = "/kaggle/input/wilddeepfake/ffpp_subset_c23"

    # Output is saved to the standard Kaggle working directory
    USER_OUT_DIR = "/kaggle/working/training_results"
    # =================================================================================

    args_dict = {
        "image_data_root": USER_IMAGE_DATA_ROOT, "video_data_root": USER_VIDEO_DATA_ROOT,
        "out_dir": USER_OUT_DIR, "per_seq_cap": 60, "epochs": 40, 
        "batch_size": 128, "img_size": 224, "workers": 2,
        "lr": 3e-4, "weight_decay": 1e-4, "label_smoothing": 0.05, "freeze_backbone": False,
        "class_weights": True, "mixup_p": 0.3, "cutmix_p": 0.2, "mix_alpha": 0.2,
        "ema_decay": 0.999, "grad_clip": 1.0, "warmup_pct": 0.1, "tta": 4, "seed": 42,
    }
    args = argparse.Namespace(**args_dict)

    print("\n[Kaggle Setup]");
    if not torch.cuda.is_available(): print("\n!!!! WARNING: GPU ACCELERATOR NOT DETECTED !!!!\n")
    else: print(f"GPU Detected: {torch.cuda.get_device_name(0)}")

    set_seed(args.seed); device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); out_dir = Path(args.out_dir)
    ckpt_dir = out_dir / "checkpoints"; art_dir = out_dir / "artifacts"
    ckpt_dir.mkdir(parents=True, exist_ok=True); art_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving outputs to: {out_dir}")

    train_ds, val_ds, train_dl, val_dl = build_loaders(
        image_data_root=args.image_data_root, video_data_root=args.video_data_root,
        img_size=args.img_size, batch_size=args.batch_size, workers=args.workers, per_seq_cap=args.per_seq_cap
    )
    print(f"Master Classes: {train_ds.class_to_idx}"); idx_fake = train_ds.class_to_idx.get("fake", 0)

    class_weights_tensor = None
    if args.class_weights:
        print("Computing class weights for combined dataset...")
        sampler_indices = list(train_dl.sampler.indices); labels = [train_ds.targets[i] for i in sampler_indices]
        counts = np.bincount(labels, minlength=len(train_ds.classes)); total = counts.sum()
        weights = [total/(len(counts)*c) if c > 0 else 0.0 for c in counts]
        class_weights_tensor = torch.tensor(weights, dtype=torch.float32, device=device)
        print(f"Class counts in one epoch: {counts}. Weights: {class_weights_tensor.cpu().numpy()}")

    print("\nInitializing ResNet-50 model..."); model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))
    if args.freeze_backbone: [p.requires_grad_(False) for n, p in model.named_parameters() if not n.startswith("fc.")]
    model.to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=args.label_smoothing)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    total_steps = len(train_dl) * args.epochs; warmup_steps = int(args.warmup_pct * total_steps)
    scheduler = WarmupThenCosine(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    ema = EMA(model, decay=args.ema_decay)
    history = defaultdict(list); best_auc, best_epoch = -1.0, -1; global_step = 0
    
    print("\nStarting Training...")
    for epoch in range(1, args.epochs+1):
        model.train(); tr_loss_sum, tr_correct, n_samples = 0.0, 0, 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        for x, y in pbar:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            use_mixup = random.random() < args.mixup_p; use_cutmix = (not use_mixup) and (random.random() < args.cutmix_p)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                if use_mixup:
                    x_mix, y_a, y_b, lam = mixup_data(x, y, alpha=args.mix_alpha)
                    logits = model(x_mix); loss = mixup_criterion(criterion, logits, y_a, y_b, lam); y_for_acc = y_a
                elif use_cutmix:
                    lam = np.random.beta(args.mix_alpha, args.mix_alpha); idx_perm = torch.randperm(x.size(0), device=x.device)
                    y_a, y_b = y, y[idx_perm]; W, H = x.shape[3], x.shape[2]; x1, y1, x2, y2 = rand_bbox(W, H, lam)
                    x[:, :, y1:y2, x1:x2] = x[idx_perm, :, y1:y2, x1:x2]; lam_adj = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
                    logits = model(x); loss = mixup_criterion(criterion, logits, y_a, y_b, lam_adj); y_for_acc = y_a
                else: logits = model(x); loss = criterion(logits, y); y_for_acc = y
            scaler.scale(loss).backward(); scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer); scaler.update(); ema.update(model); scheduler.step(); global_step += 1
            bs = y.size(0); tr_loss_sum += loss.item() * bs; tr_correct += (logits.detach().argmax(1) == y_for_acc).sum().item(); n_samples += bs
            pbar.set_postfix(loss=tr_loss_sum/max(1,n_samples), acc=tr_correct/max(1,n_samples))
        tr_loss = tr_loss_sum / max(1, n_samples); tr_acc  = tr_correct / max(1, n_samples)
        state_backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(ema.shadow, strict=True)
        val_loss, val_acc, prec, rec, f1, val_auc, cm, y_true, y_prob_pos = evaluate(model, val_dl, device, pos_index=idx_fake, tta=args.tta)
        model.load_state_dict(state_backup, strict=True)
        ckpt = {"model": ema.shadow, "epoch": epoch, "val_auc": float(val_auc), "class_to_idx": train_ds.class_to_idx, "args": args_dict}
        torch.save(ckpt, ckpt_dir / f"resnet50_epoch{epoch:02d}.pt")
        history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc); history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc); history["val_auc"].append(val_auc); history["val_f1"].append(f1)
        tqdm.write(f"Epoch {epoch:02d} | tr Ls {tr_loss:.4f} Ac {tr_acc:.4f} | val Ls {val_loss:.4f} Ac {val_acc:.4f} AUC {val_auc:.4f} F1 {f1:.4f}")
        class_names = [k for k,_ in sorted(train_ds.class_to_idx.items(), key=lambda kv: kv[1])]
        plot_cm(cm, art_dir, class_names=class_names)
        y_true_bin = (y_true == idx_fake).astype(int)
        _ = plot_roc_and_save(y_true_bin, y_prob_pos, art_dir, pos_label_name="fake")
        if not math.isnan(val_auc) and val_auc > best_auc:
            best_auc, best_epoch = val_auc, epoch; torch.save(ckpt, ckpt_dir/"resnet50_best.pt")
            tqdm.write(f"  >>> New Best AUC: {best_auc:.4f} saved.")
    plot_curves(history, art_dir);
    with open(out_dir/"history.json","w") as f: json.dump(history, f, indent=2)
    print(f"\nTraining Complete.\nBest epoch by AUC: {best_epoch} (AUC={best_auc:.4f})\nCheckpoints: {ckpt_dir}\nArtifacts: {art_dir}")

if __name__ == "__main__":
    main()


[Kaggle Setup]
GPU Detected: Tesla T4
Seed set to 42.

Saving outputs to: /kaggle/working/training_results
Loading IMAGE data from: /kaggle/input/faceforensics/WD_subset_png
Loading VIDEO data from: /kaggle/input/wilddeepfake/ffpp_subset_c23

Combined 2 dataset(s). Training samples: 47927, Validation samples: 5219
Building sequence-limited sampler for combined dataset...


Indexing sequences:   0%|          | 0/47927 [00:00<?, ?it/s]

Master Classes: {'fake': 0, 'real': 1}
Computing class weights for combined dataset...
Class counts in one epoch: [1211 1530]. Weights: [1.1317093  0.89575166]

Initializing ResNet-50 model...


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 160MB/s] 



Starting Training...


Epoch 1/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 01 | tr Ls 0.6798 Ac 0.5662 | val Ls 0.6698 Ac 0.5576 AUC 0.6793 F1 0.0625
  >>> New Best AUC: 0.6793 saved.


Epoch 2/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 02 | tr Ls 0.5364 Ac 0.7125 | val Ls 0.6688 Ac 0.5612 AUC 0.6889 F1 0.0803
  >>> New Best AUC: 0.6889 saved.


Epoch 3/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 03 | tr Ls 0.3972 Ac 0.7935 | val Ls 0.6673 Ac 0.5620 AUC 0.7027 F1 0.0856
  >>> New Best AUC: 0.7027 saved.


Epoch 4/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 04 | tr Ls 0.3373 Ac 0.8927 | val Ls 0.6664 Ac 0.5664 AUC 0.7099 F1 0.1087
  >>> New Best AUC: 0.7099 saved.


Epoch 5/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 05 | tr Ls 0.2574 Ac 0.8595 | val Ls 0.6660 Ac 0.5677 AUC 0.7139 F1 0.1330
  >>> New Best AUC: 0.7139 saved.


Epoch 6/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 06 | tr Ls 0.2697 Ac 0.8997 | val Ls 0.6657 Ac 0.5731 AUC 0.7155 F1 0.1637
  >>> New Best AUC: 0.7155 saved.


Epoch 7/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 07 | tr Ls 0.2505 Ac 0.8781 | val Ls 0.6652 Ac 0.5794 AUC 0.7185 F1 0.1992
  >>> New Best AUC: 0.7185 saved.


Epoch 8/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 08 | tr Ls 0.2476 Ac 0.8763 | val Ls 0.6650 Ac 0.5850 AUC 0.7175 F1 0.2405


Epoch 9/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 09 | tr Ls 0.2380 Ac 0.9230 | val Ls 0.6652 Ac 0.5917 AUC 0.7114 F1 0.2913


Epoch 10/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 10 | tr Ls 0.2005 Ac 0.9307 | val Ls 0.6659 Ac 0.6049 AUC 0.7043 F1 0.3612


Epoch 11/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 11 | tr Ls 0.2089 Ac 0.9489 | val Ls 0.6667 Ac 0.6095 AUC 0.6968 F1 0.4323


Epoch 12/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 12 | tr Ls 0.2399 Ac 0.8760 | val Ls 0.6674 Ac 0.6197 AUC 0.6923 F1 0.5066


Epoch 13/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 13 | tr Ls 0.2834 Ac 0.8701 | val Ls 0.6675 Ac 0.6386 AUC 0.6942 F1 0.5718


Epoch 14/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 14 | tr Ls 0.2382 Ac 0.9044 | val Ls 0.6678 Ac 0.6497 AUC 0.6965 F1 0.6193


Epoch 15/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 15 | tr Ls 0.1743 Ac 0.9289 | val Ls 0.6669 Ac 0.6507 AUC 0.7036 F1 0.6416


Epoch 16/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 16 | tr Ls 0.2038 Ac 0.9354 | val Ls 0.6660 Ac 0.6526 AUC 0.7093 F1 0.6634


Epoch 17/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 17 | tr Ls 0.2456 Ac 0.9110 | val Ls 0.6641 Ac 0.6480 AUC 0.7194 F1 0.6723
  >>> New Best AUC: 0.7194 saved.


Epoch 18/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 18 | tr Ls 0.2201 Ac 0.8767 | val Ls 0.6625 Ac 0.6371 AUC 0.7284 F1 0.6737
  >>> New Best AUC: 0.7284 saved.


Epoch 19/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 19 | tr Ls 0.2392 Ac 0.8570 | val Ls 0.6600 Ac 0.6306 AUC 0.7411 F1 0.6756
  >>> New Best AUC: 0.7411 saved.


Epoch 20/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 20 | tr Ls 0.2423 Ac 0.8763 | val Ls 0.6576 Ac 0.6177 AUC 0.7512 F1 0.6718
  >>> New Best AUC: 0.7512 saved.


Epoch 21/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 21 | tr Ls 0.2199 Ac 0.8738 | val Ls 0.6547 Ac 0.6135 AUC 0.7625 F1 0.6721
  >>> New Best AUC: 0.7625 saved.


Epoch 22/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 22 | tr Ls 0.1956 Ac 0.9296 | val Ls 0.6511 Ac 0.6093 AUC 0.7750 F1 0.6715
  >>> New Best AUC: 0.7750 saved.


Epoch 23/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 23 | tr Ls 0.2014 Ac 0.8858 | val Ls 0.6475 Ac 0.6097 AUC 0.7858 F1 0.6734
  >>> New Best AUC: 0.7858 saved.


Epoch 24/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 24 | tr Ls 0.2118 Ac 0.8599 | val Ls 0.6433 Ac 0.6103 AUC 0.7970 F1 0.6751
  >>> New Best AUC: 0.7970 saved.


Epoch 25/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 25 | tr Ls 0.2216 Ac 0.8986 | val Ls 0.6391 Ac 0.6122 AUC 0.8078 F1 0.6768
  >>> New Best AUC: 0.8078 saved.


Epoch 26/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 26 | tr Ls 0.1984 Ac 0.9460 | val Ls 0.6341 Ac 0.6168 AUC 0.8189 F1 0.6805
  >>> New Best AUC: 0.8189 saved.


Epoch 27/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 27 | tr Ls 0.1721 Ac 0.9165 | val Ls 0.6287 Ac 0.6218 AUC 0.8279 F1 0.6835
  >>> New Best AUC: 0.8279 saved.


Epoch 28/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 28 | tr Ls 0.2337 Ac 0.8614 | val Ls 0.6224 Ac 0.6269 AUC 0.8376 F1 0.6864
  >>> New Best AUC: 0.8376 saved.


Epoch 29/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 29 | tr Ls 0.2171 Ac 0.9000 | val Ls 0.6161 Ac 0.6317 AUC 0.8454 F1 0.6895
  >>> New Best AUC: 0.8454 saved.


Epoch 30/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 30 | tr Ls 0.2106 Ac 0.8435 | val Ls 0.6100 Ac 0.6390 AUC 0.8516 F1 0.6939
  >>> New Best AUC: 0.8516 saved.


Epoch 31/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 31 | tr Ls 0.2085 Ac 0.8825 | val Ls 0.6037 Ac 0.6465 AUC 0.8574 F1 0.6983
  >>> New Best AUC: 0.8574 saved.


Epoch 32/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 32 | tr Ls 0.2158 Ac 0.8325 | val Ls 0.5976 Ac 0.6526 AUC 0.8622 F1 0.7020
  >>> New Best AUC: 0.8622 saved.


Epoch 33/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 33 | tr Ls 0.1963 Ac 0.9267 | val Ls 0.5912 Ac 0.6599 AUC 0.8669 F1 0.7063
  >>> New Best AUC: 0.8669 saved.


Epoch 34/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 34 | tr Ls 0.1724 Ac 0.8172 | val Ls 0.5842 Ac 0.6660 AUC 0.8718 F1 0.7099
  >>> New Best AUC: 0.8718 saved.


Epoch 35/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 35 | tr Ls 0.2050 Ac 0.8709 | val Ls 0.5773 Ac 0.6722 AUC 0.8757 F1 0.7137
  >>> New Best AUC: 0.8757 saved.


Epoch 36/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 36 | tr Ls 0.2139 Ac 0.8858 | val Ls 0.5707 Ac 0.6775 AUC 0.8792 F1 0.7172
  >>> New Best AUC: 0.8792 saved.


Epoch 37/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 37 | tr Ls 0.2628 Ac 0.8971 | val Ls 0.5636 Ac 0.6815 AUC 0.8829 F1 0.7195
  >>> New Best AUC: 0.8829 saved.


Epoch 38/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 38 | tr Ls 0.1802 Ac 0.8690 | val Ls 0.5575 Ac 0.6858 AUC 0.8854 F1 0.7222
  >>> New Best AUC: 0.8854 saved.


Epoch 39/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 39 | tr Ls 0.2388 Ac 0.8530 | val Ls 0.5508 Ac 0.6911 AUC 0.8881 F1 0.7256
  >>> New Best AUC: 0.8881 saved.


Epoch 40/40:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 40 | tr Ls 0.2053 Ac 0.8854 | val Ls 0.5440 Ac 0.6973 AUC 0.8908 F1 0.7294
  >>> New Best AUC: 0.8908 saved.

Training Complete.
Best epoch by AUC: 40 (AUC=0.8908)
Checkpoints: /kaggle/working/training_results/checkpoints
Artifacts: /kaggle/working/training_results/artifacts


In [2]:
"""
ResNet-50 Deepfake Detector - Kaggle GPU Optimized
VERSION FOR COMBINED IMAGE & VIDEO DATASETS WITH RESUME CAPABILITY
"""

import argparse, io, json, math, os, random, time, sys
from pathlib import Path
from collections import defaultdict

# --- [SETUP] Install necessary libraries ---
try:
    import cv2
except ImportError:
    print("OpenCV not found. Installing opencv-python-headless...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "opencv-python-headless"])
    import cv2

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler, ConcatDataset
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score, roc_curve
)

# ----------------------- Utilities & Repro -----------------------
def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.benchmark = True
    print(f"Seed set to {seed}.")

# ----------------------- Social-style Augmentations -----------------------
class RandomJPEGCompression:
    def __init__(self, qmin=35, qmax=92, p=0.7): self.qmin, self.qmax, self.p = qmin, qmax, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        buf = io.BytesIO(); img.save(buf, format="JPEG", quality=random.randint(self.qmin, self.qmax))
        buf.seek(0); return Image.open(buf).convert("RGB")
class RandomDownscale:
    def __init__(self, scale_min=0.4, scale_max=0.85, p=0.6): self.scale_min, self.scale_max, self.p = scale_min, scale_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        w, h = img.size; s = random.uniform(self.scale_min, self.scale_max)
        nw, nh = max(8, int(w*s)), max(8, int(h*s))
        img = img.resize((nw, nh), resample=Image.BILINEAR); img = img.resize((w, h), resample=Image.BILINEAR)
        return img
class RandomGaussianNoise:
    def __init__(self, sigma_min=0.0, sigma_max=0.03, p=0.4): self.sigma_min, self.sigma_max, self.p = sigma_min, sigma_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        arr = np.asarray(img).astype(np.float32)/255.0; sigma = random.uniform(self.sigma_min, self.sigma_max)
        noise = np.random.normal(0.0, sigma, arr.shape).astype(np.float32); arr = np.clip(arr + noise, 0.0, 1.0)
        arr = (arr*255.0 + 0.5).astype(np.uint8); return Image.fromarray(arr)

# ----------------------- Custom Video Dataset -----------------------
class VideoFrameDataset(torch.utils.data.Dataset):
    def __init__(self, root: Path, transform=None, class_to_idx=None):
        self.root = root; self.transform = transform; self.class_to_idx = class_to_idx
        self.samples = []; self.targets = []
        for class_name in self.class_to_idx.keys():
            class_dir = self.root / class_name
            if not class_dir.exists(): continue
            for video_path in sorted(list(class_dir.rglob('*.mp4'))):
                label = self.class_to_idx[class_name]
                self.samples.append((str(video_path), label)); self.targets.append(label)
        if not self.samples: print(f"  --> WARNING: No .mp4 videos found in {self.root}")
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames < 1: return self.__getitem__((idx + 1) % len(self))
            frame_idx = random.randint(0, total_frames - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret: return self.__getitem__((idx + 1) % len(self))
        except Exception as e:
            print(f"  --> Error processing {video_path}: {e}. Skipping."); return self.__getitem__((idx + 1) % len(self))
        finally:
            if cap: cap.release()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB); img = Image.fromarray(frame_rgb)
        if self.transform: img = self.transform(img)
        return img, label

# ----------------------- Data Loading Pipeline -----------------------
def build_indices_limit_per_seq(dataset: ConcatDataset, max_per_seq=60, seed=42):
    if max_per_seq is None or max_per_seq <= 0: return list(range(len(dataset)))
    by_seq = defaultdict(list)
    print("Building sequence-limited sampler for combined dataset...")
    for idx, (path, _) in enumerate(tqdm(dataset.samples, desc="Indexing sequences")):
        path_obj = Path(path)
        seq_id = path_obj.stem if path_obj.suffix.lower() == '.mp4' else path_obj.parent.name
        by_seq[seq_id].append(idx)
    rng = random.Random(seed); indices = []
    for seq, idxs in by_seq.items(): rng.shuffle(idxs); indices.extend(idxs[:max_per_seq])
    rng.shuffle(indices); return indices

def build_loaders(image_data_root, video_data_root, img_size=224, batch_size=64, workers=2, per_seq_cap=60):
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    train_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.2)), transforms.RandomResizedCrop(img_size, scale=(0.55, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=10),
        transforms.RandomPerspective(distortion_scale=0.25, p=0.25), RandomJPEGCompression(qmin=35, qmax=92, p=0.7),
        RandomDownscale(scale_min=0.4, scale_max=0.85, p=0.6), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
        transforms.ColorJitter(0.15, 0.15, 0.15, 0.05), transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.3),
        transforms.RandomAutocontrast(p=0.3), RandomGaussianNoise(sigma_min=0.0, sigma_max=0.03, p=0.4),
        transforms.ToTensor(), normalize, transforms.RandomErasing(p=0.25, scale=(0.02,0.2), ratio=(0.3,3.3), value='random'),
    ])
    val_tf = transforms.Compose([transforms.Resize(int(img_size*1.15)), transforms.CenterCrop(img_size), transforms.ToTensor(), normalize])
    train_datasets, val_datasets = [], []; master_class_to_idx = {'fake': 0, 'real': 1}
    if image_data_root:
        root = Path(image_data_root); print(f"Loading IMAGE data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(datasets.ImageFolder(root/"train", transform=train_tf))
            val_datasets.append(datasets.ImageFolder(root/"val", transform=val_tf))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if video_data_root:
        root = Path(video_data_root); print(f"Loading VIDEO data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(VideoFrameDataset(root/"train", transform=train_tf, class_to_idx=master_class_to_idx))
            val_datasets.append(VideoFrameDataset(root/"val", transform=val_tf, class_to_idx=master_class_to_idx))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if not train_datasets: raise ValueError("No valid datasets were loaded. Check paths in USER CONFIGURATION.")
    train_ds = ConcatDataset(train_datasets); val_ds = ConcatDataset(val_datasets)
    train_ds.class_to_idx = master_class_to_idx; train_ds.classes = list(master_class_to_idx.keys())
    train_ds.samples = [s for ds in train_ds.datasets for s in ds.samples]
    train_ds.targets = [t for ds in train_ds.datasets for t in ds.targets]
    val_ds.targets = [t for ds in val_ds.datasets for t in ds.targets]
    print(f"\nCombined {len(train_datasets)} dataset(s). Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")
    indices = build_indices_limit_per_seq(train_ds, max_per_seq=per_seq_cap, seed=42)
    num_workers = 2 if os.cpu_count() <= 4 else 4
    train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=SubsetRandomSampler(indices), num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    return train_ds, val_ds, train_dl, val_dl

# --- [The rest of the helper functions are unchanged] ---
class WarmupThenCosine(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1): self.warmup_steps=max(1,warmup_steps); self.total_steps=max(self.warmup_steps+1,total_steps); super().__init__(optimizer, last_epoch)
    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.warmup_steps: return [base * (step/self.warmup_steps) for base in self.base_lrs]
        t = (step-self.warmup_steps)/(self.total_steps-self.warmup_steps); return [base*0.5*(1+math.cos(math.pi*t)) for base in self.base_lrs]
class EMA:
    def __init__(self, model, decay=0.999): self.decay = decay; self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if not v.is_floating_point(): self.shadow[k] = v.detach().clone(); continue
            if (self.shadow[k].dtype!=v.dtype)or(self.shadow[k].device!=v.device): self.shadow[k]=v.detach().clone(); continue
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
def rand_bbox(W, H, lam):
    cut_w=int(W*math.sqrt(1-lam)); cut_h=int(H*math.sqrt(1-lam)); cx,cy=random.randint(0,W-1),random.randint(0,H-1)
    x1,y1=max(cx-cut_w//2,0),max(cy-cut_h//2,0); x2,y2=min(cx+cut_w//2,W),min(cy+cut_h//2,H); return x1,y1,x2,y2
def mixup_data(x, y, alpha=0.2):
    lam=np.random.beta(alpha,alpha) if alpha>0 else 1.0; idx=torch.randperm(x.size(0),device=x.device)
    mixed_x=lam*x+(1-lam)*x[idx]; y_a,y_b=y,y[idx]; return mixed_x,y_a,y_b,lam
def mixup_criterion(criterion, pred, y_a, y_b, lam): return lam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)
@torch.no_grad()
def tta_logits(model,x,device,tta=4):
    outs=[]
    for op in range(tta):
        xx=x;
        if op%2==1: xx=torch.flip(xx,dims=[-1])
        with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")): outs.append(model(xx))
    return torch.stack(outs,0).mean(0)
@torch.no_grad()
def evaluate(model,dl,device,pos_index,tta=4):
    model.eval(); y_true,y_pred,y_prob_pos=[],[],[]; total_loss,n=0.0,0
    criterion=nn.CrossEntropyLoss()
    for x,y in dl:
        x,y=x.to(device,non_blocking=True),y.to(device,non_blocking=True)
        logits=tta_logits(model,x,device,tta=tta) if tta and tta>1 else model(x)
        with torch.amp.autocast(device_type=device.type,enabled=(device.type=="cuda")): loss=criterion(logits,y)
        prob=torch.softmax(logits,dim=1)
        y_true.extend(y.cpu().tolist());y_pred.extend(prob.argmax(1).cpu().tolist());y_prob_pos.extend(prob[:,pos_index].cpu().tolist())
        bs=y.size(0);total_loss+=loss.item()*bs;n+=bs
    acc=accuracy_score(y_true,y_pred); prec,rec,f1,_=precision_recall_fscore_support(y_true,y_pred,average="binary",pos_label=pos_index,zero_division=0)
    try: auc=roc_auc_score([1 if t==pos_index else 0 for t in y_true], y_prob_pos)
    except Exception: auc=float("nan")
    cm=confusion_matrix(y_true,y_pred,labels=[0,1]); return (total_loss/n,acc,prec,rec,f1,auc,np.array(cm),np.array(y_true),np.array(y_prob_pos))
def plot_curves(history,out_dir:Path):
    out_dir.mkdir(parents=True,exist_ok=True);epochs=np.arange(1,len(history["train_loss"])+1)
    plt.figure();plt.plot(epochs,history["train_loss"],label="train loss");plt.plot(epochs,history["val_loss"],label="val loss");plt.xlabel("epoch");plt.ylabel("loss");plt.legend();plt.tight_layout();plt.savefig(out_dir/"loss_curves.png");plt.close()
    plt.figure();plt.plot(epochs,history["train_acc"],label="train acc");plt.plot(epochs,history["val_acc"],label="val acc");plt.xlabel("epoch");plt.ylabel("accuracy");plt.legend();plt.tight_layout();plt.savefig(out_dir/"acc_curves.png");plt.close()
def plot_cm(cm,out_dir:Path,class_names):
    plt.figure();plt.imshow(cm,interpolation="nearest");plt.title("Confusion Matrix");plt.colorbar();ticks=np.arange(len(class_names));plt.xticks(ticks,class_names,rotation=45);plt.yticks(ticks,class_names);thresh=cm.max()/2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]): plt.text(j,i,int(cm[i,j]),ha="center",va="center",color="white" if cm[i,j]>thresh else "black")
    plt.ylabel("True");plt.xlabel("Predicted");plt.tight_layout();plt.savefig(out_dir/"confusion_matrix.png");plt.close()
def plot_roc_and_save(y_true_bin,y_prob_pos,out_dir:Path,pos_label_name="fake"):
    fpr,tpr,_=roc_curve(y_true_bin,y_prob_pos);auc=roc_auc_score(y_true_bin,y_prob_pos)
    plt.figure();plt.plot(fpr,tpr,label=f"AUC={auc:.3f}");plt.plot([0,1],[0,1],"--");plt.xlabel("FPR");plt.ylabel("TPR");plt.title(f"ROC ({pos_label_name} positive)");plt.legend();plt.tight_layout();plt.savefig(out_dir/"roc_curve.png");plt.close();return auc

# ----------------------- Main -----------------------
def main():
    
    # =================================================================================
    # ===> KAGGLE USER CONFIGURATION SECTION (CRITICAL!) <===
    USER_IMAGE_DATA_ROOT = "/kaggle/input/faceforensics/WD_subset_png"
    USER_VIDEO_DATA_ROOT = "/kaggle/input/wilddeepfake/ffpp_subset_c23"
    USER_OUT_DIR = "/kaggle/working/training_results"

    # ===> NEW: RESUME TRAINING CONFIGURATION <===
    RESUME_TRAINING = True # SET THIS TO True TO RESUME
    # This path points to the 'latest' checkpoint inside your output directory
    CHECKPOINT_PATH = f"{USER_OUT_DIR}/checkpoints/resnet50_epoch40.pt"
    # =================================================================================

    args_dict = {
        "image_data_root": USER_IMAGE_DATA_ROOT, "video_data_root": USER_VIDEO_DATA_ROOT,
        "out_dir": USER_OUT_DIR, "per_seq_cap": 60, 
        "epochs": 60, # <== INCREASE THE TOTAL NUMBER OF EPOCHS
        "batch_size": 128, "img_size": 224, "workers": 2,
        "lr": 3e-4, "weight_decay": 1e-4, "label_smoothing": 0.05, "freeze_backbone": False,
        "class_weights": True, "mixup_p": 0.3, "cutmix_p": 0.2, "mix_alpha": 0.2,
        "ema_decay": 0.999, "grad_clip": 1.0, "warmup_pct": 0.1, "tta": 4, "seed": 42,
    }
    args = argparse.Namespace(**args_dict)

    print("\n[Kaggle Setup]");
    if not torch.cuda.is_available(): print("\n!!!! WARNING: GPU ACCELERATOR NOT DETECTED !!!!\n")
    else: print(f"GPU Detected: {torch.cuda.get_device_name(0)}")

    set_seed(args.seed); device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); out_dir = Path(args.out_dir)
    ckpt_dir = out_dir / "checkpoints"; art_dir = out_dir / "artifacts"
    ckpt_dir.mkdir(parents=True, exist_ok=True); art_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving outputs to: {out_dir}")

    train_ds, val_ds, train_dl, val_dl = build_loaders(
        image_data_root=args.image_data_root, video_data_root=args.video_data_root,
        img_size=args.img_size, batch_size=args.batch_size, workers=args.workers, per_seq_cap=args.per_seq_cap
    )
    print(f"Master Classes: {train_ds.class_to_idx}"); idx_fake = train_ds.class_to_idx.get("fake", 0)

    class_weights_tensor = None
    if args.class_weights:
        print("Computing class weights for combined dataset...")
        sampler_indices = list(train_dl.sampler.indices); labels = [train_ds.targets[i] for i in sampler_indices]
        counts = np.bincount(labels, minlength=len(train_ds.classes)); total = counts.sum()
        weights = [total/(len(counts)*c) if c > 0 else 0.0 for c in counts]
        class_weights_tensor = torch.tensor(weights, dtype=torch.float32, device=device)
        print(f"Class counts in one epoch: {counts}. Weights: {class_weights_tensor.cpu().numpy()}")

    print("\nInitializing ResNet-50 model..."); model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))
    if args.freeze_backbone: [p.requires_grad_(False) for n, p in model.named_parameters() if not n.startswith("fc.")]
    model.to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=args.label_smoothing)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    total_steps = len(train_dl) * args.epochs; warmup_steps = int(args.warmup_pct * total_steps)
    scheduler = WarmupThenCosine(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    ema = EMA(model, decay=args.ema_decay)
    history = defaultdict(list); best_auc, best_epoch = -1.0, -1; global_step = 0; start_epoch = 1

    # ===> NEW: LOGIC TO LOAD CHECKPOINT AND RESUME <===
    if RESUME_TRAINING:
        print("\n[Resuming Training]")
        if not Path(CHECKPOINT_PATH).exists():
            print(f"  !!! WARNING: Checkpoint not found at {CHECKPOINT_PATH}. Starting from scratch. !!!")
        else:
            # For modern PyTorch, weights_only=False is needed for non-tensor data
            ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)
            
            # The saved model is the EMA model
            model.load_state_dict(ckpt['model'], strict=True) 
            model.to(device)
            ema.shadow = {k: v.to(device) for k, v in ckpt['model'].items()}
            
            start_epoch = ckpt.get('epoch', 0) + 1
            
            print(f"  Fast-forwarding LR scheduler to step of epoch {start_epoch - 1}...")
            steps_to_advance = (start_epoch - 1) * len(train_dl)
            for _ in range(steps_to_advance): scheduler.step()

            history_path = out_dir / "history.json"
            if history_path.exists():
                with open(history_path, 'r') as f: history = defaultdict(list, json.load(f))
                if history.get('val_auc'): best_auc = max(history.get('val_auc', [0.0]))
            print(f"  --> Checkpoint loaded. Resuming from Epoch {start_epoch}. Previous best AUC: {best_auc:.4f}")

    print("\nStarting Training...")
    # The loop now starts from 'start_epoch'
    for epoch in range(start_epoch, args.epochs + 1):
        model.train(); tr_loss_sum, tr_correct, n_samples = 0.0, 0, 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        for x, y in pbar:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            use_mixup = random.random() < args.mixup_p; use_cutmix = (not use_mixup) and (random.random() < args.cutmix_p)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                if use_mixup:
                    x_mix, y_a, y_b, lam = mixup_data(x, y, alpha=args.mix_alpha)
                    logits = model(x_mix); loss = mixup_criterion(criterion, logits, y_a, y_b, lam); y_for_acc = y_a
                elif use_cutmix:
                    lam = np.random.beta(args.mix_alpha, args.mix_alpha); idx_perm = torch.randperm(x.size(0), device=x.device)
                    y_a, y_b = y, y[idx_perm]; W, H = x.shape[3], x.shape[2]; x1, y1, x2, y2 = rand_bbox(W, H, lam)
                    x[:, :, y1:y2, x1:x2] = x[idx_perm, :, y1:y2, x1:x2]; lam_adj = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
                    logits = model(x); loss = mixup_criterion(criterion, logits, y_a, y_b, lam_adj); y_for_acc = y_a
                else: logits = model(x); loss = criterion(logits, y); y_for_acc = y
            scaler.scale(loss).backward(); scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer); scaler.update(); ema.update(model); scheduler.step(); global_step += 1
            bs = y.size(0); tr_loss_sum += loss.item() * bs; tr_correct += (logits.detach().argmax(1) == y_for_acc).sum().item(); n_samples += bs
            pbar.set_postfix(loss=tr_loss_sum/max(1,n_samples), acc=tr_correct/max(1,n_samples))
        tr_loss = tr_loss_sum / max(1, n_samples); tr_acc  = tr_correct / max(1, n_samples)
        state_backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(ema.shadow, strict=True)
        val_loss, val_acc, prec, rec, f1, val_auc, cm, y_true, y_prob_pos = evaluate(model, val_dl, device, pos_index=idx_fake, tta=args.tta)
        model.load_state_dict(state_backup, strict=True)
        
        # We now save optimizer and scheduler states for even better resumption next time
        ckpt = {
            "model": ema.shadow, "epoch": epoch, "val_auc": float(val_auc), 
            "class_to_idx": train_ds.class_to_idx, "args": args_dict,
            "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()
        }
        torch.save(ckpt, ckpt_dir / f"resnet50_epoch{epoch:02d}.pt")
        # Overwrite a 'latest' file for easy resuming
        torch.save(ckpt, ckpt_dir / "resnet50_latest.pt")
        
        history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc); history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc); history["val_auc"].append(val_auc); history["val_f1"].append(f1)
        tqdm.write(f"Epoch {epoch:02d} | tr Ls {tr_loss:.4f} Ac {tr_acc:.4f} | val Ls {val_loss:.4f} Ac {val_acc:.4f} AUC {val_auc:.4f} F1 {f1:.4f}")
        class_names = [k for k,_ in sorted(train_ds.class_to_idx.items(), key=lambda kv: kv[1])]
        plot_cm(cm, art_dir, class_names=class_names)
        y_true_bin = (y_true == idx_fake).astype(int)
        _ = plot_roc_and_save(y_true_bin, y_prob_pos, art_dir, pos_label_name="fake")
        if not math.isnan(val_auc) and val_auc > best_auc:
            best_auc, best_epoch = val_auc, epoch; torch.save(ckpt, ckpt_dir/"resnet50_best.pt")
            tqdm.write(f"  >>> New Best AUC: {best_auc:.4f} saved.")
    plot_curves(history, art_dir);
    with open(out_dir/"history.json","w") as f: json.dump(history, f, indent=2)
    print(f"\nTraining Complete.\nBest epoch by AUC: {best_epoch} (AUC={best_auc:.4f})\nCheckpoints: {ckpt_dir}\nArtifacts: {art_dir}")

if __name__ == "__main__":
    main()


[Kaggle Setup]
GPU Detected: Tesla T4
Seed set to 42.

Saving outputs to: /kaggle/working/training_results
Loading IMAGE data from: /kaggle/input/faceforensics/WD_subset_png
Loading VIDEO data from: /kaggle/input/wilddeepfake/ffpp_subset_c23

Combined 2 dataset(s). Training samples: 47927, Validation samples: 5219
Building sequence-limited sampler for combined dataset...


Indexing sequences:   0%|          | 0/47927 [00:00<?, ?it/s]

Master Classes: {'fake': 0, 'real': 1}
Computing class weights for combined dataset...
Class counts in one epoch: [1211 1530]. Weights: [1.1317093  0.89575166]

Initializing ResNet-50 model...

[Resuming Training]
  Fast-forwarding LR scheduler to step of epoch 40...
  --> Checkpoint loaded. Resuming from Epoch 41. Previous best AUC: 0.8908

Starting Training...




Epoch 41/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 41 | tr Ls 0.2965 Ac 0.8548 | val Ls 0.5404 Ac 0.7007 AUC 0.8928 F1 0.7316
  >>> New Best AUC: 0.8928 saved.


Epoch 42/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 42 | tr Ls 0.1863 Ac 0.8599 | val Ls 0.5336 Ac 0.7078 AUC 0.8958 F1 0.7363
  >>> New Best AUC: 0.8958 saved.


Epoch 43/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 43 | tr Ls 0.1977 Ac 0.8792 | val Ls 0.5273 Ac 0.7143 AUC 0.8988 F1 0.7406
  >>> New Best AUC: 0.8988 saved.


Epoch 44/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 44 | tr Ls 0.2207 Ac 0.9511 | val Ls 0.5215 Ac 0.7204 AUC 0.9013 F1 0.7447
  >>> New Best AUC: 0.9013 saved.


Epoch 45/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 45 | tr Ls 0.1769 Ac 0.8906 | val Ls 0.5145 Ac 0.7275 AUC 0.9044 F1 0.7496
  >>> New Best AUC: 0.9044 saved.


Epoch 46/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 46 | tr Ls 0.2099 Ac 0.9281 | val Ls 0.5083 Ac 0.7363 AUC 0.9067 F1 0.7557
  >>> New Best AUC: 0.9067 saved.


Epoch 47/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 47 | tr Ls 0.1984 Ac 0.8913 | val Ls 0.5000 Ac 0.7455 AUC 0.9100 F1 0.7621
  >>> New Best AUC: 0.9100 saved.


Epoch 48/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 48 | tr Ls 0.1949 Ac 0.9004 | val Ls 0.4925 Ac 0.7519 AUC 0.9125 F1 0.7665
  >>> New Best AUC: 0.9125 saved.


Epoch 49/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 49 | tr Ls 0.2008 Ac 0.9383 | val Ls 0.4862 Ac 0.7565 AUC 0.9150 F1 0.7697
  >>> New Best AUC: 0.9150 saved.


Epoch 50/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 50 | tr Ls 0.1775 Ac 0.9351 | val Ls 0.4805 Ac 0.7601 AUC 0.9166 F1 0.7720
  >>> New Best AUC: 0.9166 saved.


Epoch 51/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 51 | tr Ls 0.1899 Ac 0.9584 | val Ls 0.4751 Ac 0.7662 AUC 0.9179 F1 0.7765
  >>> New Best AUC: 0.9179 saved.


Epoch 52/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 52 | tr Ls 0.2185 Ac 0.8854 | val Ls 0.4700 Ac 0.7701 AUC 0.9190 F1 0.7789
  >>> New Best AUC: 0.9190 saved.


Epoch 53/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 53 | tr Ls 0.2573 Ac 0.8822 | val Ls 0.4655 Ac 0.7749 AUC 0.9203 F1 0.7824
  >>> New Best AUC: 0.9203 saved.


Epoch 54/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 54 | tr Ls 0.2191 Ac 0.9033 | val Ls 0.4611 Ac 0.7783 AUC 0.9214 F1 0.7847
  >>> New Best AUC: 0.9214 saved.


Epoch 55/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 55 | tr Ls 0.1617 Ac 0.9347 | val Ls 0.4561 Ac 0.7821 AUC 0.9224 F1 0.7870
  >>> New Best AUC: 0.9224 saved.


Epoch 56/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 56 | tr Ls 0.1907 Ac 0.9424 | val Ls 0.4518 Ac 0.7856 AUC 0.9229 F1 0.7895
  >>> New Best AUC: 0.9229 saved.


Epoch 57/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 57 | tr Ls 0.2312 Ac 0.9161 | val Ls 0.4477 Ac 0.7877 AUC 0.9236 F1 0.7907
  >>> New Best AUC: 0.9236 saved.


Epoch 58/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 58 | tr Ls 0.2095 Ac 0.8738 | val Ls 0.4430 Ac 0.7908 AUC 0.9245 F1 0.7929
  >>> New Best AUC: 0.9245 saved.


Epoch 59/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 59 | tr Ls 0.2277 Ac 0.8599 | val Ls 0.4384 Ac 0.7956 AUC 0.9254 F1 0.7963
  >>> New Best AUC: 0.9254 saved.


Epoch 60/60:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 60 | tr Ls 0.2316 Ac 0.8741 | val Ls 0.4340 Ac 0.7980 AUC 0.9262 F1 0.7979
  >>> New Best AUC: 0.9262 saved.

Training Complete.
Best epoch by AUC: 60 (AUC=0.9262)
Checkpoints: /kaggle/working/training_results/checkpoints
Artifacts: /kaggle/working/training_results/artifacts


In [3]:
"""
ResNet-50 Deepfake Detector - Kaggle GPU Optimized
VERSION FOR COMBINED IMAGE & VIDEO DATASETS WITH RESUME CAPABILITY
"""

import argparse, io, json, math, os, random, time, sys
from pathlib import Path
from collections import defaultdict

# --- [SETUP] Install necessary libraries ---
try:
    import cv2
except ImportError:
    print("OpenCV not found. Installing opencv-python-headless...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "opencv-python-headless"])
    import cv2

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler, ConcatDataset
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score, roc_curve
)

# ----------------------- Utilities & Repro -----------------------
def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.benchmark = True
    print(f"Seed set to {seed}.")

# ----------------------- Social-style Augmentations -----------------------
class RandomJPEGCompression:
    def __init__(self, qmin=35, qmax=92, p=0.7): self.qmin, self.qmax, self.p = qmin, qmax, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        buf = io.BytesIO(); img.save(buf, format="JPEG", quality=random.randint(self.qmin, self.qmax))
        buf.seek(0); return Image.open(buf).convert("RGB")
class RandomDownscale:
    def __init__(self, scale_min=0.4, scale_max=0.85, p=0.6): self.scale_min, self.scale_max, self.p = scale_min, scale_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        w, h = img.size; s = random.uniform(self.scale_min, self.scale_max)
        nw, nh = max(8, int(w*s)), max(8, int(h*s))
        img = img.resize((nw, nh), resample=Image.BILINEAR); img = img.resize((w, h), resample=Image.BILINEAR)
        return img
class RandomGaussianNoise:
    def __init__(self, sigma_min=0.0, sigma_max=0.03, p=0.4): self.sigma_min, self.sigma_max, self.p = sigma_min, sigma_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        arr = np.asarray(img).astype(np.float32)/255.0; sigma = random.uniform(self.sigma_min, self.sigma_max)
        noise = np.random.normal(0.0, sigma, arr.shape).astype(np.float32); arr = np.clip(arr + noise, 0.0, 1.0)
        arr = (arr*255.0 + 0.5).astype(np.uint8); return Image.fromarray(arr)

# ----------------------- Custom Video Dataset -----------------------
class VideoFrameDataset(torch.utils.data.Dataset):
    def __init__(self, root: Path, transform=None, class_to_idx=None):
        self.root = root; self.transform = transform; self.class_to_idx = class_to_idx
        self.samples = []; self.targets = []
        for class_name in self.class_to_idx.keys():
            class_dir = self.root / class_name
            if not class_dir.exists(): continue
            for video_path in sorted(list(class_dir.rglob('*.mp4'))):
                label = self.class_to_idx[class_name]
                self.samples.append((str(video_path), label)); self.targets.append(label)
        if not self.samples: print(f"  --> WARNING: No .mp4 videos found in {self.root}")
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames < 1: return self.__getitem__((idx + 1) % len(self))
            frame_idx = random.randint(0, total_frames - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret: return self.__getitem__((idx + 1) % len(self))
        except Exception as e:
            print(f"  --> Error processing {video_path}: {e}. Skipping."); return self.__getitem__((idx + 1) % len(self))
        finally:
            if cap: cap.release()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB); img = Image.fromarray(frame_rgb)
        if self.transform: img = self.transform(img)
        return img, label

# ----------------------- Data Loading Pipeline -----------------------
def build_indices_limit_per_seq(dataset: ConcatDataset, max_per_seq=60, seed=42):
    if max_per_seq is None or max_per_seq <= 0: return list(range(len(dataset)))
    by_seq = defaultdict(list)
    print("Building sequence-limited sampler for combined dataset...")
    for idx, (path, _) in enumerate(tqdm(dataset.samples, desc="Indexing sequences")):
        path_obj = Path(path)
        seq_id = path_obj.stem if path_obj.suffix.lower() == '.mp4' else path_obj.parent.name
        by_seq[seq_id].append(idx)
    rng = random.Random(seed); indices = []
    for seq, idxs in by_seq.items(): rng.shuffle(idxs); indices.extend(idxs[:max_per_seq])
    rng.shuffle(indices); return indices

def build_loaders(image_data_root, video_data_root, img_size=224, batch_size=64, workers=2, per_seq_cap=60):
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    train_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.2)), transforms.RandomResizedCrop(img_size, scale=(0.55, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=10),
        transforms.RandomPerspective(distortion_scale=0.25, p=0.25), RandomJPEGCompression(qmin=35, qmax=92, p=0.7),
        RandomDownscale(scale_min=0.4, scale_max=0.85, p=0.6), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
        transforms.ColorJitter(0.15, 0.15, 0.15, 0.05), transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.3),
        transforms.RandomAutocontrast(p=0.3), RandomGaussianNoise(sigma_min=0.0, sigma_max=0.03, p=0.4),
        transforms.ToTensor(), normalize, transforms.RandomErasing(p=0.25, scale=(0.02,0.2), ratio=(0.3,3.3), value='random'),
    ])
    val_tf = transforms.Compose([transforms.Resize(int(img_size*1.15)), transforms.CenterCrop(img_size), transforms.ToTensor(), normalize])
    train_datasets, val_datasets = [], []; master_class_to_idx = {'fake': 0, 'real': 1}
    if image_data_root:
        root = Path(image_data_root); print(f"Loading IMAGE data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(datasets.ImageFolder(root/"train", transform=train_tf))
            val_datasets.append(datasets.ImageFolder(root/"val", transform=val_tf))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if video_data_root:
        root = Path(video_data_root); print(f"Loading VIDEO data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(VideoFrameDataset(root/"train", transform=train_tf, class_to_idx=master_class_to_idx))
            val_datasets.append(VideoFrameDataset(root/"val", transform=val_tf, class_to_idx=master_class_to_idx))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if not train_datasets: raise ValueError("No valid datasets were loaded. Check paths in USER CONFIGURATION.")
    train_ds = ConcatDataset(train_datasets); val_ds = ConcatDataset(val_datasets)
    train_ds.class_to_idx = master_class_to_idx; train_ds.classes = list(master_class_to_idx.keys())
    train_ds.samples = [s for ds in train_ds.datasets for s in ds.samples]
    train_ds.targets = [t for ds in train_ds.datasets for t in ds.targets]
    val_ds.targets = [t for ds in val_ds.datasets for t in ds.targets]
    print(f"\nCombined {len(train_datasets)} dataset(s). Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")
    indices = build_indices_limit_per_seq(train_ds, max_per_seq=per_seq_cap, seed=42)
    num_workers = 2 if os.cpu_count() <= 4 else 4
    train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=SubsetRandomSampler(indices), num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    return train_ds, val_ds, train_dl, val_dl

# --- [The rest of the helper functions are unchanged] ---
class WarmupThenCosine(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1): self.warmup_steps=max(1,warmup_steps); self.total_steps=max(self.warmup_steps+1,total_steps); super().__init__(optimizer, last_epoch)
    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.warmup_steps: return [base * (step/self.warmup_steps) for base in self.base_lrs]
        t = (step-self.warmup_steps)/(self.total_steps-self.warmup_steps); return [base*0.5*(1+math.cos(math.pi*t)) for base in self.base_lrs]
class EMA:
    def __init__(self, model, decay=0.999): self.decay = decay; self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if not v.is_floating_point(): self.shadow[k] = v.detach().clone(); continue
            if (self.shadow[k].dtype!=v.dtype)or(self.shadow[k].device!=v.device): self.shadow[k]=v.detach().clone(); continue
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
def rand_bbox(W, H, lam):
    cut_w=int(W*math.sqrt(1-lam)); cut_h=int(H*math.sqrt(1-lam)); cx,cy=random.randint(0,W-1),random.randint(0,H-1)
    x1,y1=max(cx-cut_w//2,0),max(cy-cut_h//2,0); x2,y2=min(cx+cut_w//2,W),min(cy+cut_h//2,H); return x1,y1,x2,y2
def mixup_data(x, y, alpha=0.2):
    lam=np.random.beta(alpha,alpha) if alpha>0 else 1.0; idx=torch.randperm(x.size(0),device=x.device)
    mixed_x=lam*x+(1-lam)*x[idx]; y_a,y_b=y,y[idx]; return mixed_x,y_a,y_b,lam
def mixup_criterion(criterion, pred, y_a, y_b, lam): return lam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)
@torch.no_grad()
def tta_logits(model,x,device,tta=4):
    outs=[]
    for op in range(tta):
        xx=x;
        if op%2==1: xx=torch.flip(xx,dims=[-1])
        with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")): outs.append(model(xx))
    return torch.stack(outs,0).mean(0)
@torch.no_grad()
def evaluate(model,dl,device,pos_index,tta=4):
    model.eval(); y_true,y_pred,y_prob_pos=[],[],[]; total_loss,n=0.0,0
    criterion=nn.CrossEntropyLoss()
    for x,y in dl:
        x,y=x.to(device,non_blocking=True),y.to(device,non_blocking=True)
        logits=tta_logits(model,x,device,tta=tta) if tta and tta>1 else model(x)
        with torch.amp.autocast(device_type=device.type,enabled=(device.type=="cuda")): loss=criterion(logits,y)
        prob=torch.softmax(logits,dim=1)
        y_true.extend(y.cpu().tolist());y_pred.extend(prob.argmax(1).cpu().tolist());y_prob_pos.extend(prob[:,pos_index].cpu().tolist())
        bs=y.size(0);total_loss+=loss.item()*bs;n+=bs
    acc=accuracy_score(y_true,y_pred); prec,rec,f1,_=precision_recall_fscore_support(y_true,y_pred,average="binary",pos_label=pos_index,zero_division=0)
    try: auc=roc_auc_score([1 if t==pos_index else 0 for t in y_true], y_prob_pos)
    except Exception: auc=float("nan")
    cm=confusion_matrix(y_true,y_pred,labels=[0,1]); return (total_loss/n,acc,prec,rec,f1,auc,np.array(cm),np.array(y_true),np.array(y_prob_pos))
def plot_curves(history,out_dir:Path):
    out_dir.mkdir(parents=True,exist_ok=True);epochs=np.arange(1,len(history["train_loss"])+1)
    plt.figure();plt.plot(epochs,history["train_loss"],label="train loss");plt.plot(epochs,history["val_loss"],label="val loss");plt.xlabel("epoch");plt.ylabel("loss");plt.legend();plt.tight_layout();plt.savefig(out_dir/"loss_curves.png");plt.close()
    plt.figure();plt.plot(epochs,history["train_acc"],label="train acc");plt.plot(epochs,history["val_acc"],label="val acc");plt.xlabel("epoch");plt.ylabel("accuracy");plt.legend();plt.tight_layout();plt.savefig(out_dir/"acc_curves.png");plt.close()
def plot_cm(cm,out_dir:Path,class_names):
    plt.figure();plt.imshow(cm,interpolation="nearest");plt.title("Confusion Matrix");plt.colorbar();ticks=np.arange(len(class_names));plt.xticks(ticks,class_names,rotation=45);plt.yticks(ticks,class_names);thresh=cm.max()/2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]): plt.text(j,i,int(cm[i,j]),ha="center",va="center",color="white" if cm[i,j]>thresh else "black")
    plt.ylabel("True");plt.xlabel("Predicted");plt.tight_layout();plt.savefig(out_dir/"confusion_matrix.png");plt.close()
def plot_roc_and_save(y_true_bin,y_prob_pos,out_dir:Path,pos_label_name="fake"):
    fpr,tpr,_=roc_curve(y_true_bin,y_prob_pos);auc=roc_auc_score(y_true_bin,y_prob_pos)
    plt.figure();plt.plot(fpr,tpr,label=f"AUC={auc:.3f}");plt.plot([0,1],[0,1],"--");plt.xlabel("FPR");plt.ylabel("TPR");plt.title(f"ROC ({pos_label_name} positive)");plt.legend();plt.tight_layout();plt.savefig(out_dir/"roc_curve.png");plt.close();return auc

# ----------------------- Main -----------------------
def main():
    
    # =================================================================================
    # ===> KAGGLE USER CONFIGURATION SECTION (CRITICAL!) <===
    USER_IMAGE_DATA_ROOT = "/kaggle/input/faceforensics/WD_subset_png"
    USER_VIDEO_DATA_ROOT = "/kaggle/input/wilddeepfake/ffpp_subset_c23"
    USER_OUT_DIR = "/kaggle/working/training_results"

    # ===> NEW: RESUME TRAINING CONFIGURATION <===
    RESUME_TRAINING = True # SET THIS TO True TO RESUME
    # This path points to the 'latest' checkpoint inside your output directory
    CHECKPOINT_PATH = f"{USER_OUT_DIR}/checkpoints/resnet50_epoch60.pt"
    # =================================================================================

    args_dict = {
        "image_data_root": USER_IMAGE_DATA_ROOT, "video_data_root": USER_VIDEO_DATA_ROOT,
        "out_dir": USER_OUT_DIR, "per_seq_cap": 60, 
        "epochs": 80, # <== INCREASE THE TOTAL NUMBER OF EPOCHS
        "batch_size": 128, "img_size": 224, "workers": 2,
        "lr": 3e-4, "weight_decay": 1e-4, "label_smoothing": 0.05, "freeze_backbone": False,
        "class_weights": True, "mixup_p": 0.3, "cutmix_p": 0.2, "mix_alpha": 0.2,
        "ema_decay": 0.999, "grad_clip": 1.0, "warmup_pct": 0.1, "tta": 4, "seed": 42,
    }
    args = argparse.Namespace(**args_dict)

    print("\n[Kaggle Setup]");
    if not torch.cuda.is_available(): print("\n!!!! WARNING: GPU ACCELERATOR NOT DETECTED !!!!\n")
    else: print(f"GPU Detected: {torch.cuda.get_device_name(0)}")

    set_seed(args.seed); device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); out_dir = Path(args.out_dir)
    ckpt_dir = out_dir / "checkpoints"; art_dir = out_dir / "artifacts"
    ckpt_dir.mkdir(parents=True, exist_ok=True); art_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving outputs to: {out_dir}")

    train_ds, val_ds, train_dl, val_dl = build_loaders(
        image_data_root=args.image_data_root, video_data_root=args.video_data_root,
        img_size=args.img_size, batch_size=args.batch_size, workers=args.workers, per_seq_cap=args.per_seq_cap
    )
    print(f"Master Classes: {train_ds.class_to_idx}"); idx_fake = train_ds.class_to_idx.get("fake", 0)

    class_weights_tensor = None
    if args.class_weights:
        print("Computing class weights for combined dataset...")
        sampler_indices = list(train_dl.sampler.indices); labels = [train_ds.targets[i] for i in sampler_indices]
        counts = np.bincount(labels, minlength=len(train_ds.classes)); total = counts.sum()
        weights = [total/(len(counts)*c) if c > 0 else 0.0 for c in counts]
        class_weights_tensor = torch.tensor(weights, dtype=torch.float32, device=device)
        print(f"Class counts in one epoch: {counts}. Weights: {class_weights_tensor.cpu().numpy()}")

    print("\nInitializing ResNet-50 model..."); model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))
    if args.freeze_backbone: [p.requires_grad_(False) for n, p in model.named_parameters() if not n.startswith("fc.")]
    model.to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=args.label_smoothing)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    total_steps = len(train_dl) * args.epochs; warmup_steps = int(args.warmup_pct * total_steps)
    scheduler = WarmupThenCosine(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    ema = EMA(model, decay=args.ema_decay)
    history = defaultdict(list); best_auc, best_epoch = -1.0, -1; global_step = 0; start_epoch = 1

    # ===> NEW: LOGIC TO LOAD CHECKPOINT AND RESUME <===
    if RESUME_TRAINING:
        print("\n[Resuming Training]")
        if not Path(CHECKPOINT_PATH).exists():
            print(f"  !!! WARNING: Checkpoint not found at {CHECKPOINT_PATH}. Starting from scratch. !!!")
        else:
            # For modern PyTorch, weights_only=False is needed for non-tensor data
            ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)
            
            # The saved model is the EMA model
            model.load_state_dict(ckpt['model'], strict=True) 
            model.to(device)
            ema.shadow = {k: v.to(device) for k, v in ckpt['model'].items()}
            
            start_epoch = ckpt.get('epoch', 0) + 1
            
            print(f"  Fast-forwarding LR scheduler to step of epoch {start_epoch - 1}...")
            steps_to_advance = (start_epoch - 1) * len(train_dl)
            for _ in range(steps_to_advance): scheduler.step()

            history_path = out_dir / "history.json"
            if history_path.exists():
                with open(history_path, 'r') as f: history = defaultdict(list, json.load(f))
                if history.get('val_auc'): best_auc = max(history.get('val_auc', [0.0]))
            print(f"  --> Checkpoint loaded. Resuming from Epoch {start_epoch}. Previous best AUC: {best_auc:.4f}")

    print("\nStarting Training...")
    # The loop now starts from 'start_epoch'
    for epoch in range(start_epoch, args.epochs + 1):
        model.train(); tr_loss_sum, tr_correct, n_samples = 0.0, 0, 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        for x, y in pbar:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            use_mixup = random.random() < args.mixup_p; use_cutmix = (not use_mixup) and (random.random() < args.cutmix_p)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                if use_mixup:
                    x_mix, y_a, y_b, lam = mixup_data(x, y, alpha=args.mix_alpha)
                    logits = model(x_mix); loss = mixup_criterion(criterion, logits, y_a, y_b, lam); y_for_acc = y_a
                elif use_cutmix:
                    lam = np.random.beta(args.mix_alpha, args.mix_alpha); idx_perm = torch.randperm(x.size(0), device=x.device)
                    y_a, y_b = y, y[idx_perm]; W, H = x.shape[3], x.shape[2]; x1, y1, x2, y2 = rand_bbox(W, H, lam)
                    x[:, :, y1:y2, x1:x2] = x[idx_perm, :, y1:y2, x1:x2]; lam_adj = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
                    logits = model(x); loss = mixup_criterion(criterion, logits, y_a, y_b, lam_adj); y_for_acc = y_a
                else: logits = model(x); loss = criterion(logits, y); y_for_acc = y
            scaler.scale(loss).backward(); scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer); scaler.update(); ema.update(model); scheduler.step(); global_step += 1
            bs = y.size(0); tr_loss_sum += loss.item() * bs; tr_correct += (logits.detach().argmax(1) == y_for_acc).sum().item(); n_samples += bs
            pbar.set_postfix(loss=tr_loss_sum/max(1,n_samples), acc=tr_correct/max(1,n_samples))
        tr_loss = tr_loss_sum / max(1, n_samples); tr_acc  = tr_correct / max(1, n_samples)
        state_backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(ema.shadow, strict=True)
        val_loss, val_acc, prec, rec, f1, val_auc, cm, y_true, y_prob_pos = evaluate(model, val_dl, device, pos_index=idx_fake, tta=args.tta)
        model.load_state_dict(state_backup, strict=True)
        
        # We now save optimizer and scheduler states for even better resumption next time
        ckpt = {
            "model": ema.shadow, "epoch": epoch, "val_auc": float(val_auc), 
            "class_to_idx": train_ds.class_to_idx, "args": args_dict,
            "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()
        }
        torch.save(ckpt, ckpt_dir / f"resnet50_epoch{epoch:02d}.pt")
        # Overwrite a 'latest' file for easy resuming
        torch.save(ckpt, ckpt_dir / "resnet50_latest.pt")
        
        history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc); history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc); history["val_auc"].append(val_auc); history["val_f1"].append(f1)
        tqdm.write(f"Epoch {epoch:02d} | tr Ls {tr_loss:.4f} Ac {tr_acc:.4f} | val Ls {val_loss:.4f} Ac {val_acc:.4f} AUC {val_auc:.4f} F1 {f1:.4f}")
        class_names = [k for k,_ in sorted(train_ds.class_to_idx.items(), key=lambda kv: kv[1])]
        plot_cm(cm, art_dir, class_names=class_names)
        y_true_bin = (y_true == idx_fake).astype(int)
        _ = plot_roc_and_save(y_true_bin, y_prob_pos, art_dir, pos_label_name="fake")
        if not math.isnan(val_auc) and val_auc > best_auc:
            best_auc, best_epoch = val_auc, epoch; torch.save(ckpt, ckpt_dir/"resnet50_best.pt")
            tqdm.write(f"  >>> New Best AUC: {best_auc:.4f} saved.")
    plot_curves(history, art_dir);
    with open(out_dir/"history.json","w") as f: json.dump(history, f, indent=2)
    print(f"\nTraining Complete.\nBest epoch by AUC: {best_epoch} (AUC={best_auc:.4f})\nCheckpoints: {ckpt_dir}\nArtifacts: {art_dir}")

if __name__ == "__main__":
    main()


[Kaggle Setup]
GPU Detected: Tesla T4
Seed set to 42.

Saving outputs to: /kaggle/working/training_results
Loading IMAGE data from: /kaggle/input/faceforensics/WD_subset_png
Loading VIDEO data from: /kaggle/input/wilddeepfake/ffpp_subset_c23

Combined 2 dataset(s). Training samples: 47927, Validation samples: 5219
Building sequence-limited sampler for combined dataset...


Indexing sequences:   0%|          | 0/47927 [00:00<?, ?it/s]

Master Classes: {'fake': 0, 'real': 1}
Computing class weights for combined dataset...
Class counts in one epoch: [1211 1530]. Weights: [1.1317093  0.89575166]

Initializing ResNet-50 model...

[Resuming Training]
  Fast-forwarding LR scheduler to step of epoch 60...
  --> Checkpoint loaded. Resuming from Epoch 61. Previous best AUC: 0.9262

Starting Training...




Epoch 61/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 61 | tr Ls 0.2603 Ac 0.8709 | val Ls 0.4309 Ac 0.7994 AUC 0.9269 F1 0.7988
  >>> New Best AUC: 0.9269 saved.


Epoch 62/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 62 | tr Ls 0.1741 Ac 0.8621 | val Ls 0.4266 Ac 0.8023 AUC 0.9276 F1 0.8006
  >>> New Best AUC: 0.9276 saved.


Epoch 63/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 63 | tr Ls 0.1873 Ac 0.8822 | val Ls 0.4213 Ac 0.8046 AUC 0.9286 F1 0.8019
  >>> New Best AUC: 0.9286 saved.


Epoch 64/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 64 | tr Ls 0.2106 Ac 0.9570 | val Ls 0.4180 Ac 0.8080 AUC 0.9287 F1 0.8044
  >>> New Best AUC: 0.9287 saved.


Epoch 65/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 65 | tr Ls 0.1690 Ac 0.8924 | val Ls 0.4147 Ac 0.8103 AUC 0.9287 F1 0.8055


Epoch 66/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 66 | tr Ls 0.2003 Ac 0.9314 | val Ls 0.4116 Ac 0.8130 AUC 0.9289 F1 0.8076
  >>> New Best AUC: 0.9289 saved.


Epoch 67/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 67 | tr Ls 0.1914 Ac 0.8924 | val Ls 0.4085 Ac 0.8164 AUC 0.9293 F1 0.8103
  >>> New Best AUC: 0.9293 saved.


Epoch 68/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 68 | tr Ls 0.1898 Ac 0.9011 | val Ls 0.4047 Ac 0.8191 AUC 0.9296 F1 0.8124
  >>> New Best AUC: 0.9296 saved.


Epoch 69/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 69 | tr Ls 0.1962 Ac 0.9427 | val Ls 0.4013 Ac 0.8201 AUC 0.9305 F1 0.8128
  >>> New Best AUC: 0.9305 saved.


Epoch 70/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 70 | tr Ls 0.1749 Ac 0.9351 | val Ls 0.3991 Ac 0.8222 AUC 0.9306 F1 0.8143
  >>> New Best AUC: 0.9306 saved.


Epoch 71/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 71 | tr Ls 0.1878 Ac 0.9606 | val Ls 0.3967 Ac 0.8241 AUC 0.9303 F1 0.8158


Epoch 72/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 72 | tr Ls 0.2176 Ac 0.8865 | val Ls 0.3944 Ac 0.8258 AUC 0.9304 F1 0.8169


Epoch 73/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 73 | tr Ls 0.2561 Ac 0.8811 | val Ls 0.3922 Ac 0.8274 AUC 0.9306 F1 0.8180
  >>> New Best AUC: 0.9306 saved.


Epoch 74/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 74 | tr Ls 0.2181 Ac 0.9037 | val Ls 0.3909 Ac 0.8272 AUC 0.9305 F1 0.8178


Epoch 75/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 75 | tr Ls 0.1615 Ac 0.9347 | val Ls 0.3888 Ac 0.8274 AUC 0.9305 F1 0.8174


Epoch 76/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 76 | tr Ls 0.1909 Ac 0.9435 | val Ls 0.3866 Ac 0.8277 AUC 0.9304 F1 0.8174


Epoch 77/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 77 | tr Ls 0.2320 Ac 0.9190 | val Ls 0.3852 Ac 0.8281 AUC 0.9301 F1 0.8174


Epoch 78/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 78 | tr Ls 0.2092 Ac 0.8734 | val Ls 0.3839 Ac 0.8293 AUC 0.9297 F1 0.8181


Epoch 79/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 79 | tr Ls 0.2281 Ac 0.8548 | val Ls 0.3819 Ac 0.8320 AUC 0.9296 F1 0.8203


Epoch 80/80:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 80 | tr Ls 0.2325 Ac 0.8730 | val Ls 0.3805 Ac 0.8329 AUC 0.9294 F1 0.8208

Training Complete.
Best epoch by AUC: 73 (AUC=0.9306)
Checkpoints: /kaggle/working/training_results/checkpoints
Artifacts: /kaggle/working/training_results/artifacts


In [4]:
"""
ResNet-50 Deepfake Detector - Kaggle GPU Optimized
VERSION FOR COMBINED IMAGE & VIDEO DATASETS WITH RESUME CAPABILITY
"""

import argparse, io, json, math, os, random, time, sys
from pathlib import Path
from collections import defaultdict

# --- [SETUP] Install necessary libraries ---
try:
    import cv2
except ImportError:
    print("OpenCV not found. Installing opencv-python-headless...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "opencv-python-headless"])
    import cv2

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler, ConcatDataset
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score, roc_curve
)

# ----------------------- Utilities & Repro -----------------------
def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.benchmark = True
    print(f"Seed set to {seed}.")

# ----------------------- Social-style Augmentations -----------------------
class RandomJPEGCompression:
    def __init__(self, qmin=35, qmax=92, p=0.7): self.qmin, self.qmax, self.p = qmin, qmax, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        buf = io.BytesIO(); img.save(buf, format="JPEG", quality=random.randint(self.qmin, self.qmax))
        buf.seek(0); return Image.open(buf).convert("RGB")
class RandomDownscale:
    def __init__(self, scale_min=0.4, scale_max=0.85, p=0.6): self.scale_min, self.scale_max, self.p = scale_min, scale_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        w, h = img.size; s = random.uniform(self.scale_min, self.scale_max)
        nw, nh = max(8, int(w*s)), max(8, int(h*s))
        img = img.resize((nw, nh), resample=Image.BILINEAR); img = img.resize((w, h), resample=Image.BILINEAR)
        return img
class RandomGaussianNoise:
    def __init__(self, sigma_min=0.0, sigma_max=0.03, p=0.4): self.sigma_min, self.sigma_max, self.p = sigma_min, sigma_max, p
    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p: return img
        arr = np.asarray(img).astype(np.float32)/255.0; sigma = random.uniform(self.sigma_min, self.sigma_max)
        noise = np.random.normal(0.0, sigma, arr.shape).astype(np.float32); arr = np.clip(arr + noise, 0.0, 1.0)
        arr = (arr*255.0 + 0.5).astype(np.uint8); return Image.fromarray(arr)

# ----------------------- Custom Video Dataset -----------------------
class VideoFrameDataset(torch.utils.data.Dataset):
    def __init__(self, root: Path, transform=None, class_to_idx=None):
        self.root = root; self.transform = transform; self.class_to_idx = class_to_idx
        self.samples = []; self.targets = []
        for class_name in self.class_to_idx.keys():
            class_dir = self.root / class_name
            if not class_dir.exists(): continue
            for video_path in sorted(list(class_dir.rglob('*.mp4'))):
                label = self.class_to_idx[class_name]
                self.samples.append((str(video_path), label)); self.targets.append(label)
        if not self.samples: print(f"  --> WARNING: No .mp4 videos found in {self.root}")
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames < 1: return self.__getitem__((idx + 1) % len(self))
            frame_idx = random.randint(0, total_frames - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret: return self.__getitem__((idx + 1) % len(self))
        except Exception as e:
            print(f"  --> Error processing {video_path}: {e}. Skipping."); return self.__getitem__((idx + 1) % len(self))
        finally:
            if cap: cap.release()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB); img = Image.fromarray(frame_rgb)
        if self.transform: img = self.transform(img)
        return img, label

# ----------------------- Data Loading Pipeline -----------------------
def build_indices_limit_per_seq(dataset: ConcatDataset, max_per_seq=60, seed=42):
    if max_per_seq is None or max_per_seq <= 0: return list(range(len(dataset)))
    by_seq = defaultdict(list)
    print("Building sequence-limited sampler for combined dataset...")
    for idx, (path, _) in enumerate(tqdm(dataset.samples, desc="Indexing sequences")):
        path_obj = Path(path)
        seq_id = path_obj.stem if path_obj.suffix.lower() == '.mp4' else path_obj.parent.name
        by_seq[seq_id].append(idx)
    rng = random.Random(seed); indices = []
    for seq, idxs in by_seq.items(): rng.shuffle(idxs); indices.extend(idxs[:max_per_seq])
    rng.shuffle(indices); return indices

def build_loaders(image_data_root, video_data_root, img_size=224, batch_size=64, workers=2, per_seq_cap=60):
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    train_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.2)), transforms.RandomResizedCrop(img_size, scale=(0.55, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=10),
        transforms.RandomPerspective(distortion_scale=0.25, p=0.25), RandomJPEGCompression(qmin=35, qmax=92, p=0.7),
        RandomDownscale(scale_min=0.4, scale_max=0.85, p=0.6), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
        transforms.ColorJitter(0.15, 0.15, 0.15, 0.05), transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.3),
        transforms.RandomAutocontrast(p=0.3), RandomGaussianNoise(sigma_min=0.0, sigma_max=0.03, p=0.4),
        transforms.ToTensor(), normalize, transforms.RandomErasing(p=0.25, scale=(0.02,0.2), ratio=(0.3,3.3), value='random'),
    ])
    val_tf = transforms.Compose([transforms.Resize(int(img_size*1.15)), transforms.CenterCrop(img_size), transforms.ToTensor(), normalize])
    train_datasets, val_datasets = [], []; master_class_to_idx = {'fake': 0, 'real': 1}
    if image_data_root:
        root = Path(image_data_root); print(f"Loading IMAGE data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(datasets.ImageFolder(root/"train", transform=train_tf))
            val_datasets.append(datasets.ImageFolder(root/"val", transform=val_tf))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if video_data_root:
        root = Path(video_data_root); print(f"Loading VIDEO data from: {root}")
        if (root/"train").exists() and (root/"val").exists():
            train_datasets.append(VideoFrameDataset(root/"train", transform=train_tf, class_to_idx=master_class_to_idx))
            val_datasets.append(VideoFrameDataset(root/"val", transform=val_tf, class_to_idx=master_class_to_idx))
        else: print(f"  --> WARNING: 'train' or 'val' not found in {root}. Skipping.")
    if not train_datasets: raise ValueError("No valid datasets were loaded. Check paths in USER CONFIGURATION.")
    train_ds = ConcatDataset(train_datasets); val_ds = ConcatDataset(val_datasets)
    train_ds.class_to_idx = master_class_to_idx; train_ds.classes = list(master_class_to_idx.keys())
    train_ds.samples = [s for ds in train_ds.datasets for s in ds.samples]
    train_ds.targets = [t for ds in train_ds.datasets for t in ds.targets]
    val_ds.targets = [t for ds in val_ds.datasets for t in ds.targets]
    print(f"\nCombined {len(train_datasets)} dataset(s). Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}")
    indices = build_indices_limit_per_seq(train_ds, max_per_seq=per_seq_cap, seed=42)
    num_workers = 2 if os.cpu_count() <= 4 else 4
    train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=SubsetRandomSampler(indices), num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers>0))
    return train_ds, val_ds, train_dl, val_dl

# --- [The rest of the helper functions are unchanged] ---
class WarmupThenCosine(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1): self.warmup_steps=max(1,warmup_steps); self.total_steps=max(self.warmup_steps+1,total_steps); super().__init__(optimizer, last_epoch)
    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.warmup_steps: return [base * (step/self.warmup_steps) for base in self.base_lrs]
        t = (step-self.warmup_steps)/(self.total_steps-self.warmup_steps); return [base*0.5*(1+math.cos(math.pi*t)) for base in self.base_lrs]
class EMA:
    def __init__(self, model, decay=0.999): self.decay = decay; self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if not v.is_floating_point(): self.shadow[k] = v.detach().clone(); continue
            if (self.shadow[k].dtype!=v.dtype)or(self.shadow[k].device!=v.device): self.shadow[k]=v.detach().clone(); continue
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
def rand_bbox(W, H, lam):
    cut_w=int(W*math.sqrt(1-lam)); cut_h=int(H*math.sqrt(1-lam)); cx,cy=random.randint(0,W-1),random.randint(0,H-1)
    x1,y1=max(cx-cut_w//2,0),max(cy-cut_h//2,0); x2,y2=min(cx+cut_w//2,W),min(cy+cut_h//2,H); return x1,y1,x2,y2
def mixup_data(x, y, alpha=0.2):
    lam=np.random.beta(alpha,alpha) if alpha>0 else 1.0; idx=torch.randperm(x.size(0),device=x.device)
    mixed_x=lam*x+(1-lam)*x[idx]; y_a,y_b=y,y[idx]; return mixed_x,y_a,y_b,lam
def mixup_criterion(criterion, pred, y_a, y_b, lam): return lam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)
@torch.no_grad()
def tta_logits(model,x,device,tta=4):
    outs=[]
    for op in range(tta):
        xx=x;
        if op%2==1: xx=torch.flip(xx,dims=[-1])
        with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")): outs.append(model(xx))
    return torch.stack(outs,0).mean(0)
@torch.no_grad()
def evaluate(model,dl,device,pos_index,tta=4):
    model.eval(); y_true,y_pred,y_prob_pos=[],[],[]; total_loss,n=0.0,0
    criterion=nn.CrossEntropyLoss()
    for x,y in dl:
        x,y=x.to(device,non_blocking=True),y.to(device,non_blocking=True)
        logits=tta_logits(model,x,device,tta=tta) if tta and tta>1 else model(x)
        with torch.amp.autocast(device_type=device.type,enabled=(device.type=="cuda")): loss=criterion(logits,y)
        prob=torch.softmax(logits,dim=1)
        y_true.extend(y.cpu().tolist());y_pred.extend(prob.argmax(1).cpu().tolist());y_prob_pos.extend(prob[:,pos_index].cpu().tolist())
        bs=y.size(0);total_loss+=loss.item()*bs;n+=bs
    acc=accuracy_score(y_true,y_pred); prec,rec,f1,_=precision_recall_fscore_support(y_true,y_pred,average="binary",pos_label=pos_index,zero_division=0)
    try: auc=roc_auc_score([1 if t==pos_index else 0 for t in y_true], y_prob_pos)
    except Exception: auc=float("nan")
    cm=confusion_matrix(y_true,y_pred,labels=[0,1]); return (total_loss/n,acc,prec,rec,f1,auc,np.array(cm),np.array(y_true),np.array(y_prob_pos))
def plot_curves(history,out_dir:Path):
    out_dir.mkdir(parents=True,exist_ok=True);epochs=np.arange(1,len(history["train_loss"])+1)
    plt.figure();plt.plot(epochs,history["train_loss"],label="train loss");plt.plot(epochs,history["val_loss"],label="val loss");plt.xlabel("epoch");plt.ylabel("loss");plt.legend();plt.tight_layout();plt.savefig(out_dir/"loss_curves.png");plt.close()
    plt.figure();plt.plot(epochs,history["train_acc"],label="train acc");plt.plot(epochs,history["val_acc"],label="val acc");plt.xlabel("epoch");plt.ylabel("accuracy");plt.legend();plt.tight_layout();plt.savefig(out_dir/"acc_curves.png");plt.close()
def plot_cm(cm,out_dir:Path,class_names):
    plt.figure();plt.imshow(cm,interpolation="nearest");plt.title("Confusion Matrix");plt.colorbar();ticks=np.arange(len(class_names));plt.xticks(ticks,class_names,rotation=45);plt.yticks(ticks,class_names);thresh=cm.max()/2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]): plt.text(j,i,int(cm[i,j]),ha="center",va="center",color="white" if cm[i,j]>thresh else "black")
    plt.ylabel("True");plt.xlabel("Predicted");plt.tight_layout();plt.savefig(out_dir/"confusion_matrix.png");plt.close()
def plot_roc_and_save(y_true_bin,y_prob_pos,out_dir:Path,pos_label_name="fake"):
    fpr,tpr,_=roc_curve(y_true_bin,y_prob_pos);auc=roc_auc_score(y_true_bin,y_prob_pos)
    plt.figure();plt.plot(fpr,tpr,label=f"AUC={auc:.3f}");plt.plot([0,1],[0,1],"--");plt.xlabel("FPR");plt.ylabel("TPR");plt.title(f"ROC ({pos_label_name} positive)");plt.legend();plt.tight_layout();plt.savefig(out_dir/"roc_curve.png");plt.close();return auc

# ----------------------- Main -----------------------
def main():
    
    # =================================================================================
    # ===> KAGGLE USER CONFIGURATION SECTION (CRITICAL!) <===
    USER_IMAGE_DATA_ROOT = "/kaggle/input/faceforensics/WD_subset_png"
    USER_VIDEO_DATA_ROOT = "/kaggle/input/wilddeepfake/ffpp_subset_c23"
    USER_OUT_DIR = "/kaggle/working/training_results"

    # ===> NEW: RESUME TRAINING CONFIGURATION <===
    RESUME_TRAINING = True # SET THIS TO True TO RESUME
    # This path points to the 'latest' checkpoint inside your output directory
    CHECKPOINT_PATH = f"{USER_OUT_DIR}/checkpoints/resnet50_epoch80.pt"
    # =================================================================================

    args_dict = {
        "image_data_root": USER_IMAGE_DATA_ROOT, "video_data_root": USER_VIDEO_DATA_ROOT,
        "out_dir": USER_OUT_DIR, "per_seq_cap": 60, 
        "epochs": 90, # <== INCREASE THE TOTAL NUMBER OF EPOCHS
        "batch_size": 128, "img_size": 224, "workers": 2,
        "lr": 3e-4, "weight_decay": 1e-4, "label_smoothing": 0.05, "freeze_backbone": False,
        "class_weights": True, "mixup_p": 0.3, "cutmix_p": 0.2, "mix_alpha": 0.2,
        "ema_decay": 0.999, "grad_clip": 1.0, "warmup_pct": 0.1, "tta": 4, "seed": 42,
    }
    args = argparse.Namespace(**args_dict)

    print("\n[Kaggle Setup]");
    if not torch.cuda.is_available(): print("\n!!!! WARNING: GPU ACCELERATOR NOT DETECTED !!!!\n")
    else: print(f"GPU Detected: {torch.cuda.get_device_name(0)}")

    set_seed(args.seed); device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); out_dir = Path(args.out_dir)
    ckpt_dir = out_dir / "checkpoints"; art_dir = out_dir / "artifacts"
    ckpt_dir.mkdir(parents=True, exist_ok=True); art_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving outputs to: {out_dir}")

    train_ds, val_ds, train_dl, val_dl = build_loaders(
        image_data_root=args.image_data_root, video_data_root=args.video_data_root,
        img_size=args.img_size, batch_size=args.batch_size, workers=args.workers, per_seq_cap=args.per_seq_cap
    )
    print(f"Master Classes: {train_ds.class_to_idx}"); idx_fake = train_ds.class_to_idx.get("fake", 0)

    class_weights_tensor = None
    if args.class_weights:
        print("Computing class weights for combined dataset...")
        sampler_indices = list(train_dl.sampler.indices); labels = [train_ds.targets[i] for i in sampler_indices]
        counts = np.bincount(labels, minlength=len(train_ds.classes)); total = counts.sum()
        weights = [total/(len(counts)*c) if c > 0 else 0.0 for c in counts]
        class_weights_tensor = torch.tensor(weights, dtype=torch.float32, device=device)
        print(f"Class counts in one epoch: {counts}. Weights: {class_weights_tensor.cpu().numpy()}")

    print("\nInitializing ResNet-50 model..."); model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))
    if args.freeze_backbone: [p.requires_grad_(False) for n, p in model.named_parameters() if not n.startswith("fc.")]
    model.to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=args.label_smoothing)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    total_steps = len(train_dl) * args.epochs; warmup_steps = int(args.warmup_pct * total_steps)
    scheduler = WarmupThenCosine(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)
    scaler = torch.amp.GradScaler(enabled=(device.type=="cuda"))
    ema = EMA(model, decay=args.ema_decay)
    history = defaultdict(list); best_auc, best_epoch = -1.0, -1; global_step = 0; start_epoch = 1

    # ===> NEW: LOGIC TO LOAD CHECKPOINT AND RESUME <===
    if RESUME_TRAINING:
        print("\n[Resuming Training]")
        if not Path(CHECKPOINT_PATH).exists():
            print(f"  !!! WARNING: Checkpoint not found at {CHECKPOINT_PATH}. Starting from scratch. !!!")
        else:
            # For modern PyTorch, weights_only=False is needed for non-tensor data
            ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)
            
            # The saved model is the EMA model
            model.load_state_dict(ckpt['model'], strict=True) 
            model.to(device)
            ema.shadow = {k: v.to(device) for k, v in ckpt['model'].items()}
            
            start_epoch = ckpt.get('epoch', 0) + 1
            
            print(f"  Fast-forwarding LR scheduler to step of epoch {start_epoch - 1}...")
            steps_to_advance = (start_epoch - 1) * len(train_dl)
            for _ in range(steps_to_advance): scheduler.step()

            history_path = out_dir / "history.json"
            if history_path.exists():
                with open(history_path, 'r') as f: history = defaultdict(list, json.load(f))
                if history.get('val_auc'): best_auc = max(history.get('val_auc', [0.0]))
            print(f"  --> Checkpoint loaded. Resuming from Epoch {start_epoch}. Previous best AUC: {best_auc:.4f}")

    print("\nStarting Training...")
    # The loop now starts from 'start_epoch'
    for epoch in range(start_epoch, args.epochs + 1):
        model.train(); tr_loss_sum, tr_correct, n_samples = 0.0, 0, 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        for x, y in pbar:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            use_mixup = random.random() < args.mixup_p; use_cutmix = (not use_mixup) and (random.random() < args.cutmix_p)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                if use_mixup:
                    x_mix, y_a, y_b, lam = mixup_data(x, y, alpha=args.mix_alpha)
                    logits = model(x_mix); loss = mixup_criterion(criterion, logits, y_a, y_b, lam); y_for_acc = y_a
                elif use_cutmix:
                    lam = np.random.beta(args.mix_alpha, args.mix_alpha); idx_perm = torch.randperm(x.size(0), device=x.device)
                    y_a, y_b = y, y[idx_perm]; W, H = x.shape[3], x.shape[2]; x1, y1, x2, y2 = rand_bbox(W, H, lam)
                    x[:, :, y1:y2, x1:x2] = x[idx_perm, :, y1:y2, x1:x2]; lam_adj = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
                    logits = model(x); loss = mixup_criterion(criterion, logits, y_a, y_b, lam_adj); y_for_acc = y_a
                else: logits = model(x); loss = criterion(logits, y); y_for_acc = y
            scaler.scale(loss).backward(); scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer); scaler.update(); ema.update(model); scheduler.step(); global_step += 1
            bs = y.size(0); tr_loss_sum += loss.item() * bs; tr_correct += (logits.detach().argmax(1) == y_for_acc).sum().item(); n_samples += bs
            pbar.set_postfix(loss=tr_loss_sum/max(1,n_samples), acc=tr_correct/max(1,n_samples))
        tr_loss = tr_loss_sum / max(1, n_samples); tr_acc  = tr_correct / max(1, n_samples)
        state_backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(ema.shadow, strict=True)
        val_loss, val_acc, prec, rec, f1, val_auc, cm, y_true, y_prob_pos = evaluate(model, val_dl, device, pos_index=idx_fake, tta=args.tta)
        model.load_state_dict(state_backup, strict=True)
        
        # We now save optimizer and scheduler states for even better resumption next time
        ckpt = {
            "model": ema.shadow, "epoch": epoch, "val_auc": float(val_auc), 
            "class_to_idx": train_ds.class_to_idx, "args": args_dict,
            "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()
        }
        torch.save(ckpt, ckpt_dir / f"resnet50_epoch{epoch:02d}.pt")
        # Overwrite a 'latest' file for easy resuming
        torch.save(ckpt, ckpt_dir / "resnet50_latest.pt")
        
        history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc); history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc); history["val_auc"].append(val_auc); history["val_f1"].append(f1)
        tqdm.write(f"Epoch {epoch:02d} | tr Ls {tr_loss:.4f} Ac {tr_acc:.4f} | val Ls {val_loss:.4f} Ac {val_acc:.4f} AUC {val_auc:.4f} F1 {f1:.4f}")
        class_names = [k for k,_ in sorted(train_ds.class_to_idx.items(), key=lambda kv: kv[1])]
        plot_cm(cm, art_dir, class_names=class_names)
        y_true_bin = (y_true == idx_fake).astype(int)
        _ = plot_roc_and_save(y_true_bin, y_prob_pos, art_dir, pos_label_name="fake")
        if not math.isnan(val_auc) and val_auc > best_auc:
            best_auc, best_epoch = val_auc, epoch; torch.save(ckpt, ckpt_dir/"resnet50_best.pt")
            tqdm.write(f"  >>> New Best AUC: {best_auc:.4f} saved.")
    plot_curves(history, art_dir);
    with open(out_dir/"history.json","w") as f: json.dump(history, f, indent=2)
    print(f"\nTraining Complete.\nBest epoch by AUC: {best_epoch} (AUC={best_auc:.4f})\nCheckpoints: {ckpt_dir}\nArtifacts: {art_dir}")

if __name__ == "__main__":
    main()


[Kaggle Setup]
GPU Detected: Tesla T4
Seed set to 42.

Saving outputs to: /kaggle/working/training_results
Loading IMAGE data from: /kaggle/input/faceforensics/WD_subset_png
Loading VIDEO data from: /kaggle/input/wilddeepfake/ffpp_subset_c23

Combined 2 dataset(s). Training samples: 47927, Validation samples: 5219
Building sequence-limited sampler for combined dataset...


Indexing sequences:   0%|          | 0/47927 [00:00<?, ?it/s]

Master Classes: {'fake': 0, 'real': 1}
Computing class weights for combined dataset...
Class counts in one epoch: [1211 1530]. Weights: [1.1317093  0.89575166]

Initializing ResNet-50 model...

[Resuming Training]
  Fast-forwarding LR scheduler to step of epoch 80...
  --> Checkpoint loaded. Resuming from Epoch 81. Previous best AUC: 0.9306

Starting Training...




Epoch 81/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 81 | tr Ls 0.2474 Ac 0.8760 | val Ls 0.3791 Ac 0.8335 AUC 0.9295 F1 0.8211


Epoch 82/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 82 | tr Ls 0.1672 Ac 0.8606 | val Ls 0.3765 Ac 0.8373 AUC 0.9296 F1 0.8241


Epoch 83/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 83 | tr Ls 0.1798 Ac 0.8840 | val Ls 0.3737 Ac 0.8396 AUC 0.9299 F1 0.8260


Epoch 84/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 84 | tr Ls 0.2044 Ac 0.9602 | val Ls 0.3711 Ac 0.8429 AUC 0.9301 F1 0.8288


Epoch 85/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 85 | tr Ls 0.1642 Ac 0.8916 | val Ls 0.3689 Ac 0.8444 AUC 0.9304 F1 0.8299


Epoch 86/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 86 | tr Ls 0.1945 Ac 0.9325 | val Ls 0.3666 Ac 0.8461 AUC 0.9306 F1 0.8313


Epoch 87/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 87 | tr Ls 0.1869 Ac 0.8916 | val Ls 0.3644 Ac 0.8481 AUC 0.9308 F1 0.8327
  >>> New Best AUC: 0.9308 saved.


Epoch 88/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 88 | tr Ls 0.1874 Ac 0.9019 | val Ls 0.3624 Ac 0.8498 AUC 0.9308 F1 0.8338
  >>> New Best AUC: 0.9308 saved.


Epoch 89/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 89 | tr Ls 0.1947 Ac 0.9449 | val Ls 0.3602 Ac 0.8509 AUC 0.9311 F1 0.8344
  >>> New Best AUC: 0.9311 saved.


Epoch 90/90:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 90 | tr Ls 0.1734 Ac 0.9369 | val Ls 0.3584 Ac 0.8527 AUC 0.9312 F1 0.8359
  >>> New Best AUC: 0.9312 saved.

Training Complete.
Best epoch by AUC: 90 (AUC=0.9312)
Checkpoints: /kaggle/working/training_results/checkpoints
Artifacts: /kaggle/working/training_results/artifacts


In [5]:
import os, json, zipfile
from pathlib import Path

OUT = Path("/kaggle/working/training_results")
CKPT = OUT / "checkpoints" / "resnet50_best.pt"
ART = OUT / "artifacts"
HIST = OUT / "history.json"
ZIP_PATH = Path("/kaggle/working/resnet50_mixed_export.zip")

assert CKPT.exists(), f"Best checkpoint not found: {CKPT}"

# Also save a light state_dict-only file for easy loading
import torch
best = torch.load(CKPT, map_location="cpu")
state_dict = best["model"]
torch.save(state_dict, OUT / "resnet50_streamlit_state_dict.pth")

# Pack ZIP
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED) as z:
    z.write(CKPT, arcname="resnet50_best.pt")
    if (OUT/"resnet50_streamlit_state_dict.pth").exists():
        z.write(OUT/"resnet50_streamlit_state_dict.pth", arcname="resnet50_streamlit_state_dict.pth")
    if HIST.exists():
        z.write(HIST, arcname="history.json")
    if ART.exists():
        for p in ART.glob("*"):
            z.write(p, arcname=f"artifacts/{p.name}")

print(f"Ready: {ZIP_PATH}")

Ready: /kaggle/working/resnet50_mixed_export.zip
