# Training ViT Model

In [None]:
import zipfile
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch

# Root of the Kaggle dataset; the code below reads the train/, validation/ and
# test/ sub-folders from it.
data_dir = "/kaggle/input/imagedataset"

# Resize every image to 224x224 and convert it to a float tensor.
# NOTE: no mean/std normalisation is applied. The ensemble cell builds the same
# un-normalised pipeline (transform1) at inference, so keep the two in sync.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# ImageFolder derives integer labels from the class sub-folder names in sorted
# order; inspect `train_data.class_to_idx` to see which index means what.
train_data = datasets.ImageFolder(root=f"{data_dir}/train", transform=transform)
val_data = datasets.ImageFolder(root=f"{data_dir}/validation", transform=transform)
test_data = datasets.ImageFolder(root=f"{data_dir}/test", transform=transform)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)

# Model Setup (Vision Transformer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `pretrained=True` is the deprecated legacy flag of torchvision's model
# builders; "IMAGENET1K_V1" is the checkpoint that flag resolved to, so the
# starting weights are unchanged.
model = vit_b_16(weights="IMAGENET1K_V1")
# Swap the 1000-way ImageNet head for a 2-way head.
model.heads.head = nn.Linear(model.heads.head.in_features, 2)  # binary classification
model.to(device)

criterion = nn.CrossEntropyLoss()
# All parameters (backbone and new head) are optimised.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
def train_model(model, loader, val_loader, epochs=5):
    """Fine-tune ``model`` and record loss/accuracy for every epoch.

    Relies on the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: Network to train; its weights are updated in place.
        loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``loader``.

    Returns:
        ``(train_losses, val_losses, train_accuracies, val_accuracies)`` with
        one entry per epoch. Losses are averaged over batches; accuracies are
        percentages in the range 0-100.
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        running_loss, hits, seen = 0, 0, 0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            hits += (logits.argmax(dim=1) == batch_y).sum().item()
            seen += batch_y.size(0)

        train_acc = 100 * hits / seen
        train_epoch_loss = running_loss / len(loader)

        # ---- validation pass: eval-mode layers, no gradient tracking ----
        model.eval()
        running_val_loss, val_hits, val_seen = 0, 0, 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                logits = model(batch_x)
                running_val_loss += criterion(logits, batch_y).item()
                val_hits += (logits.argmax(dim=1) == batch_y).sum().item()
                val_seen += batch_y.size(0)

        val_acc = 100 * val_hits / val_seen
        val_epoch_loss = running_val_loss / len(val_loader)

        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {train_epoch_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_epoch_loss:.4f} | Val Acc: {val_acc:.2f}%")

    return train_losses, val_losses, train_accuracies, val_accuracies

# Train the Model
# Uses train_model's default of 5 epochs; the four returned lists hold one
# value per epoch and feed the plots further down.
train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader)

# Save the model
# NOTE: torch.save(model, ...) pickles the whole module object rather than a
# state_dict. The ensemble cell loads a `vit_model.pth` with
# torch.load(..., weights_only=False) — presumably this file copied to Drive.
model_save_path = "/kaggle/working/vit_model.pth"
torch.save(model, model_save_path)
print(f"Model saved to {model_save_path}")

# Plot evaluation metrics
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    """Draw loss and accuracy curves side by side, save them and show them.

    Args:
        train_losses: Per-epoch training loss.
        val_losses: Per-epoch validation loss.
        train_accuracies: Per-epoch training accuracy.
        val_accuracies: Per-epoch validation accuracy.

    The figure is written to ``/kaggle/working/training_curves.png`` before it
    is displayed.
    """
    epochs = range(1, len(train_losses) + 1)

    # (subplot slot, train series, train label, val series, val label, y label, title)
    panels = (
        (1, train_losses, 'Train Loss', val_losses, 'Val Loss', 'Loss', 'Loss Curves'),
        (2, train_accuracies, 'Train Acc', val_accuracies, 'Val Acc', 'Accuracy', 'Accuracy Curves'),
    )

    plt.figure(figsize=(12, 5))
    for slot, train_series, train_label, val_series, val_label, y_label, title in panels:
        plt.subplot(1, 2, slot)
        plt.plot(epochs, train_series, label=train_label)
        plt.plot(epochs, val_series, label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig('/kaggle/working/training_curves.png')  # Optional: Save plot
    plt.show()

def evaluate_model(model, val_loader, device=None, class_names=None):
    """Report confusion matrix, classification report and ROC curve.

    Args:
        model: Trained two-class classifier returning logits of shape (N, 2).
        val_loader: Iterable of ``(images, labels)`` batches to evaluate on.
        device: Device to run on; defaults to CUDA when available, else CPU.
        class_names: Display name for each class index, in index order. When
            omitted, the names are taken from ``val_loader.dataset.classes``
            (which ``ImageFolder`` fills in sorted folder order) so the labels
            on the plots match the indices the model was actually trained
            with. Falls back to ``["Real", "Fake"]`` if the dataset does not
            expose ``classes``.

    Side effects:
        Saves ``confusion_matrix.png`` and ``roc_curve.png`` under
        ``/kaggle/working/``, shows both figures and prints the report.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if class_names is None:
        # Hard-coding the names would silently mislabel the report whenever the
        # folder sort order differs from the assumed one.
        dataset = getattr(val_loader, "dataset", None)
        class_names = getattr(dataset, "classes", None) or ["Real", "Fake"]
    class_names = list(class_names)
    class_ids = list(range(len(class_names)))

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            # Probability of class index 1, used as the ROC score below.
            all_probs.extend(probs[:, 1].cpu().numpy())

    # Confusion Matrix
    # `labels=` pins the matrix to one row/column per class even when a class
    # is absent from this split, so it always lines up with the tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('/kaggle/working/confusion_matrix.png')  # Optional: Save plot
    plt.show()

    # Classification Report
    print("Classification Report:\n")
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

    # ROC Curve (class index 1 is treated as the positive class)
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal for reference
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc='lower right')
    plt.grid()
    plt.tight_layout()
    plt.savefig('/kaggle/working/roc_curve.png')  # Optional: Save plot
    plt.show()

