In [3]:
import pandas as pd
from pathlib import Path
import os
os.environ["TORCHDYNAMO_DISABLE"] = "1"

BASE = Path("/kaggle/input/faceforencispp-extracted-frames")

rows = []

# real videos
for vid_dir in sorted((BASE/"real").iterdir()):
    if not vid_dir.is_dir(): continue
    frames = sorted(list(vid_dir.glob("*.jpg")) + list(vid_dir.glob("*.png")))
    if not frames: continue
    rows.append({
        "video_id": vid_dir.name,
        "class": "real",
        "subtype": "real",
        "video_dir": str(vid_dir),
        "n_frames": len(frames)
    })

# fake videos (5 subtypes)
for subtype_dir in sorted((BASE/"fake").iterdir()):
    if not subtype_dir.is_dir(): continue
    subtype = subtype_dir.name
    for vid_dir in sorted(subtype_dir.iterdir()):
        if not vid_dir.is_dir(): continue
        frames = sorted(list(vid_dir.glob("*.jpg")) + list(vid_dir.glob("*.png")))
        if not frames: continue
        rows.append({
            "video_id": f"{subtype}/{vid_dir.name}",
            "class": "fake",
            "subtype": subtype,
            "video_dir": str(vid_dir),
            "n_frames": len(frames)
        })

df = pd.DataFrame(rows)
print(df.head(), "\n")
print("Total videos:", len(df))
print(df["class"].value_counts())
print(df["subtype"].value_counts())

df.to_csv("/kaggle/working/ffpp_manifest.csv", index=False)
print("Saved manifest ‚Üí /kaggle/working/ffpp_manifest.csv")

  video_id class subtype                                          video_dir  \
0      000  real    real  /kaggle/input/faceforencispp-extracted-frames/...   
1      001  real    real  /kaggle/input/faceforencispp-extracted-frames/...   
2      002  real    real  /kaggle/input/faceforencispp-extracted-frames/...   
3      003  real    real  /kaggle/input/faceforencispp-extracted-frames/...   
4      004  real    real  /kaggle/input/faceforencispp-extracted-frames/...   

   n_frames  
0        32  
1        32  
2        32  
3        32  
4        32   

Total videos: 5995
class
fake    4996
real     999
Name: count, dtype: int64
subtype
FaceSwap          1000
real               999
Deepfakes          999
Face2Face          999
FaceShifter        999
NeuralTextures     999
Name: count, dtype: int64
Saved manifest ‚Üí /kaggle/working/ffpp_manifest.csv


In [None]:
# FFT function with normalization
import torch

