<a href="https://colab.research.google.com/github/rajkumar9474/project_foff/blob/main/multimodal_deepfake_detection(updated).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip -q '/content/drive/MyDrive/archive.zip' -d '/content/archive'

unzip:  cannot find or open /content/drive/MyDrive/archive.zip, /content/drive/MyDrive/archive.zip.zip or /content/drive/MyDrive/archive.zip.ZIP.


In [None]:
!pip install numpy torch facenet-pytorch moviepy librosa opencv-python diffusers

Collecting facenet-pytorch
  Downloading facenet_pytorch-2.6.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 k

# **CONFIGURATION**

In [None]:
# Import necessary libraries that will be used later.
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader
# from diffusers import DDPMPipeline  # Diffusion preprocessor
from facenet_pytorch import MTCNN
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [None]:
from moviepy.editor import VideoFileClip


In [None]:
# Global configuration variables
FRAME_SKIP = 2  # Stride used when subsampling MFCC time frames in the audio extractor
VIDEO_FRAMES = 16  # Number of frames kept per video by the dataset
VIDEO_SIZE = 224  # Frames are resized to VIDEO_SIZE x VIDEO_SIZE pixels
AUDIO_SAMPLE_RATE = 16000  # Sample rate (Hz) used for audio extraction and MFCCs
BATCH_SIZE = 8  # Batch size shared by the train/val/test DataLoaders
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # Prefer the GPU when one is visible

# For quick testing, use a flag to disable heavy diffusion processing.
# NOTE(review): no visible cell reads this flag — confirm whether it is still needed.
USE_DIFFUSION = False  # Set to False for faster testing in Colab.

In [None]:
# Report which compute device the configuration cell selected.
print("Running on:", DEVICE)

Running on: cuda


# **DATASET CLASS**

