In [None]:
!pip install torch torchvision torch-geometric scikit-learn tqdm seaborn matplotlib

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
from itertools import cycle
from google.colab import drive

# Mount Google Drive; the dataset and the output directory below live under it.
drive.mount('/content/drive')

input_folder = "/content/drive/MyDrive/defense/data/test"
output_path = "/content/drive/MyDrive/defense/output"
os.makedirs(output_path, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Resize to 224x224 and normalise with the ImageNet channel statistics that
# the pretrained torchvision ResNet weights were trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ImageFolder derives one class per sub-directory of `input_folder`.
dataset = datasets.ImageFolder(root=input_folder, transform=transform)
class_names = dataset.classes

# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` argument; IMAGENET1K_V1 is the checkpoint `pretrained=True` used
# to load. Fall back to the legacy flag on older torchvision releases.
if hasattr(models, "ResNet18_Weights"):
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    backbone = models.resnet18(pretrained=True)

# Drop the last child module (the classification head) so the network emits
# pooled feature maps instead of class logits.
feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
feature_extractor.eval().to(device)

# Run every image through the frozen backbone once and collect the per-image
# embeddings (and their class ids) on the CPU, in dataset order.
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
feature_chunks, label_chunks = [], []

with torch.no_grad():
    for batch_images, batch_targets in tqdm(dataloader, desc='Extracting Features'):
        embedded = feature_extractor(batch_images.to(device))
        # Collapse the trailing 1x1 spatial dimensions: (N, C, 1, 1) -> (N, C).
        feature_chunks.append(torch.flatten(embedded, start_dim=1).cpu())
        label_chunks.append(batch_targets.cpu())

features = torch.cat(feature_chunks).numpy()
labels = torch.cat(label_chunks).numpy()

# Stratified 70 / 15 / 15 split of node indices: hold out 30 %, then halve the
# held-out part into validation and test.
idx = np.arange(len(labels))
idx_train, idx_test = train_test_split(idx, test_size=0.3, stratify=labels, random_state=42)
idx_val, idx_test = train_test_split(idx_test, test_size=0.5, stratify=labels[idx_test], random_state=42)

# Connect every node to its 7 nearest neighbours in cosine distance.
adj = kneighbors_graph(features, n_neighbors=7, metric='cosine', mode='connectivity', include_self=False)
# "j is among i's nearest neighbours" is not a symmetric relation, so the raw
# kNN graph is directed. Message passing flows source -> target, which would
# make a node aggregate from the nodes that picked *it* rather than from its
# own neighbours (and possibly from nobody). Keep an edge in both directions
# whenever either endpoint selected the other.
adj = adj.maximum(adj.T)
edge_index = torch.tensor(np.array(adj.nonzero()), dtype=torch.long)

# Single graph holding all nodes; the train/val/test index arrays above
# select which nodes contribute to each loss/metric.
data = Data(
    x=torch.tensor(features, dtype=torch.float),
    edge_index=edge_index,
    y=torch.tensor(labels, dtype=torch.long)
).to(device)

class GCNModel(nn.Module):
    """Three-layer graph convolutional network for node classification.

    Layer widths are ``input_dim -> hidden_dim -> hidden_dim // 2 ->
    output_dim``. The first two convolutions are each followed by ReLU,
    batch normalisation and dropout; the last one feeds a log-softmax.

    Args:
        input_dim: Size of each node feature vector.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            uses ``hidden_dim // 2``.
        output_dim: Number of classes.
        dropout: Dropout probability applied after each hidden layer.
            Defaults to 0.6, the value that was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.6):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim // 2)
        self.conv3 = GCNConv(hidden_dim // 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 2)

    def forward(self, data):
        """Return per-node log-probabilities over the classes.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (edge list).

        Returns:
            Tensor with one row per node, log-softmax normalised along
            dimension 1 (suitable for ``nn.NLLLoss``).
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.conv3(x, edge_index)
        return F.log_softmax(x, dim=1)

def train_model(model, name, patience=20, max_epochs=300, lr=0.003, weight_decay=1e-4, t_max=50):
    """Train ``model`` full-batch on the module-level graph with early stopping.

    Relies on the module-level ``data``, ``device``, ``idx_train``,
    ``idx_val``, ``idx_test``, ``output_path`` and ``class_names``. The
    weights of the epoch with the highest validation accuracy are written to
    ``<output_path>/<name>_best.pth`` and restored before the final
    evaluation.

    Args:
        model: Module whose forward takes the graph and returns per-node
            log-probabilities (the loss is ``nn.NLLLoss``).
        name: Prefix of the checkpoint file name.
        patience: Stop once validation accuracy has failed to improve for
            this many consecutive epochs.
        max_epochs: Upper bound on the number of training epochs (>= 1).
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        t_max: ``T_max`` of the cosine-annealing learning-rate schedule.

    Returns:
        Tuple ``(best_val_acc, best_test_acc, best_epoch, history, cm,
        report, out)``: the best validation accuracy, the test accuracy at
        that epoch, the 1-based epoch number, a dict of per-epoch loss and
        accuracy lists, the test confusion matrix, the classification report
        as a DataFrame, and the restored model's output for all nodes.

    Raises:
        ValueError: If ``max_epochs`` is smaller than 1.
    """
    if max_epochs < 1:
        raise ValueError("max_epochs must be at least 1")

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.NLLLoss()
    # NOTE(review): when t_max < max_epochs the schedule does not hold the
    # minimum learning rate after t_max steps; it keeps following the cosine.
    # Confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    checkpoint_path = os.path.join(output_path, f"{name}_best.pth")
    targets = data.y
    # Convert the index arrays once instead of on every epoch.
    split_index = {
        split: torch.as_tensor(indices, dtype=torch.long, device=targets.device)
        for split, indices in (("train", idx_train), ("val", idx_val), ("test", idx_test))
    }

    history = {"train_loss": [], "val_loss": [], "test_loss": [],
               "train_acc": [], "val_acc": [], "test_acc": []}
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run whose validation accuracy stays at 0 would
    # reload a stale file from a previous run (or fail to find one).
    best_val_acc, best_test_acc, best_epoch = float("-inf"), 0.0, 0
    patience_counter = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[split_index["train"]], targets[split_index["train"]])
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Re-evaluate in eval mode so dropout/batch-norm do not distort metrics.
        model.eval()
        epoch_acc = {}
        with torch.no_grad():
            out = model(data)
            pred = out.argmax(dim=1)
            for split, index in split_index.items():
                split_loss = criterion(out[index], targets[index]).item()
                split_acc = (pred[index] == targets[index]).sum().item() / index.numel()
                history[f"{split}_loss"].append(split_loss)
                history[f"{split}_acc"].append(split_acc)
                epoch_acc[split] = split_acc

        if epoch_acc["val"] > best_val_acc:
            best_val_acc, best_test_acc, best_epoch = epoch_acc["val"], epoch_acc["test"], epoch
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 20 == 0:
            print(f"Epoch {epoch} | Train: {epoch_acc['train']:.4f} | Val: {epoch_acc['val']:.4f} | Test: {epoch_acc['test']:.4f}")

        # `>=` so that exactly `patience` stale epochs trigger the stop.
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Restore the best weights; map_location keeps this working if the
    # checkpoint was written from a different device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    with torch.no_grad():
        out = model(data)
        pred = out.argmax(dim=1)

    y_true = targets[split_index["test"]].cpu().numpy()
    y_pred = pred[split_index["test"]].cpu().numpy()
    # Pin the label set so the matrix/report always line up with class_names,
    # even when a class is absent from the test split.
    class_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    report = classification_report(
        y_true, y_pred, labels=class_ids, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    return best_val_acc, best_test_acc, best_epoch, history, cm, pd.DataFrame(report).transpose(), out

# Build the classifier: input width from the graph's node features, one
# output per image class.
gcn_model = GCNModel(data.num_features, 128, len(class_names))

# Train with the default patience / epoch budget; checkpoint prefix "GCN_FT".
training_result = train_model(gcn_model, "GCN_FT")
best_val_acc, best_test_acc, best_epoch, history, cm, report, out = training_result

# Headline numbers for the selected epoch, then the per-class report.
for caption, value in (
    ("Best Epoch:", best_epoch),
    ("Validation Accuracy:", best_val_acc),
    ("Test Accuracy:", best_test_acc),
):
    print(caption, value)
print(report)

# Test-set confusion matrix as an annotated heatmap.
plt.figure(figsize=(6, 6))
sns.heatmap(
    cm,
    cmap="Blues",
    annot=True,
    fmt="d",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title("Fine-tuned GCN Confusion Matrix")
plt.show()

# One x-axis tick per recorded epoch (1-based).
epochs = range(1, len(history["train_loss"]) + 1)

# (y-axis label, panel title, ((history key, legend label), ...)) per panel.
curve_panels = (
    ("Loss", "Loss Curve", (
        ("train_loss", "Train Loss"),
        ("val_loss", "Val Loss"),
        ("test_loss", "Test Loss"),
    )),
    ("Accuracy", "Accuracy Curve", (
        ("train_acc", "Train Acc"),
        ("val_acc", "Val Acc"),
        ("test_acc", "Test Acc"),
    )),
)

plt.figure(figsize=(12, 5))
for panel_position, (y_label, panel_title, series) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_position)
    for history_key, legend_label in series:
        plt.plot(epochs, history[history_key], label=legend_label)
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
plt.show()

# ROC analysis on the test nodes. Softmax over the model output yields class
# probabilities whether `out` holds logits or log-probabilities, because
# softmax is invariant to a per-row shift.
y_test = data.y[idx_test].cpu().numpy()
y_score = F.softmax(out[idx_test], dim=1).cpu().numpy()
n_classes = y_score.shape[1]

plt.figure(figsize=(8,6))
if n_classes == 2:
    # Binary case: score of class index 1 against the 0/1 labels.
    fpr, tpr, _ = roc_curve(y_test, y_score[:, 1])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, color="blue", label=f"ROC curve (AUC = {roc_auc:.2f})")
    roc_title = "ROC Curve for Positive Class"
else:
    # roc_curve only accepts binary targets, so draw one one-vs-rest curve
    # per class from the binarised labels.
    y_test_bin = label_binarize(y_test, classes=list(range(n_classes)))
    fpr, tpr, roc_auc = {}, {}, {}
    palette = cycle(["blue", "red", "green", "orange", "purple", "brown"])
    for class_id, color in zip(range(n_classes), palette):
        if y_test_bin[:, class_id].sum() == 0:
            # No positive test sample for this class: its ROC is undefined.
            continue
        fpr[class_id], tpr[class_id], _ = roc_curve(y_test_bin[:, class_id], y_score[:, class_id])
        roc_auc[class_id] = auc(fpr[class_id], tpr[class_id])
        plt.plot(
            fpr[class_id], tpr[class_id], color=color,
            label=f"{class_names[class_id]} (AUC = {roc_auc[class_id]:.2f})",
        )
    roc_title = "One-vs-Rest ROC Curves"

# Diagonal = chance-level classifier.
plt.plot([0, 1], [0, 1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(roc_title)
plt.legend()
plt.show()