def to_fft_tensor(img_tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert RGB tensor to FFT magnitude spectrum
    img_tensor: [3, H, W] normalized tensor
    returns: [1, H, W] normalized FFT magnitude
    """
    # Convert to grayscale
    gray = img_tensor.mean(dim=0, keepdim=True) 
    
    # 2D FFT
    fft = torch.fft.fft2(gray)
    fft_shift = torch.fft.fftshift(fft)
    magnitude = torch.abs(fft_shift)
    
    # Log scale for better dynamic range
    log_mag = torch.log1p(magnitude)
    
    # Robust normalization
    min_val = log_mag.min()
    max_val = log_mag.max()
    
    if (max_val - min_val) > 1e-8:
        log_mag = (log_mag - min_val) / (max_val - min_val)
    else:
        log_mag = torch.zeros_like(log_mag)
    
    return log_mag

In [None]:
# Dataset with frame sampling and class balancing
import os, random, math, glob
from pathlib import Path
from typing import List, Tuple
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
import numpy as np
import torchvision.transforms as T
from sklearn.model_selection import train_test_split

MANIFEST = Path("/kaggle/working/ffpp_manifest.csv")
IMG_SIZE = 224
FRAMES_PER_VIDEO = 12

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

def list_frames(video_dir: Path) -> List[Path]:
    frames = sorted(video_dir.glob("*.jpg")) + sorted(video_dir.glob("*.png"))
    return frames

# Frame sampling for temporal coverage
def sample_frames(n_frames: int, num_samples: int, train: bool) -> List[int]:
    """Sample frames with better temporal coverage"""
    if n_frames <= 0:
        return []
    
    if n_frames <= num_samples:
        indices = list(range(n_frames))
        # Pad with last frame if needed
        while len(indices) < num_samples:
            indices.append(indices[-1])
        return indices
    
    if train:
        # During training: random sampling but with temporal spread
        segments = np.linspace(0, n_frames - 1, num_samples + 1, dtype=int)
        indices = []
        for i in range(num_samples):
            start, end = segments[i], segments[i + 1]
            indices.append(random.randint(start, end - 1) if start < end else start)
        return sorted(indices)
    else:
        # During validation: uniform sampling
        return np.linspace(0, n_frames - 1, num_samples, dtype=int).tolist()

# ---------- Enhanced Dataset ----------
class FrameVideoDataset(Dataset):
    def __init__(self, df, train=True, frames_per_video=32, img_size=224):
        self.df = df.reset_index(drop=True)
        self.train = train
        self.frames_per_video = frames_per_video

        aug = []
        if train:
            aug += [
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
                T.RandomGrayscale(p=0.1),
                T.RandomApply([T.GaussianBlur(3)], p=0.1),
            ]
        self.tf = T.Compose([
            T.Resize((img_size, img_size)),
            *aug,
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        vid_dir = Path(row["video_dir"])
        frames = sorted(list(vid_dir.glob("*.jpg")) + list(vid_dir.glob("*.png")))

        # Sampling
        chosen = sample_frames(len(frames), self.frames_per_video, self.train)
        rgb_imgs, fft_imgs = [], []

        for i in chosen:
            with Image.open(frames[i]) as im:
                im = im.convert("RGB")
                rgb = self.tf(im)  
                rgb_imgs.append(rgb)
                fft_map = to_fft_tensor(rgb)  
                fft_imgs.append(fft_map)

        rgb_imgs = torch.stack(rgb_imgs, dim=0) 
        fft_imgs = torch.stack(fft_imgs, dim=0) 

        label = torch.tensor(1 if row["class"] == "fake" else 0, dtype=torch.long)
        meta = {
            "video_id": row["video_id"],
            "class": row["class"],
            "subtype": row["subtype"],
        }
        return rgb_imgs, fft_imgs, label, meta

# Split with balanced subtype representation
def make_balanced_splits(df: pd.DataFrame, val_ratio=0.2, seed=42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Create splits with balanced representation of all manipulation types"""
    train_dfs, val_dfs = [], []
    
    # Split each subtype separately to ensure balance
    for subtype in df['subtype'].unique():
        subtype_df = df[df['subtype'] == subtype].copy()
        if len(subtype_df) > 1:
            subtype_train, subtype_val = train_test_split(
                subtype_df, test_size=val_ratio, random_state=seed, stratify=subtype_df['class']
            )
        else:
            subtype_train, subtype_val = subtype_df, subtype_df.iloc[:0]
        
        train_dfs.append(subtype_train)
        val_dfs.append(subtype_val)
    
    train_df = pd.concat(train_dfs, ignore_index=True)
    val_df = pd.concat(val_dfs, ignore_index=True)
    
    print(f"Train: {len(train_df)} videos | Val: {len(val_df)} videos")
    print(f"Train class balance: {train_df['class'].value_counts().to_dict()}")
    print(f"Val class balance: {val_df['class'].value_counts().to_dict()}")
    
    return train_df, val_df

# ---------- Build splits + loaders ----------
full = pd.read_csv(MANIFEST)
train_df, val_df = make_balanced_splits(full, val_ratio=0.2, seed=42)

train_ds = FrameVideoDataset(train_df, train=True,  frames_per_video=FRAMES_PER_VIDEO)
val_ds   = FrameVideoDataset(val_df,   train=False, frames_per_video=FRAMES_PER_VIDEO)  

class_counts = train_df["class"].value_counts().to_dict()
print(f"Class counts: {class_counts}")

# Give more weight to real class to handle imbalance
weights = []
for _, row in train_df.iterrows():
    if row["class"] == "real":
        weights.append(3.0 / class_counts["real"])  # Higher weight for real class
    else:
        weights.append(1.0 / class_counts["fake"])

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)


