In [2]:
# IMPORTS

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import os
from glob import glob
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split
import json

In [3]:
# ! pip install efficientnet_pytorch

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on", device)

Running on cuda


In [5]:
torch.cuda.get_device_name(0)

'Tesla T4'

In [6]:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.6.0+cu124
12.4
True


In [7]:
# mount drive

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load and validate videos
def frame_extract(path):
    video = cv2.VideoCapture(path)
    success = True
    while success:
        success, image = video.read()
        if success:
            yield image

    video.release()

# Validate video by trying to extract and transform 20 frames.


def validate_video(video_path, transform, count=20):
    all_frames = [frame for frame in frame_extract(video_path) if frame is not None]

    if len(all_frames) < count:
        raise ValueError(f"Not enough frames in video: {video_path} (Found {len(all_frames)})")

    selected_frames = random.sample(all_frames, count)

    transformed_frames = [transform(frame) for frame in selected_frames]
    frames_tensor = torch.stack(transformed_frames)

    return frames_tensor

# parameters
image_size = 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transforms = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        T.Normalize(mean, std)
    ]
)


In [None]:
# !find "/content/drive/My Drive" -name "*celeb_df_face_cropped*"
# !ls "/content/drive/My Drive/"


In [None]:
# DATASET FETCH

# Paths to real and fake video folders on Google Drive
celeb_df_real_path = '/content/drive/My Drive/celeb_df_face_cropped/real_face_only224'
celeb_df_fake_path = '/content/drive/My Drive/celeb_df_face_cropped/fake_face_only224'

# Get all .mp4 file paths
real_videos = sorted(glob(f"{celeb_df_real_path}/*.mp4"))
fake_videos = sorted(glob(f"{celeb_df_fake_path}/*.mp4"))

In [None]:
print(f"Total Real Videos: {len(real_videos)}")
print(f"Total Fake Videos: {len(fake_videos)}")
# print("Sample real video:", real_videos[0] if real_videos else "None")
# print("Sample fake video:", fake_videos[0] if fake_videos else "None")


Total Real Videos: 588
Total Fake Videos: 5634


In [None]:
# real_count = 0
# valid_real_videos= []
# valid_fake_videos= []

# def valid_video_list(video_files, train_transforms, valid_videos):
#     for video_path in tqdm(video_files):
#         try:
#             validate_video(video_path, train_transforms)
#             video_name = os.path.basename(video_path)
#             valid_videos.append(video_name)
#         except:
#             continue


# valid_video_list(real_videos, train_transforms, valid_real_videos)
# valid_video_list(fake_videos, train_transforms, valid_fake_videos)

# print(f"Valid real videos: {len(valid_real_videos)}")
# print(f"Valid fake videos: {len(valid_fake_videos)}")

# RUN FROM HERE

In [8]:
# Paths where you want to save
save_real_path = "/content/drive/My Drive/celeb_df_face_cropped/valid_real_videos.json"
save_fake_path = "/content/drive/My Drive/celeb_df_face_cropped/valid_fake_videos.json"


In [9]:
# Already Saved -----> DO NOT TOUCH

# # Save lists
# with open(save_real_path, 'w') as f:
#     json.dump(valid_real_videos, f)

# with open(save_fake_path, 'w') as f:
#     json.dump(valid_fake_videos, f)

# print("Lists saved successfully.")

In [10]:
# Load lists from JSON
valid_real_videos = []
valid_fake_videos = []

with open(save_real_path, 'r') as f:
    valid_real_videos = json.load(f)

with open(save_fake_path, 'r') as f:
    valid_fake_videos = json.load(f)

# Paths to real and fake video folders on Google Drive
celeb_df_real_path = '/content/drive/My Drive/celeb_df_face_cropped/real_face_only224'
celeb_df_fake_path = '/content/drive/My Drive/celeb_df_face_cropped/fake_face_only224'

# Reconstruct full paths
valid_real_videos_path = [os.path.join(celeb_df_real_path, name) for name in valid_real_videos]
valid_fake_videos_path = [os.path.join(celeb_df_fake_path, name) for name in valid_fake_videos]

print(f"Total real videos for training: {len(valid_real_videos_path)}")
print(f"Total fake videos for training: {len(valid_fake_videos_path)}")



Total real videos for training: 587
Total fake videos for training: 5635


In [11]:
valid_real_videos_path[0]
valid_fake_videos_path[0]

'/content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id0_id16_0000.mp4'

In [12]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

In [13]:
class DeepFakeDataset(Dataset):
    def __init__(self, video_paths, labels, transform=None, num_frames=16):
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform
        self.num_frames = num_frames

    def read_frames(self, video_path):
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)   #
        frames = []

        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if self.transform:
                frame = self.transform(image=frame)['image']
            frames.append(frame)

        cap.release()

        if len(frames) < self.num_frames:
            # pad missing frames with black images
            for _ in range(self.num_frames - len(frames)):
                frames.append(torch.zeros_like(frames[0]))

        return torch.stack(frames)

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        try:
            video_tensor = self.read_frames(self.video_paths[idx])
            label = torch.tensor(self.labels[idx], dtype=torch.float32)
            return video_tensor, label
        except Exception as e:
            print(f"Failed loading video: {self.video_paths[idx]}, Error: {e}")
            return self.__getitem__((idx + 1) % len(self))

In [14]:
# DATA TRANSFORMS
transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(),
    ToTensorV2(),
])

