In [14]:
# Hybrid CNN + Transformer for CIFAKE (Notebook Style)

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
from tqdm import tqdm  # <-- added for progress bar
from sklearn.metrics import classification_report


In [15]:
# ---- Setup ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
EPOCHS = 5
IMG_SIZE = 128
NUM_CLASSES = 2

In [16]:
# ---- Transforms ----
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])

In [6]:
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("birdy654/cifake-real-and-ai-generated-synthetic-images")

# print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/cifake-real-and-ai-generated-synthetic-images


In [18]:
# ---- Dataset Load ----
full_train_dataset = datasets.ImageFolder('/kaggle/input/cifake-real-and-ai-generated-synthetic-images/train', transform=transform)
class_map = full_train_dataset.class_to_idx
print("Class Mapping:", class_map)
real_idx = class_map['REAL']
fake_idx = class_map['FAKE']

Class Mapping: {'FAKE': 0, 'REAL': 1}


In [19]:
def remap_targets(dataset):
    for i in range(len(dataset.targets)):
        if dataset.targets[i] == fake_idx:
            dataset.targets[i] = 1
        elif dataset.targets[i] == real_idx:
            dataset.targets[i] = 0

remap_targets(full_train_dataset)

In [20]:
# ---- Split into Train and Val ----
train_len = int(0.8 * len(full_train_dataset))
val_len = len(full_train_dataset) - train_len
train_dataset, val_dataset = random_split(full_train_dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [22]:
# ---- Test Set ----
test_dataset = datasets.ImageFolder('/kaggle/input/cifake-real-and-ai-generated-synthetic-images/test', transform=transform)
remap_targets(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [23]:
# ---- Hybrid Model ----
class HybridModel(nn.Module):
    def __init__(self):
        super(HybridModel, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.patch_embed = nn.Conv2d(256, 128, kernel_size=4, stride=4)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=128, nhead=4, dim_feedforward=256),
            num_layers=2
        )
        self.classifier = nn.Linear(128, NUM_CLASSES)

    def forward(self, x):
        x = self.cnn(x)
        x = self.patch_embed(x)         # [B, 128, 8, 8]
        x = x.flatten(2).permute(2, 0, 1)  # [64, B, 128]
        x = self.transformer(x)         # [64, B, 128]
        x = x.mean(dim=0)               # [B, 128]
        return self.classifier(x)


In [24]:
# ---- Train & Eval Functions ----
def train(model, loader, optimizer, criterion):
    model.train()
    running_loss, correct = 0, 0
    for x, y in tqdm(loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        correct += (out.argmax(1) == y).sum().item()
    acc = correct / len(loader.dataset)
    return running_loss / len(loader), acc

In [25]:
def evaluate(model, loader, criterion):
    model.eval()
    loss, correct = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss += criterion(out, y).item()
            correct += (out.argmax(1) == y).sum().item()
    acc = correct / len(loader.dataset)
    return loss / len(loader), acc


In [26]:
# ---- Training ----
model = HybridModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")


Epoch 1/5


Training: 100%|██████████| 1250/1250 [07:36<00:00,  2.74it/s]


Train Acc: 0.8080 | Val Acc: 0.9019

Epoch 2/5


Training: 100%|██████████| 1250/1250 [04:50<00:00,  4.30it/s]


Train Acc: 0.9062 | Val Acc: 0.9190

Epoch 3/5


Training: 100%|██████████| 1250/1250 [04:28<00:00,  4.65it/s]


Train Acc: 0.9246 | Val Acc: 0.9278

Epoch 4/5


Training: 100%|██████████| 1250/1250 [04:22<00:00,  4.76it/s]


Train Acc: 0.9325 | Val Acc: 0.9406

Epoch 5/5


Training: 100%|██████████| 1250/1250 [04:22<00:00,  4.77it/s]


Train Acc: 0.9413 | Val Acc: 0.9459


In [27]:
# ---- Testing ----
model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for x, y in tqdm(test_loader, desc="Testing"):
        x = x.to(device)
        out = model(x)
        preds = out.argmax(1).cpu()
        y_true.extend(y.numpy())
        y_pred.extend(preds.numpy())

print("\nTest Classification Report:")
print(classification_report(y_true, y_pred, target_names=["REAL", "FAKE"]))

Testing: 100%|██████████| 313/313 [01:59<00:00,  2.62it/s]


Test Classification Report:
              precision    recall  f1-score   support

        REAL       0.95      0.93      0.94     10000
        FAKE       0.93      0.95      0.94     10000

    accuracy                           0.94     20000
   macro avg       0.94      0.94      0.94     20000
weighted avg       0.94      0.94      0.94     20000






In [28]:
# ---- Save ----
torch.save(model.state_dict(), "hybrid_cifake.pth")

In [None]:
# # 1. Make sure the architecture is defined (HybridModel class must be included above or imported)
# model = HybridModel().to(device)

# # 2. Load the saved weights
# model.load_state_dict(torch.load("hybrid_cifake.pth", map_location=device))

# # 3. Set to evaluation mode
# model.eval()