In [None]:
class FakeAVCelebDataset(Dataset):
    """
    Dataset loader for a FakeAVCeleb-style folder of real and fake videos.

    Expected structure:
      DATA_ROOT/
        real/    -> ``.mp4`` files labelled 0 (file names containing "fake" are skipped)
        fake/    -> ``.mp4`` files labelled 1

    Each item is a ``(video, audio, label)`` triple:
      * video: float tensor [VIDEO_FRAMES, 3, VIDEO_SIZE, VIDEO_SIZE], RGB in [0, 1]
               (before ``transform`` is applied)
      * audio: mono float32 numpy array [samples] sampled at AUDIO_SAMPLE_RATE
      * label: float32 scalar tensor (0 = real, 1 = fake)

    Args:
        data_root: Folder containing the ``real`` and ``fake`` sub-folders.
        transform: Optional per-frame transform; it receives a PIL image.
        audio_transform: Stored only; see the note in ``__init__``.
        max_samples: If given, keep only the first ``max_samples`` entries.
        metadata_list: Optional ready-made list of ``{"video_path", "label"}``
            dicts (used for train/val/test splits); skips the folder scan.
    """

    def __init__(self, data_root, transform=None, audio_transform=None, max_samples=None, metadata_list=None):
        self.data_root = data_root
        self.transform = transform
        # NOTE(review): kept for API compatibility, but nothing in this class
        # applies it to the extracted audio yet.
        self.audio_transform = audio_transform

        if metadata_list is not None:
            self.samples = metadata_list  # Use provided list (for splits)
        else:
            real_samples = self._scan_folder(os.path.join(data_root, "real"), label=0, skip_fake_names=True)
            fake_samples = self._scan_folder(os.path.join(data_root, "fake"), label=1)
            self.samples = real_samples + fake_samples

        # NOTE: the list is ordered real-first, so a small max_samples can drop
        # most (or all) fake videos.
        if max_samples is not None and len(self.samples) > max_samples:
            self.samples = self.samples[:max_samples]

        print(f"Total samples loaded: {len(self.samples)}")

    @staticmethod
    def _scan_folder(folder, label, skip_fake_names=False):
        """Return one sample dict per ``.mp4`` in ``folder``, in sorted (reproducible) order."""
        if not os.path.isdir(folder):
            return []
        entries = []
        # os.listdir order is arbitrary; sorting keeps truncation and the
        # seeded splits reproducible across runs and machines.
        for name in sorted(os.listdir(folder)):
            lowered = name.lower()
            if not lowered.endswith(".mp4"):
                continue
            if skip_fake_names and "fake" in lowered:
                continue
            entries.append({"video_path": os.path.join(folder, name), "label": label})
        return entries

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample. Decoding failures are reported and replaced by zeros so
        that a single corrupt file does not abort an epoch.
        """
        sample = self.samples[idx]
        video_path = sample["video_path"]
        label = sample["label"]
        try:
            video = self.extract_frames(video_path)
        except Exception as e:
            print(f"Error extracting frames from {video_path}: {e}")
            video = torch.zeros(VIDEO_FRAMES, 3, VIDEO_SIZE, VIDEO_SIZE)
        try:
            audio = self.extract_audio(video_path)
        except Exception as e:
            print(f"Error extracting audio from {video_path}: {e}")
            # Same layout as extract_audio's "no audio track" result: 1-D, one
            # second of silence (the previous [1, N] tensor had a different rank).
            audio = np.zeros(AUDIO_SAMPLE_RATE, dtype=np.float32)

        if self.transform:
            if isinstance(video, torch.Tensor):
                # Frames are already tensors: convert each one to PIL so the
                # usual torchvision transforms (ToTensor, flips, ...) apply.
                # NOTE(review): random transforms are drawn independently per
                # frame, so e.g. a random flip is not consistent within a clip.
                to_pil = transforms.ToPILImage()
                video = torch.stack([self.transform(to_pil(frame)) for frame in video])
            else:
                video = self.transform(video)
        return video, audio, torch.tensor(label, dtype=torch.float32)

    def extract_frames(self, video_path):
        """
        Decode the first VIDEO_FRAMES frames of ``video_path``.

        Returns:
            Float tensor [VIDEO_FRAMES, 3, VIDEO_SIZE, VIDEO_SIZE] in RGB order,
            values in [0, 1]. Short videos are padded by repeating the last frame.

        Raises:
            ValueError: if no frame could be decoded.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while cap.isOpened() and len(frames) < VIDEO_FRAMES:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the downstream face detector and the
                # ImageNet-pretrained CNN expect RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float()  # HWC -> CHW
                frame_tensor = F.interpolate(
                    frame_tensor.unsqueeze(0),
                    size=(VIDEO_SIZE, VIDEO_SIZE),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(0)
                frames.append(frame_tensor / 255.0)
        finally:
            # Release the capture handle even if decoding raised.
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be decoded from {video_path}")

        frames = torch.stack(frames)
        if frames.shape[0] < VIDEO_FRAMES:
            pad = VIDEO_FRAMES - frames.shape[0]
            frames = torch.cat([frames, frames[-1].unsqueeze(0).repeat(pad, 1, 1, 1)], dim=0)
        return frames

    def extract_audio(self, video_path):
        """
        Loads the audio from the video file using MoviePy (no torchaudio).
        Returns a mono float32 numpy array of shape [samples]; one second of
        silence when the file has no audio track.
        """
        clip = VideoFileClip(video_path)
        try:
            if clip.audio is None:
                # No audio track → return one second of silence
                return np.zeros(AUDIO_SAMPLE_RATE, dtype=np.float32)
            # Extract at the desired sample rate
            audio_np = clip.audio.to_soundarray(fps=AUDIO_SAMPLE_RATE)  # [n_samples, channels]
        finally:
            # Closes both the video and the audio readers on every path,
            # including the early "no audio" return.
            clip.close()

        # Convert to mono by averaging channels
        audio_mono = audio_np.mean(axis=1) if audio_np.ndim == 2 else audio_np
        return audio_mono.astype(np.float32)


# **DIFFUSION PREPROCESSOR**

In [None]:
class DiffusionPreprocessor(nn.Module):
    """
    Uses a pretrained DDPM to refine video frames.
    Audio is left as identity (replace with DiffWave if desired).

    Frames are refined with a partial noise-then-denoise pass: each frame is
    noised to an intermediate timestep and then denoised back with the
    pipeline's UNet and scheduler. (``DDPMPipeline.__call__`` itself is an
    unconditional generator whose first argument is a batch size, so it cannot
    take input frames.)

    Args:
        for_audio: Apply the (identity) audio preprocessing in ``forward``.
        for_video: Apply the diffusion refinement to video frames in ``forward``.
        device: Device the UNet runs on.
        num_inference_steps: Length of the full denoising schedule.
        strength: Fraction (0, 1] of that schedule actually used; larger values
            add more noise and change the frames more.
    """
    def __init__(self, for_audio=True, for_video=True, device=DEVICE,
                 num_inference_steps=50, strength=0.3):
        super().__init__()
        if not 0.0 < strength <= 1.0:
            raise ValueError(f"strength must be in (0, 1], got {strength}")
        self.for_audio = for_audio
        self.for_video = for_video
        self.device = device
        self.num_inference_steps = num_inference_steps
        self.strength = strength

        # Imported locally: the notebook's top-level diffusers import is
        # commented out, so the name is otherwise undefined.
        from diffusers import DDPMPipeline

        # Load pretrained DDPM (e.g., google/ddpm-celebahq-256)
        self.ddpm = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
        # The UNet is called directly below, so keep it resident on `device`
        # rather than combining .to(device) with CPU offloading.
        self.ddpm.to(device)

    def preprocess_image_frames(self, frames):
        """
        frames: [B, T, C, H, W] tensor in [0,1]
        Returns: refined frames [B, T, C, 256, 256] in [0,1] (256x256 is the
        resolution the pretrained model was built for).
        """
        B, T, C, H, W = frames.shape
        frames = frames.to(self.device)
        resize = transforms.Resize((256, 256))
        unet = self.ddpm.unet
        scheduler = self.ddpm.scheduler

        scheduler.set_timesteps(self.num_inference_steps)
        # scheduler.timesteps is in descending order; keep only the
        # low-noise tail selected by `strength`.
        n_active = max(1, int(round(len(scheduler.timesteps) * self.strength)))
        timesteps = scheduler.timesteps[-n_active:]
        start_t = torch.full((B,), int(timesteps[0]), dtype=torch.long, device=frames.device)

        denoised_frames = []
        with torch.no_grad():
            for t in range(T):
                frame_t = resize(frames[:, t])  # [B, C, 256, 256]
                sample = frame_t * 2.0 - 1.0  # the model works in [-1, 1]
                sample = scheduler.add_noise(sample, torch.randn_like(sample), start_t)
                for step in timesteps:
                    noise_pred = unet(sample, step).sample
                    sample = scheduler.step(noise_pred, step, sample).prev_sample
                denoised_frames.append((sample / 2.0 + 0.5).clamp(0.0, 1.0))
        return torch.stack(denoised_frames, dim=1)  # [B, T, C, 256, 256]

    def preprocess_audio_waveform(self, waveforms):
        """
        waveforms: [B, samples]
        Identity for now (can plug in DiffWave).
        """
        return waveforms

    def forward(self, video_frames=None, audio_waveform=None):
        """Return ``(video, audio)``; each input is passed through unchanged when it is None or disabled."""
        out_video = video_frames
        out_audio = audio_waveform
        if video_frames is not None and self.for_video:
            out_video = self.preprocess_image_frames(video_frames)
        if audio_waveform is not None and self.for_audio:
            out_audio = self.preprocess_audio_waveform(audio_waveform)
        return out_video, out_audio

# **FACE REGION EXTRACTION**

In [None]:
import torchvision.transforms as T

class FaceRegionExtractor(nn.Module):
    """
    Crops the lips and eyes regions of the first detected face with MTCNN.

    Args:
        device: Device MTCNN runs on and on which the crops are returned.
        output_size: Side length (pixels) of the square crops.
    """
    def __init__(self, device=DEVICE, output_size=224):
        super().__init__()
        self.mtcnn = MTCNN(keep_all=True, device=device, thresholds=[0.9, 0.95, 0.99])
        self.device = device
        self.output_size = output_size

    @staticmethod
    def _padded_box(pt_a, pt_b, pad_x, pad_y):
        """
        Box around two landmarks, padded by fractions of their horizontal distance.
        Returns ``(x1, y1, x2, y2)`` or None when the box has no area.
        """
        span = abs(float(pt_b[0]) - float(pt_a[0]))
        # min/max keeps the box valid even when the face is tilted or mirrored.
        x1 = int(min(pt_a[0], pt_b[0]) - pad_x * span)
        y1 = int(min(pt_a[1], pt_b[1]) - pad_y * span)
        x2 = int(max(pt_a[0], pt_b[0]) + pad_x * span)
        y2 = int(max(pt_a[1], pt_b[1]) + pad_y * span)
        if x2 <= x1 or y2 <= y1:
            return None
        return (x1, y1, x2, y2)

    def _whole_frame(self, frame_i):
        """Fallback crop: the full frame, resized to the crop size and moved to ``self.device``."""
        size = (self.output_size, self.output_size)
        if tuple(frame_i.shape[-2:]) != size:
            frame_i = F.interpolate(
                frame_i.unsqueeze(0), size=size, mode='bilinear', align_corners=False
            ).squeeze(0)
        return frame_i.to(self.device)

    def forward(self, frame):
        """
        frame: [B, C, H, W] tensor (values in [0,1])
        Returns two crops, each [B, C, output_size, output_size] on ``self.device``:
        one for lips and one for eyes. When no face is found (or a box is
        degenerate) the whole frame is used for that crop.
        """
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        resize = transforms.Resize((self.output_size, self.output_size))

        lips_crops = []
        eyes_crops = []
        for i in range(frame.shape[0]):
            img = to_pil(frame[i].cpu())
            boxes, probs, landmarks = self.mtcnn.detect(img, landmarks=True)

            lips_box = eyes_box = None
            if boxes is not None and landmarks is not None and len(landmarks) > 0:
                lm = landmarks[0]  # first detected face
                # Landmark indices as used by the original code:
                # 0/1 = eyes, 3/4 = mouth corners.
                lips_box = self._padded_box(lm[3], lm[4], 0.2, 0.3)
                eyes_box = self._padded_box(lm[0], lm[1], 0.2, 0.2)

            if lips_box is not None:
                lips_crops.append(to_tensor(resize(img.crop(lips_box))).to(self.device))
            else:
                lips_crops.append(self._whole_frame(frame[i]))
            if eyes_box is not None:
                eyes_crops.append(to_tensor(resize(img.crop(eyes_box))).to(self.device))
            else:
                eyes_crops.append(self._whole_frame(frame[i]))

        lips_batch = torch.stack(lips_crops, dim=0)
        eyes_batch = torch.stack(eyes_crops, dim=0)
        return lips_batch, eyes_batch

# **VIDEO** **FEATURE** **EXTRACTION**

In [None]:
class VideoFeatureExtractor(nn.Module):
    """
    Uses a pretrained ResNet-18 to extract per-frame features.
    These features are then passed through an LSTM to obtain a video-level representation.
    """
    def __init__(self, device=DEVICE):
        super().__init__()
        self.cnn = resnet18(pretrained=True)
        self.cnn.fc = nn.Identity()  # remove final classification layer
        self.device = device
        # LSTM to model temporal dynamics: input_dim=512 (ResNet18 output), hidden=256
        self.lstm = nn.LSTM(input_size=512, hidden_size=256, num_layers=1, batch_first=True)
        # Move to the device only after every sub-module exists; doing it
        # before creating the LSTM left the LSTM on the CPU.
        self.to(device)

    def forward(self, video_frames):
        """
        video_frames: [B, T, C, H, W]
        Returns: video representation of shape [B, 256]
        """
        B, T, C, H, W = video_frames.shape
        # Process each frame through CNN. reshape (not view) also accepts
        # non-contiguous inputs.
        # NOTE(review): frames reach the pretrained CNN without ImageNet
        # mean/std normalization — confirm whether that is intended.
        frames = video_frames.reshape(B * T, C, H, W)
        features = self.cnn(frames)  # [B*T, 512]
        features = features.reshape(B, T, 512)
        # Pass through LSTM:
        lstm_out, (hn, cn) = self.lstm(features)
        # Use final hidden state as video representation:
        video_repr = hn[-1]  # shape: [B, 256]
        return video_repr

# **AUDIO FEATURE EXTRACTION**

In [None]:
import librosa
import torch.nn as nn
import torch

class AudioFeatureExtractorLSTM(nn.Module):
    """
    Extracts MFCCs via librosa and passes them through a Bi-LSTM.
    """
    def __init__(self, device=DEVICE):
        super().__init__()
        self.device = device
        # Bidirectional LSTM: input_dim=13 MFCC, hidden_dim=128
        self.lstm = nn.LSTM(input_size=13, hidden_size=128,
                            num_layers=1, batch_first=True, bidirectional=True)
        self.to(device)

    def forward(self, audio_waveform):
        """
        audio_waveform: numpy array or torch tensor shaped [samples],
            [B, samples] or [B, channels, samples] (channels are averaged to mono).
        Returns: torch tensor [B, 256] (concatenated hidden states)

        Raises:
            ValueError: if the input has an unsupported number of dimensions.
        """
        # MFCCs are computed by librosa on the CPU, so bring everything to numpy.
        if isinstance(audio_waveform, torch.Tensor):
            wave_np = audio_waveform.detach().cpu().numpy()
        else:
            wave_np = np.asarray(audio_waveform)

        if wave_np.ndim == 1:
            # Single waveform: add the batch dimension.
            wave_np = wave_np[None, :]
        elif wave_np.ndim == 3:
            # [B, channels, samples]: average channels to mono.
            wave_np = wave_np.mean(axis=1)
        elif wave_np.ndim != 2:
            raise ValueError(
                f"audio_waveform must have 1, 2 or 3 dimensions, got shape {wave_np.shape}"
            )
        wave_np = np.ascontiguousarray(wave_np, dtype=np.float32)

        # Compute MFCC for each sample in batch
        mfcc_list = []
        for b in range(wave_np.shape[0]):
            m = librosa.feature.mfcc(
                y=wave_np[b],
                sr=AUDIO_SAMPLE_RATE,
                n_mfcc=13,
                n_fft=400,
                hop_length=160,
                n_mels=26
            )  # shape: [13, time_frames]
            # transpose to [time_frames, 13], then subsample in time
            mfcc_list.append(m.T[::FRAME_SKIP, :])

        # Stack into tensor [B, T_audio, 13]
        mfcc_tensor = torch.from_numpy(np.stack(mfcc_list, axis=0)).float().to(self.device)

        # Pass through LSTM
        out, (hn, cn) = self.lstm(mfcc_tensor)
        # hn.shape = [2 (directions), B, 128]
        forward_h, backward_h = hn[0], hn[1]  # each [B,128]
        audio_repr = torch.cat((forward_h, backward_h), dim=1)  # [B,256]
        return audio_repr


# **FINAL** **DEEPFAKE** **DETECTOR** **MODEL**

In [None]:
class DeepfakeDetector(nn.Module):
    """
    Audio-visual deepfake classifier.

    Video branch: per-frame lips/eyes crops -> CNN+LSTM -> [B, 256].
    Audio branch: MFCC + Bi-LSTM -> [B, 256].
    The two representations are concatenated and classified into one logit
    (apply a sigmoid to obtain the probability of "fake").
    """
    def __init__(self, device=DEVICE):
        super().__init__()
        self.device = device
        # Video branch: Face extraction followed by video feature extraction.
        self.face_extractor = FaceRegionExtractor(device=device)
        self.video_extractor = VideoFeatureExtractor(device=device)
        # Audio branch:
        self.audio_extractor = AudioFeatureExtractorLSTM(device=device)
        # Fusion and classification: Fuse video and audio representations.
        # Video repr is 256, audio repr is 256 -> fused=512.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )
        # Place every sub-module (including the classifier) on `device`, so
        # the model is usable without a separate .to(device) call.
        self.to(device)

    def forward(self, video_frames, audio_waveform):
        """
        video_frames: [B, T, C, H, W] with values in [0,1]
        audio_waveform: [B, samples] mono waveform, as produced by the dataset
        Returns: logits of shape [B, 1]
        """
        # Video frames are assumed to be sub-sampled (and, if desired,
        # diffusion-preprocessed) before they reach the model.
        T = video_frames.shape[1]

        # Video branch: extract the lips and eyes regions of every frame and
        # average the two crops into a single image per frame.
        processed_frames = []
        for t in range(T):
            frame = video_frames[:, t]  # [B, C, H, W]
            lips, eyes = self.face_extractor(frame)  # both: [B, C, h, w]
            processed_frames.append((lips + eyes) / 2.0)
        processed_frames = torch.stack(processed_frames, dim=1)  # [B, T, C, h, w]
        # Now extract video temporal features via the CNN+LSTM pipeline:
        video_repr = self.video_extractor(processed_frames)  # [B, 256]

        # Audio branch:
        audio_repr = self.audio_extractor(audio_waveform)  # [B, 256]

        # Fusion:
        fused = torch.cat([video_repr, audio_repr], dim=1)  # [B, 512]
        out = self.classifier(fused)
        return out

# **TRAINING** **PIPELINE**

In [None]:
# Define hyperparameters for training
EPOCHS = 20  # Default number of epochs used by train_model
LEARNING_RATE = 1e-4  # Default Adam learning rate used by train_model
# NOTE(review): not referenced by the visible code; the classifier emits a single logit.
NUM_CLASSES = 2       # For binary classification

In [None]:
def evaluate_metrics(y_true, y_pred, threshold=0.5):
    """
    Binarize ``y_pred`` at ``threshold`` and score it against ``y_true``.

    Args:
        y_true: Array of ground-truth binary labels.
        y_pred: Numpy array of scores; values >= ``threshold`` become class 1.
        threshold: Decision threshold applied to ``y_pred``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and F1
        are reported as 0 when they are undefined.
    """
    hard_labels = (y_pred >= threshold).astype(int)
    scores = [accuracy_score(y_true, hard_labels)]
    for metric in (precision_score, recall_score, f1_score):
        scores.append(metric(y_true, hard_labels, zero_division=0))
    return tuple(scores)

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one pass over ``loader``; trains when ``optimizer`` is given, otherwise evaluates.

    Returns:
        ``(losses, labels, probabilities)``: the per-batch losses as a list, and
        the flattened labels and sigmoid probabilities as numpy arrays.
    """
    training = optimizer is not None
    model.train(training)
    losses, probs, targets = [], [], []
    with torch.set_grad_enabled(training):
        for video, audio, labels in loader:
            video = video.to(device)  # [B, T, C, H, W]
            audio = audio.to(device)
            labels = labels.to(device).unsqueeze(1)  # [B, 1]
            if training:
                optimizer.zero_grad()
            outputs = model(video, audio)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            # The model emits logits (the loss is BCEWithLogitsLoss); convert
            # them to probabilities so the 0.5 metric threshold is meaningful.
            probs.extend(torch.sigmoid(outputs).detach().cpu().numpy().flatten().tolist())
            targets.extend(labels.detach().cpu().numpy().flatten().tolist())
    return losses, np.array(targets), np.array(probs)


def train_model(model, train_loader, val_loader, num_epochs=EPOCHS, lr=LEARNING_RATE, device=DEVICE):
    """
    Train ``model`` with Adam and BCE-with-logits, printing train and
    validation loss/accuracy/precision/recall/F1 after every epoch.

    Args:
        model: Module called as ``model(video, audio)`` that returns [B, 1] logits.
        train_loader: Loader yielding ``(video, audio, labels)`` training batches.
        val_loader: Loader yielding validation batches in the same format.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the batches are moved to.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        train_losses, labels, preds = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_acc, train_prec, train_rec, train_f1 = evaluate_metrics(labels, preds)
        print(f"Epoch {epoch+1}/{num_epochs} Train Loss: {np.mean(train_losses):.4f} "
              f"Acc: {train_acc:.4f} Prec: {train_prec:.4f} Recall: {train_rec:.4f} F1: {train_f1:.4f}")

        # Validation
        val_losses, labels, preds = _run_epoch(model, val_loader, criterion, device)
        val_acc, val_prec, val_rec, val_f1 = evaluate_metrics(labels, preds)
        print(f"Epoch {epoch+1}/{num_epochs} Val Loss: {np.mean(val_losses):.4f} "
              f"Acc: {val_acc:.4f} Prec: {val_prec:.4f} Recall: {val_rec:.4f} F1: {val_f1:.4f}")

# **TESTING**

In [None]:
def test_model(model, test_loader, device=DEVICE):
    """
    Evaluate ``model`` on ``test_loader`` and print accuracy, precision, recall and F1.

    Args:
        model: Module called as ``model(video, audio)`` that returns logits.
        test_loader: Loader yielding ``(video, audio, labels)`` batches.
        device: Device the batches are moved to.
    """
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for video, audio, labels in test_loader:
            video = video.to(device)
            audio = audio.to(device)
            outputs = model(video, audio)
            # The model emits logits; turn them into probabilities before the
            # 0.5 threshold (thresholding raw logits at 0.5 is biased towards
            # the negative class).
            probs = torch.sigmoid(outputs).cpu().numpy()
            all_probs.extend(probs.flatten().tolist())
            all_labels.extend(labels.cpu().numpy().flatten().tolist())
    # Same metric helper as training/validation, so the numbers are comparable.
    acc, prec, rec, f1 = evaluate_metrics(np.array(all_labels), np.array(all_probs))
    print(f"Test Metrics - Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")


# **MAIN**

In [None]:
from sklearn.model_selection import train_test_split

# The dataset is folder-based, so the sample list is taken from a temporary
# dataset object and then split into train/val/test.

DATA_ROOT = "/content/drive/MyDrive/deepfake_dataset/small_dataset"  # Update with your FakeAVCeleb path
MAX_SAMPLES = 50  # For quick testing
SPLIT_SEED = 42  # Fixed seed so the splits are reproducible


def _split_samples(samples, test_size, seed=SPLIT_SEED):
    """
    Split ``samples`` into two lists, stratified by label when possible.

    Stratification keeps the real/fake ratio the same on both sides, which
    matters on a dataset this small. When a class is too small to stratify,
    fall back to a plain random split.
    """
    labels = [s["label"] for s in samples]
    try:
        return train_test_split(samples, test_size=test_size, random_state=seed, stratify=labels)
    except ValueError:
        return train_test_split(samples, test_size=test_size, random_state=seed)


# Create a temporary dataset to fetch all samples.
temp_dataset = FakeAVCelebDataset(data_root=DATA_ROOT, transform=None, audio_transform=None, max_samples=MAX_SAMPLES)
all_samples = temp_dataset.samples

# Split 70/30 train/test, and then 80/20 train/val from train.
train_samples, test_samples = _split_samples(all_samples, test_size=0.3)
train_samples, val_samples = _split_samples(train_samples, test_size=0.2)

print(f"Train samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")
print(f"Test samples: {len(test_samples)}")


Total samples loaded: 40
Train samples: 22
Validation samples: 6
Test samples: 12


In [None]:
# Define video and audio transforms.
# The dataset hands each frame to the transform as a PIL image.
# Training uses augmentation (random horizontal flip).
video_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5)
])
# Validation/test must be deterministic: same conversion, no augmentation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
])
audio_transform = None

