In [None]:
#Extract Frames from the training videos to train the model
import cv2
import os

def extract_frames(video_path, output_folder, target_fps=1):
    """Sample frames from a video at roughly ``target_fps`` and save them as JPEGs.

    Frames are written to ``output_folder`` as
    ``<video name>_frame_<index>.jpg`` where ``<index>`` is the zero-based
    position of the frame in the source video, zero-padded to four digits.

    Args:
        video_path: Path of the video file to read.
        output_folder: Directory that receives the frames; created if missing.
        target_fps: Desired number of saved frames per second of video.
            Must be greater than zero.

    Raises:
        ValueError: If ``target_fps`` is not positive.
    """
    # Validate first so a bad argument fails before any file is touched.
    if target_fps <= 0:
        raise ValueError(f"target_fps must be positive, got {target_fps!r}")

    # cv2.imwrite only returns False when the directory is missing, which
    # would silently drop every frame.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    # Video name (without extension) becomes the prefix of each frame filename.
    video_name = os.path.splitext(os.path.basename(video_path))[0]

    vidcap = cv2.VideoCapture(video_path)
    try:
        # Use the frame rate reported by the container; fall back to 30 when
        # the backend reports nothing usable (0, negative or NaN).
        original_fps = vidcap.get(cv2.CAP_PROP_FPS)
        if not original_fps > 0:
            original_fps = 30

        # Never let the interval reach 0 (target_fps above the source fps),
        # otherwise the modulo below would divide by zero.
        frame_interval = max(1, int(original_fps // target_fps))

        count = 0
        success, image = vidcap.read()
        while success:
            if count % frame_interval == 0:
                frame_path = os.path.join(
                    output_folder, f"{video_name}_frame_{count:04d}.jpg"
                )
                cv2.imwrite(frame_path, image)

            success, image = vidcap.read()
            count += 1
    finally:
        # Always hand the decoder/file handle back, even if writing failed.
        vidcap.release()

# Process all videos in the real and fake folders.
# Replace these paths with the corresponding directories on your machine.
real_video_folder = "Celeb-DF-v2/Celeb-real"
fake_video_folder = "Celeb-DF-v2/Celeb-synthesis"
real_frames_folder = "Frames/real_frames"
fake_frames_folder = "Frames/fake_frames"

# Each output folder is handled on its own: an existing folder is treated as
# "already extracted" and skipped, a missing one is created and filled. This
# way a run that only produced one of the two folders can still be completed.
for video_folder, frames_folder in (
    (real_video_folder, real_frames_folder),
    (fake_video_folder, fake_frames_folder),
):
    if os.path.isdir(frames_folder):
        continue

    # List the source first so a wrong source path raises before an empty
    # output folder is created (which would make later runs skip extraction).
    video_names = sorted(os.listdir(video_folder))
    os.makedirs(frames_folder, exist_ok=True)

    for video_name in video_names:
        video_path = os.path.join(video_folder, video_name)
        extract_frames(video_path, frames_folder)

In [None]:
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class DeepFakeFrameDataset(Dataset):
    """Frame-level dataset built from one folder of real and one of fake images.

    Every entry of ``real_folder`` is labelled 0 and every entry of
    ``fake_folder`` is labelled 1. Real frames come first, then fake frames,
    each group in sorted filename order.

    Args:
        real_folder: Directory containing frames of real videos.
        fake_folder: Directory containing frames of fake videos.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, real_folder, fake_folder, transform=None):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping identical from run to run.
        self.real_frames = [os.path.join(real_folder, f) for f in sorted(os.listdir(real_folder))]
        self.fake_frames = [os.path.join(fake_folder, f) for f in sorted(os.listdir(fake_folder))]
        self.all_frames = self.real_frames + self.fake_frames
        self.labels = [0] * len(self.real_frames) + [1] * len(self.fake_frames)  # 0 = real, 1 = fake
        self.transform = transform

    def __len__(self):
        """Return the total number of frames (real + fake)."""
        return len(self.all_frames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the frame at position ``idx``."""
        # Close the file handle as soon as the pixels are decoded; convert()
        # returns an independent image, so it stays valid after the block.
        with Image.open(self.all_frames[idx]) as frame:
            image = frame.convert("RGB")  # Ensure 3 channels
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Preprocessing applied to every training frame.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),  # PIL image -> float tensor in the [0, 1] range
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # rescale to [-1, 1]
    ]
)

