In [None]:
!pip install -q torch torchvision torch-geometric scikit-learn tqdm seaborn matplotlib

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import add_self_loops
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import label_binarize, StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import pandas as pd
import math
from collections import Counter
from torch.optim.swa_utils import AveragedModel, SWALR
from google.colab import drive

# Mount Google Drive so the input/output folders below are reachable.
# Colab-only (`drive` comes from google.colab); the mount has to happen before
# os.makedirs touches the Drive path.
drive.mount('/content/drive')

# ImageFolder root (read below via datasets.ImageFolder) that gets re-split into
# train/val/test, and an output directory created up front.
# NOTE(review): the source folder is named ".../data/test" but is split into
# train/val/test here — confirm this is the intended input. output_path is not
# used in the visible part of the file; presumably results are written there later.
input_folder = "/content/drive/MyDrive/defense/data/test"
output_path  = "/content/drive/MyDrive/defense/output_fusion_sota"
os.makedirs(output_path, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Training-time augmentation pipeline. Normalisation uses the ImageNet channel
# statistics matching the ImageNet-pretrained ResNet-18 backbone used below.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.08),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # RandomErasing must come AFTER Normalize: with value='random' it fills the
    # erased patch with standard-normal noise, which only matches the pixel
    # statistics once the image is already in normalised space. Before
    # Normalize, that noise gets divided by std (~0.22) into extreme outliers.
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'),
])

# Deterministic pipeline for validation/test: fixed 224x224 resize (aspect
# ratio is not preserved) followed by the same ImageNet normalisation.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Untransformed scan of the folder, used only for class names and labels.
base_dataset = datasets.ImageFolder(root=input_folder)
class_names = base_dataset.classes
print("Classes:", class_names)

# Stratified 70 / 15 / 15 split with a fixed seed: carve off a 30% holdout,
# then halve the holdout into validation and test.
idx_all = np.arange(len(base_dataset.targets))
idx_train, idx_tmp = train_test_split(
    idx_all,
    test_size=0.30,
    stratify=base_dataset.targets,
    random_state=42,
)
idx_val, idx_test = train_test_split(
    idx_tmp,
    test_size=0.50,
    stratify=np.array(base_dataset.targets)[idx_tmp],
    random_state=42,
)

# Each split wraps its own ImageFolder so it can carry its own transform:
# augmentation for training, deterministic resizing for validation/test.
train_dataset, val_dataset, test_dataset = (
    Subset(datasets.ImageFolder(root=input_folder, transform=split_tfm), split_idx)
    for split_tfm, split_idx in (
        (train_transform, idx_train),
        (eval_transform, idx_val),
        (eval_transform, idx_test),
    )
)

# Identical batching settings everywhere; only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=32, shuffle=is_train, num_workers=2, pin_memory=True)
    for split_ds, is_train in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

# ImageNet-pretrained ResNet-18 used as a fixed feature extractor.
# `weights=` replaces the `pretrained=True` flag deprecated in torchvision 0.13;
# ResNet18_Weights.DEFAULT is the same IMAGENET1K_V1 checkpoint `pretrained=True`
# loaded, so the features are unchanged.
feature_extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Drop the final fully-connected classifier so the forward pass ends at the
# global average pool, yielding (N, 512, 1, 1) embeddings.
feature_extractor = nn.Sequential(*list(feature_extractor.children())[:-1])
# eval() makes BatchNorm use its running statistics; Module.to() moves the
# parameters in place, so no reassignment is needed.
feature_extractor.eval().to(device)

@torch.no_grad()
def extract_features(loader):
    """Embed every batch from ``loader`` with the module-level backbone.

    Images are moved to the module-level ``device`` and passed through
    ``feature_extractor`` with gradient tracking disabled. The last two
    dimensions of each output are squeezed away, and results are gathered on
    the CPU.

    Args:
        loader: iterable of ``(images, targets)`` batches, e.g. a DataLoader.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays whose rows follow the
        loader's iteration order. An empty loader makes ``torch.cat`` raise.
    """
    feature_batches = []
    label_batches = []
    progress = tqdm(loader, desc="Extracting Features")
    for batch_images, batch_targets in progress:
        embeddings = feature_extractor(batch_images.to(device, non_blocking=True))
        # Pooled backbone output is presumably (N, C, 1, 1); drop both singletons.
        feature_batches.append(embeddings.squeeze(-1).squeeze(-1).cpu())
        label_batches.append(batch_targets.cpu())
    features = torch.cat(feature_batches).numpy()
    labels = torch.cat(label_batches).numpy()
    return features, labels

