In [None]:
!pip install torch torchvision torch-geometric scikit-learn tqdm seaborn matplotlib

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import pandas as pd
from itertools import cycle
from google.colab import drive

drive.mount('/content/drive')

input_folder = "/content/drive/MyDrive/defense/data/test"
output_path  = "/content/drive/MyDrive/defense/output"
os.makedirs(output_path, exist_ok=True)

# Seed numpy/torch so model initialisation and dropout are repeatable across
# runs (the data split below is already pinned with random_state=42).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Random-augmentation pipeline. Not applied below: features are extracted in a
# single pass through a frozen backbone, where one random distortion per image
# adds noise rather than augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Deterministic resize + ImageNet normalisation, used for every split.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

full_dataset = datasets.ImageFolder(root=input_folder)
class_names = full_dataset.classes
print("Classes:", class_names)

# Stratified 70 / 15 / 15 split of image indices into train / val / test.
idx = np.arange(len(full_dataset.targets))
idx_train, idx_test = train_test_split(idx, test_size=0.3, stratify=full_dataset.targets, random_state=42)
idx_val, idx_test  = train_test_split(idx_test, test_size=0.5, stratify=np.array(full_dataset.targets)[idx_test], random_state=42)

# All splits share one deterministic view of the folder. Using the random
# train_transform here would make training nodes systematically differ from the
# un-augmented val/test nodes they are linked to in the k-NN graph.
eval_folder   = datasets.ImageFolder(root=input_folder, transform=test_transform)
train_dataset = Subset(eval_folder, idx_train)
val_dataset   = Subset(eval_folder, idx_val)
test_dataset  = Subset(eval_folder, idx_test)

# These loaders are only used for one-shot feature extraction, so a fixed order
# keeps node ordering (and hence the graph) reproducible.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

# `weights=` replaces the deprecated `pretrained=True`; DEFAULT selects the
# ImageNet-1k checkpoint that `pretrained=True` loaded.
feature_extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Drop the final classification layer, keeping the backbone up to pooling.
feature_extractor = nn.Sequential(*list(feature_extractor.children())[:-1])
feature_extractor.eval().to(device)

def extract_features(loader):
    """Embed every image yielded by ``loader`` with the frozen ``feature_extractor``.

    Runs without gradient tracking and moves each batch to ``device``.

    Returns:
        A pair of NumPy arrays ``(features, labels)`` in loader iteration order.
        The two trailing size-1 dimensions of each backbone output are squeezed
        away. This assumes the backbone ends in global pooling, as the truncated
        ResNet-18 does.
    """
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        progress = tqdm(loader, desc="Extracting Features")
        for batch_images, batch_targets in progress:
            embedded = feature_extractor(batch_images.to(device))
            flat = embedded.squeeze(-1).squeeze(-1)
            feature_batches.append(flat.cpu())
            label_batches.append(batch_targets.cpu())
    all_features = torch.cat(feature_batches)
    all_labels = torch.cat(label_batches)
    return all_features.numpy(), all_labels.numpy()

features_train, labels_train = extract_features(train_loader)
features_val,   labels_val   = extract_features(val_loader)
features_test,  labels_test  = extract_features(test_loader)

# Stack all splits into a single node set (transductive setting). Nodes are
# ordered train, then val, then test, so each split is a contiguous index range.
features = np.concatenate([features_train, features_val, features_test], axis=0)
labels   = np.concatenate([labels_train, labels_val, labels_test], axis=0)

idx_train = np.arange(len(labels_train))
idx_val   = np.arange(len(labels_train), len(labels_train)+len(labels_val))
idx_test  = np.arange(len(labels_train)+len(labels_val), len(labels))

# Cosine k-NN graph over every node. kneighbors_graph is directional: row i
# marks i's own neighbours. PyG sends messages from edge_index[0] to
# edge_index[1], so the raw edges would let a node hear only from nodes that
# picked *it* as a neighbour. Symmetrise so every node also aggregates from its
# own k nearest neighbours.
adj = kneighbors_graph(features, n_neighbors=5, metric="cosine", mode="connectivity", include_self=False)
adj = adj.maximum(adj.T)
edge_index = torch.tensor(np.array(adj.nonzero()), dtype=torch.long)