num_workers = 2
batch_size = 4

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)

print(f"‚úÖ DataLoaders ready | train videos: {len(train_ds)} | val videos: {len(val_ds)}")
print(f"Class balance in train set: {train_df['class'].value_counts().to_dict()}")

Train: 4795 videos | Val: 1200 videos
Train class balance: {'fake': 3996, 'real': 799}
Val class balance: {'fake': 1000, 'real': 200}
Class counts: {'fake': 3996, 'real': 799}
‚úÖ DataLoaders ready | train videos: 4795 | val videos: 1200
Class balance in train set: {'fake': 3996, 'real': 799}


In [None]:
# Sanity check
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = next(iter(train_loader))
rgb, fft, labels, metas = batch

print("rgb shape:", rgb.shape)   
print("fft shape:", fft.shape)  
print("labels shape:", labels.shape) 
print("sample labels:", labels.tolist())
print("meta[0]:", {k: v[0] for k, v in metas.items()})

# Test GPU memory
rgb = rgb.to(device, non_blocking=True)
fft = fft.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
print("Successfully moved to", device)
print(f"Batch size: {rgb.shape[0]}, Frames: {rgb.shape[1]}") 

rgb shape: torch.Size([4, 12, 3, 224, 224])
fft shape: torch.Size([4, 12, 1, 224, 224])
labels shape: torch.Size([4])
sample labels: [0, 0, 0, 0]
meta[0]: {'video_id': '195', 'class': 'real', 'subtype': 'real'}
Successfully moved to cuda
Batch size: 4, Frames: 12


In [None]:
# Dual-branch model with batch norm handling
import torch
import torch.nn as nn
import timm

