In [None]:
!pip install torch torchvision timm kagglehub matplotlib seaborn scikit-learn tqdm

In [None]:
import os
import random

import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [None]:
# Set random seed for reproducibility
# A single constant keeps every RNG on the same seed. Python's own `random`
# module is seeded as well because the augmentation pipeline may draw from it
# (e.g. when picking which transform to apply), not only from torch/numpy.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Download dataset
# kagglehub returns the local directory holding the downloaded files; the
# image splits are read from its "Dataset" sub-folder.
dataset_path = kagglehub.dataset_download(
    "manjilkarki/deepfake-and-real-images"
)
base_dir = os.path.join(dataset_path, "Dataset")

In [None]:
# Data augmentation and preprocessing
def get_transforms(augment=False):
    """Build the torchvision preprocessing pipeline for one data split.

    Every pipeline ends with tensor conversion followed by per-channel
    normalisation (the mean/std values match the widely used ImageNet
    statistics). When ``augment`` is true, a random perturbation stage is
    placed in front: with probability 0.5 exactly one of horizontal flip,
    vertical flip, colour jitter, rotation (``degrees=10``) or a 3x3 Gaussian
    blur is applied to the image.

    Args:
        augment: Whether to prepend the random perturbation stage.

    Returns:
        A ``transforms.Compose`` ready to pass to a torchvision dataset.
    """
    pipeline = []

    if augment:
        # Each candidate is deterministic in *whether* it fires (p=1.0 on the
        # flips); the randomness comes from RandomApply / RandomChoice.
        perturbations = [
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.RandomVerticalFlip(p=1.0),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(degrees=10),
            transforms.GaussianBlur(kernel_size=3),
        ]
        pipeline.append(
            transforms.RandomApply([transforms.RandomChoice(perturbations)], p=0.5)
        )

    # Shared tail: identical for training and evaluation.
    pipeline.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transforms.Compose(pipeline)

def create_dataset(data_dir, augment=False, batch_size=128, num_workers=4, shuffle=None):
    """Build a DataLoader over an ImageFolder-style directory.

    Args:
        data_dir: Root directory containing one sub-folder per class.
        augment: Whether to use the augmenting transform pipeline.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes (0 loads in the main
            process).
        shuffle: Whether to reshuffle the data every epoch. ``None`` (the
            default) keeps the previous behaviour of shuffling exactly when
            ``augment`` is true.

    Returns:
        A ``DataLoader`` (despite the function's name); the underlying
        ``ImageFolder`` is reachable through its ``.dataset`` attribute.
    """
    if shuffle is None:
        shuffle = augment

    transform = get_transforms(augment)
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only benefits host-to-GPU copies, so request it
        # only when a CUDA device is actually present.
        pin_memory=torch.cuda.is_available(),
    )

In [None]:
# Create datasets
# Despite the `_ds` suffix these are DataLoaders; the ImageFolder datasets are
# available as `<loader>.dataset`. Only the training split is augmented (and
# therefore shuffled).
train_ds = create_dataset(os.path.join(base_dir, "Train"), augment=True)
val_ds = create_dataset(os.path.join(base_dir, "Validation"))
test_ds = create_dataset(os.path.join(base_dir, "Test"))

# Display names used for the confusion matrix and classification report,
# listed in label-index order.
class_names = ["Deepfake", "Real"]

# Each split builds its own folder-name -> index mapping, so a missing or
# renamed class folder in one split would silently shift its labels. Fail fast
# instead of training/evaluating on misaligned labels.
for split_name, loader in (("Validation", val_ds), ("Test", test_ds)):
    assert loader.dataset.class_to_idx == train_ds.dataset.class_to_idx, (
        f"{split_name} split label mapping {loader.dataset.class_to_idx} differs "
        f"from the training split {train_ds.dataset.class_to_idx}"
    )
assert len(train_ds.dataset.classes) == len(class_names), (
    f"Found classes {train_ds.dataset.classes}, expected {len(class_names)} "
    f"to match class_names {class_names}"
)

In [None]:
# Show the folder-name -> label-index mapping that ImageFolder derived for the
# training split, so it can be checked against the order of `class_names`.
print(train_ds.dataset.class_to_idx)

In [None]:
# Select the compute device and build the Vision Transformer classifier.
# The model is created without pretrained weights (pretrained=False), with a
# 2-way classification head and a 256-pixel input size.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    num_classes=2,
    img_size=256,
)

In [None]:
# Wrap model for multi-GPU training
# With more than one visible GPU the model is wrapped in DataParallel; in
# every case it is then moved to the selected device.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)

In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# The training loop calls scheduler.step() and updates the early-stopping
# counter once per *validation round*, so both patience values are measured in
# validation rounds. The scheduler patience must be smaller than the
# early-stopping patience: with the previous value of 3, the first LR cut
# happened on the 4th stale round, which is the same round on which early
# stopping (patience = 4) ends training, so the reduced LR was never used.
# A patience of 1 cuts the LR on the 2nd consecutive stale round, leaving two
# more validation rounds at the lower LR before early stopping can fire.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1)

