In [None]:
# Mount Google Drive so the dataset and metadata under /content/drive/MyDrive are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [None]:
# CELL 1: Imports, paths, hyperparameters

import os
import json
import glob
import random
import warnings
from collections import Counter

import cv2
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnext50_32x4d, ResNeXt50_32X4D_Weights

import matplotlib.pyplot as plt

# ---- Paths (adjust only if your paths are different) ----
BASE_DIR = "/content/drive/MyDrive/FF_REAL_Face_only_data"
METADATA_JSON = "/content/drive/MyDrive/metadata.json"

# ---- Hyperparameters ----
SEQUENCE_LENGTH = 10      # frames per video
IM_SIZE = 112             # frames are resized to IM_SIZE x IM_SIZE
BATCH_SIZE = 4
NUM_WORKERS = 0           # keep 0 on Colab
NUM_EPOCHS = 10
LR = 1e-4
HIDDEN_SIZE = 256
SEED = 42                 # global RNG seed so repeated runs are comparable

# ---- Reproducibility ----
# Training is stochastic (DataLoader shuffling, LSTM / classifier initialisation),
# so seed every RNG the notebook touches. torch.manual_seed also seeds CUDA devices.
# NOTE(review): bit-exact GPU determinism would additionally need cuDNN settings.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [None]:
# CELL 2: Load metadata + build labels + split

# 1) Load metadata.json -> labels_map {video basename: upper-cased label}
with open(METADATA_JSON, "r") as f:
    meta = json.load(f)

labels_map = {}
for meta_key, meta_entry in meta.items():
    fname = os.path.basename(meta_key)
    label = str(meta_entry["label"]).upper()
    # Keys are reduced to basenames, so entries from different directories can collide;
    # fail loudly on a conflicting label instead of silently keeping the last one.
    if fname in labels_map and labels_map[fname] != label:
        raise ValueError(
            f"Conflicting labels for {fname!r} in metadata: {labels_map[fname]} vs {label}"
        )
    labels_map[fname] = label

print("Total entries in metadata:", len(labels_map))

# 2) Collect all video files that actually exist
video_files = sorted(glob.glob(os.path.join(BASE_DIR, "*.mp4")))
print("Total video files found:", len(video_files))

# 3) Keep only videos that have a label in metadata (unlabelled files are dropped)
valid_video_files = []
valid_labels = []

for vp in video_files:
    fname = os.path.basename(vp)
    if fname in labels_map:
        valid_video_files.append(vp)
        valid_labels.append(labels_map[fname])

print("Videos with labels:", len(valid_video_files))
print("Label counts:", Counter(valid_labels))

if not valid_video_files:
    # Clearer than the error train_test_split would raise on an empty input.
    raise RuntimeError(f"No labelled .mp4 files found in {BASE_DIR!r}")

# 4) Convert labels to numeric (FAKE=0, REAL=1)
label_to_num = {"FAKE": 0, "REAL": 1}
unknown_labels = set(valid_labels) - set(label_to_num)
if unknown_labels:
    # Readable message instead of a bare KeyError in the comprehension below.
    raise ValueError(f"Unexpected labels in metadata: {sorted(unknown_labels)}")
numeric_labels = [label_to_num[lbl] for lbl in valid_labels]

# 5) Stratified 80/20 split (same FAKE/REAL ratio in train and valid)
indices = list(range(len(valid_video_files)))

train_idx, valid_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=numeric_labels,
    random_state=42
)

train_videos = [valid_video_files[i] for i in train_idx]
valid_videos = [valid_video_files[i] for i in valid_idx]

train_labels = [numeric_labels[i] for i in train_idx]
valid_labels_num = [numeric_labels[i] for i in valid_idx]

print("\nTrain size:", len(train_videos), " Valid size:", len(valid_videos))
print("Train label counts:", Counter(train_labels))
print("Valid label counts:", Counter(valid_labels_num))

# sanity: no overlap
overlap = set(os.path.basename(p) for p in train_videos) & set(os.path.basename(p) for p in valid_videos)
print("Overlap between train and valid filenames (should be empty):", overlap)
assert not overlap, f"Train/valid leakage: {sorted(overlap)}"