data = Data(
    x=torch.tensor(features, dtype=torch.float),
    edge_index=edge_index,
    y=torch.tensor(labels, dtype=torch.long)
).to(device)

class LightHybridGCNGAT(nn.Module):
    """GCN layer, then multi-head GAT layer, then a two-layer MLP head.

    ``forward`` returns per-node log-probabilities over ``output_dim`` classes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=2):
        super().__init__()
        # GATConv concatenates its heads by default, so its output width is
        # hidden_dim * heads.
        concat_dim = hidden_dim * heads
        # Submodules are created in this exact order so parameter
        # initialisation consumes the RNG identically.
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads, dropout=0.5)
        self.fc1  = nn.Linear(concat_dim, hidden_dim)
        self.fc2  = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)
        self.bn = nn.BatchNorm1d(concat_dim)

    def forward(self, data):
        edges = data.edge_index
        hidden = self.gcn1(data.x, edges).relu()
        hidden = F.elu(self.gat1(hidden, edges))
        hidden = self.dropout(self.bn(hidden))
        hidden = self.dropout(self.fc1(hidden).relu())
        logits = self.fc2(hidden)
        return logits.log_softmax(dim=1)

def _evaluate_splits(model, criterion):
    """Score ``model`` on the global graph ``data`` in eval mode.

    Runs one no-grad forward pass and computes loss and accuracy on each of the
    node index sets ``idx_train``, ``idx_val`` and ``idx_test``.

    Returns:
        ``(out, losses, accs)``: the model output for every node, plus two dicts
        keyed by ``"train"``, ``"val"`` and ``"test"``.
    """
    model.eval()
    with torch.no_grad():
        out = model(data)
        pred = out.argmax(dim=1)
        losses, accs = {}, {}
        for split, idx in (("train", idx_train), ("val", idx_val), ("test", idx_test)):
            target = data.y[idx]
            losses[split] = criterion(out[idx], target).item()
            accs[split] = (pred[idx] == target).sum().item() / len(idx)
    return out, losses, accs


def train_model(model, name, max_epochs=10000):
    """Train ``model`` full-batch on ``data`` and restore its best-validation weights.

    Each epoch takes one AdamW step on the training-node loss, then scores all
    three splits. Whenever validation accuracy strictly improves, the weights
    are snapshotted in memory and written to ``<output_path>/<name>_best.pth``.
    There is no early stopping; all ``max_epochs`` epochs run.

    Reads the module-level globals ``data``, ``idx_train``, ``idx_val``,
    ``idx_test``, ``device`` and ``output_path``.

    Args:
        model: module whose ``forward(data)`` returns per-node log-probabilities.
            It is scored with ``nn.NLLLoss``.
        name: prefix for the checkpoint file name.
        max_epochs: number of training epochs; must be >= 1.

    Returns:
        ``(best_val_acc, best_test_acc, best_epoch, history, out)``.
        ``history`` maps ``"<split>_loss"`` / ``"<split>_acc"`` to per-epoch
        lists. ``out`` is the eval-mode output of the *restored best* model on
        every node, so it agrees with ``best_test_acc``.

    Raises:
        ValueError: if ``max_epochs`` < 1, since there would be nothing to restore.
    """
    import copy  # local: only needed for the in-memory best-weights snapshot

    if max_epochs < 1:
        raise ValueError(f"max_epochs must be >= 1, got {max_epochs}")

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=1e-4)
    criterion = nn.NLLLoss()
    splits = ("train", "val", "test")
    history = {f"{split}_{metric}": [] for metric in ("loss", "acc") for split in splits}
    ckpt_path = os.path.join(output_path, f"{name}_best.pth")

    # Start below any reachable accuracy so epoch 1 always yields a snapshot.
    # Starting at 0 meant a run whose val accuracy never left 0 had no
    # checkpoint, so the final load could fail or pick up a stale file from an
    # earlier run.
    best_val_acc, best_test_acc, best_epoch = -1.0, 0.0, 0
    best_state = None

    for epoch in range(1, max_epochs + 1):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(data)[idx_train], data.y[idx_train])
        loss.backward()
        optimizer.step()

        _, losses, accs = _evaluate_splits(model, criterion)
        for split in splits:
            history[f"{split}_loss"].append(losses[split])
            history[f"{split}_acc"].append(accs[split])

        if accs["val"] > best_val_acc:
            best_val_acc, best_test_acc, best_epoch = accs["val"], accs["test"], epoch
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, ckpt_path)

        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Train {accs['train']:.3f}/{losses['train']:.3f} | "
                  f"Val {accs['val']:.3f}/{losses['val']:.3f} | Test {accs['test']:.3f}/{losses['test']:.3f}")

    # Restore the best snapshot, then recompute outputs so the returned `out`
    # comes from the best epoch rather than the last one.
    model.load_state_dict(best_state)
    out, _, _ = _evaluate_splits(model, criterion)
    return best_val_acc, best_test_acc, best_epoch, history, out

hybrid_model = LightHybridGCNGAT(data.num_features, 32, len(class_names), heads=2)
best_val, best_test, epoch, history, out = train_model(hybrid_model, "LightHybrid")

print("Best Epoch:", epoch)
print("Validation Accuracy:", best_val)
print("Test Accuracy:", best_test)

# ---- Learning curves: one panel per metric, one line per split ----
epochs = range(1, len(history["train_loss"]) + 1)
split_labels = (("train", "Train"), ("val", "Val"), ("test", "Test"))
panels = (
    # (history metric key, legend suffix, y-axis label, panel title)
    ("loss", "Loss", "Loss", "Loss Curve"),
    ("acc", "Acc", "Accuracy", "Accuracy Curve"),
)

plt.figure(figsize=(12,5))
for panel_no, (metric, legend_suffix, ylabel, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    for split, split_label in split_labels:
        plt.plot(epochs, history[f"{split}_{metric}"], label=f"{split_label} {legend_suffix}")
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
plt.show()

# ---- Test-set confusion matrix and per-class report ----
class_ids = np.arange(len(class_names))
pred = out.argmax(dim=1)
y_test = data.y[idx_test].cpu().numpy()
y_pred = pred[idx_test].cpu().numpy()

# Pass `labels` explicitly so the matrix is always n_classes x n_classes.
# Without it, a class missing from both y_test and y_pred drops a row and
# column, and the tick labels no longer line up with the cells.
cm = confusion_matrix(y_test, y_pred, labels=class_ids)
plt.figure(figsize=(6,6))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names, cmap="mako")
plt.title("LightHybrid GCN+GAT Confusion Matrix")
plt.xlabel("Predicted"); plt.ylabel("True")
plt.show()

print(classification_report(y_test, y_pred, labels=class_ids, target_names=class_names,
                            digits=4, zero_division=0))

# ---- ROC: positive-class curve for binary, micro-average otherwise ----
# The model emits log-probabilities; softmax turns them back into probabilities.
probs = F.softmax(out[idx_test], dim=1).cpu().numpy()

if probs.shape[1] == 2:
    fpr, tpr, _ = roc_curve(y_test, probs[:, 1])
    curve_name, roc_title = "ROC curve", "ROC Curve (Positive Class)"
else:
    # Scoring the flattened one-vs-rest indicator matrix against the flattened
    # probabilities yields the micro-averaged curve.
    y_bin = label_binarize(y_test, classes=np.arange(probs.shape[1]))
    fpr, tpr, _ = roc_curve(y_bin.ravel(), probs.ravel())
    curve_name, roc_title = "Micro-average ROC", "Micro-average ROC (Multi-class)"
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8,6))
plt.plot(fpr, tpr, label=f"{curve_name} (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(roc_title)
plt.legend()
plt.show()