# Visualize Training and Evaluate
plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)
# NOTE(review): the final metrics are computed on the validation split;
# `test_loader` is built above but never used in this cell — confirm whether
# the held-out test set was meant to be evaluated here.
evaluate_model(model, val_loader)

# Training XceptionNet+GRU model (on .mp4 data) — note: the CNN backbone used in this cell is EfficientNet-B0

In [None]:
import os
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# Customise the video dataset
class VideoDataset(Dataset):
    """Fixed-length frame clips read from ``real/`` and ``fake/`` video folders.

    Every ``.mp4`` file (any letter case) directly inside ``<video_dir>/real``
    gets label 0 and every one inside ``<video_dir>/fake`` gets label 1. Each
    item is the first ``max_frames`` decoded frames of one video; shorter
    videos are padded by repeating their last frame.

    Args:
        video_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        transform: Callable applied to each frame (a PIL RGB image); it must
            return a tensor so that frames can be stacked.
        max_frames: Number of frames per clip; must be at least 1.

    Raises:
        ValueError: If ``max_frames`` is smaller than 1, or (from
            ``__getitem__``) if no frame can be decoded from a video.
    """

    def __init__(self, video_dir, transform, max_frames=16):
        if max_frames < 1:
            raise ValueError(f"max_frames must be >= 1, got {max_frames}")
        self.samples = []  # list of (video path, integer label)
        self.transform = transform
        self.max_frames = max_frames
        for label, subdir in enumerate(['real', 'fake']):
            folder = os.path.join(video_dir, subdir)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices stable between runs and machines.
            for file in sorted(os.listdir(folder)):
                if file.lower().endswith(".mp4"):
                    self.samples.append((os.path.join(folder, file), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(clip, label)`` where ``clip`` stacks ``max_frames`` frames."""
        path, label = self.samples[idx]
        cap = cv2.VideoCapture(path)
        frames = []
        try:
            while len(frames) < self.max_frames and cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; convert before handing to PIL.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(Image.fromarray(frame)))
        finally:
            # Release the decoder even if the transform raises.
            cap.release()

        if not frames:
            # Previously this surfaced as an opaque IndexError on frames[-1].
            raise ValueError(f"No frames could be read from video: {path}")

        # Pad short videos by repeating the last decoded frame.
        frames.extend([frames[-1]] * (self.max_frames - len(frames)))
        return torch.stack(frames), label

# Image transforms
# Per-frame preprocessing: resize to 224x224 and convert to a float tensor
# (no mean/std normalisation is applied).
video_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load dataset
video_ds = VideoDataset("/content/datasets", transform=video_transform)
# 80/20 train/validation split; the remainder goes to validation so the two
# sizes always add up to len(video_ds).
train_size = int(0.8 * len(video_ds))
val_size = len(video_ds) - train_size
# NOTE(review): no generator/seed is passed, so the split differs on every run.
train_ds, val_ds = torch.utils.data.random_split(video_ds, [train_size, val_size])
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=4)

# Model: EfficientNet + GRU
class XceptionGRU(nn.Module):
    """Frozen EfficientNet-B0 frame encoder followed by a GRU and a linear head.

    The class keeps its ``XceptionGRU`` name, but the backbone built here is
    torchvision's EfficientNet-B0 with its classifier removed.

    Args:
        hidden_size: Size of the GRU hidden state.
        num_classes: Number of output logits.

    Input to ``forward`` is a batch of clips shaped ``(B, T, C, H, W)``; the
    output is ``(B, num_classes)`` logits computed from the last GRU state.
    """

    def __init__(self, hidden_size=256, num_classes=2):
        super().__init__()
        # `pretrained=True` is the deprecated legacy flag; "IMAGENET1K_V1" is
        # the checkpoint it resolved to, so the weights are unchanged.
        backbone = models.efficientnet_b0(weights="IMAGENET1K_V1")
        # Keep everything except the last child (the classifier).
        self.cnn = nn.Sequential(*list(backbone.children())[:-1])
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.cnn.eval()
        self.gru = nn.GRU(input_size=1280, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def train(self, mode=True):
        """Switch GRU/head between train and eval; keep the backbone in eval.

        ``requires_grad=False`` and ``torch.no_grad()`` do not stop BatchNorm
        from updating its running statistics (or stochastic depth from
        dropping branches) while the module is in training mode. Forcing the
        backbone to eval mode makes it a genuinely fixed feature extractor.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, x):
        B, T, C, H, W = x.shape
        # Fold time into the batch so the 2-D CNN sees individual frames.
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(B * T, C, H, W)
        with torch.no_grad():
            x = self.cnn(x)
        # Back to one feature vector per frame: (B, T, features).
        x = x.reshape(B, T, -1)
        _, h = self.gru(x)
        # h[-1] is the final hidden state of the last GRU layer.
        return self.fc(h[-1])

# Device and model
# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Rebinds the notebook-level `model` name to the video model.
model = XceptionGRU().to(device)

# Train + Eval Functions
def train_model(model, train_dl, val_dl, epochs=10, lr=0.001):
    """Train ``model`` with Adam and cross-entropy, validating after each epoch.

    Relies on the notebook-level ``device``.

    Args:
        model: Network to train; updated in place.
        train_dl: Iterable of ``(inputs, labels)`` training batches.
        val_dl: Iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of passes over ``train_dl``.
        lr: Adam learning rate.

    Returns:
        ``(train_losses, val_losses, train_accuracies, val_accuracies)`` with
        one entry per epoch. Accuracies are fractions in the range 0-1.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        running_loss = 0
        hits = 0
        seen = 0
        for clips, targets in train_dl:
            clips = clips.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            logits = model(clips)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
        train_losses.append(running_loss / len(train_dl))
        train_accuracies.append(hits / seen)

        # ---- validation pass: eval mode, no gradient tracking ----
        model.eval()
        running_val_loss = 0
        val_hits = 0
        val_seen = 0
        with torch.no_grad():
            for clips, targets in val_dl:
                clips = clips.to(device)
                targets = targets.to(device)
                logits = model(clips)
                running_val_loss += criterion(logits, targets).item()
                val_hits += (logits.argmax(dim=1) == targets).sum().item()
                val_seen += targets.size(0)
        val_losses.append(running_val_loss / len(val_dl))
        val_accuracies.append(val_hits / val_seen)

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {train_losses[-1]:.4f}, Acc: {train_accuracies[-1]*100:.2f}% | "
              f"Val Loss: {val_losses[-1]:.4f}, Acc: {val_accuracies[-1]*100:.2f}%")

    return train_losses, val_losses, train_accuracies, val_accuracies

# Plotting curves
def plot_training_curves(train_losses, val_losses, train_acc, val_acc):
    """Show loss and accuracy histories as two side-by-side line plots.

    Args:
        train_losses: Per-epoch training loss.
        val_losses: Per-epoch validation loss.
        train_acc: Per-epoch training accuracy.
        val_acc: Per-epoch validation accuracy.
    """
    epochs = range(1, len(train_losses) + 1)

    # (subplot slot, train series, train label, val series, val label, y label, title)
    panels = (
        (1, train_losses, 'Train Loss', val_losses, 'Val Loss', 'Loss', 'Loss Curve'),
        (2, train_acc, 'Train Accuracy', val_acc, 'Val Accuracy', 'Accuracy', 'Accuracy Curve'),
    )

    plt.figure(figsize=(12, 5))
    for slot, train_series, train_label, val_series, val_label, y_label, title in panels:
        plt.subplot(1, 2, slot)
        # Blue = training, red = validation.
        plt.plot(epochs, train_series, 'b-', label=train_label)
        plt.plot(epochs, val_series, 'r-', label=val_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.show()

# Evaluate the model
def evaluate_model(model, dataloader):
    """Print a classification report and show a confusion matrix.

    Relies on the notebook-level ``device``. Class index 0 is displayed as
    "Real" and index 1 as "Fake".

    Args:
        model: Trained classifier returning per-class logits.
        dataloader: Iterable of ``(inputs, labels)`` batches.
    """
    display_names = ["Real", "Fake"]
    all_labels = []
    all_preds = []

    model.eval()
    with torch.no_grad():
        for clips, targets in dataloader:
            clips = clips.to(device)
            targets = targets.to(device)
            predicted = model(clips).argmax(dim=1)
            all_labels.extend(targets.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    print("📊 Classification Report:")
    report = classification_report(all_labels, all_preds, target_names=display_names)
    print(report)

    matrix = confusion_matrix(all_labels, all_preds)
    ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=display_names).plot(cmap='Blues')
    plt.show()

# Train and Save
# Train with train_model's defaults (10 epochs, lr=0.001), then plot the
# history and report metrics on the validation split.
train_losses, val_losses, train_acc, val_acc = train_model(model, train_dl, val_dl)
plot_training_curves(train_losses, val_losses, train_acc, val_acc)
evaluate_model(model, val_dl)

# Only the state_dict is saved here, so loading requires rebuilding the same
# architecture first (the ensemble cell does this via load_state_dict).
torch.save(model.state_dict(), "/content/drive/MyDrive/DeepFake/xception_gru_video_model.pth")
print("Model Saved Successfully!")

# Training XceptionNet+GRU model (on sequential image data)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm

# Unzip Dataset
# IPython shell escape: quietly (-q) extracts the archive from Drive into /content/.
!unzip -q "/content/drive/MyDrive/DeepFake/SeqImagesDataset.zip" -d /content/

# Dataset Path
# The loaders below read "<seq_data_path>/train" and "<seq_data_path>/validation";
# presumably the archive unpacks those folders directly under /content — verify.
seq_data_path = "/content"

# Custom Dataset Loader (for Sequences of Images)
class SeqImageDataset(Dataset):
    """Fixed-length sequences of frame images grouped by video.

    Files inside ``<root_dir>/real`` (label 0) and ``<root_dir>/fake``
    (label 1) are grouped by everything before the last underscore in their
    name, so ``clipA_1.jpg`` and ``clipA_2.jpg`` belong to video ``clipA``.
    Groups with fewer than ``sequence_len`` files are skipped; longer groups
    are truncated to their first ``sequence_len`` frames.

    Args:
        root_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        transform: Callable applied to each PIL RGB image. Defaults to
            ``transforms.ToTensor()`` because ``__getitem__`` must stack
            tensors.
        sequence_len: Number of frames per sequence.
    """

    def __init__(self, root_dir, transform=None, sequence_len=10):
        self.sequence_len = sequence_len
        self.data = []  # one list of frame paths per sequence
        self.labels = []  # integer label per sequence
        # The signature allows transform=None, but calling None in __getitem__
        # would raise TypeError; fall back to a plain tensor conversion.
        self.transform = transform if transform is not None else transforms.ToTensor()

        for label_idx, label in enumerate(['real', 'fake']):
            folder = os.path.join(root_dir, label)
            if not os.path.exists(folder):
                print(f"Warning: Folder {folder} does not exist!")
                continue

            videos = {}
            for file in sorted(os.listdir(folder)):
                video_id = "_".join(file.split("_")[:-1])
                videos.setdefault(video_id, []).append(file)

            for files in videos.values():
                if len(files) >= sequence_len:
                    # A plain string sort puts "_10" before "_2"; order frames
                    # by their numeric index so the GRU sees them in temporal
                    # order even when the index is not zero-padded.
                    files.sort(key=self._frame_sort_key)
                    self.data.append([os.path.join(folder, f) for f in files[:sequence_len]])
                    self.labels.append(label_idx)

        print(f"Loaded {len(self.data)} sequences from {root_dir}")

    @staticmethod
    def _frame_sort_key(filename):
        """Sort key: numeric frame index first, lexicographic name otherwise."""
        stem = os.path.splitext(filename)[0]
        index = stem.rsplit("_", 1)[-1]
        if index.isascii() and index.isdigit():
            return (0, int(index), filename)
        # No trailing number: keep the original alphabetical ordering.
        return (1, 0, filename)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sequence, label)``; ``sequence`` stacks the frames in order."""
        frames = self.data[idx]
        label = self.labels[idx]
        images = [self.transform(Image.open(f).convert('RGB')) for f in frames]
        return torch.stack(images), label

# Define Image Transformations
# Resize frames to 299x299 (the same size the ensemble cell uses for this
# model at inference) and convert them to float tensors; no normalisation.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
])

