<a href="https://colab.research.google.com/github/sarveshshirulkar/Deepfake-Detection-Model/blob/main/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (Colab-only); the dataset directories used
# for training further down live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Fix the sympy/torch conflict
!pip install --upgrade sympy==1.12
!pip install --force-reinstall torch torchvision torchaudio

Collecting sympy==1.12
  Downloading sympy-1.12-py3-none-any.whl.metadata (12 kB)
Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m72.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.1
    Uninstalling sympy-1.13.1:
      Successfully uninstalled sympy-1.13.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.
torch 2

Collecting torch
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Downloading typing_extensions-4.13.2-py3-none-any.whl.metadata (3.0 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidi

In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from tqdm import tqdm
from PIL import Image
from transformers import ViTModel, ViTConfig

# Configuration
class Config:
    """Hyperparameters and runtime settings, read everywhere as class attributes."""

    # Model
    image_size = 224   # side length (pixels) frames are resized to; also given to ViTConfig
    patch_size = 16    # ViT patch side length
    num_classes = 2    # label 0 = real, label 1 = fake

    # Training
    batch_size = 8
    epochs = 20
    lr = 0.0001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Video Processing
    # Defined exactly once. This attribute used to be assigned twice in this
    # class body (10 under "Training", then 3 here); the later assignment
    # silently won, so 3 is the value every run has actually used.
    frames_per_video = 3  # frames sampled per video; kept small for short videos

# 1. Video Processing Functions
def extract_frames(video_path, num_frames):
    """Extract up to ``num_frames`` evenly spaced RGB frames from a video.

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample.

    Returns:
        list[PIL.Image.Image]: The decoded frames, converted from BGR to RGB.
        The list is shorter than ``num_frames`` when the video holds fewer
        frames or an individual read fails, and empty when ``num_frames`` is
        not positive, the file cannot be opened, or no frame count is reported.
    """
    frames = []
    if num_frames <= 0:
        return frames

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return frames

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return frames

        # Never ask for more samples than the video holds: otherwise the
        # floor-divided step becomes 0 and every seek lands on frame 0,
        # returning the same frame num_frames times.
        sample_count = min(num_frames, total_frames)
        step = total_frames // sample_count

        for i in range(sample_count):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if not ret:
                # Skip frames that fail to decode instead of aborting the video.
                continue
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb))
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()

    return frames

# 2. Dataset Class
class VideoDataset(Dataset):
    """Frame-level dataset built from one directory of real and one of fake videos.

    All videos are decoded eagerly in the constructor. Each extracted frame
    becomes an independent sample labelled 0 (real) or 1 (fake).
    """

    def __init__(self, real_dir, fake_dir, transform=None):
        self.samples = []
        self.transform = transform

        # enumerate() yields label 0 for the real directory and 1 for the fake one,
        # and the real directory is fully processed before the fake one is listed.
        for label, directory in enumerate((real_dir, fake_dir)):
            for filename in os.listdir(directory):
                video_file = os.path.join(directory, filename)
                for frame in extract_frames(video_file, Config.frames_per_video):
                    self.samples.append((frame, label))

    def __len__(self):
        """Number of stored frames."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        image, target = self.samples[idx]
        image = self.transform(image) if self.transform else image
        return image, torch.tensor(target, dtype=torch.long)

# 3. Vision Transformer Model
class DeepfakeViT(nn.Module):
    """ViT backbone followed by a small MLP head producing real/fake logits.

    Args:
        pretrained_name: Optional Hugging Face checkpoint name or local path.
            When given, the backbone is loaded with ``ViTModel.from_pretrained``.
            When ``None`` (the default) the backbone is randomly initialised
            from ``Config``, which is the original behaviour. Note that the
            recorded from-scratch run in this notebook never left chance level
            (validation AUC stayed around 0.5), so a pretrained backbone is
            worth trying.
    """

    def __init__(self, pretrained_name=None):
        super().__init__()
        if pretrained_name is not None:
            self.vit = ViTModel.from_pretrained(pretrained_name)
        else:
            config = ViTConfig(
                image_size=Config.image_size,
                patch_size=Config.patch_size,
                # ViTConfig's keyword is `num_labels`; `num_classes` is not a
                # recognised configuration field.
                num_labels=Config.num_classes,
            )
            self.vit = ViTModel(config)

        # Size the head from the backbone instead of hard-coding 768 / 2 so it
        # stays correct if the backbone or the number of classes changes.
        hidden_size = self.vit.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, Config.num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Pixel tensor forwarded to the backbone as ``pixel_values``
                (assumed to be (batch, 3, H, W) - TODO confirm against callers).

        Returns:
            Tensor of shape (batch, Config.num_classes) with unnormalised logits.
        """
        outputs = self.vit(pixel_values=x)
        # The first sequence position is used as the pooled (CLS) representation.
        cls_token = outputs.last_hidden_state[:, 0]
        return self.classifier(cls_token)

