# Training (Sobel + YOLO + FaceMesh + Xception) - Module

In [None]:
# Importing Libraries
import os
import cv2
import torch
import numpy as np
import mediapipe as mp
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import pretrainedmodels  # For Xception model
from ultralytics import YOLO  # For YOLOv8 face detection
import ssl
import logging
import matplotlib.pyplot as plt
import warnings
import torch.nn as nn

## The training pipeline below is for WildRF; it can be extended similarly to CollabDiff and eVe StyleGAN

In [None]:

# Global runtime setup shared by every cell below: warning/log filtering,
# device selection, preprocessing transform and the pretrained models.
warnings.filterwarnings("ignore")
logging.getLogger("ultralytics").setLevel(logging.WARNING)
# NOTE(review): this swaps ssl's unverified context factory in as the default
# for HTTPS connections made by this process -- presumably so the pretrained
# weight downloads succeed; confirm this is acceptable outside a sandbox.
ssl._create_default_https_context = ssl._create_unverified_context
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to both the raw image and its Sobel map:
# resize to 299x299, convert to a tensor, then normalise with mean 0.5 / std 0.5 per channel
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((299, 299)),  # Xception requires 299x299 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Set up Mediapipe for facial landmarks extraction (single-image mode, at most one face)
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1)

# Loading YOLOv8 model
yolo_model = YOLO("yolov8n.pt").to(device)  # Runs on the GPU if available; choose the YOLOv8 variant based on resources
# Loading Xception model (ImageNet weights)
xception_model = pretrainedmodels.__dict__['xception'](pretrained='imagenet').to(device)
xception_model.last_linear = torch.nn.Linear(xception_model.last_linear.in_features, 128).to(device)  # Replace the final layer with a 128-dimensional feature output

# COCO classes we are interested in (people, vehicles, animals, household items, etc.).
# 80 entries. DeepfakeDataset.__getitem__ indexes its one-hot vector by the YOLO
# class id, so this order is assumed to match yolo_model.names -- TODO confirm.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
]