# Wrap the extracted frame folders in a dataset and a shuffling batch loader.
dataset = DeepFakeFrameDataset(
    real_folder=real_frames_folder,
    fake_folder=fake_frames_folder,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import timm

class Discriminator(nn.Module):
    """Binary classifier: Xception feature extractor, pooled linear head, sigmoid.

    ``forward`` takes a batch of images and returns a tensor of shape
    ``(batch, num_classes)`` holding sigmoid-activated scores.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Pretrained Xception from timm; with features_only=True the model
        # returns a sequence of feature maps instead of class logits.
        self.backbone = timm.create_model(
            'legacy_xception', pretrained=True, features_only=True
        )

        # Collapse each channel of the final feature map to a single value.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Linear head sized from the channel count of the last feature map.
        last_channels = self.backbone.feature_info.channels()[-1]
        self.fc = nn.Linear(last_channels, num_classes)

        # Sigmoid squashes the head output into (0, 1) for use with BCE loss.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Only the deepest (last) feature map feeds the classifier head.
        last_map = self.backbone(x)[-1]

        # (batch, C, 1, 1) after pooling -> (batch, C) after flattening.
        pooled = self.global_pool(last_map).flatten(1)

        return self.sigmoid(self.fc(pooled))

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> 3 x 32 x 32 image.

    Input is expected as ``(batch, latent_dim, 1, 1)``; the final ``Tanh``
    keeps every output value inside [-1, 1].

    NOTE(review): the output is 32 x 32 while the training transform in this
    notebook resizes dataset frames to 128 x 128 - confirm the mismatch is
    intended.
    """

    def __init__(self, latent_dim=100):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of each
        # ConvTranspose2d -> BatchNorm2d -> ReLU stage, in order.
        stages = (
            (latent_dim, 256, 1, 0),  # latent_dim x 1 x 1 -> 256 x 4 x 4
            (256, 128, 2, 1),         # 256 x 4 x 4      -> 128 x 8 x 8
            (128, 64, 2, 1),          # 128 x 8 x 8      -> 64 x 16 x 16
        )

        layers = []
        for in_channels, out_channels, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(
                    in_channels, out_channels, kernel_size=4, stride=stride, padding=padding
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())

        # Output stage: 64 x 16 x 16 -> 3 x 32 x 32, normalized to [-1, 1].
        layers.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build both networks and place them on the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device)
generator = Generator()
generator = generator.to(device)

# Binary cross-entropy on probabilities (the discriminator ends in a sigmoid).
criterion = nn.BCELoss()


In [None]:
# Adam for both networks with beta1 = 0.5 and beta2 = 0.999.
# The discriminator uses a learning rate of 1e-3 and the generator 5e-4;
# lower these values if the model turns out to overfit.
optimizer_d = optim.Adam(
    params=discriminator.parameters(), lr=1e-3, betas=(0.5, 0.999)
)
optimizer_g = optim.Adam(
    params=generator.parameters(), lr=5e-4, betas=(0.5, 0.999)
)

In [None]:
# Training loop
num_epochs = 20
latent_dim = 100  # size of the noise vector; must match the Generator's latent_dim