# Load Training & Validation Datasets
# Each dataset uses SeqImageDataset's default sequence_len of 10 frames.
train_ds = SeqImageDataset(f"{seq_data_path}/train", transform=transform)
val_ds = SeqImageDataset(f"{seq_data_path}/validation", transform=transform)

# Create DataLoaders
# Small batches: every sample is a whole sequence of frames.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=4)

# Define Xception + GRU Model
class XceptionGRU(nn.Module):
    """Xception frame encoder feeding a GRU and a two-way linear classifier.

    Args:
        hidden_size: Size of the GRU hidden state.

    ``forward`` takes clips shaped ``(batch, steps, channels, height, width)``
    and returns ``(batch, 2)`` logits (Real vs. Fake).
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        # Pretrained Xception created with num_classes=0; its output is fed to
        # the GRU as 2048-dimensional per-frame features.
        self.cnn = timm.create_model("xception", pretrained=True, num_classes=0)
        self.rnn = nn.GRU(input_size=2048, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        batch, steps, channels, height, width = x.shape
        # Merge batch and time so the CNN processes every frame independently.
        frames = x.view(batch * steps, channels, height, width)
        frame_features = self.cnn(frames)
        # Restore the time axis for the recurrent layer.
        sequence = frame_features.view(batch, steps, -1)
        _, hidden = self.rnn(sequence)
        # Classify from the final hidden state of the last GRU layer.
        last_hidden = hidden[-1]
        return self.fc(last_hidden)

# Device Configuration (Use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Rebinds the notebook-level `model` name to the sequence model.
model = XceptionGRU().to(device)

# Define Optimizer and Loss Function
# These notebook-level objects are read directly by train_model below, so they
# must be created for this `model` before training starts.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Train the Model
def train_model(model, train_dl, val_dl, epochs=5):
    """Train ``model`` on image sequences and validate after each epoch.

    Relies on the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: Network to train; updated in place.
        train_dl: Iterable of ``(images, labels)`` training batches.
        val_dl: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``train_dl``.

    Returns:
        ``(train_losses, val_losses, train_accuracies, val_accuracies)`` with
        one entry per epoch. Accuracies are fractions in the range 0-1.
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        running_loss = 0
        hits = 0
        seen = 0

        for sequences, targets in train_dl:
            sequences = sequences.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            logits = model(sequences)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

        train_acc = hits / seen
        train_losses.append(running_loss / len(train_dl))
        train_accuracies.append(train_acc)

        # ---- validation pass: eval mode, no gradient tracking ----
        model.eval()
        running_val_loss = 0
        val_hits = 0
        val_seen = 0
        with torch.no_grad():
            for sequences, targets in val_dl:
                sequences = sequences.to(device)
                targets = targets.to(device)
                logits = model(sequences)
                running_val_loss += criterion(logits, targets).item()
                val_hits += (logits.argmax(dim=1) == targets).sum().item()
                val_seen += targets.size(0)

        val_acc = val_hits / val_seen
        val_losses.append(running_val_loss / len(val_dl))
        val_accuracies.append(val_acc)

        print(f"Epoch {epoch+1}/{epochs}, "
              f"Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, "
              f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    return train_losses, val_losses, train_accuracies, val_accuracies

# Validation metrics
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot per-epoch loss and accuracy in two side-by-side panels.

    The x axis is the list index (starting at 0), not the 1-based epoch number.

    Args:
        train_losses: Per-epoch training loss.
        val_losses: Per-epoch validation loss.
        train_accuracies: Per-epoch training accuracy.
        val_accuracies: Per-epoch validation accuracy.
    """
    # (subplot slot, train series, train label, val series, val label, title, y label)
    panels = (
        (1, train_losses, "Train Loss", val_losses, "Val Loss", "Loss over Epochs", "Loss"),
        (2, train_accuracies, "Train Accuracy", val_accuracies, "Val Accuracy", "Accuracy over Epochs", "Accuracy"),
    )

    plt.figure(figsize=(12, 5))
    for slot, train_series, train_label, val_series, val_label, title, y_label in panels:
        plt.subplot(1, 2, slot)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

def evaluate_model(model, val_dl):
    """Print a classification report and class counts, then plot a confusion matrix.

    Relies on the notebook-level ``device``. Class index 0 is displayed as
    "Real" and index 1 as "Fake". The report is restricted to the classes that
    actually occur in the labels or predictions; the confusion matrix always
    has one row and column per class.

    Args:
        model: Trained classifier returning per-class logits.
        val_dl: Iterable of ``(images, labels)`` batches.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_dl:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            # tolist() yields plain Python ints, which index lists and print
            # cleanly (NumPy scalars would show up as np.int64(...) keys).
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Classification report
    from sklearn.utils.multiclass import unique_labels
    labels_present = unique_labels(all_labels, all_preds)

    class_names = ["Real", "Fake"]
    class_labels = list(range(len(class_names)))

    # Filter names to match existing labels
    filtered_names = [class_names[i] for i in labels_present]

    print("Classification Report:\n", classification_report(all_labels, all_preds, labels=labels_present, target_names=filtered_names))

    print("Validation class counts:", {label: all_labels.count(label) for label in sorted(set(all_labels))})

    # Confusion matrix
    # Without `labels=` a split containing a single class yields a 1x1 matrix
    # that no longer matches the two tick labels below.
    cm = confusion_matrix(all_labels, all_preds, labels=class_labels)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()

# Train and collect metrics
train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_dl, val_dl, epochs=5)