# 4. Training Function
class _TransformedSubset(Dataset):
    """Apply ``transform`` to the image of each (image, label) item of ``subset``."""

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        return self.transform(img), label


def train_model(real_dir, fake_dir):
    """Train a DeepfakeViT on video frames and print validation metrics per epoch.

    Args:
        real_dir: Directory containing real videos (label 0).
        fake_dir: Directory containing fake videos (label 1).

    Returns:
        The trained ``DeepfakeViT`` instance (on ``Config.device``).

    Raises:
        ValueError: If fewer than two frames could be extracted, which would
            leave the training or validation split empty.
    """
    resize = transforms.Resize((Config.image_size, Config.image_size))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]  # ImageNet stats
    )
    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1, 0.1, 0.1),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation must be deterministic: no random flip / colour jitter.
    eval_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])

    # Load frames untransformed once; each split gets its own transform below.
    dataset = VideoDataset(real_dir, fake_dir, transform=None)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 frames to build train/val splits, got {len(dataset)} "
            f"from {real_dir} and {fake_dir}"
        )

    # NOTE(review): the split is per frame, so frames of the same video can end
    # up in both subsets; validation numbers are therefore optimistic. A
    # per-video split needs video ids, which VideoDataset does not record.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_subset, val_subset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_dataset = _TransformedSubset(train_subset, train_transform)
    val_dataset = _TransformedSubset(val_subset, eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.batch_size)

    # Model setup
    model = DeepfakeViT().to(Config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(Config.epochs):
        # Training pass
        model.train()
        train_loss = 0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(Config.device), labels.to(Config.device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation pass
        model.eval()
        val_loss = 0
        all_preds = []
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(Config.device), labels.to(Config.device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()

                # Keep the fake-class probability for ROC-AUC (a ranking metric)
                # and the hard decision for accuracy / F1 / confusion matrix.
                all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().tolist())
                all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, zero_division=0)
        # ROC-AUC is undefined when the validation split holds a single class.
        if len(set(all_labels)) > 1:
            auc = roc_auc_score(all_labels, all_probs)
        else:
            auc = float("nan")
        # labels=[0, 1] guarantees a 2x2 matrix even if a class is missing.
        tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()

        print(f"\nEpoch {epoch+1} Results:")
        print(f"Train Loss: {train_loss/len(train_loader):.4f}")
        print(f"Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"Accuracy: {accuracy:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}")
        print(f"Confusion Matrix:\nTP: {tp} | FP: {fp}\nFN: {fn} | TN: {tn}")

    return model

# 5. Test Function
def test_video(model, video_path):
    """Classify a single video as REAL or FAKE from its sampled frames.

    Args:
        model: Trained classifier returning 2-class logits per frame.
        video_path: Path of the video to score.

    Returns:
        tuple: ``(prediction, avg_prob)`` where ``prediction`` is ``"FAKE"`` or
        ``"REAL"`` and ``avg_prob`` is the mean fake-class probability over the
        sampled frames.

    Raises:
        ValueError: If no frame could be read from ``video_path``.
    """
    transform = transforms.Compose([
        transforms.Resize((Config.image_size, Config.image_size)),
        transforms.ToTensor(),
        # Same ImageNet statistics that train_model normalises with; using
        # different statistics here would feed the model shifted inputs.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    frames = extract_frames(video_path, Config.frames_per_video)
    if not frames:
        # Averaging an empty list would give NaN and silently print "REAL".
        raise ValueError(f"No frames could be read from {video_path}")

    model.eval()
    fake_probs = []

    with torch.no_grad():
        for frame in frames:
            img_tensor = transform(frame).unsqueeze(0).to(Config.device)
            output = model(img_tensor)
            # Index 1 is the "fake" class.
            fake_probs.append(torch.softmax(output, dim=1)[0, 1].item())

    avg_prob = float(np.mean(fake_probs))
    prediction = "FAKE" if avg_prob > 0.5 else "REAL"

    print(f"\nTest Results for {video_path}:")
    # The value labelled "Confidence" is the mean fake-class probability.
    print(f"Prediction: {prediction} (Confidence: {avg_prob:.4f})")
    print(f"Frame-level probabilities: {[round(p, 4) for p in fake_probs]}")
    return prediction, avg_prob

# Example Usage
if __name__ == "__main__":
    # Paths used by this run; adjust to where the dataset and sample video live.
    REAL_DIR = "/content/drive/MyDrive/Documents/SDFVD/SDFVD/videos_real"
    FAKE_DIR = "/content/drive/MyDrive/Documents/SDFVD/SDFVD/videos_fake"
    CHECKPOINT_PATH = "deepfake_vit.pth"
    SAMPLE_VIDEO = "/content/dfvideo.mp4"

    # 1. Train the model and persist its weights
    print("Starting training...")
    model = train_model(real_dir=REAL_DIR, fake_dir=FAKE_DIR)
    torch.save(model.state_dict(), CHECKPOINT_PATH)

    # 2. Reload the weights into a fresh model and score a sample video
    print("\nTesting sample video...")
    test_model = DeepfakeViT().to(Config.device)
    # map_location lets a checkpoint saved from a GPU run be restored on a
    # CPU-only runtime (and vice versa).
    test_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=Config.device))
    test_video(test_model, SAMPLE_VIDEO)

Starting training...


Epoch 1: 100%|██████████| 32/32 [00:12<00:00,  2.66it/s]



Epoch 1 Results:
Train Loss: 0.7794
Val Loss: 0.7106
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 2: 100%|██████████| 32/32 [00:11<00:00,  2.89it/s]



Epoch 2 Results:
Train Loss: 0.7157
Val Loss: 0.7034
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 3: 100%|██████████| 32/32 [00:11<00:00,  2.82it/s]



Epoch 3 Results:
Train Loss: 0.7383
Val Loss: 0.7018
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 4: 100%|██████████| 32/32 [00:10<00:00,  3.02it/s]



Epoch 4 Results:
Train Loss: 0.7112
Val Loss: 0.6948
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 5: 100%|██████████| 32/32 [00:10<00:00,  3.12it/s]



Epoch 5 Results:
Train Loss: 0.7096
Val Loss: 0.6945
Accuracy: 0.4531 | F1: 0.0541 | AUC: 0.4668
Confusion Matrix:
TP: 1 | FP: 3
FN: 32 | TN: 28


Epoch 6: 100%|██████████| 32/32 [00:10<00:00,  3.11it/s]



Epoch 6 Results:
Train Loss: 0.7009
Val Loss: 0.6964
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 7: 100%|██████████| 32/32 [00:10<00:00,  3.07it/s]



Epoch 7 Results:
Train Loss: 0.7001
Val Loss: 0.6938
Accuracy: 0.5156 | F1: 0.6804 | AUC: 0.5000
Confusion Matrix:
TP: 33 | FP: 31
FN: 0 | TN: 0


Epoch 8: 100%|██████████| 32/32 [00:10<00:00,  2.95it/s]



Epoch 8 Results:
Train Loss: 0.7058
Val Loss: 0.7122
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 9: 100%|██████████| 32/32 [00:10<00:00,  3.00it/s]



Epoch 9 Results:
Train Loss: 0.6992
Val Loss: 0.6975
Accuracy: 0.5156 | F1: 0.6804 | AUC: 0.5000
Confusion Matrix:
TP: 33 | FP: 31
FN: 0 | TN: 0


Epoch 10: 100%|██████████| 32/32 [00:10<00:00,  3.00it/s]



Epoch 10 Results:
Train Loss: 0.7002
Val Loss: 0.7017
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 11: 100%|██████████| 32/32 [00:10<00:00,  3.02it/s]



Epoch 11 Results:
Train Loss: 0.7138
Val Loss: 0.6964
Accuracy: 0.4531 | F1: 0.0000 | AUC: 0.4677
Confusion Matrix:
TP: 0 | FP: 2
FN: 33 | TN: 29


Epoch 12: 100%|██████████| 32/32 [00:10<00:00,  3.01it/s]



Epoch 12 Results:
Train Loss: 0.6992
Val Loss: 0.6961
Accuracy: 0.4375 | F1: 0.5814 | AUC: 0.4272
Confusion Matrix:
TP: 25 | FP: 28
FN: 8 | TN: 3


Epoch 13: 100%|██████████| 32/32 [00:10<00:00,  3.02it/s]



Epoch 13 Results:
Train Loss: 0.7008
Val Loss: 0.7023
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31


Epoch 14: 100%|██████████| 32/32 [00:10<00:00,  3.08it/s]



Epoch 14 Results:
Train Loss: 0.7191
Val Loss: 0.6967
Accuracy: 0.4688 | F1: 0.0000 | AUC: 0.4839
Confusion Matrix:
TP: 0 | FP: 1
FN: 33 | TN: 30


Epoch 15: 100%|██████████| 32/32 [00:10<00:00,  3.08it/s]



Epoch 15 Results:
Train Loss: 0.6988
Val Loss: 0.6976
Accuracy: 0.5156 | F1: 0.6804 | AUC: 0.5000
Confusion Matrix:
TP: 33 | FP: 31
FN: 0 | TN: 0


Epoch 16: 100%|██████████| 32/32 [00:10<00:00,  3.03it/s]



Epoch 16 Results:
Train Loss: 0.6977
Val Loss: 0.6962
Accuracy: 0.4688 | F1: 0.0556 | AUC: 0.4829
Confusion Matrix:
TP: 1 | FP: 2
FN: 32 | TN: 29


Epoch 17: 100%|██████████| 32/32 [00:10<00:00,  3.02it/s]



Epoch 17 Results:
Train Loss: 0.7112
Val Loss: 0.6955
Accuracy: 0.5156 | F1: 0.6804 | AUC: 0.5000
Confusion Matrix:
TP: 33 | FP: 31
FN: 0 | TN: 0


Epoch 18: 100%|██████████| 32/32 [00:10<00:00,  3.01it/s]



Epoch 18 Results:
Train Loss: 0.6952
Val Loss: 0.6958
Accuracy: 0.5156 | F1: 0.6804 | AUC: 0.5000
Confusion Matrix:
TP: 33 | FP: 31
FN: 0 | TN: 0


Epoch 19: 100%|██████████| 32/32 [00:10<00:00,  3.03it/s]



Epoch 19 Results:
Train Loss: 0.6974
Val Loss: 0.7116
Accuracy: 0.4688 | F1: 0.0556 | AUC: 0.4829
Confusion Matrix:
TP: 1 | FP: 2
FN: 32 | TN: 29


Epoch 20: 100%|██████████| 32/32 [00:10<00:00,  3.00it/s]



Epoch 20 Results:
Train Loss: 0.7035
Val Loss: 0.7012
Accuracy: 0.4844 | F1: 0.0000 | AUC: 0.5000
Confusion Matrix:
TP: 0 | FP: 0
FN: 33 | TN: 31

Testing sample video...

Test Results for /content/dfvideo.mp4:
Prediction: REAL (Confidence: 0.4669)
Frame-level probabilities: [0.4671, 0.4666, 0.4669]


             Input Token Embeddings (197 x D)
                          ↓
          ┌───────────────────────────────────┐
          │  Multi-Head Self-Attention (MHSA) │
          └───────────────────────────────────┘
                          ↓
             Add & LayerNorm (Residual #1)
                          ↓
     ┌─────────────────────────────────────────┐
     │   Feedforward Network (MLP / FFN block) │
     └─────────────────────────────────────────┘
                          ↓
             Add & LayerNorm (Residual #2)
                          ↓
                  Output Token Embeddings
