In [None]:
!pip install -q torch torchvision torch-geometric scikit-learn tqdm seaborn matplotlib

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import pandas as pd
from google.colab import drive

# Mount Google Drive so the dataset and output folders below are reachable
# (`drive` comes from google.colab, so this only works inside Colab).
drive.mount('/content/drive')

# Root folder later handed to ImageFolder, and the directory created here for
# this run's outputs.
input_folder = "/content/drive/MyDrive/defense/data/test"
output_path  = "/content/drive/MyDrive/defense/output_finehybrid"
os.makedirs(output_path, exist_ok=True)  # no error if the folder already exists

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Per-channel normalisation constants shared by both pipelines (the mean/std
# values conventionally used with ImageNet-pretrained torchvision models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / rotation / colour jitter, followed by
# tensor conversion and normalisation.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.08),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize to 224x224, then the same tensor
# conversion and normalisation as above (no random augmentation).
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
eval_transform = transforms.Compose(_eval_steps)

# Transform-free view of the image folder; only its class list and labels are
# read here, to drive the stratified split below.
base_dataset = datasets.ImageFolder(root=input_folder)
class_names = base_dataset.classes
print("Classes:", class_names)

# Stratified split over sample indices with a fixed seed: 30% is held out,
# then that held-out part is halved into validation and test.
_all_targets = np.asarray(base_dataset.targets)
idx_all = np.arange(len(_all_targets))
idx_train, idx_tmp = train_test_split(
    idx_all,
    test_size=0.30,
    stratify=_all_targets,
    random_state=42,
)
idx_val, idx_test = train_test_split(
    idx_tmp,
    test_size=0.50,
    stratify=_all_targets[idx_tmp],
    random_state=42,
)

# One ImageFolder per split so each can carry its own transform: augmentation
# for training, the deterministic pipeline for validation and test.
train_dataset_full = datasets.ImageFolder(root=input_folder, transform=train_transform)
val_dataset_full = datasets.ImageFolder(root=input_folder, transform=eval_transform)
test_dataset_full = datasets.ImageFolder(root=input_folder, transform=eval_transform)

# Restrict every dataset to the indices of its split.
train_dataset = Subset(train_dataset_full, idx_train)
val_dataset = Subset(val_dataset_full, idx_val)
test_dataset = Subset(test_dataset_full, idx_test)

# Loader options common to all three splits; only the training loader shuffles.
_loader_kwargs = dict(batch_size=32, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# ResNet-18 backbone with ImageNet weights, used only to produce embeddings.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` argument; IMAGENET1K_V1 is the checkpoint `pretrained=True` used to
# load. torchvision builds that predate the weights enum keep the legacy call.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    feature_extractor = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    feature_extractor = models.resnet18(pretrained=True)

# Keep every child module except the last one (ResNet's final `fc` classifier),
# so the network ends at its pooling layer and outputs features, not logits.
feature_extractor = nn.Sequential(*list(feature_extractor.children())[:-1])
# eval() switches BatchNorm layers to their running statistics.
feature_extractor.eval().to(device)

@torch.no_grad()
def extract_features(loader):
    """Run every batch of *loader* through the module-level ``feature_extractor``.

    Gradient tracking is disabled for the whole call. Relies on the
    module-level ``feature_extractor`` and ``device``.

    Args:
        loader: iterable yielding ``(images, targets)`` tensor batches.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays, each the concatenation
        of the per-batch results in the order the loader produced them.
    """
    feature_batches = []
    label_batches = []
    for batch_images, batch_targets in tqdm(loader, desc="Extracting Features"):
        batch_images = batch_images.to(device, non_blocking=True)
        embeddings = feature_extractor(batch_images)
        # Drop the last two axes when they have size 1 (the pooled spatial
        # dims), leaving one feature vector per image.
        embeddings = torch.squeeze(torch.squeeze(embeddings, -1), -1)
        # Move results to the CPU so they can be converted to NumPy below.
        feature_batches.append(embeddings.cpu())
        label_batches.append(batch_targets.cpu())

    all_features = torch.cat(feature_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()
    return all_features, all_labels

# Embed each split with the CNN backbone.
# NOTE(review): train_loader shuffles and applies random augmentation, so the
# training rows of X (and their order) differ from run to run — confirm this
# is intended for a one-off feature extraction.
features_train, labels_train = extract_features(train_loader)
features_val, labels_val = extract_features(val_loader)
features_test, labels_test = extract_features(test_loader)

# Stack the splits into one node table: train rows first, then val, then test.
X = np.concatenate((features_train, features_val, features_test), axis=0)
y = np.concatenate((labels_train, labels_val, labels_test), axis=0)

# From here on idx_* are row positions inside X / y (contiguous ranges in the
# stacking order above), replacing the earlier dataset-level split indices.
_n_train = len(labels_train)
_n_val = len(labels_val)
idx_train = np.arange(_n_train)
idx_val = np.arange(_n_train, _n_train + _n_val)
idx_test = np.arange(_n_train + _n_val, len(y))

print(f"Shapes: X={X.shape}, y={y.shape}, train={len(idx_train)}, val={len(idx_val)}, test={len(idx_test)}")

# k-NN graph (k=7, cosine distance, no self-loops) over all nodes of all three
# splits together.
# NOTE(review): the k-NN relation is not symmetric and the edges are used as
# returned — confirm a directed graph is what the model should see.
adj = kneighbors_graph(X, n_neighbors=7, metric="cosine", mode="connectivity", include_self=False)
_rows, _cols = adj.nonzero()
edge_index = torch.tensor(np.stack([_rows, _cols]), dtype=torch.long)

# Bundle features, edges and labels into a single graph object on `device`.
data = Data(
    x=torch.tensor(X, dtype=torch.float),
    edge_index=edge_index,
    y=torch.tensor(y, dtype=torch.long),
)
data = data.to(device)

class FineTunedHybridGCNGAT(nn.Module):
    """Two-branch node classifier combining a GCN stack and a GAT stack.

    Both branches read the same node features and edges. The GCN branch has
    three layers (hidden_dim -> hidden_dim // 2 -> hidden_dim); the GAT branch
    has two layers (``heads`` heads, then 2 heads). Their outputs are
    concatenated and passed through a two-layer MLP ending in log-softmax.

    Args:
        input_dim: size of each input node feature vector.
        hidden_dim: base hidden width used by both branches and the MLP.
        output_dim: number of output classes.
        heads: number of attention heads in the first GAT layer.
        p_drop: dropout probability for the shared ``nn.Dropout`` and for the
            ``dropout`` argument of both GAT layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=4, p_drop=0.6):
        super().__init__()
        half_dim = hidden_dim // 2
        gat1_width = hidden_dim * heads
        gat2_width = hidden_dim * 2

        # Modules are created in the original order so that parameter
        # initialisation under a fixed seed is unchanged.
        # --- GCN branch ---
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, half_dim)
        self.gcn3 = GCNConv(half_dim, hidden_dim)
        self.bn_gcn1 = nn.BatchNorm1d(hidden_dim)
        self.bn_gcn2 = nn.BatchNorm1d(half_dim)
        self.bn_gcn3 = nn.BatchNorm1d(hidden_dim)

        # One dropout module reused after every stage and inside the MLP head.
        self.dropout = nn.Dropout(p_drop)

        # --- GAT branch ---
        self.gat1 = GATConv(input_dim, hidden_dim, heads=heads, dropout=p_drop)
        self.bn_gat1 = nn.BatchNorm1d(gat1_width)
        self.gat2 = GATConv(gat1_width, hidden_dim, heads=2, dropout=p_drop)
        self.bn_gat2 = nn.BatchNorm1d(gat2_width)

        # --- Fusion head: GCN output (hidden_dim) + GAT output (2 * hidden_dim) ---
        fused_dim = hidden_dim + gat2_width
        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def _stage(self, conv, norm, activation, h, edge_index):
        """Apply one graph stage: conv -> activation -> batch norm -> dropout."""
        h = activation(conv(h, edge_index))
        h = norm(h)
        return self.dropout(h)

    def forward(self, data):
        """Return per-node log-probabilities for the graph in *data*.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of log-softmax scores taken over dimension 1.
        """
        node_feats, edges = data.x, data.edge_index

        # GCN branch (ReLU activations); evaluated first, as in the original.
        gcn_stages = (
            (self.gcn1, self.bn_gcn1),
            (self.gcn2, self.bn_gcn2),
            (self.gcn3, self.bn_gcn3),
        )
        h_gcn = node_feats
        for conv, norm in gcn_stages:
            h_gcn = self._stage(conv, norm, F.relu, h_gcn, edges)

        # GAT branch (ELU activations), starting again from the raw features.
        gat_stages = (
            (self.gat1, self.bn_gat1),
            (self.gat2, self.bn_gat2),
        )
        h_gat = node_feats
        for conv, norm in gat_stages:
            h_gat = self._stage(conv, norm, F.elu, h_gat, edges)

        # Concatenate both branches along the feature axis and classify.
        fused = torch.cat((h_gcn, h_gat), dim=1)
        hidden = self.dropout(F.relu(self.fc1(fused)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

def train_model(model, name, patience=25, max_epochs=300, lr=0.002, weight_decay=1e-4):
    model = model