Total entries in metadata: 200
Total video files found: 200
Videos with labels: 200
Label counts: Counter({'REAL': 100, 'FAKE': 100})

Train size: 160  Valid size: 40
Train label counts: Counter({0: 80, 1: 80})
Valid label counts: Counter({0: 20, 1: 20})
Overlap between train and valid filenames (should be empty): set()


In [None]:
# CELL 3: Dataset class (for LSTM)

# Channel statistics used by Normalize (the standard ImageNet values, matching the
# ImageNet-pretrained backbone).
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]


def _build_frame_transform():
    """Per-frame pipeline: H x W x 3 RGB numpy frame -> normalised (3, IM_SIZE, IM_SIZE) tensor."""
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((IM_SIZE, IM_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(steps)


# The train and eval pipelines are currently identical (no augmentation). They are kept
# as two separate objects so train-only augmentation can be added later.
train_transforms = _build_frame_transform()
test_transforms = _build_frame_transform()

class VideoDatasetLSTM(Dataset):
    """Video-clip dataset for the CNN+LSTM model.

    Each item is ``(frames, label)``. ``frames`` is a ``(T, 3, H, W)`` float tensor built
    from the FIRST ``sequence_length`` decoded frames of the video, each passed through
    ``transform``. ``label`` is a 0-d ``torch.long`` tensor.

    Short videos are padded by repeating their last frame. Videos with no decodable
    frames yield all-zero frames and emit a ``UserWarning``.
    """

    def __init__(self, video_paths, labels, sequence_length=SEQUENCE_LENGTH, transform=None):
        """
        video_paths: list of full paths
        labels: list of ints (0=FAKE, 1=REAL), aligned index-by-index with video_paths
        sequence_length: number of frames (T) returned per video, must be >= 1
        transform: per-frame transform applied to each RGB frame; defaults to train_transforms

        Raises ValueError if the inputs are inconsistent.
        """
        if len(video_paths) != len(labels):
            raise ValueError(
                f"video_paths and labels must have the same length "
                f"({len(video_paths)} != {len(labels)})"
            )
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.video_paths = video_paths
        self.labels = labels
        self.seq_len = sequence_length
        self.transform = transform if transform is not None else train_transforms

    def __len__(self):
        return len(self.video_paths)

    def _read_frames(self, path):
        """Decode and transform up to ``self.seq_len`` frames from the start of ``path``."""
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            while len(frames) < self.seq_len:
                ok, frame = cap.read()
                if not ok:
                    break
                # OpenCV decodes to BGR; the transform pipeline treats frames as RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))   # (3,H,W)
        finally:
            # Release the decoder even if the transform raises.
            cap.release()
        return frames

    def __getitem__(self, idx):
        path = self.video_paths[idx]
        label = self.labels[idx]

        frames = self._read_frames(path)

        if not frames:
            # Completely unreadable video: keep batch shapes valid with zeros, but warn,
            # because this sample still carries a real label and would otherwise be
            # trained on silently.
            warnings.warn(f"No frames could be decoded from {path!r}; returning all-zero frames.")
            frames = [torch.zeros(3, IM_SIZE, IM_SIZE)] * self.seq_len
        elif len(frames) < self.seq_len:
            # Short video: pad with copies of the last decoded frame.
            last = frames[-1]
            frames.extend(last.clone() for _ in range(self.seq_len - len(frames)))

        frames = torch.stack(frames)[:self.seq_len]  # (T,3,H,W)

        return frames, torch.tensor(label, dtype=torch.long)


In [None]:
# CELL 4: Dataloaders + sanity check

train_dataset = VideoDatasetLSTM(train_videos, train_labels, sequence_length=SEQUENCE_LENGTH, transform=train_transforms)
valid_dataset = VideoDatasetLSTM(valid_videos, valid_labels_num, sequence_length=SEQUENCE_LENGTH, transform=test_transforms)

# Shuffle only the training set; validation order stays fixed across epochs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print("Train batches:", len(train_loader), " Valid batches:", len(valid_loader))

# check one sample (decodes the first training video)
sample_frames, sample_label = train_dataset[0]
print("Sample frames shape:", sample_frames.shape)  # expected (T,3,112,112)
print("Sample numeric label:", sample_label.item())
print("From labels_map:", labels_map.get(os.path.basename(train_videos[0])))

# Turn the eyeballed checks above into hard invariants.
assert sample_frames.shape == (SEQUENCE_LENGTH, 3, IM_SIZE, IM_SIZE), sample_frames.shape
assert label_to_num[labels_map[os.path.basename(train_videos[0])]] == sample_label.item(), \
    "dataset label does not match metadata label"


Train batches: 40  Valid batches: 10
Sample frames shape: torch.Size([10, 3, 112, 112])
Sample numeric label: 0
From labels_map: FAKE


In [None]:
# CELL 5: Define LSTM model

class DeepfakeLSTMModel(nn.Module):
    """ResNeXt-50 frame encoder, then an LSTM over time, then a linear classifier.

    forward() takes a clip batch of shape (B, T, C, H, W) and returns logits of shape
    (B, num_classes), computed from the LSTM output at the last time step.
    """

    def __init__(self, hidden_size=HIDDEN_SIZE, num_layers=1, num_classes=2):
        super().__init__()
        # CNN backbone with ImageNet weights; the final FC is replaced by Identity so the
        # backbone returns pooled features instead of class scores.
        self.backbone = resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V1)
        in_features = self.backbone.fc.in_features   # 2048 for ResNeXt-50
        self.backbone.fc = nn.Identity()

        # LSTM on top of per-frame features; batch_first -> inputs are (B, T, features)
        self.lstm = nn.LSTM(
            input_size=in_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # final classifier on the last LSTM time step
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if x.dim() != 5:
            raise ValueError(f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}")
        B, T = x.shape[:2]

        # Fold time into batch so the 2D backbone sees B*T independent frames.
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted tensors.
        feats = self.backbone(x.reshape(B * T, *x.shape[2:]))   # (B*T, in_features)
        feats = feats.reshape(B, T, -1)                          # (B, T, in_features)

        lstm_out, _ = self.lstm(feats)    # (B, T, hidden)
        last_out = lstm_out[:, -1, :]     # (B, hidden)

        logits = self.fc(last_out)        # (B, num_classes)
        return logits

model = DeepfakeLSTMModel().to(device)
print("Model on", device)


Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth


100%|██████████| 95.8M/95.8M [00:01<00:00, 81.0MB/s]


Model on cuda


In [None]:
# CELL 6: Train LSTM model

# The loss works on raw logits (CrossEntropyLoss applies log-softmax internally).
# Adam optimises every model parameter, including the pretrained backbone.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

def run_epoch(loader, train=True, net=None, opt=None, loss_fn=None, dev=None):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy_percent)``.

    Args:
        loader: iterable of ``(inputs, labels)`` batches.
        train: if True, update the weights with the optimizer. If False, evaluate only:
            the model is put in eval mode and autograd is disabled to save memory.
        net, opt, loss_fn, dev: optional overrides. When None they default to the
            notebook globals ``model``, ``optimizer``, ``criterion`` and ``device``.
            ``opt`` is only needed when ``train`` is True.

    Returns:
        (average per-sample loss, accuracy in percent). The average assumes ``loss_fn``
        returns a per-batch mean, as the default CrossEntropyLoss does.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev
    if train:
        opt = optimizer if opt is None else opt

    net.train(train)   # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are only needed for training; in eval this avoids storing activations.
    with torch.set_grad_enabled(train):
        for frames, labels in loader:
            frames = frames.to(dev)      # (B,T,3,H,W)
            labels = labels.to(dev)      # (B,)

            if train:
                opt.zero_grad()

            outputs = net(frames)        # (B,num_classes)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)

            if train:
                loss.backward()
                opt.step()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("run_epoch: loader yielded no samples")

    avg_loss = running_loss / total
    acc = correct / total * 100.0
    return avg_loss, acc

# NOTE(review): no checkpointing happens here, so the weights left in `model` after the
# loop are the last epoch's, not those of the best-validation epoch.
for epoch in range(1, NUM_EPOCHS + 1):
    # One optimisation pass over train, then a no-update pass over valid.
    train_loss, train_acc = run_epoch(train_loader, train=True)
    val_loss, val_acc = run_epoch(valid_loader, train=False)
    summary = (
        f"[Epoch {epoch}/{NUM_EPOCHS}] "
        f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
        f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%"
    )
    print(summary)

print("Training complete.")


[Epoch 1/10] Train Loss: 0.1832 Acc: 93.75% | Val Loss: 0.0798 Acc: 100.00%
[Epoch 2/10] Train Loss: 0.2050 Acc: 93.12% | Val Loss: 0.1644 Acc: 92.50%
[Epoch 3/10] Train Loss: 0.0521 Acc: 98.12% | Val Loss: 0.1202 Acc: 95.00%
[Epoch 4/10] Train Loss: 0.1276 Acc: 96.88% | Val Loss: 0.1432 Acc: 95.00%
[Epoch 5/10] Train Loss: 0.0443 Acc: 98.75% | Val Loss: 0.0414 Acc: 97.50%
[Epoch 6/10] Train Loss: 0.0773 Acc: 97.50% | Val Loss: 0.0459 Acc: 97.50%
[Epoch 7/10] Train Loss: 0.0618 Acc: 97.50% | Val Loss: 0.1941 Acc: 87.50%
[Epoch 8/10] Train Loss: 0.0393 Acc: 98.12% | Val Loss: 0.2115 Acc: 90.00%
[Epoch 9/10] Train Loss: 0.0690 Acc: 97.50% | Val Loss: 0.1978 Acc: 92.50%
[Epoch 10/10] Train Loss: 0.0803 Acc: 97.50% | Val Loss: 0.4327 Acc: 85.00%
Training complete.


In [None]:
# CELL 7: Evaluate on validation set

model.eval()
all_true = []
all_pred = []

# Collect predictions for the whole validation set without tracking gradients.
with torch.no_grad():
    for frames, labels in valid_loader:
        outputs = model(frames.to(device))
        preds = outputs.argmax(dim=1)
        all_pred.extend(preds.cpu().tolist())
        all_true.extend(labels.cpu().tolist())

# Map numeric classes back to the metadata strings (FAKE=0, REAL=1).
num_to_str = {0: "FAKE", 1: "REAL"}
true_str = [num_to_str[n] for n in all_true]
pred_str = [num_to_str[n] for n in all_pred]

print("y_true counts:", Counter(true_str))
print("y_pred counts:", Counter(pred_str))

# Rows = true class, columns = predicted class, both in REAL, FAKE order.
labels_order = ["REAL", "FAKE"]
cm = confusion_matrix(true_str, pred_str, labels=labels_order)
print("\nConfusion matrix (rows=true, cols=pred):", labels_order)
print(cm)

acc = 100.0 * accuracy_score(true_str, pred_str)
print(f"\nAccuracy: {acc:.2f}%")
print("\nClassification report:")
report = classification_report(true_str, pred_str, labels=labels_order, zero_division=0)
print(report)


y_true counts: Counter({'FAKE': 20, 'REAL': 20})
y_pred counts: Counter({'FAKE': 26, 'REAL': 14})

Confusion matrix (rows=true, cols=pred): ['REAL', 'FAKE']
[[14  6]
 [ 0 20]]

Accuracy: 85.00%

Classification report:
              precision    recall  f1-score   support

        REAL       1.00      0.70      0.82        20
        FAKE       0.77      1.00      0.87        20

    accuracy                           0.85        40
   macro avg       0.88      0.85      0.85        40
weighted avg       0.88      0.85      0.85        40



In [None]:
# CELL 8: Save trained model

# Only the weights (state_dict) are written. To reload, build DeepfakeLSTMModel with the
# same constructor arguments and call load_state_dict().
save_path = "/content/drive/MyDrive/deepfake_lstm_trained.pth"
weights = model.state_dict()
torch.save(weights, save_path)
print("Saved model to:", save_path)


Saved model to: /content/drive/MyDrive/deepfake_lstm_trained.pth
