# GAN-Manipulated Medical Image Detection

This notebook implements the approach described in **A. S. & S. Narayan (2024, AMATHE)**:
- Preprocessing with **Local Binary Patterns (LBP)**
- Feature extraction via a **U-Net encoder**
- Classification using an **SVM**

The dataset here is synthetic: toy medical-like images labeled either real (a centered bright disc) or fake (uniform random noise). You can later swap in a real dataset such as LIDC-IDRI.

In [None]:
!pip install torch torchvision scikit-learn scikit-image optuna matplotlib joblib

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage.feature import local_binary_pattern
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from torch.utils.data import DataLoader, Dataset

# ---------------- U-Net ----------------
class DoubleConv(nn.Module):
    """Two stacked ``3x3 Conv2d -> BatchNorm2d -> ReLU`` units.

    With ``padding=1`` each convolution preserves the spatial size, so only the
    channel count changes: ``in_ch`` for the first conv, ``out_ch`` afterwards.
    The layers live in ``self.conv`` (an ``nn.Sequential`` of six modules).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # The first unit maps in_ch -> out_ch, the second out_ch -> out_ch.
        for source_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(source_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU units to ``x``."""
        return self.conv(x)

class UNet2D(nn.Module):
    """Contracting (encoder) half of a 2-D U-Net, used as a feature extractor.

    Four ``DoubleConv`` stages widen the channels ``in_ch -> 64 -> 128 -> 256 ->
    512``, with a 2x2 max-pool before each stage after the first. The deepest
    feature map is global-average pooled to 1x1 and flattened, so ``forward``
    returns a ``(batch, 512)`` embedding.

    There is no decoder: ``self.up`` is constructed but never used by ``forward``.
    """

    def __init__(self, in_ch=1):
        super().__init__()
        widths = (in_ch, 64, 128, 256, 512)
        # Stages are created (and registered) in order d1..d4.
        self.d1, self.d2, self.d3, self.d4 = (
            DoubleConv(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])
        )

        self.pool = nn.MaxPool2d(2)
        # NOTE(review): unused by forward(); kept so the module's attributes are unchanged.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.emb_layer = nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        """Encode ``x`` and return one flattened feature vector per batch item."""
        features = self.d1(x)
        for stage in (self.d2, self.d3, self.d4):
            features = stage(self.pool(features))
        pooled = self.emb_layer(features)
        return pooled.view(x.size(0), -1)

