In [None]:
!pip install torch torchvision torchaudio tqdm face_recognition pandas timm opencv-python --quiet

In [None]:
import os
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from tqdm import tqdm
from google.colab import drive

# Mount Google Drive only when the mount point does not exist yet
# (data_root and checkpoint_dir below both live under /content/drive).
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive', force_remount=True)

# Central hyperparameter / path table read by the dataset, the attack and main().
CONFIG = {
    "gpu_id": 0,            # CUDA device index used to build `device`
    "num_workers": 2,       # DataLoader worker processes
    "im_size": 299,         # frames are resized to im_size x im_size
    "batch_size": 4,        # videos per batch (each video contributes sequence_length frames)
    "epochs": 6,            # number of training epochs
    "lr": 1e-4,             # Adam learning rate
    "epsilon": 8/255,       # PGD perturbation bound (passed as `eps`), in [0, 1] pixel units
    "alpha": 2/255,         # PGD step size per iteration
    "pgd_steps": 3,         # PGD iterations used during training
    "adv_ratio": 0.5,       # fraction of each training batch replaced by adversarial examples
    "sequence_length": 5,   # frames sampled per video
    "data_root": "/content/drive/MyDrive/deepfake_detection_project/Dataset_split/baseline_splits_madry",
    "checkpoint_dir": "/content/drive/MyDrive/deepfake_detection_project/Madry-Style_training/checkpoints"
}

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device(f"cuda:{CONFIG['gpu_id']}" if torch.cuda.is_available() else "cpu")
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

# Training pipeline: array -> PIL -> resize -> random horizontal flip -> tensor.
# No mean/std normalization here; VideoXception.normalize applies it inside the model.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((CONFIG['im_size'], CONFIG['im_size'])),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation pipeline: same as training but without the random flip.
val_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((CONFIG['im_size'], CONFIG['im_size'])),
    transforms.ToTensor(),
])

class DeepfakeVideoDataset(Dataset):
    """Video dataset driven by a text file listing one video path per line.

    The list file is looked up under ``CONFIG['data_root']``. Each item is a
    ``(frames, label)`` pair where ``frames`` stacks ``seq_len`` transformed
    frames and ``label`` is 0 when the path contains "original"
    (case-insensitive) and 1 otherwise.

    NOTE(review): with ``transform=None`` the frames stay NumPy arrays and
    ``torch.stack`` will reject them; pass a transform that returns tensors.
    """

    def __init__(self, txt_file, seq_len, transform=None):
        """Load the path list; a missing list file yields an empty dataset."""
        path = os.path.join(CONFIG['data_root'], txt_file)
        if not os.path.exists(path):
            self.paths = []
            print(f"File not found: {path}")
        else:
            with open(path, 'r') as f:
                # Skip blank lines so they never become bogus video paths.
                self.paths = [line.strip() for line in f if line.strip()]
        self.seq_len = seq_len
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def get_label(self, path):
        """Return 0 for paths containing "original" (any case), else 1."""
        return 0 if "original" in path.lower() else 1

    def _read_frames(self, path):
        """Decode up to ``seq_len`` randomly chosen frames, in temporal order."""
        cap = cv2.VideoCapture(path)
        try:
            cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # A set gives O(1) membership tests in the per-frame loop below.
            if cnt > self.seq_len:
                wanted = set(random.sample(range(cnt), self.seq_len))
            else:
                wanted = set(range(cnt))

            frames = []
            # Frames are read sequentially; only the selected indices are kept.
            for i in range(cnt):
                ret, frame = cap.read()
                if not ret:
                    break
                if i not in wanted:
                    continue
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.transform:
                    frame = self.transform(frame)
                frames.append(frame)
                if len(frames) >= self.seq_len:
                    break
            return frames
        finally:
            # Always free the decoder, even if decoding or the transform raised.
            cap.release()

    def __getitem__(self, idx):
        """Return ``(frames, label)``; unreadable videos yield an all-zero clip."""
        path = self.paths[idx]
        label = self.get_label(path)
        try:
            frames = self._read_frames(path)
        except Exception as exc:
            # Best-effort loading: keep the run going, but say what went wrong.
            print(f"Failed to read {path}: {exc}")
            frames = []

        if not frames:
            print(f"No frames decoded from {path}; returning an all-zero clip")
            return torch.zeros((self.seq_len, 3, CONFIG['im_size'], CONFIG['im_size'])), label

        # Short videos are padded by repeating the last decoded frame.
        frames.extend([frames[-1]] * (self.seq_len - len(frames)))

        return torch.stack(frames), label

class VideoXception(nn.Module):
    """Per-frame Xception classifier whose logits are averaged over each clip.

    ``forward`` takes a 5-D tensor ``(batch, sequence, channels, height, width)``
    and returns ``(batch, 2)`` logits (mean over the sequence dimension).
    Normalization is applied inside the model, so callers pass un-normalized
    pixel tensors.
    """

    def __init__(self):
        super().__init__()
        # Pretrained timm Xception with a 2-class head.
        self.backbone = timm.create_model('xception', pretrained=True, num_classes=2)

    def normalize(self, x):
        """Standardize a ``(N, 3, H, W)`` batch with fixed per-channel mean/std."""
        # NOTE(review): these look like the standard ImageNet statistics; confirm
        # they match what the pretrained 'xception' weights expect.
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
        return (x - mean) / std

    def forward(self, x):
        b, s, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed clips, are folded into the batch dimension without error.
        x = x.reshape(b * s, c, h, w)
        x = self.normalize(x)
        logits = self.backbone(x)
        # Un-fold the frames and average the per-frame logits of each clip.
        return torch.mean(logits.view(b, s, -1), dim=1)

