In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install mlflow torch torchvision timm --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.7/24.7 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m85.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m58.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import mlflow

MLFLOW_DIR = "/content/drive/MyDrive/deepfake-detection/runs"
mlflow.set_tracking_uri(f"file:{MLFLOW_DIR}")

In [10]:
import os
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, recall_score, precision_score
import numpy as np
from PIL import Image
import random
import timm
import matplotlib.pyplot as plt
import seaborn as sns

# MLFlow Setup


In [11]:
config = {
    "experiment_name": "custom_deepfake_multiclass_large_random",
    "model_name": "custom",
    "data_path": "/content/drive/MyDrive/deepfake-detection/datasets/large_dataset_random",
    "batch_size": 32,
    "num_epochs": 10,
    "learning_rate": 1e-4,
    "img_height": 218,
    "img_width": 178,
    "dropout": 0.2,
    "optimizer": "adam",
    "loss_fn": "crossentropy",
    "random_seed": 42
}

In [12]:
mlflow.set_tracking_uri("file:/content/drive/MyDrive/deepfake-detection/runs")
mlflow.set_experiment(config["experiment_name"])

<Experiment: artifact_location='file:///content/drive/MyDrive/deepfake-detection/runs/993973592097192765', creation_time=1749687457511, experiment_id='993973592097192765', last_update_time=1749687457511, lifecycle_stage='active', name='custom_deepfake_multiclass_large_random', tags={}>

# Data Preparation

In [13]:
def compute_fft_channel(img):
    # img: [H,W,3], numpy, range [0,255]
    gray = np.mean(img, axis=2)
    fft = np.abs(np.fft.fft2(gray))
    fft = np.fft.fftshift(fft)
    fft = np.log(fft + 1)
    fft = (fft - fft.min()) / (fft.max() - fft.min() + 1e-8)
    return fft.astype(np.float32)

In [14]:
class MultiClassDataset(Dataset):
    def __init__(self, root_dir, transform=None, real_fraction=0.2, real_class_name="real", seed=42):
        self.samples = []
        self.transform = transform
        self.class_to_idx = {}
        self.idx_to_class = {}

        subdirs = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        for idx, subdir in enumerate(subdirs):
            self.class_to_idx[subdir] = idx
            self.idx_to_class[idx] = subdir

        rng = random.Random(seed)
        for subdir in subdirs:
            subdir_path = os.path.join(root_dir, subdir)
            label = self.class_to_idx[subdir]
            images = [
                os.path.join(subdir_path, fname)
                for fname in os.listdir(subdir_path)
                if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
            ]
            if subdir.lower() == real_class_name.lower() and real_fraction < 1.0:
                n_keep = max(1, int(len(images) * real_fraction))
                images = rng.sample(images, n_keep)
            self.samples.extend((img_path, label) for img_path in images)
        self.num_classes = len(self.class_to_idx)
        self.total_images = len(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB").resize((224,224))
        img_np = np.array(img).astype(np.float32) / 255.0
        fft_channel = compute_fft_channel((img_np * 255).astype(np.uint8))
        img_4ch = np.concatenate([img_np.transpose(2,0,1), fft_channel[None]], axis=0)
        return torch.tensor(img_4ch, dtype=torch.float32), int(label)

In [15]:
torch.manual_seed(config["random_seed"])

transform = transforms.Compose([
    transforms.Resize((config["img_height"], config["img_width"])),
    transforms.ToTensor()
])

dataset = MultiClassDataset(config["data_path"], transform=transform)

total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(config["random_seed"])
)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=2)

# Model Definition

In [16]:
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
    def forward(self, x):
        w = x.mean(dim=(2,3))
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w)).unsqueeze(-1).unsqueeze(-1)
        return x * w