# Save the model
# NOTE: this pickles the whole module object, which records the class by name
# (`XceptionGRU`, defined in this notebook). Whatever loads the file must be
# able to resolve that name.
torch.save(model, "/content/drive/MyDrive/DeepFake/xception_gru_seq_model.pth")
print("Model Saved Successfully!")

# Plot curves
plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)

# Evaluate performance
evaluate_model(model, val_dl)

# Building ensemble model for prediction using a weighted voting method

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
import timm
import os
import zipfile
import shutil
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== Code 1: ViT Model ====
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you created
# yourself or otherwise fully trust.
model1 = torch.load(
    "/content/drive/MyDrive/DeepFake/deepfake_detector/vit_model.pth",
    map_location=device,
    weights_only=False
)
model1.to(device)
model1.eval()  # inference only

# Same 224x224, un-normalised preprocessing as the ViT training cell.
transform1 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict_code1_image(image_path):
    """Classify a single image file with the ViT model.

    Args:
        image_path: Path to an image readable by PIL.

    Returns:
        ``'Deepfake'`` when the model picks class index 0, ``'Real'`` when it
        picks index 1.
    """
    class_names = ['Deepfake', 'Real']

    rgb_image = Image.open(image_path).convert('RGB')
    # Add a leading batch dimension of 1 before moving to the device.
    batch = transform1(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = F.softmax(model1(batch), dim=1)

    predicted_index = probabilities.argmax(dim=1).item()
    return class_names[predicted_index]

# ==== Code 2: XceptionNet+GRU Model (for video data) ====
class XceptionGRU_Code2(nn.Module):
    """EfficientNet-B0 frame encoder (frozen) + GRU + linear head for video clips.

    Mirrors the architecture of the video-training cell so that its saved
    ``state_dict`` can be loaded; the attribute names ``cnn``, ``gru`` and
    ``fc`` therefore must not change.

    Args:
        hidden_size: Size of the GRU hidden state.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_size=256, num_classes=2):
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=True)
        # Everything except the last child (the classifier) becomes the encoder.
        encoder_layers = list(backbone.children())[:-1]
        self.cnn = nn.Sequential(*encoder_layers)
        for weight in self.cnn.parameters():
            weight.requires_grad = False
        self.gru = nn.GRU(input_size=1280, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        batch, steps, channels, height, width = x.shape
        # Fold time into the batch so the CNN sees individual frames.
        frames = x.view(batch * steps, channels, height, width)
        with torch.no_grad():
            frame_features = self.cnn(frames)
        # One feature vector per frame again: (batch, steps, features).
        sequence = frame_features.view(batch, steps, -1)
        _, hidden = self.gru(sequence)
        # Classify from the final hidden state of the last GRU layer.
        return self.fc(hidden[-1])

# Build the architecture first, then overwrite its weights with the trained
# state_dict saved by the video-training cell.
model2 = XceptionGRU_Code2().to(device)
model2.load_state_dict(torch.load(
    "/content/drive/MyDrive/DeepFake/deepfake_detector/xception_gru_video_model.pth",
    map_location=device))
model2.eval()  # inference only

# Same 224x224, un-normalised per-frame preprocessing as in video training.
video_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict_code2_video(video_path, max_frames=16):
    """Classify a video with the EfficientNet+GRU model.

    Reads the first ``max_frames`` frames, repeating the last decoded frame if
    the video is shorter, and runs them through ``model2`` as one clip.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_frames: Number of frames fed to the model.

    Returns:
        ``"Real"`` for class index 0, ``"Fake"`` for class index 1.

    Raises:
        ValueError: If no frame can be decoded from ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while len(frames) < max_frames and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert before handing to PIL.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(video_transform(Image.fromarray(frame)))
    finally:
        # Release the decoder even if preprocessing raises.
        cap.release()

    if not frames:
        # Previously an unreadable file surfaced as IndexError on frames[-1].
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Pad short videos by repeating the last decoded frame.
    frames.extend([frames[-1]] * (max_frames - len(frames)))

    # Shape (1, T, C, H, W): a batch holding one clip.
    input_tensor = torch.stack(frames).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model2(input_tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
    class_names = ["Real", "Fake"]
    return class_names[predicted_class]

# ==== Code 3: Xception+GRU Model (for sequential images data) ====
class XceptionGRU_Code3(nn.Module):
    """Xception frame encoder + GRU + two-way linear head for frame sequences.

    Same layout (``cnn``, ``rnn``, ``fc``) as the model trained in the
    sequential-images cell.

    Args:
        hidden_size: Size of the GRU hidden state.
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        # Created with num_classes=0; the GRU consumes 2048-dim frame features.
        self.cnn = timm.create_model("xception", pretrained=True, num_classes=0)
        self.rnn = nn.GRU(input_size=2048, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        batch, steps, channels, height, width = x.shape
        # Merge batch and time so every frame goes through the CNN separately.
        frames = x.view(batch * steps, channels, height, width)
        frame_features = self.cnn(frames)
        # Restore the time axis for the recurrent layer.
        sequence = frame_features.view(batch, steps, -1)
        _, hidden = self.rnn(sequence)
        # Classify from the final hidden state of the last GRU layer.
        return self.fc(hidden[-1])

# The sequence model is stored as a whole pickled module. Pickle records the
# class by name, and the training cell saved an instance of a class called
# `XceptionGRU` defined in the notebook's `__main__`. This cell only defines
# `XceptionGRU_Code3`, so in a fresh session the load would fail with
# "Can't get attribute 'XceptionGRU'". Expose the identical architecture under
# the pickled name before loading.
# NOTE(review): assumes this .pth is the file written by the sequence-training
# cell — confirm.
XceptionGRU = XceptionGRU_Code3

# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you trust.
model3 = torch.load(
    '/content/drive/MyDrive/DeepFake/deepfake_detector/xception_gru_seq_model.pth',
    map_location=device,
    weights_only=False
)
# Explicit move for parity with model1; a no-op when map_location already
# placed every tensor on `device`.
model3.to(device)
model3.eval()  # inference only

# Same 299x299, un-normalised preprocessing as in sequence training.
transform3 = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
])

def extract_frames(video_path, num_frames=10):
    """Return the first ``num_frames`` raw frames of a video.

    Frames are returned exactly as OpenCV decodes them (BGR arrays). Videos
    shorter than ``num_frames`` are padded by repeating the last decoded frame.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to return.

    Returns:
        A list of ``num_frames`` frames (empty when ``num_frames`` <= 0).

    Raises:
        ValueError: If frames were requested but none could be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while len(frames) < num_frames and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the decoder even if reading raises.
        cap.release()

    if not frames and num_frames > 0:
        # Previously an unreadable file surfaced as IndexError on frames[-1].
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Pad short videos by repeating the last decoded frame (no-op otherwise).
    frames.extend(frames[-1:] * (num_frames - len(frames)))
    return frames[:num_frames]

def preprocess_frames(frames):
    """Convert raw BGR frames into one stacked tensor for the sequence model.

    Args:
        frames: Iterable of BGR frames as returned by ``extract_frames``.

    Returns:
        A tensor stacking the ``transform3`` output of every frame, in order.
    """
    return torch.stack([
        # BGR -> RGB -> PIL image -> resized tensor
        transform3(Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)))
        for bgr_frame in frames
    ])

