In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
from transformers import ViTForImageClassification


In [4]:
# Block 2: Data Preparation and Transformation
data_dir = "/kaggle/input/deepfake-image-detection"

# Define the image transformations (resize to 224x224, convert to tensor, and normalize)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Create training and validation datasets using ImageFolder.
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train-20250112T065955Z-001'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test-20250112T065939Z-001'), transform=transform)

# Create DataLoaders for efficient batching and shuffling
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Block 3a: Load Pre-trained ViT Model from Local Files downloaded from hugging face.
local_vit_path = "/kaggle/input/pretrained-vit"  

model_vit = ViTForImageClassification.from_pretrained(
    local_vit_path,
    num_labels=2, 
    ignore_mismatched_sizes=True  # Allows the classifier head to reinitialize for 2 classes.
)
#  Explicitly replace the classifier head:
model_vit.classifier = nn.Linear(model_vit.config.hidden_size, 2)
model_vit = model_vit.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at /kaggle/input/pretrained-vit and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
#Load a Pre-trained ResNet Model and Adjust for Binary Classification
from torchvision.models import resnet18

model_resnet = resnet18(pretrained=True)
# Modify the final fully-connected layer for 2 classes.
num_ftrs = model_resnet.fc.in_features
model_resnet.fc = nn.Linear(num_ftrs, 2)
model_resnet = model_resnet.to(device)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 218MB/s]


In [8]:
# Defining a Training Function and then Training Each Model.
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()  # Set model to training mode
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # For ViT, outputs are in outputs.logits; for ResNet, outputs are directly logits.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(logits, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return running_loss / total, correct / total

criterion = nn.CrossEntropyLoss()
optimizer_vit = optim.AdamW(model_vit.parameters(), lr=2e-5)
optimizer_resnet = optim.AdamW(model_resnet.parameters(), lr=2e-5)

num_epochs = 3  
for epoch in range(num_epochs):
    vit_loss, vit_acc = train_epoch(model_vit, train_loader, optimizer_vit, criterion, device)
    resnet_loss, resnet_acc = train_epoch(model_resnet, train_loader, optimizer_resnet, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"ViT -> Loss: {vit_loss:.4f}, Acc: {vit_acc:.4f}")
    print(f"ResNet -> Loss: {resnet_loss:.4f}, Acc: {resnet_acc:.4f}\n")




Epoch 1/3
ViT -> Loss: 0.5002, Acc: 0.9374
ResNet -> Loss: 1.2166, Acc: 0.0710





Epoch 2/3
ViT -> Loss: 0.1840, Acc: 1.0000
ResNet -> Loss: 0.9427, Acc: 0.1190





Epoch 3/3
ViT -> Loss: 0.0509, Acc: 1.0000
ResNet -> Loss: 0.7279, Acc: 0.4572



In [9]:
# Ensemble Prediction using a Voting Ensemble

def ensemble_predict(images):
    # Setrting both models to evaluation mode
    model_vit.eval()
    model_resnet.eval()
    with torch.no_grad():
        # Get predictions from each model.
        outputs_vit = model_vit(images)
        outputs_resnet = model_resnet(images)
        
        # Retrieve logits; ViT returns outputs.logits
        logits_vit = outputs_vit.logits if hasattr(outputs_vit, "logits") else outputs_vit
        logits_resnet = outputs_resnet.logits if hasattr(outputs_resnet, "logits") else outputs_resnet
        
        # Convert logits to probabilities
        probs_vit = F.softmax(logits_vit, dim=1)
        probs_resnet = F.softmax(logits_resnet, dim=1)
        
        # Average the probabilities from both models
        avg_probs = (probs_vit + probs_resnet) / 2.0
        
        # Final prediction is the class with the highest averaged probability
        preds = torch.argmax(avg_probs, dim=1)
    return preds

# Evaluating the ensemble on the validation set
def evaluate_ensemble(val_loader, device):
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = ensemble_predict(images)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    accuracy = correct / total
    return accuracy

ensemble_acc = evaluate_ensemble(val_loader, device)
print(f"Ensemble Validation Accuracy: {ensemble_acc:.4f}")




Ensemble Validation Accuracy: 1.0000


In [10]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

In [11]:
# Function to Collect Predictions and True Labels from a DataLoader
def get_predictions_and_labels(model, data_loader, device):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []
    
    with torch.no_grad():  # No gradients needed during inference
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            
            # For ViT, outputs are in outputs.logits; for ResNet, outputs are directly logits.
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            
            # Get predicted class indices
            _, preds = torch.max(logits, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return all_preds, all_labels


In [12]:
# Get Predictions and Labels for ViT and ResNet
vit_preds, vit_labels = get_predictions_and_labels(model_vit, val_loader, device)
resnet_preds, resnet_labels = get_predictions_and_labels(model_resnet, val_loader, device)

# Printing classification reports for additional insights:
print("ViT Classification Report:")
print(classification_report(vit_labels, vit_preds))

print("ResNet Classification Report:")
print(classification_report(resnet_labels, resnet_preds))




ViT Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       499

    accuracy                           1.00       499
   macro avg       1.00      1.00      1.00       499
weighted avg       1.00      1.00      1.00       499

ResNet Classification Report:
              precision    recall  f1-score   support

           0       1.00      0.60      0.75       499
           1       0.00      0.00      0.00         0

    accuracy                           0.60       499
   macro avg       0.50      0.30      0.37       499
weighted avg       1.00      0.60      0.75       499



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [1]:
# Computing Confusion Matrices for Both Models
cm_vit = confusion_matrix(vit_labels, vit_preds)
cm_resnet = confusion_matrix(resnet_labels, resnet_preds)

# Plotting the confusion matrices side-by-side for comparison.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Confusion matrix for ViT
sns.heatmap(cm_vit, annot=True, fmt='d', cmap='Blues', ax=axes[0])
axes[0].set_title('ViT Confusion Matrix')
axes[0].set_xlabel('Predicted Label')
axes[0].set_ylabel('True Label')

# Confusion matrix for ResNet
sns.heatmap(cm_resnet, annot=True, fmt='d', cmap='Blues', ax=axes[1])
axes[1].set_title('ResNet Confusion Matrix')
axes[1].set_xlabel('Predicted Label')
axes[1].set_ylabel('True Label')

plt.tight_layout()
plt.show()


NameError: name 'confusion_matrix' is not defined