# Training-node features must use the same deterministic preprocessing as the
# val/test nodes. `train_loader` applies random augmentation and shuffles, so a
# single pass through it embeds one arbitrary distorted view per image. That
# gives training nodes systematically noisier features than val/test nodes, and
# it also skews the scaler fitted below. Instead, reuse the eval-transformed
# ImageFolder that backs `val_dataset`, restricted to the training indices and
# read in split order.
train_feature_loader = DataLoader(
    Subset(val_dataset.dataset, train_dataset.indices),
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

features_train, labels_train = extract_features(train_feature_loader)
features_val,   labels_val   = extract_features(val_loader)
features_test,  labels_test  = extract_features(test_loader)

# Standardisation statistics are fitted on training features only, so no
# val/test statistics leak into preprocessing.
scaler = StandardScaler().fit(features_train)
def preprocess(X):
    """Standardise ``X`` with the train-fitted ``scaler``, then L2-normalise rows.

    Args:
        X: 2-D feature matrix accepted by ``scaler.transform``.

    Returns:
        Array of the same shape whose rows have (near) unit Euclidean norm.
        The 1e-12 term keeps all-zero rows from dividing by zero.
    """
    standardized = scaler.transform(X)
    row_lengths = np.linalg.norm(standardized, axis=1, keepdims=True)
    return standardized / (row_lengths + 1e-12)

# Apply the train-fitted standardisation + row normalisation to every split.
features_train, features_val, features_test = (
    preprocess(split_feats) for split_feats in (features_train, features_val, features_test)
)

# Stack all splits into one node set: train rows first, then val, then test.
X = np.concatenate((features_train, features_val, features_test), axis=0)
y = np.concatenate((labels_train, labels_val, labels_test), axis=0)

# Re-express the split indices as row positions in X / y (until here they held
# ImageFolder indices from the stratified split).
idx_train, idx_val, idx_test = np.split(
    np.arange(len(y)), np.cumsum([len(labels_train), len(labels_val)])
)

print(f"X={X.shape}, y={y.shape} | train={len(idx_train)}, val={len(idx_val)}, test={len(idx_test)}")

def mutual_knn_graph(X, k=7, metric="cosine"):
    """Return the edge list of a mutual k-nearest-neighbour graph over X's rows.

    In the plain kNN graph, row i points to its ``k`` nearest rows (itself
    excluded). Taking the element-wise minimum of that 0/1 adjacency and its
    transpose keeps only pairs that appear in each other's neighbour lists.
    The result is therefore symmetric: every kept pair shows up as both
    (i, j) and (j, i).

    Args:
        X: (n_samples, n_features) feature matrix.
        k: neighbours per node. Clamped to ``n_samples - 1``, so small inputs
            do not make scikit-learn raise.
        metric: distance metric forwarded to ``kneighbors_graph``
            (default ``"cosine"``, as before).

    Returns:
        ``torch.long`` tensor of shape (2, E) holding source/target row indices.
        Nodes without a mutual neighbour get no edges. Fewer than two samples
        yields an empty (2, 0) tensor.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
    if n_samples < 2:
        # No pair of distinct nodes exists, so there can be no edges.
        return torch.empty((2, 0), dtype=torch.long)
    k = min(k, n_samples - 1)

    A = kneighbors_graph(X, n_neighbors=k, metric=metric, mode="connectivity", include_self=False)
    # Entries are 0/1, so min(A, A.T) is 1 exactly where both directions exist.
    A_mutual = A.minimum(A.T)
    edge_idx = np.array(A_mutual.nonzero())
    return torch.tensor(edge_idx, dtype=torch.long)

# X stacks the train, val and test rows, so this mutual 7-NN graph links nodes
# across splits (transductive setup). One self-loop per node is added on top.
edge_index, _ = add_self_loops(mutual_knn_graph(X, k=7))

# Single PyG graph holding node features, edges and labels, moved to `device`.
data = Data(
    x=torch.tensor(X, dtype=torch.float32),
    edge_index=edge_index,
    y=torch.tensor(y, dtype=torch.int64),
).to(device)

num_features, num_classes = data.num_features, len(class_names)

class_count = Counter(labels_train.tolist())
weights = torch.ones(num_classes, dtype=torch.flo_