def predict_code3_video(video_path):
    """Classify a video with the Xception+GRU sequence model.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        ``'Real'`` when the highest logit is at index 0, otherwise ``'Fake'``.
    """
    clip = preprocess_frames(extract_frames(video_path))
    # Add a batch dimension of 1 in front of the frame axis.
    batch = clip.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model3(batch)
    # The batch has a single row, so the flattened argmax is the class index.
    if torch.argmax(logits).item() == 0:
        return 'Real'
    return 'Fake'

# ==== Weighted vote ensemble ====
def biased_majority_vote(predictions, weights=None):
    """Combine per-model 'Real'/'Fake' votes into one weighted decision.

    Each prediction adds its model's weight to the score of the label it voted
    for; the label with the higher total wins and ties go to ``'Real'``.

    NOTE: with the default weights ``(1, 2, 0)`` the third model never
    contributes and the second model outweighs the first, so the result is
    always the second model's vote. Pass ``weights`` to change that.

    Args:
        predictions: Sequence of ``'Real'``/``'Fake'`` strings, ordered as
            (Code1, Code2, Code3).
        weights: Optional sequence with one weight per prediction, in the same
            order. Defaults to ``(1, 2, 0)``.

    Returns:
        ``'Real'`` or ``'Fake'``.

    Raises:
        ValueError: If there are more predictions than weights, or a
            prediction is not ``'Real'`` or ``'Fake'``.
    """
    if weights is None:
        weights = (1, 2, 0)
    if len(predictions) > len(weights):
        raise ValueError(
            f"Got {len(predictions)} predictions but only {len(weights)} weights"
        )

    scores = {'Real': 0, 'Fake': 0}
    for pred, weight in zip(predictions, weights):
        if pred not in scores:
            raise ValueError(
                f"Unknown prediction label {pred!r}; expected 'Real' or 'Fake'"
            )
        scores[pred] += weight

    # max() returns the first key on ties, i.e. 'Real'.
    return max(scores, key=scores.get)