def pgd_attack(model, x, y, eps, alpha, steps):
    """Craft L-infinity PGD adversarial examples for ``x``.

    Starting from ``x`` itself (no random start), takes ``steps`` signed-gradient
    ascent steps of size ``alpha`` on the cross-entropy loss, projecting back
    into the ``eps``-ball around ``x`` and into the valid pixel range [0, 1]
    after every step.

    The attack runs with the model in eval mode and restores whatever
    train/eval mode the model was in on entry, so it is safe to call from both
    the training loop and the validation loop. Gradients are enabled
    internally, so it also works when called under ``torch.no_grad()``.
    Parameter ``.grad`` buffers are not touched.

    Args:
        model: classifier mapping ``x`` to logits.
        x: clean input batch with values in [0, 1].
        y: integer class labels for ``x``.
        eps: maximum per-element perturbation.
        alpha: step size of each iteration.
        steps: number of PGD iterations (0 returns a copy of ``x``).

    Returns:
        A detached tensor of adversarial inputs with the same shape as ``x``.
    """
    was_training = model.training
    model.eval()
    x_orig = x.detach()
    x_adv = x_orig.clone()
    try:
        with torch.enable_grad():
            for _ in range(steps):
                x_adv.requires_grad_(True)
                loss = F.cross_entropy(model(x_adv), y)
                # Differentiate w.r.t. the input only; avoids computing and
                # accumulating gradients for every model parameter.
                (grad,) = torch.autograd.grad(loss, x_adv)
                with torch.no_grad():
                    x_adv = x_adv + alpha * grad.sign()
                    # Project onto the eps-ball, then onto the valid pixel range.
                    delta = torch.clamp(x_adv - x_orig, -eps, eps)
                    x_adv = torch.clamp(x_orig + delta, 0, 1)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return x_adv.detach()

def _train_one_epoch(model, loader, optimizer, criterion, epoch):
    """Train for one epoch on partially adversarial batches; return accuracy in percent."""
    model.train()
    correct, seen = 0, 0

    loop = tqdm(loader, desc=f"Epoch {epoch}")
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        bs = x.size(0)

        # Replace a random adv_ratio share of the batch with PGD examples.
        n_adv = int(bs * CONFIG['adv_ratio'])
        if n_adv > 0:
            idx = torch.randperm(bs)[:n_adv]
            x[idx] = pgd_attack(model, x[idx], y[idx], CONFIG['epsilon'], CONFIG['alpha'], CONFIG['pgd_steps'])
            # Training must proceed in train mode whatever mode the attack left behind.
            model.train()

        out = model(x)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct += out.argmax(1).eq(y).sum().item()
        seen += bs
        loop.set_postfix(loss=loss.item())

    return 100. * correct / seen


def _validate(model, loader, pgd_steps=2):
    """Return (clean accuracy, robust accuracy) in percent over ``loader``."""
    model.eval()
    clean, robust, seen = 0, 0, 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            seen += x.size(0)

            clean += model(x).argmax(1).eq(y).sum().item()

            # The attack needs input gradients even though evaluation runs under no_grad.
            with torch.enable_grad():
                x_adv = pgd_attack(model, x, y, CONFIG['epsilon'], CONFIG['alpha'], steps=pgd_steps)
            # Force eval mode after the attack so predictions never run with
            # training-mode BatchNorm/Dropout, regardless of the mode the
            # attack leaves the model in.
            model.eval()
            robust += model(x_adv).argmax(1).eq(y).sum().item()

    return 100. * clean / seen, 100. * robust / seen


def main():
    """Adversarially train VideoXception and checkpoint it after every epoch.

    Each epoch trains on batches where a share of the samples is replaced by
    PGD examples, then measures clean and robust (PGD) validation accuracy.
    The model with the best mean of the two is saved as
    ``best_madry_model.pth``; every epoch is also saved as ``epoch_<n>.pth``.

    Raises:
        RuntimeError: if the training or validation split contains no videos.
    """
    print("loading data...")
    train_ds = DeepfakeVideoDataset("train.txt", CONFIG['sequence_length'], train_transforms)
    val_ds = DeepfakeVideoDataset("val.txt", CONFIG['sequence_length'], val_transforms)

    print(f"train size: {len(train_ds)}, val size: {len(val_ds)}")

    # Fail early with a clear message instead of a later sampler error or
    # division by zero when a split file is missing or empty.
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise RuntimeError(
            "Empty train/val split; check train.txt and val.txt under CONFIG['data_root']"
        )

    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'])
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'])

    model = VideoXception().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])

    # Class 0 ("original") is weighted 3x relative to class 1 in the training loss.
    weights = torch.tensor([3.0, 1.0]).to(device)
    criterion = nn.CrossEntropyLoss(weight=weights)

    # -inf so the first epoch always produces a "best" checkpoint, even at score 0.
    best_score = float("-inf")

    print("starting training loop")

    for epoch in range(1, CONFIG['epochs'] + 1):
        train_acc = _train_one_epoch(model, train_loader, optimizer, criterion, epoch)
        print(f"epoch {epoch} train acc: {train_acc:.2f}")

        print("validating...")
        acc_c, acc_r = _validate(model, val_loader)
        score = (acc_c + acc_r) / 2

        print(f"clean acc: {acc_c:.2f}, robust acc: {acc_r:.2f}, score: {score:.2f}")

        if score > best_score:
            best_score = score
            torch.save(model.state_dict(), os.path.join(CONFIG['checkpoint_dir'], "best_madry_model.pth"))
            print("saved best model")

        torch.save(model.state_dict(), os.path.join(CONFIG['checkpoint_dir'], f"epoch_{epoch}.pth"))

# Start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()