# Create dataset objects using the splits:
train_dataset = FakeAVCelebDataset(data_root=DATA_ROOT, metadata_list=train_samples,
                                    transform=video_transform, audio_transform=audio_transform)
val_dataset = FakeAVCelebDataset(data_root=DATA_ROOT, metadata_list=val_samples,
                                  transform=eval_transform, audio_transform=audio_transform)
test_dataset = FakeAVCelebDataset(data_root=DATA_ROOT, metadata_list=test_samples,
                                   transform=eval_transform, audio_transform=audio_transform)

# For Colab, set num_workers=0 to avoid multiprocessing issues.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


Total samples loaded: 22
Total samples loaded: 6
Total samples loaded: 12


In [None]:
# Build the detector and make sure every sub-module sits on the selected device.
model = DeepfakeDetector(device=DEVICE)
model = model.to(DEVICE)





In [None]:
# Train with the default epoch count and learning rate; metrics are printed per epoch.
train_model(model, train_loader, val_loader)

Epoch 1/20 Train Loss: 0.7013 Acc: 0.4091 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 1/20 Val Loss: 0.7008 Acc: 0.6667 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 2/20 Train Loss: 0.6608 Acc: 0.4091 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 2/20 Val Loss: 0.7292 Acc: 0.6667 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 3/20 Train Loss: 0.6248 Acc: 0.4091 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 3/20 Val Loss: 0.7659 Acc: 0.6667 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 4/20 Train Loss: 0.6098 Acc: 0.4545 Prec: 1.0000 Recall: 0.0769 F1: 0.1429
Epoch 4/20 Val Loss: 0.7975 Acc: 0.6667 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 5/20 Train Loss: 0.5695 Acc: 0.5909 Prec: 1.0000 Recall: 0.3077 F1: 0.4706
Epoch 5/20 Val Loss: 0.8851 Acc: 0.5000 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 6/20 Train Loss: 0.5294 Acc: 0.8182 Prec: 1.0000 Recall: 0.6923 F1: 0.8182
Epoch 6/20 Val Loss: 0.9838 Acc: 0.0000 Prec: 0.0000 Recall: 0.0000 F1: 0.0000
Epoch 7/20 Train Loss: 0.4912 Acc: 0.909

In [None]:
# Save the model after training is done
def save_model(model, path='model.pth'):
    """Write ``model``'s state dict to ``path`` and print a confirmation."""
    weights = model.state_dict()
    torch.save(weights, path)
    print("Model saved to {}".format(path))
# Call this after the last epoch or training completion
save_model(model, 'deepfake_detector.pth')

Model saved to deepfake_detector.pth