# Training parameters
best_val_loss = float('inf')           # lowest validation loss seen so far
patience = 4                           # early stopping: stale validation rounds tolerated
current_patience = 0                   # stale validation rounds since the last improvement
checkpoint_path = 'best_model_vit.pth' # where the best weights are written
num_epochs = 20                        # upper bound on training epochs

In [None]:
from tqdm.notebook import tqdm
from torch.amp import autocast, GradScaler

val_interval = 2  # Run validation every 2 epochs

# Mixed precision follows the selected device instead of assuming CUDA: on a
# CPU-only machine autocast and the gradient scaler are explicitly disabled, so
# training runs in ordinary float32 rather than relying on the CUDA code path
# disabling itself at runtime.
use_amp = device.type == 'cuda'
scaler = GradScaler(device=device.type, enabled=use_amp)


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, desc):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(loader, desc=desc, leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        # Mixed-precision forward pass (a no-op context when use_amp is False)
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Backward pass with scaling (the scaler passes through when disabled)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so a shorter final batch does not skew the mean.
        total_loss += loss.item() * inputs.size(0)

    return total_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device, desc):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, labels, preds)`` where ``labels`` and ``preds`` are flat
        lists covering every sample in ``loader``.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return total_loss / len(loader.dataset), all_labels, all_preds


for epoch in range(num_epochs):
    # Training phase
    train_loss = _train_one_epoch(
        model, train_ds, criterion, optimizer, scaler, device, use_amp,
        desc=f"Training Epoch {epoch+1}",
    )
    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}', flush=True)

    # Run validation only every 'val_interval' epochs or at the last epoch
    is_last_epoch = (epoch + 1) == num_epochs
    if (epoch + 1) % val_interval != 0 and not is_last_epoch:
        continue

    val_loss, all_labels, all_preds = _evaluate(
        model, val_ds, criterion, device,
        desc=f"Validation Epoch {epoch+1}",
    )
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_precision = precision_score(all_labels, all_preds)
    val_recall = recall_score(all_labels, all_preds)
    val_f1 = f1_score(all_labels, all_preds)

    # The scheduler is stepped once per validation round, so its patience is
    # counted in validation rounds rather than epochs.
    scheduler.step(val_loss)
    current_lr = scheduler.get_last_lr()  # Get the current learning rate(s)
    print(f"Current learning rate: {current_lr}", flush=True)

    print(f'Epoch {epoch+1} Validation | Loss: {val_loss:.4f} | Acc: {val_accuracy:.4f} | '
          f'Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f}', flush=True)

    # Early stopping and best model checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Save the unwrapped weights so the checkpoint loads with or without
        # DataParallel.
        unwrapped = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(unwrapped.state_dict(), checkpoint_path)
        current_patience = 0
    else:
        # Like the scheduler, this counter advances once per validation round.
        current_patience += 1
        if current_patience >= patience:
            print(f'Early stopping at epoch {epoch+1}', flush=True)
            break


In [None]:
# Load best model for testing
# map_location places the tensors on the active device even if the checkpoint
# was written on a different one; weights_only=True limits unpickling to
# tensors and plain containers, which is all a state_dict needs.
best_state = torch.load(checkpoint_path, map_location=device, weights_only=True)

# The checkpoint holds the unwrapped module's weights, so load them into the
# unwrapped module when the model is wrapped in DataParallel.
target_model = model.module if isinstance(model, nn.DataParallel) else model
target_model.load_state_dict(best_state)

# Trailing semicolon keeps the notebook from printing the whole model repr.
model.eval();

In [None]:
# Test evaluation
# Collect the predicted and true label for every test sample; gradients are
# not needed for inference.
test_preds = []
test_labels = []
with torch.no_grad():
    for batch_images, batch_targets in test_ds:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        # Predicted class = index of the largest logit in each row.
        batch_logits = model(batch_images)
        batch_preds = torch.max(batch_logits, dim=1).indices

        test_preds.extend(batch_preds.cpu().numpy())
        test_labels.extend(batch_targets.cpu().numpy())

In [None]:
# Calculate test metrics
# NOTE(review): precision/recall/F1 are called with scikit-learn's defaults,
# which score label 1 as the positive class -- confirm against the printed
# class_to_idx that this is the class you intend to report on.
test_accuracy = accuracy_score(test_labels, test_preds)
test_precision = precision_score(test_labels, test_preds)
test_recall = recall_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds)

print('\nFinal Test Results:')
for metric_name, metric_value in (
    ('Accuracy', test_accuracy),
    ('Precision', test_precision),
    ('Recall', test_recall),
    ('F1 Score', test_f1),
):
    print(f'{metric_name}: {metric_value:.4f}')

In [None]:
# Confusion matrix
# Heatmap of test-set counts, with true classes on the y-axis and predicted
# classes on the x-axis.
cm = confusion_matrix(test_labels, test_preds)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
# Classification report
# Per-class precision / recall / F1 on the test set, labelled with class_names.
report_text = classification_report(test_labels, test_preds, target_names=class_names)
print('\nClassification Report:')
print(report_text)