In [15]:
# Combine video paths and labels
video_paths = valid_real_videos_path + valid_fake_videos_path
labels = [0]*len(valid_real_videos_path) + [1]*len(valid_fake_videos_path)

# Use train_test_split to shuffle and split
train_paths, val_paths, train_labels, val_labels = train_test_split(
    video_paths, labels, test_size=0.2, stratify=labels, random_state=10
)

# Create datasets and loaders
train_dataset = DeepFakeDataset(train_paths, train_labels, transform)
val_dataset = DeepFakeDataset(val_paths, val_labels, transform)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2, pin_memory=True)


In [16]:
# MODEL COMPONENTS
class TemporalAttention(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.attention = nn.Linear(feature_dim, 1)

    def forward(self, x):
        # x: (batch, time, features)
        weights = F.softmax(self.attention(x), dim=1)
        return torch.sum(weights * x, dim=1)


class DeepFakeDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = EfficientNet.from_pretrained('efficientnet-b0')
        self.feature_extractor._fc = nn.Identity()
        self.lstm = nn.LSTM(input_size=1280, hidden_size=256, num_layers=1, batch_first=True, bidirectional=True)
        self.attention = TemporalAttention(512)
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        with torch.no_grad():  # freeze feature extractor to reduce memory
            feats = self.feature_extractor(x)
        feats = feats.view(B, T, -1)
        lstm_out, _ = self.lstm(feats)
        attn_out = self.attention(lstm_out)
        return self.classifier(attn_out).squeeze(1)


In [17]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0

    loop = tqdm(dataloader, desc="Training", leave=False)
    for inputs, labels in loop:
        inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct += (preds == labels).sum().item()
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = correct / len(dataloader.dataset)
    return epoch_loss, epoch_acc


def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0

    loop = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for inputs, labels in loop:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = correct / len(dataloader.dataset)
    return epoch_loss, epoch_acc

In [18]:
# START TRAINING
model = DeepFakeDetector().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

Loaded pretrained weights for efficientnet-b0


In [19]:
model

DeepFakeDetector(
  (feature_extractor): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dSt

In [20]:
# Prepare save path
model_dir = '/content/drive/My Drive/trained_model'
os.makedirs(model_dir, exist_ok=True)
save_path = os.path.join(model_dir, 'best_model_celeb_df.pth')

# Training loop
EPOCHS = 5
best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_accuracy': best_val_acc
        }, save_path)
        print(f"💾 Model saved to Google Drive with Val Accuracy: {best_val_acc:.4f}")


Epoch 1/5


Training:   0%|          | 0/2489 [00:00<?, ?it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id56_id54_0003.mp4, Error: list index out of range
Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id55_id58_0007.mp4, Error: list index out of range


Validation:  29%|██▉       | 180/623 [01:33<04:08,  1.78it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/real_face_only224/id1_0003.mp4, Error: list index out of range


Validation:  63%|██████▎   | 394/623 [03:28<02:04,  1.83it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id31_id6_0001.mp4, Error: list index out of range




Train Loss: 0.3091, Accuracy: 0.9050
Val   Loss: 0.2378, Accuracy: 0.9084
💾 Model saved to Google Drive with Val Accuracy: 0.9084

Epoch 2/5


Validation:  28%|██▊       | 177/623 [00:28<01:05,  6.79it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/real_face_only224/id1_0003.mp4, Error: list index out of range


Validation:  63%|██████▎   | 395/623 [01:04<00:47,  4.82it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id31_id6_0001.mp4, Error: list index out of range




Train Loss: 0.2641, Accuracy: 0.9076
Val   Loss: 0.2263, Accuracy: 0.9229
💾 Model saved to Google Drive with Val Accuracy: 0.9229

Epoch 3/5


Validation:  28%|██▊       | 177/623 [00:29<01:15,  5.92it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/real_face_only224/id1_0003.mp4, Error: list index out of range


Validation:  63%|██████▎   | 395/623 [01:04<00:33,  6.75it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id31_id6_0001.mp4, Error: list index out of range




Train Loss: 0.2483, Accuracy: 0.9098
Val   Loss: 0.2197, Accuracy: 0.9293
💾 Model saved to Google Drive with Val Accuracy: 0.9293

Epoch 4/5


Validation:  28%|██▊       | 177/623 [00:29<01:14,  5.99it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/real_face_only224/id1_0003.mp4, Error: list index out of range


Validation:  63%|██████▎   | 395/623 [01:05<00:33,  6.81it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id31_id6_0001.mp4, Error: list index out of range




Train Loss: 0.2349, Accuracy: 0.9154
Val   Loss: 0.2171, Accuracy: 0.9317
💾 Model saved to Google Drive with Val Accuracy: 0.9317

Epoch 5/5


Validation:  29%|██▊       | 178/623 [00:28<01:05,  6.77it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/real_face_only224/id1_0003.mp4, Error: list index out of range


Validation:  63%|██████▎   | 394/623 [01:02<00:32,  6.95it/s]

Failed loading video: /content/drive/My Drive/celeb_df_face_cropped/fake_face_only224/id31_id6_0001.mp4, Error: list index out of range




Train Loss: 0.2321, Accuracy: 0.9182
Val   Loss: 0.2058, Accuracy: 0.9333
💾 Model saved to Google Drive with Val Accuracy: 0.9333