class DualBranchEfficientNet(nn.Module):
    def __init__(self, embed_dim=512, num_classes=1, dropout=0.3):
        super().__init__()
        # RGB branch ‚Üí EfficientNet-B4
        self.rgb_backbone = timm.create_model(
            "tf_efficientnet_b4_ns", pretrained=True, num_classes=0, global_pool="avg"
        )
        rgb_dim = self.rgb_backbone.num_features
        self.rgb_proj = nn.Sequential(
            nn.Linear(rgb_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

        # FFT branch ‚Üí EfficientNet-B0
        self.fft_backbone = timm.create_model(
            "tf_efficientnet_b0_ns", pretrained=True, in_chans=1, num_classes=0, global_pool="avg"
        )
        fft_dim = self.fft_backbone.num_features
        self.fft_proj = nn.Sequential(
            nn.Linear(fft_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, rgb, fft):
        B, F, _, H, W = rgb.shape

        # RGB branch
        rgb = rgb.view(B * F, 3, H, W)
        rgb_feats = self.rgb_backbone(rgb)
        rgb_feats = self.rgb_proj(rgb_feats)
        rgb_feats = rgb_feats.view(B, F, -1).permute(0, 2, 1)
        rgb_pooled = self.temporal_pool(rgb_feats).squeeze(-1)

        # FFT branch
        fft = fft.view(B * F, 1, H, W)
        fft_feats = self.fft_backbone(fft)
        fft_feats = self.fft_proj(fft_feats)
        fft_feats = fft_feats.view(B, F, -1).permute(0, 2, 1)
        fft_pooled = self.temporal_pool(fft_feats).squeeze(-1)

        # Fused features
        fused = torch.cat([rgb_pooled, fft_pooled], dim=1)
        out = self.classifier(fused)
        return out



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_score, recall_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model with enhanced architecture
model = DualBranchEfficientNet(embed_dim=512, num_classes=1, dropout=0.3).to(device)

# Class weight calculation
real_count = train_df['class'].value_counts()['real']
fake_count = train_df['class'].value_counts()['fake']
pos_weight = torch.tensor([fake_count / real_count]).to(device)  
print(f"Class counts - Real: {real_count}, Fake: {fake_count}")
print(f"pos_weight: {pos_weight.item():.2f} (should be >1)")

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scaler = GradScaler()

training_history = []

def train_one_epoch(loader, epoch):
    model.train()
    epoch_loss = 0.0
    all_labels, all_probs, all_preds = [], [], []

    for batch_idx, (rgb, fft, labels, _) in enumerate(loader):
        if batch_idx % 200 == 0:
            print(f"Batch {batch_idx}/{len(loader)}")
            
        rgb, fft = rgb.to(device), fft.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        with autocast():
            logits = model(rgb, fft)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item() * rgb.size(0)

        probs = torch.sigmoid(logits).detach().cpu().numpy().flatten()
        preds = (probs > 0.5).astype(int)

        all_labels.extend(labels.cpu().numpy().flatten())
        all_probs.extend(probs)
        all_preds.extend(preds)

    # Convert to numpy arrays
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    all_preds = np.array(all_preds)
    
    acc = accuracy_score(all_labels, all_preds)
    
    # Handle AUC calculation
    if len(np.unique(all_labels)) > 1:
        auc = roc_auc_score(all_labels, all_probs)
    else:
        auc = 0.0
    
    # Handle F1 calculation
    if len(np.unique(all_preds)) > 1:
        f1 = f1_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds)
    else:
        f1 = 0.0
        precision = 0.0
        recall = 0.0
    
    return epoch_loss / len(loader.dataset), acc, auc, f1, precision, recall

@torch.no_grad()
def validate(loader, epoch):
    model.eval()
    epoch_loss = 0.0
    all_labels, all_probs, all_preds = [], [], []

    for rgb, fft, labels, _ in loader:
        rgb, fft = rgb.to(device), fft.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        with autocast():
            logits = model(rgb, fft)
            loss = criterion(logits, labels)

        epoch_loss += loss.item() * rgb.size(0)

        probs = torch.sigmoid(logits).detach().cpu().numpy().flatten()
        preds = (probs > 0.5).astype(int)

        all_labels.extend(labels.cpu().numpy().flatten())
        all_probs.extend(probs)
        all_preds.extend(preds)

    # Convert to numpy arrays
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    all_preds = np.array(all_preds)
    
    acc = accuracy_score(all_labels, all_preds)
    
    # Handle AUC calculation
    if len(np.unique(all_labels)) > 1:
        auc = roc_auc_score(all_labels, all_probs)
    else:
        auc = 0.0
    
    # Handle F1 calculation
    if len(np.unique(all_preds)) > 1:
        f1 = f1_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds)
    else:
        f1 = 0.0
        precision = 0.0
        recall = 0.0

    return epoch_loss / len(loader.dataset), acc, auc, f1, precision, recall

print("‚úÖ Training setup complete!")
print(f"Device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"pos_weight: {pos_weight.item():.2f}x (should be ~5.0x)")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Frames per video: {FRAMES_PER_VIDEO} | Batch size: {batch_size}")
print("üíæ Will save ALL models (not just best)")

model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Class counts - Real: 799, Fake: 3996
pos_weight: 5.00 (should be >1)
‚úÖ Training setup complete!
Device: cuda
Model parameters: 23,656,837
pos_weight: 5.00x (should be ~5.0x)
Training samples: 4795
Validation samples: 1200
Frames per video: 12 | Batch size: 4
üíæ Will save ALL models (not just best)


In [None]:
import os
import torch

# Hyperparameters
EPOCHS = 15
PATIENCE = 4
epochs_no_improve = 0
base_save_path = "/kaggle/working/model_epoch_{}.pth"

# Ensure required variables are initialized
training_history = []
assert 'pos_weight' in globals(), "pos_weight must be defined"
assert 'train_loader' in globals(), "train_loader must be defined"
assert 'val_loader' in globals(), "val_loader must be defined"
assert 'model' in globals(), "model must be defined"
assert 'optimizer' in globals(), "optimizer must be defined"
assert 'scaler' in globals(), "scaler must be defined"

print("Starting enhanced training with class balancing...")
print(f"Real class weight: {pos_weight.item():.2f}x")
print("üíæ Will save ALL models after each epoch")

# Track best metrics for early stopping (but save all models)
best_auc = 0.0

for epoch in range(1, EPOCHS + 1):
    # Train
    train_loss, train_acc, train_auc, train_f1, train_prec, train_rec = train_one_epoch(train_loader, epoch)
    # Validate
    val_loss, val_acc, val_auc, val_f1, val_prec, val_rec = validate(val_loader, epoch)

    # Store history
    training_history.append({
        'epoch': epoch,
        'train_loss': train_loss, 'train_acc': train_acc, 'train_auc': train_auc, 'train_f1': train_f1,
        'val_loss': val_loss, 'val_acc': val_acc, 'val_auc': val_auc, 'val_f1': val_f1,
        'val_precision': val_prec, 'val_recall': val_rec
    })

    print(f"""\n[Epoch {epoch}]
    Train ‚Üí Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | AUC: {train_auc:.4f} | F1: {train_f1:.4f}
    Val   ‚Üí Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | AUC: {val_auc:.4f} | F1: {val_f1:.4f}
    Val Details ‚Üí Precision: {val_prec:.4f} | Recall: {val_rec:.4f}""")

    # Save model after every epoch
    save_path = base_save_path.format(epoch)
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scaler_state": scaler.state_dict(),
        "val_auc": val_auc,
        "val_acc": val_acc,
        "val_f1": val_f1,
        "val_precision": val_prec,
        "val_recall": val_rec,
        "training_history": training_history
    }, save_path)
    print(f"üíæ Saved model for epoch {epoch} ‚Üí AUC: {val_auc:.4f} | Path: {save_path}")

    # Track best AUC only for early stopping (not for saving)
    if val_auc > best_auc:
        best_auc = val_auc
        epochs_no_improve = 0
        print(f"üèÜ New best AUC: {val_auc:.4f}")
    else:
        epochs_no_improve += 1
        print(f"‚ö†Ô∏è No AUC improvement for {epochs_no_improve}/{PATIENCE} epochs.")

    # Early stopping (optional - keeps training but stops if no improvement)
    if epochs_no_improve >= PATIENCE:
        print(f"‚èπ Early stopping at epoch {epoch}. Best val_auc = {best_auc:.4f}")
        break