# ---------------- Synthetic Dataset ----------------
class SyntheticMedicalDataset(Dataset):
    """Toy real-vs-fake image dataset, pre-processed with Local Binary Patterns.

    Each of the ``n`` samples is a ``size x size`` image that is either

    * real (label 0): a filled disc of radius ``size // 4`` centred in a black
      image, or
    * fake (label 1): uniform random noise in [0, 1).

    Every image is converted to a uniform LBP map (P=8, R=1), divided by its
    maximum, and stored with a leading channel axis. ``data`` is a float32 tensor
    of shape ``(n, 1, size, size)`` and ``labels`` an int64 tensor of shape
    ``(n,)``.

    Args:
        n: number of samples to generate.
        size: side length of each square image, in pixels.
        fake_ratio: probability that a sample is fake; must lie in [0, 1].
        seed: optional seed for a private ``np.random.RandomState``. ``None``
            (default) draws from NumPy's global random state.

    Raises:
        ValueError: if ``fake_ratio`` is outside [0, 1].
    """

    def __init__(self, n=200, size=64, fake_ratio=0.5, seed=None):
        super().__init__()
        if not 0.0 <= fake_ratio <= 1.0:
            raise ValueError(f"fake_ratio must be in [0, 1], got {fake_ratio}")
        rng = np.random if seed is None else np.random.RandomState(seed)

        # The "real" disc is the same for every sample, so build its mask once.
        rr, cc = np.ogrid[:size, :size]
        disc = (rr - size // 2) ** 2 + (cc - size // 2) ** 2 < (size // 4) ** 2

        maps, labels = [], []
        for _ in range(n):
            if rng.rand() > fake_ratio:
                img = np.zeros((size, size), dtype=np.float32)
                img[disc] = 1.0
                label = 0
            else:
                img = rng.rand(size, size).astype(np.float32)
                label = 1
            maps.append(self._lbp_map(img)[None, ...])
            labels.append(label)

        if maps:
            # Stack into one ndarray first: building a tensor from a Python list of
            # ndarrays is very slow in PyTorch (and triggers a UserWarning).
            self.data = torch.from_numpy(np.stack(maps).astype(np.float32))
        else:
            self.data = torch.empty((0, 1, size, size), dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    @staticmethod
    def _lbp_map(img):
        """Return the uniform LBP (P=8, R=1) of a [0, 1] image, scaled into [0, 1]."""
        # scikit-image recommends integer-typed input for local_binary_pattern
        # (it warns on float images), so quantise the [0, 1] image to uint8.
        img_u8 = np.round(np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        lbp = local_binary_pattern(img_u8, P=8, R=1, method="uniform")
        peak = lbp.max()
        # Guard the normalisation: an all-zero map would otherwise become NaNs.
        return lbp / peak if peak > 0 else lbp

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# ---------------- Training ----------------
# Fix the seeds so the synthetic data, the train/test split, the weight init and
# the shuffling order are reproducible across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = SyntheticMedicalDataset(n=300)
train_idx, test_idx = train_test_split(
    range(len(dataset)),
    test_size=0.3,
    stratify=dataset.labels.numpy(),
    random_state=SEED,
)
train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx), batch_size=16, shuffle=True)
test_loader = DataLoader(torch.utils.data.Subset(dataset, test_idx), batch_size=16)

model = UNet2D().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    model.train()
    running_loss, n_seen = 0.0, 0
    for x, _ in train_loader:
        x = x.to(device)
        emb = model(x)
        # NOTE(review): this proxy loss has its global minimum at the all-zero
        # embedding, so it rewards shrinking features rather than making them
        # discriminative for the SVM. Consider a reconstruction (full U-Net with
        # decoder) or supervised objective instead.
        loss = torch.mean(emb**2)
        opt.zero_grad()
        loss.backward()
        opt.step()
        running_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
    # Report the sample-weighted mean over the whole epoch, not just the last batch.
    epoch_loss = running_loss / max(n_seen, 1)
    print(f"Epoch {epoch+1}: loss={epoch_loss:.4f}")

def get_embeddings(loader, encoder=None, dev=None):
    """Embed every batch of ``loader`` and return all embeddings and labels.

    Args:
        loader: iterable of ``(x, y)`` batches, e.g. a ``DataLoader``.
        encoder: module used to embed ``x``; defaults to the global ``model``.
        dev: device ``x`` is moved to before encoding; defaults to the global
            ``device``.

    Returns:
        ``(embeddings, labels)``: the encoder outputs (moved to CPU) and the
        labels, each concatenated along dim 0 in loader order.

    The encoder runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so calling this does not silently
    switch the model into eval mode for later code.
    """
    encoder = model if encoder is None else encoder
    dev = device if dev is None else dev

    was_training = encoder.training
    encoder.eval()
    embs, labels = [], []
    try:
        with torch.no_grad():
            for x, y in loader:
                embs.append(encoder(x.to(dev)).cpu())
                labels.append(y)
    finally:
        encoder.train(was_training)
    return torch.cat(embs), torch.cat(labels)

X_train, y_train = get_embeddings(train_loader)
X_test, y_test = get_embeddings(test_loader)

# SVMs are not scale invariant, so standardise each embedding dimension before
# the RBF kernel. The scaler is fit on the training split only, to avoid leaking
# test statistics; `clf` expects scaled inputs (use scaler.transform first).
scaler = StandardScaler().fit(X_train)
clf = SVC(kernel="rbf", C=1, gamma="scale")
clf.fit(scaler.transform(X_train), y_train)
y_pred = clf.predict(scaler.transform(X_test))
# Name the classes (the dataset assigns 0 = real, 1 = fake) and always report
# both rows; zero_division=0 avoids warnings if a class is never predicted.
print(classification_report(
    y_test, y_pred, labels=[0, 1], target_names=["Real", "Fake"], zero_division=0,
))

In [None]:
# Visualize a few examples.
# NOTE: the dataset stores LBP maps (not the raw images), so those are what is shown.
import random
fig, axes = plt.subplots(2, 5, figsize=(10,4))
panels = axes.ravel()
# Draw distinct indices so no example is shown twice, and never ask for more
# samples than the dataset holds (random.randint crashed on an empty dataset).
picks = random.sample(range(len(dataset)), k=min(len(panels), len(dataset)))
for ax in panels:
    ax.axis("off")
for ax, idx in zip(panels, picks):
    img, label = dataset[idx]
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title("Real" if label.item()==0 else "Fake")
plt.tight_layout()
plt.show()