In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

cn_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
mci_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"

Histogram_feature_CN_FA_array = np.load(cn_path, allow_pickle=True)
Histogram_feature_MCI_FA_array = np.load(mci_path, allow_pickle=True)

X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

rng = np.random.RandomState(42)
perm = rng.permutation(num_nodes)
X = X[perm]
y = y[perm]

def create_adj(F_data, alpha=1):
    norms = np.linalg.norm(F_data, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    F_norm = F_data / norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    return W

alpha = 0.8
W = create_adj(X, alpha=alpha)
np.fill_diagonal(W, 1.0)

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    rows, cols = np.nonzero(adj > 0)
    edge_index = torch.from_numpy(np.vstack([rows, cols]).astype(np.int64))
    return Data(x=node_feats, edge_index=edge_index)

data = load_data(W, X).to(device)
print(data)

class GAT_AE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, device,
                 heads=4, activ="RELU", dropout=0.25):
        super(GAT_AE, self).__init__()
        self.device = device
        self.heads = heads

        # --- Encoder ---
        # Multi-head attention (concatenates outputs by default)
        self.conv1 = GATConv(input_dim, hidden_dim // heads, heads=heads, dropout=dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)

        # --- Decoder ---
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # --- Activation ---
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu,
            "ELU": nnFn.elu
        }
        self.act = activations.get(activ, nnFn.relu)

    def forward(self, x, edge_index):
        # --- Encoder ---
        h = self.conv1(x, edge_index)
        h = self.bn1(h)
        h = self.act(h)
        h = self.dropout(h)

        # --- Classification / Embedding ---
        logits = self.fc(h)

        # --- Reconstruction ---
        recon = self.decoder(h)

        return logits, recon, h


feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
batch_print_freq = 100
lambda_recon = 8

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=42)
accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []
y_all = y.copy()

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y_all), start=1):
    print(f"\n=== Fold {fold} ===")
    cn_idx = np.where(y_all == 0)[0]
    mci_idx = np.where(y_all == 1)[0]
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=fold)
    cn_train_idx, _ = next(sss_class.split(X[cn_idx], y_all[cn_idx]))
    mci_train_idx, _ = next(sss_class.split(X[mci_idx], y_all[mci_idx]))
    cn_train = cn_idx[cn_train_idx]
    mci_train = mci_idx[mci_train_idx]
    train_idx_final = np.concatenate([cn_train, mci_train])
    np.random.shuffle(train_idx_final)

    print(f"Train CN: {len(cn_train)}, Train MCI: {len(mci_train)}")
    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y_all[train_idx_final]).long().to(device)

    model = GAT_AE(feats_dim, hidden_dim, num_classes, device, heads=4).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss_fn = nn.CrossEntropyLoss()
    recon_loss_fn = nn.MSELoss()
    data = load_data(W, X).to(device)

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits, recon_x, _ = model(data.x, data.edge_index)

        loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
        loss_recon = recon_loss_fn(recon_x, data.x)
        total_loss = loss_sup + lambda_recon * loss_recon

        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print(f"NaN/Inf detected at fold {fold} epoch {epoch}. Aborting epoch.")
            break

        total_loss.backward()
        optimizer.step()

        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval, _, _ = model(data.x, data.edge_index)
                preds_train = logits_eval[train_idx_t].argmax(dim=1)
                acc_train = accuracy_score(y_train_tensor.cpu(), preds_train.cpu())
            print(f"Fold {fold} Epoch {epoch}: TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | TrainAcc={acc_train:.4f}")

    model.eval()
    with torch.no_grad():
        out_logits, _, _ = model(data.x, data.edge_index)
        preds = out_logits.argmax(dim=1).cpu().numpy()
        probs = nnFn.softmax(out_logits, dim=1)[:, 1].cpu().numpy()

    y_test = y_all[test_idx_global]
    y_pred_test = preds[test_idx_global]
    y_prob_test = probs[test_idx_global]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    try:
        auc = roc_auc_score(y_test, y_prob_test)
    except ValueError:
        auc = float('nan')
    ce = log_loss(y_test, y_prob_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)

    print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

print("\n=== Average Results Across 20 Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")
print(f"CE Loss:   {np.mean(ce_losses):.4f} ± {np.std(ce_losses):.4f}")


  from .autonotebook import tqdm as notebook_tqdm


Device: cuda
Features: (300, 180), Labels: (300,)
Data(x=[300, 180], edge_index=[2, 79964])

=== Fold 1 ===
Train CN: 93, Train MCI: 116
Fold 1 Epoch 1: TotalLoss=0.908737 | Sup=0.679824 | Recon=0.028614 | TrainAcc=0.4450
Fold 1 Epoch 100: TotalLoss=0.584955 | Sup=0.561494 | Recon=0.002933 | TrainAcc=0.7703
Fold 1 Epoch 200: TotalLoss=0.528403 | Sup=0.515885 | Recon=0.001565 | TrainAcc=0.7751
Fold 1 Epoch 300: TotalLoss=0.485471 | Sup=0.478387 | Recon=0.000886 | TrainAcc=0.7847
Fold 1 Epoch 400: TotalLoss=0.435465 | Sup=0.429836 | Recon=0.000704 | TrainAcc=0.8182
Fold 1 Epoch 500: TotalLoss=0.364107 | Sup=0.359085 | Recon=0.000628 | TrainAcc=0.8278
Fold 1 Epoch 600: TotalLoss=0.323401 | Sup=0.318575 | Recon=0.000603 | TrainAcc=0.8182
Fold 1 Epoch 700: TotalLoss=0.258201 | Sup=0.253949 | Recon=0.000532 | TrainAcc=0.8947
Fold 1 Epoch 800: TotalLoss=0.190965 | Sup=0.187271 | Recon=0.000462 | TrainAcc=0.8900
Fold 1 Epoch 900: TotalLoss=0.198458 | Sup=0.194883 | Recon=0.000447 | TrainAcc=0.