print(f"\nTraining complete. Best val_auc = {best_auc:.4f}")
print(f"üìÅ All models saved to /kaggle/working/model_epoch_*.pth")

Starting enhanced training with class balancing...
Real class weight: 5.00x
üíæ Will save ALL models after each epoch
Batch 0/1198
Batch 200/1198
Batch 400/1198
Batch 600/1198
Batch 800/1198
Batch 1000/1198

[Epoch 1]
    Train ‚Üí Loss: 1.1248 | Acc: 0.5620 | AUC: 0.7479 | F1: 0.4899
    Val   ‚Üí Loss: 2.5501 | Acc: 0.7233 | AUC: 0.9070 | F1: 0.8042
    Val Details ‚Üí Precision: 0.9799 | Recall: 0.6820
üíæ Saved model for epoch 1 ‚Üí AUC: 0.9070 | Path: /kaggle/working/model_epoch_1.pth
üèÜ New best AUC: 0.9070
Batch 0/1198
Batch 200/1198
Batch 400/1198
Batch 600/1198
Batch 800/1198
Batch 1000/1198

[Epoch 2]
    Train ‚Üí Loss: 0.7576 | Acc: 0.7938 | AUC: 0.9053 | F1: 0.6860
    Val   ‚Üí Loss: 1.1329 | Acc: 0.8908 | AUC: 0.9156 | F1: 0.9334
    Val Details ‚Üí Precision: 0.9493 | Recall: 0.9180
üíæ Saved model for epoch 2 ‚Üí AUC: 0.9156 | Path: /kaggle/working/model_epoch_2.pth
üèÜ New best AUC: 0.9156
Batch 0/1198
Batch 200/1198
Batch 400/1198
Batch 600/1198
Batch 800/1198