# ==== Process Input ====
def process_input(input_path):
    """Run all three models on one file and return the ensemble decision.

    Files ending in .mp4/.avi/.mov/.mkv (case-insensitive) are treated as
    videos; everything else is treated as an image.

    * Video: the ViT sees only the first frame; the two GRU models see their
      own frame clips.
    * Image: the ViT sees the image; the two GRU models receive it as a
      one-frame sequence.

    Args:
        input_path: Path to a video or image file.

    Returns:
        ``'Real'`` or ``'Fake'`` as decided by ``biased_majority_vote``. The
        individual votes and the final decision are also printed.
    """
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
    is_video = input_path.lower().endswith(video_extensions)

    if is_video:
        print("Detected: Video Input")
        code2_result = predict_code2_video(input_path)

        # ViT vote on the first frame only (class index 0 means fake).
        first_frame = extract_frames(input_path)[0]
        first_image = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
        vit_batch = transform1(first_image).unsqueeze(0).to(device)
        with torch.no_grad():
            vit_logits = model1(vit_batch)
            vit_index = torch.argmax(F.softmax(vit_logits, dim=1)).item()
        code1_result = 'Fake' if vit_index == 0 else 'Real'

        code3_result = predict_code3_video(input_path)
        predictions = [code1_result, code2_result, code3_result]
    else:
        print("Detected: Image Input")
        code1_result = predict_code1_image(input_path)
        still = Image.open(input_path).convert('RGB')

        # Two unsqueezes add batch and time axes: a clip with a single frame.
        one_frame_clip = video_transform(still).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            clip_logits = model2(one_frame_clip)
            clip_index = torch.argmax(torch.softmax(clip_logits, dim=1)).item()
        code2_result = "Real" if clip_index == 0 else "Fake"

        one_frame_sequence = transform3(still).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            sequence_logits = model3(one_frame_sequence)
            sequence_index = torch.argmax(sequence_logits).item()
        code3_result = 'Real' if sequence_index == 0 else 'Fake'

        # predict_code1_image answers 'Deepfake'/'Real'; map to the vote labels.
        vit_vote = 'Fake' if code1_result == 'Deepfake' else 'Real'
        predictions = [vit_vote, code2_result, code3_result]

    final = biased_majority_vote(predictions)
    print(f"\nVotes: {predictions}")
    print(f"Final Decision: {final}")
    return final