class DeepfakeNet(nn.Module):
    def __init__(self, in_channels=4, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.se1 = SEBlock(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.se2 = SEBlock(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.attention_head = nn.Conv2d(128, 1, 1)
        self.fc1 = nn.Linear(128 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.se1(self.bn1(self.conv1(x)))))
        x = self.pool(F.relu(self.se2(self.bn2(self.conv2(x)))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        attn_map = torch.sigmoid(self.attention_head(x))
        x_flat = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x_flat))
        logits = self.fc2(x)
        return logits, attn_map


In [23]:
def train_one_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    all_labels, all_outputs = [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        all_labels.extend(labels.cpu().numpy())
        all_outputs.extend(logits.detach().cpu().numpy())
    acc = correct / total
    avg_loss = running_loss / total
    return avg_loss, acc, np.array(all_labels), np.array(all_outputs)

In [24]:
def evaluate(model, loader, loss_fn, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    all_labels, all_outputs = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits, _ = model(images)
            loss = loss_fn(logits, labels)
            running_loss += loss.item() * images.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_outputs.extend(logits.cpu().numpy())
    acc = correct / total
    avg_loss = running_loss / total
    return avg_loss, acc, np.array(all_labels), np.array(all_outputs)

In [25]:
def plot_and_log_curve(train_values, val_values, ylabel, fname):
    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(train_values, label='Train')
    plt.plot(val_values, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(f'{ylabel} Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig(fname)
    mlflow.log_artifact(fname)
    plt.close()

In [26]:
def plot_and_log_confusion_matrix(y_true, y_pred, step, label="val"):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(4,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"{label.capitalize()} Confusion Matrix")
    fname = f"{label}_confusion_matrix_{step}.png"
    plt.savefig(fname)
    mlflow.log_artifact(fname)
    plt.close()

In [27]:
import os
import shutil
import numpy as np
import warnings
from sklearn.exceptions import UndefinedMetricWarning

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepfakeNet().to(device)
loss_fn = nn.CrossEntropyLoss()


if config["optimizer"].lower() == "adagrad":
    optimizer = optim.Adagrad(model.parameters(), lr=config["learning_rate"])
else:
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])

best_val_f1 = 0
best_model_path = "/tmp/best_custom.pth"

with mlflow.start_run():
    mlflow.log_params(config)

    train_loss_list, train_acc_list = [], []
    val_loss_list, val_acc_list = [], []

    for epoch in range(config["num_epochs"]):
        train_loss, train_acc, train_labels, train_outputs = train_one_epoch(
            model, train_loader, optimizer, loss_fn, device)
        val_loss, val_acc, val_labels, val_outputs = evaluate(
            model, val_loader, loss_fn, device)
        val_labels = np.array(val_labels).astype(int).flatten()

        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)

        val_preds = np.argmax(val_outputs, axis=1)
        val_preds = np.array(val_preds).astype(int).flatten()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UndefinedMetricWarning)
            if len(np.unique(val_labels)) < 2:
                val_f1 = 0.0
                val_tpr = 0.0
                val_fpr = 0.0
                val_auc = float('nan')
            else:
                val_f1 = f1_score(val_labels, val_preds, average="macro", zero_division=0)
                val_tpr = recall_score(val_labels, val_preds, average="macro", zero_division=0)
                val_fpr = 1 - precision_score(val_labels, val_preds, average="macro", zero_division=0)
                try:
                    val_auc = roc_auc_score(val_labels, val_outputs)
                except:
                    val_auc = float('nan')

        mlflow.log_metrics({
            "train_loss": train_loss, "train_acc": train_acc,
            "val_loss": val_loss, "val_acc": val_acc,
            "val_f1": val_f1, "val_tpr": val_tpr, "val_auc": val_auc,
        }, step=epoch)

        print(f"Epoch {epoch+1}/{config['num_epochs']}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, "
              f"val_f1={val_f1:.4f}, val_auc={val_auc:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), best_model_path)
            mlflow.log_artifact(best_model_path)

    plot_and_log_curve(train_loss_list, val_loss_list, "Loss", "loss_curve.png")
    plot_and_log_curve(train_acc_list, val_acc_list, "Accuracy", "accuracy_curve.png")

    # Save the full model for reproducibility (architecture + weights)
    final_model_path = "/tmp/final_model"
    mlflow.pytorch.save_model(model, final_model_path)
    mlflow.log_artifacts(final_model_path, artifact_path="final_model")
    shutil.rmtree(final_model_path)

    print("Training finished. Best validation F1:", best_val_f1)
    plot_and_log_confusion_matrix(val_labels, val_preds, step="final", label="val")


Epoch 1/10: train_loss=0.8741, train_acc=0.6580, val_loss=0.6428, val_acc=0.7293, val_f1=0.7249, val_auc=nan
Epoch 2/10: train_loss=0.4237, train_acc=0.8229, val_loss=0.6143, val_acc=0.7373, val_f1=0.7189, val_auc=nan
Epoch 3/10: train_loss=0.3174, train_acc=0.8786, val_loss=0.3809, val_acc=0.8493, val_f1=0.8507, val_auc=nan
Epoch 4/10: train_loss=0.2100, train_acc=0.9346, val_loss=0.3596, val_acc=0.8507, val_f1=0.8547, val_auc=nan
Epoch 5/10: train_loss=0.1556, train_acc=0.9574, val_loss=0.3408, val_acc=0.8573, val_f1=0.8622, val_auc=nan
Epoch 6/10: train_loss=0.1068, train_acc=0.9766, val_loss=0.3371, val_acc=0.8547, val_f1=0.8584, val_auc=nan
Epoch 7/10: train_loss=0.0768, train_acc=0.9851, val_loss=0.3926, val_acc=0.8600, val_f1=0.8603, val_auc=nan
Epoch 8/10: train_loss=0.0611, train_acc=0.9914, val_loss=0.3571, val_acc=0.8653, val_f1=0.8665, val_auc=nan
Epoch 9/10: train_loss=0.0615, train_acc=0.9860, val_loss=0.3477, val_acc=0.8800, val_f1=0.8802, val_auc=nan
Epoch 10/10: train_



Training finished. Best validation F1: 0.8802312383288605


In [29]:
model.load_state_dict(torch.load(best_model_path))
model.eval()

test_labels, test_outputs = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits, attn_map = model(images)  # <-- Unpack!
        test_outputs.extend(logits.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_labels = np.array(test_labels).astype(int)
test_outputs = np.array(test_outputs)
test_preds = np.argmax(test_outputs, axis=1)


In [30]:
from sklearn.metrics import f1_score, recall_score, precision_score, roc_auc_score

# test_labels: shape (N,) integers; test_preds: shape (N,) integers; test_outputs: shape (N, num_classes) logits or probs

test_f1 = f1_score(test_labels, test_preds, average="macro", zero_division=0)
test_recall = recall_score(test_labels, test_preds, average="macro", zero_division=0)
test_precision = precision_score(test_labels, test_preds, average="macro", zero_division=0)

# For multiclass AUC, one-hot encode labels
import numpy as np
num_classes = np.max(test_labels) + 1
test_labels_onehot = np.eye(num_classes)[test_labels]

try:
    test_auc = roc_auc_score(test_labels_onehot, np.array(test_outputs), multi_class="ovr")
except Exception as e:
    test_auc = float('nan')

print(f"Test F1: {test_f1:.4f}, Test Recall: {test_recall:.4f}, Test Precision: {test_precision:.4f}, Test AUC: {test_auc:.4f}")


Test F1: 0.8937, Test Recall: 0.8956, Test Precision: 0.8948, Test AUC: 0.9607


In [31]:
mlflow.log_metrics({
    "test_f1": test_f1,
    "test_recall": test_recall,
    "test_precision": test_precision,
    "test_auc": test_auc
})


plot_and_log_confusion_matrix(test_labels, test_preds, step="final", label="test")

In [32]:
# Get the classes and their values from the model
classes = dataset.idx_to_class
class_values = list(classes.values())
classes, class_values

({0: 'real',
  1: 'stable_diffusion_xl',
  2: 'stylegan1',
  3: 'stylegan2',
  4: 'thispersondoesnotexist'},
 ['real',
  'stable_diffusion_xl',
  'stylegan1',
  'stylegan2',
  'thispersondoesnotexist'])