# Resume from the newest epoch for which BOTH per-epoch checkpoints exist on
# disk. On a fresh run nothing is found and training starts at epoch 0, so the
# cell no longer depends on models left in memory by an earlier session.
# NOTE: optimizer state is not checkpointed, so Adam statistics restart on resume.
start_epoch = 0
for finished_epoch in range(num_epochs - 1, -1, -1):
    generator_ckpt = f"generator_epoch_{finished_epoch}.pth"
    discriminator_ckpt = f"discriminator_epoch_{finished_epoch}.pth"
    if os.path.isfile(generator_ckpt) and os.path.isfile(discriminator_ckpt):
        generator.load_state_dict(torch.load(generator_ckpt, map_location=device))
        discriminator.load_state_dict(torch.load(discriminator_ckpt, map_location=device))
        start_epoch = finished_epoch + 1
        print(f"Resuming from checkpoints of epoch {finished_epoch}")
        break

# Make sure BatchNorm layers are in training mode, even if an evaluation cell
# switched the discriminator to eval() earlier in the session.
discriminator.train()
generator.train()

for epoch in range(start_epoch, num_epochs):
    for i, (images, frame_labels) in enumerate(dataloader):
        batch_size = images.size(0)

        # Move data to device
        images = images.to(device)

        # The dataset labels frames 0 = real, 1 = fake, while the discriminator
        # is trained to output 1 for "real". Convert the labels to that
        # convention so deepfake frames are NOT presented as real examples.
        data_targets = (frame_labels == 0).float().unsqueeze(1).to(device)

        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator
        optimizer_d.zero_grad()

        # Dataset frames (real and deepfake), each with its own target
        data_outputs = discriminator(images)
        d_loss_real = criterion(data_outputs, data_targets)

        # Generated images are always "fake"; detach so this step does not
        # back-propagate into the generator.
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)  # Random noise
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # Train Generator: reward samples the discriminator scores as real
        optimizer_g.zero_grad()
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        # Print losses every 5 steps
        if not i % 5:
            print(f"Epoch: [{epoch}/{num_epochs}], Step: [{i}/{len(dataloader)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Save models after each epoch
    torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch}.pth")
    print(f"Models saved after epoch {epoch}")

In [None]:
# Persist the trained discriminator weights under a fixed file name.
torch.save(obj=discriminator.state_dict(), f="discriminator.pth")

## Validating the model on the testing videos

In [None]:
def load_testing_videos(file_path, common_directory):
    """Read a test list whose lines have the form ``<label> <relative video path>``.

    Args:
        file_path: Path of the text file listing the test videos.
        common_directory: Directory that every listed video path is relative to.

    Returns:
        A ``(video_paths, labels)`` tuple of equally long lists:
        ``video_paths`` holds ``common_directory`` joined with each listed
        path, ``labels`` holds the integer label of each line.

    Raises:
        ValueError: If a non-blank line has no path after the label, or the
            label is not an integer.
    """
    video_paths = []
    labels = []
    with open(file_path, "r") as file:
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            # Split on any run of whitespace so tabs or repeated spaces do not
            # leave stray whitespace at the start of the video path.
            label, video_name = line.split(None, 1)
            video_paths.append(os.path.join(common_directory, video_name))
            labels.append(int(label))
    return video_paths, labels

# Example usage
common_directory = "Celeb-DF-v2"
# Build the list path from common_directory so the dataset folder name (and its
# letter case, which matters on case-sensitive filesystems) is spelled once.
video_paths, labels = load_testing_videos(
    os.path.join(common_directory, "List_of_testing_videos.txt"), common_directory
)

In [None]:

def extract_frames(video_path, frame_interval=30):
    """Decode a video and return every ``frame_interval``-th frame as an RGB array.

    NOTE(review): this redefines ``extract_frames`` from the frame-extraction
    cell at the top of the notebook with a different signature; once this cell
    has run, that earlier cell's calls will no longer work without re-running
    its own definition. Consider giving the two functions distinct names.

    Args:
        video_path: Path of the video file to read.
        frame_interval: Keep one frame out of this many (frame 0, then frame
            ``frame_interval``, ...). Must be at least 1.

    Returns:
        List of frames converted from OpenCV's BGR order to RGB; empty when no
        frame could be read.

    Raises:
        ValueError: If ``frame_interval`` is smaller than 1.
    """
    # A zero interval would make the modulo below divide by zero.
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval!r}")

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to RGB
            frame_count += 1
    finally:
        # Release the capture even if decoding or conversion raised.
        cap.release()
    return frames