# === UNZIP THE DATASET ===
zip_path = "/content/drive/MyDrive/DeepFake/Overall_Test.zip"
extract_to = "/content/deepfake_data"

# Start from a clean directory so files from an earlier extraction cannot leak
# into this evaluation. WARNING: this deletes everything under `extract_to`.
if os.path.exists(extract_to):
    shutil.rmtree(extract_to)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

# === COLLECT FILE PATHS ===
# The folder name supplies the ground-truth label for every file inside it.
real_dir = os.path.join(extract_to, "Overall_Test/Real")
fake_dir = os.path.join(extract_to, "Overall_Test/Fake")

# NOTE(review): the suffix match is case-sensitive and limited to .jpg/.png, so
# files such as .jpeg or .JPG — and any videos — are silently skipped.
real_paths = [os.path.join(real_dir, f) for f in os.listdir(real_dir) if f.endswith(('.jpg', '.png'))]
fake_paths = [os.path.join(fake_dir, f) for f in os.listdir(fake_dir) if f.endswith(('.jpg', '.png'))]

# === EVALUATE ===
# Ground-truth and predicted labels, appended in lockstep.
y_true = []
y_pred = []

# One pass per class folder: (banner to print, ground-truth label, file paths).
for banner, actual_label, paths in (
    ("----- Predicting Real Images -----\n", "Real", real_paths),
    ("\n----- Predicting Fake Images -----\n", "Fake", fake_paths),
):
    print(banner)
    for path in paths:
        pred = process_input(path)
        y_true.append(actual_label)
        y_pred.append(pred)
        print(f"{os.path.basename(path)} → Predicted: {pred} | Actual: {actual_label}")

# === METRICS & VISUALIZATION ===
# Headline accuracy followed by the per-class precision/recall report.
accuracy = accuracy_score(y_true, y_pred)
print(f"\n\nOverall Accuracy: {accuracy*100:.2f}%")
print("\nClassification Report:\n", classification_report(y_true, y_pred))

# Confusion Matrix
# Row/column order is fixed to Real, Fake so it matches the tick labels.
class_order = ["Real", "Fake"]
conf_matrix = confusion_matrix(y_true, y_pred, labels=class_order)
plt.figure(figsize=(6,5))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_order,
    yticklabels=class_order,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

# Bar plot: Correct vs Incorrect
correct = sum(1 for actual, predicted in zip(y_true, y_pred) if actual == predicted)
incorrect = len(y_true) - correct

outcome_names = ["Correct", "Incorrect"]
outcome_counts = [correct, incorrect]
plt.figure(figsize=(5,4))
sns.barplot(x=outcome_names, y=outcome_counts, palette="Set2")
plt.title("Prediction Accuracy Summary")
plt.ylabel("Number of Images")
plt.show()