# Dataset Class
class DeepfakeDataset(Dataset):
    """Real/fake image dataset that also computes auxiliary features per sample.

    Expected layout under ``root_dir``:

    * ``<split>/0_real`` and ``<split>/1_fake`` for every split except ``'test'``;
    * ``test/<platform>/0_real`` and ``test/<platform>/1_fake`` for ``'test'``,
      with ``platform`` in ``facebook``, ``reddit``, ``twitter``.

    Folders that do not exist are skipped silently, so a wrong ``root_dir``
    produces an empty dataset rather than an error.

    The class relies on the module-level ``yolo_model``, ``face_mesh``,
    ``COCO_CLASSES`` and ``device``.

    Args:
        root_dir: Dataset root directory.
        split: Sub-directory of ``root_dir`` to load (``'train'``, ``'val'``, ``'test'``).
        transform: Callable applied to the image and to its Sobel map. In
            practice it must return tensors, because ``__getitem__`` moves its
            results with ``.to(device)``.
        limit: Optional cap on the total number of samples. The cap is global,
            not per class: folders are visited in order, so a small limit may be
            filled entirely from the first folder(s).
    """

    def __init__(self, root_dir, split='train', transform=None, limit=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.data = []    # image file paths
        self.labels = []  # 0 = real, 1 = fake, aligned with self.data
        self._load_data(limit)

    def _load_data(self, limit):
        """Walk the label folders of the configured split and collect image paths."""
        if self.split == 'test':
            # The test split has one extra directory level per source platform.
            for platform in ['facebook', 'reddit', 'twitter']:
                for label in ['0_real', '1_fake']:
                    path = os.path.join(self.root_dir, self.split, platform, label)
                    self._load_images(path, label, limit)
        else:
            for label in ['0_real', '1_fake']:
                path = os.path.join(self.root_dir, self.split, label)
                self._load_images(path, label, limit)

    def _load_images(self, path, label, limit):
        """Append every regular file in ``path`` to the sample list.

        Entries are taken in sorted order so the sample order is reproducible
        (``os.listdir`` order is arbitrary). Sub-directories such as
        ``.ipynb_checkpoints`` are skipped because they cannot be read as images.
        """
        if not os.path.isdir(path):
            return
        target = 0 if label == '0_real' else 1
        for img_name in sorted(os.listdir(path)):
            if limit and len(self.data) >= limit:
                break
            img_path = os.path.join(path, img_name)
            if not os.path.isfile(img_path):
                continue
            self.data.append(img_path)
            self.labels.append(target)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample and compute its features.

        Returns:
            Tuple ``(image, sobel_image, yolo_features, face_landmarks, label)``,
            every element moved to ``device``. ``yolo_features`` has one slot per
            entry of ``COCO_CLASSES``; ``face_landmarks`` has 936 values and stays
            all-zero when no face is found.

        Raises:
            ValueError: If the image file cannot be read by OpenCV.
        """
        img_path = self.data[idx]
        label = self.labels[idx]

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image file: {img_path}")

        # YOLOv8 for object detection
        results = yolo_model(image)
        detected_objects = []
        face_landmarks = np.zeros((936,), dtype=np.float32)

        for box in results[0].boxes:
            class_id = int(box.cls[0])  # YOLOv8 returns class IDs
            class_name = yolo_model.names[class_id]

            # Keep only detections of the COCO classes we care about
            if class_name not in COCO_CLASSES:
                continue
            detected_objects.append(class_id)

            # Only person detections are searched for facial landmarks
            if class_name != 'person':
                continue
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            obj_crop = image[int(y1):int(y2), int(x1):int(x2)]
            if obj_crop.size == 0:
                # Integer truncation can collapse a thin box to an empty crop,
                # which cv2.cvtColor would reject.
                continue

            results_face = face_mesh.process(cv2.cvtColor(obj_crop, cv2.COLOR_BGR2RGB))
            if results_face.multi_face_landmarks:
                # (x, y) of every landmark of the first face, flattened.
                # With several person boxes, the last one with a face wins.
                face_landmarks = np.array(
                    [[p.x, p.y] for p in results_face.multi_face_landmarks[0].landmark],
                    dtype=np.float32,
                ).flatten()

        # Encode detected objects as a multi-hot vector indexed by YOLO class id
        yolo_features = np.zeros(len(COCO_CLASSES))
        for obj_id in detected_objects:
            yolo_features[obj_id] = 1

        # Sobel gradient magnitude, replicated to three channels so it can go
        # through the same transform as the colour image
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        sobel_x = cv2.Sobel(gray_image, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(gray_image, cv2.CV_64F, 0, 1, ksize=3)
        sobel_combined = cv2.magnitude(sobel_x, sobel_y)
        sobel_combined = cv2.convertScaleAbs(sobel_combined)
        sobel_combined = cv2.merge([sobel_combined, sobel_combined, sobel_combined])

        # NOTE(review): `image` is still in OpenCV's BGR channel order here and is
        # handed to the transform without conversion -- confirm this is intended
        # for the ImageNet-pretrained Xception backbone.
        if self.transform:
            image = self.transform(image)
            sobel_combined = self.transform(sobel_combined)

        yolo_features = torch.tensor(yolo_features, dtype=torch.float32).to(device)
        face_landmarks = torch.tensor(face_landmarks, dtype=torch.float32).to(device)

        return image.to(device), sobel_combined.to(device), yolo_features, face_landmarks, torch.tensor(label, dtype=torch.long).to(device)

# Define the Classifier Model using Xception
class DeepfakeClassifier(torch.nn.Module):
    """Two-class classifier fusing four feature streams.

    Streams and their widths after projection:

    * colour image through the shared module-level ``xception_model`` (128),
    * Sobel edge image through a small CNN plus a linear layer (128),
    * YOLO class vector (80 -> 64),
    * flattened face landmarks (936 -> 128).

    ``sobel_linear`` and ``fc1`` depend on the Sobel input size and are created
    by :meth:`initialize_sobel_linear`, which must be called before the first
    forward pass, before building the optimizer, and before loading a state dict.
    """

    def __init__(self):
        super().__init__()
        # Shared global backbone: every instance references the same module object.
        self.xception = xception_model  # Outputs 128 features
        self.sobel_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.sobel_linear = None  # Created in initialize_sobel_linear
        self.fc_landmarks = nn.Linear(936, 128).to(device)  # 936 = flattened landmarks
        self.fc_yolo = nn.Linear(80, 64).to(device)  # Adjust YOLO features to 64
        self.fc1 = None  # Created in initialize_sobel_linear
        self.fc2 = nn.Linear(128, 2).to(device)

    def initialize_sobel_linear(self, input_shape):
        """Create the layers whose size depends on the Sobel image shape.

        Args:
            input_shape: ``(channels, height, width)`` of one Sobel image.
        """
        # Use the device the CNN actually lives on, so this works whether or not
        # the model has already been moved with .to(device).
        target_device = next(self.sobel_cnn.parameters()).device

        with torch.no_grad():
            # Dummy forward pass to measure the flattened CNN output size
            sample_input = torch.zeros(1, *input_shape, device=target_device)
            flattened_size = self.sobel_cnn(sample_input).numel()

        self.sobel_linear = nn.Linear(flattened_size, 128).to(target_device)

        # Width of the concatenation built in forward():
        # xception + sobel + YOLO + landmarks
        total_feature_size = 128 + 128 + 64 + 128
        self.fc1 = nn.Linear(total_feature_size, 128).to(target_device)

    def forward(self, image, sobel_image, yolo_features, face_landmarks):
        """Return the class logits for a batch.

        Args:
            image: Colour image batch fed to the Xception backbone.
            sobel_image: Sobel image batch with the shape given to
                ``initialize_sobel_linear``.
            yolo_features: Batch of 80-wide class vectors.
            face_landmarks: Batch of 936-wide landmark vectors.

        Returns:
            Logits produced by ``fc2`` (two outputs per sample).

        Raises:
            RuntimeError: If ``initialize_sobel_linear`` has not been called.
        """
        if self.sobel_linear is None or self.fc1 is None:
            raise RuntimeError(
                "DeepfakeClassifier is not fully built: call "
                "initialize_sobel_linear(input_shape) before forward()."
            )

        # Per-stream features
        image_features = self.xception(image)
        sobel_maps = self.sobel_cnn(sobel_image)
        sobel_features = self.sobel_linear(torch.flatten(sobel_maps, start_dim=1))
        yolo_features = torch.relu(self.fc_yolo(yolo_features))
        landmark_features = torch.relu(self.fc_landmarks(face_landmarks))

        # Combine features
        combined = torch.cat((image_features, sobel_features, yolo_features, landmark_features), dim=1)

        # Fully connected head
        hidden = torch.relu(self.fc1(combined))
        return self.fc2(hidden)

W0000 00:00:1732399347.913995  238834 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1732399347.924421  238834 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


In [None]:
# Training parameters (learning_rate is a tunable hyperparameter)
num_epochs = 10
learning_rate = 0.001
batch_size = 16

# Dataset and DataLoader
root_dir = 'WildRF'  # Replace with actual path
train_dataset = DeepfakeDataset(root_dir=root_dir, split='train', transform=transform)
val_dataset = DeepfakeDataset(root_dir=root_dir, split='val', transform=transform)
test_dataset = DeepfakeDataset(root_dir=root_dir, split='test', transform=transform)

# DeepfakeDataset skips missing folders silently; fail here with a clear message
# instead of an obscure sampler / division error later on.
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError(
        f"No training or validation images found under '{root_dir}'. "
        "Check that root_dir points at the dataset."
    )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define model, loss function, and optimizer
model = DeepfakeClassifier().to(device)
# Create the size-dependent layers before building the optimizer so that their
# parameters are included in model.parameters()
model.initialize_sobel_linear(input_shape=(3, 299, 299))  # Assuming Sobel image size is 299x299
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Early stopping parameters
patience = 3  # Stop after this many consecutive epochs without validation-loss improvement
best_val_loss = float('inf')
epochs_no_improve = 0

# Per-epoch history used for plotting
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

best_model_path = "best_model_module1.pth"

# Training loop with validation and early stopping
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_train_loss = 0.0
    correct_train = 0
    total_train = 0

    for images, sobel_images, yolo_features, landmarks, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        outputs = model(images, sobel_images, yolo_features, landmarks)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    # Loss is averaged over batches, accuracy over samples
    epoch_train_loss = running_train_loss / len(train_loader)
    epoch_train_accuracy = correct_train / total_train
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_accuracy)

    # Validation phase
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for images, sobel_images, yolo_features, landmarks, labels in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation"):
            outputs = model(images, sobel_images, yolo_features, landmarks)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)

    epoch_val_loss = running_val_loss / len(val_loader)
    epoch_val_accuracy = correct_val / total_val
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_accuracy)

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_accuracy * 100:.2f}%, "
          f"Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_accuracy * 100:.2f}%")

    # Early stopping and model selection on validation loss
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_model_path)  # Save the best model
        print(f"Best model saved with validation loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered.")
            break

# Load the best model for final testing or further evaluation.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=device))
print("Loaded the best model based on validation performance.")

In [None]:
# Accuracy curves: one line per split, x axis is the 1-based epoch number
plt.figure(figsize=(10, 5))
for series, legend_label in (
    (train_accuracies, "Training Accuracy"),
    (val_accuracies, "Validation Accuracy"),
):
    epoch_numbers = range(1, len(series) + 1)
    plt.plot(epoch_numbers, series, label=legend_label, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training and Validation Accuracy over Epochs")
plt.legend()
plt.grid(True)
plt.savefig("training_validation_accuracy_module1_WildRf.png")
print("Training and Validation accuracy plot saved as 'training_validation_accuracy_module1_WildRf.png'.")

# Loss curves: same layout as the accuracy figure
plt.figure(figsize=(10, 5))
for series, legend_label in (
    (train_losses, "Training Loss"),
    (val_losses, "Validation Loss"),
):
    epoch_numbers = range(1, len(series) + 1)
    plt.plot(epoch_numbers, series, label=legend_label, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss over Epochs")
plt.legend()
plt.grid(True)
plt.savefig("training_validation_loss_module1_WildRf.png")
print("Training and Validation loss plot saved as 'training_validation_loss_module1_WildRf.png'.")

# Testing


In [None]:
# Testing on Facebook, Twitter, Reddit, and validation

In [4]:
# Load the best model.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()  # Set to evaluation mode
print("Loaded the best model for testing.")

# Test phase: accumulate batch losses and the number of correct predictions
running_test_loss = 0.0
correct_test = 0
total_test = 0

with torch.no_grad():
    for images, sobel_images, yolo_features, landmarks, labels in tqdm(test_loader, desc="Testing on Test Set"):
        outputs = model(images, sobel_images, yolo_features, landmarks)
        loss = criterion(outputs, labels)
        running_test_loss += loss.item()

        preds = outputs.argmax(dim=1)
        correct_test += (preds == labels).sum().item()
        total_test += labels.size(0)

if total_test == 0:
    # An empty loader would otherwise surface as a bare ZeroDivisionError below.
    raise RuntimeError("The test loader produced no samples; check the test split path.")

# Loss is averaged over batches, accuracy over samples
test_loss = running_test_loss / len(test_loader)
test_accuracy = correct_test / total_test

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy * 100:.2f}%")

Loaded the best model for testing.


Testing on Test Set: 100%|██████████| 157/157 [28:16<00:00, 10.81s/it]

Test Loss: 0.2924, Test Accuracy: 87.61%





In [None]:
import os

from torchcam.methods import GradCAM

# Rebuild the test loader so this cell can run without the training cell
root_dir = 'WildRF'  # Replace with actual path
test_dataset = DeepfakeDataset(root_dir=root_dir, split='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Define Model
model = DeepfakeClassifier().to(device)

# The size-dependent layers must exist before the state dict is loaded,
# otherwise their saved weights have nowhere to go
sobel_input_shape = (3, 299, 299)  # Assuming Sobel image size is 299x299
model.initialize_sobel_linear(input_shape=sobel_input_shape)

# Load trained weights.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load("best_model_module1.pth", map_location=device))
model.eval()

# Use TorchCAM's GradCAM
# Replace with the correct convolutional layer from your model
target_layer = "xception.block5.rep.4.pointwise"  # Example convolutional layer
grad_cam = GradCAM(model, target_layer)

def gradcam_visualization_on_fake_images(model, loader, grad_cam, num_images=5, save_dir="gradcam_outputs"):
    """Save Grad-CAM figures for the first ``num_images`` fake samples of ``loader``.

    Each figure has three panels (input image, heatmap, overlay) and is written
    to ``save_dir`` as ``gradcam_fake_output_<k>.png``. The heatmap is computed
    for the class the model predicts, not for the true label.

    Args:
        model: Classifier called as ``model(image, sobel, yolo, landmarks)``.
        loader: Iterable of ``(images, sobel_images, yolo_features, landmarks,
            labels)`` batches; samples with label ``1`` are treated as fake.
        grad_cam: CAM extractor called as ``grad_cam(class_idx=..., scores=...)``.
        num_images: Maximum number of fake samples to visualise.
        save_dir: Output directory, created if missing.

    Returns:
        None. Stops early once ``num_images`` figures have been saved.
    """
    # Imported locally because the file's top-level imports do not bring
    # to_pil_image into scope.
    from torchvision.transforms.functional import to_pil_image

    model.eval()  # Set the model to evaluation mode
    images_processed = 0

    # Create directory to save Grad-CAM outputs
    os.makedirs(save_dir, exist_ok=True)

    for images, sobel_images, yolo_features, landmarks, labels in loader:
        images = images.to(device)
        sobel_images = sobel_images.to(device)
        yolo_features = yolo_features.to(device)
        landmarks = landmarks.to(device)
        labels = labels.to(device)

        for i, label in enumerate(labels):
            if images_processed >= num_images:
                return  # Quota reached: stop without loading further batches

            true_label = label.item()
            if true_label != 1:
                continue  # Process only fake images (label == 1)

            input_image = images[i].unsqueeze(0)
            sobel_image = sobel_images[i].unsqueeze(0)
            yolo_feature = yolo_features[i].unsqueeze(0)
            landmark = landmarks[i].unsqueeze(0)

            # Forward pass with gradients enabled: Grad-CAM backpropagates the score
            outputs = model(input_image, sobel_image, yolo_feature, landmark)
            pred_class = outputs.argmax(dim=1).item()

            print(f"Fake Image {images_processed + 1} - Predicted Class: {pred_class}, True Label: {true_label}")

            # Generate Grad-CAM heatmap
            activation_map = grad_cam(class_idx=pred_class, scores=outputs)  # Explicitly pass class_idx and scores

            # Remove batch dimension for visualization
            heatmap = activation_map[0].squeeze().cpu().numpy()

            # Resize the heatmap to the input's spatial size (cv2 expects (width, height))
            height, width = input_image.shape[-2:]
            heatmap_resized = cv2.resize(heatmap, (width, height))

            # Min-max normalise; a constant map would otherwise divide by zero
            heat_min = heatmap_resized.min()
            heat_range = heatmap_resized.max() - heat_min
            if heat_range > 0:
                heatmap_resized = (heatmap_resized - heat_min) / heat_range
            else:
                heatmap_resized = np.zeros_like(heatmap_resized)

            # Rebuild a displayable image. This assumes the loader's images come
            # from this file's `transform` (mean 0.5 / std 0.5) applied to OpenCV
            # BGR arrays: undo the normalisation, then flip channels to RGB.
            display_tensor = (input_image.squeeze(0).detach().cpu() * 0.5 + 0.5).clamp(0, 1).flip(0)
            input_image_vis = to_pil_image(display_tensor)

            # Plot and save the images
            plt.figure(figsize=(10, 5))

            # Original Image
            plt.subplot(1, 3, 1)
            plt.imshow(input_image_vis)
            plt.title("Original Image")
            plt.axis("off")

            # Heatmap
            plt.subplot(1, 3, 2)
            plt.imshow(heatmap_resized, cmap="jet")
            plt.title("Grad-CAM Heatmap")
            plt.axis("off")

            # Overlayed Image
            plt.subplot(1, 3, 3)
            plt.imshow(input_image_vis)
            plt.imshow(heatmap_resized, cmap="jet", alpha=0.5)  # Overlay heatmap
            plt.title("Overlay")
            plt.axis("off")

            # Save the figure
            output_path = os.path.join(save_dir, f"gradcam_fake_output_{images_processed + 1}.png")
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved Grad-CAM visualization for fake image as {output_path}")

            plt.close()
            images_processed += 1


# Apply Grad-CAM to at most 20 fake images from the test loader; figures are written to gradcam_fake_outputs/
gradcam_visualization_on_fake_images(model, test_loader, grad_cam, num_images=20, save_dir="gradcam_fake_outputs")


Fake Image 1 - Predicted Class: 1, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_1.png
Fake Image 2 - Predicted Class: 0, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_2.png
Fake Image 3 - Predicted Class: 0, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_3.png
Fake Image 4 - Predicted Class: 1, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_4.png
Fake Image 5 - Predicted Class: 1, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_5.png
Fake Image 6 - Predicted Class: 1, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_6.png
Fake Image 7 - Predicted Class: 1, True Label: 1
Saved Grad-CAM visualization for fake image as gradcam_fake_outputs/gradcam_fake_output_7.png