In [None]:

# Preprocessing for evaluation frames: array -> PIL image -> 128 x 128 tensor
# normalized with mean 0.5 / std 0.5 per channel.
# NOTE(review): this rebinds the module-level name `transform` that the
# training cell also defines - confirm that is intended.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def preprocess_frames(frames):
    """Apply ``transform`` to each frame and stack the results into one tensor."""
    processed = []
    for frame in frames:
        processed.append(transform(frame))
    return torch.stack(processed)

In [None]:
def evaluate_video(discriminator, video_path, device):
    """Return the fraction of sampled frames the discriminator scores above 0.5.

    Args:
        discriminator: Model called on the preprocessed frame batch; expected
            to return one score per frame.
        video_path: Path of the video to evaluate.
        device: Torch device the frame batch is moved to before inference.

    Returns:
        A float in [0, 1]: the share of sampled frames with score > 0.5.

    Raises:
        ValueError: If no frame could be read from ``video_path``.
    """
    frames = extract_frames(video_path)
    if not frames:
        # torch.stack on an empty list fails with an unhelpful message; report
        # the unreadable or empty video explicitly instead.
        raise ValueError(f"No frames could be read from {video_path!r}")
    batch = preprocess_frames(frames).to(device)

    with torch.no_grad():
        outputs = discriminator(batch)
        # The Discriminator in this notebook already pools spatially and
        # returns shape (N, 1), so reducing over dims (2, 3) raised an
        # IndexError. Flatten whatever follows the batch dimension and
        # average it: this gives one score per frame for an (N, 1) output
        # and would also cope with an (N, 1, H, W) score map.
        scores = outputs.reshape(outputs.size(0), -1).mean(dim=1)
        predictions = (scores > 0.5).float()  # Threshold at 0.5
        avg_prediction = predictions.mean().item()  # Average prediction for the video

    return avg_prediction

In [None]:
def calculate_accuracy(discriminator, video_paths, labels, device):
    """Compute video-level accuracy of the discriminator.

    A video is predicted as label 1 when ``evaluate_video`` returns more than
    0.5 for it, and as label 0 otherwise.

    Args:
        discriminator: Model passed through to ``evaluate_video``.
        video_paths: Sequence of video file paths.
        labels: Sequence of integer labels, aligned with ``video_paths``.
        device: Torch device passed through to ``evaluate_video``.

    Returns:
        Fraction of videos whose predicted label equals the given label;
        ``0.0`` when ``video_paths`` is empty.

    Raises:
        ValueError: If ``video_paths`` and ``labels`` differ in length.
    """
    total = len(video_paths)
    # zip() would silently drop the unmatched tail and skew the result.
    if total != len(labels):
        raise ValueError(
            f"video_paths and labels differ in length: {total} vs {len(labels)}"
        )
    # Nothing to evaluate: avoid dividing by zero.
    if total == 0:
        return 0.0

    correct = 0
    for video_path, label in zip(video_paths, labels):
        avg_prediction = evaluate_video(discriminator, video_path, device)
        predicted_label = 1 if avg_prediction > 0.5 else 0
        if predicted_label == label:
            correct += 1

    return correct / total

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU machine also loads on a CPU-only one.
discriminator.load_state_dict(torch.load("discriminator.pth", map_location=device))  # Load trained model
discriminator.eval()  # Set to evaluation mode

accuracy = calculate_accuracy(discriminator, video_paths, labels, device)
print(f"Accuracy on test videos: {accuracy * 100:.2f}%")