In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import numpy as np
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, log_loss
)

In [None]:
# === Load Patients ===
fa_patients_path = "/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy"
Patients_FA_array = np.load(fa_patients_path, allow_pickle=True)

# === Load Controls ===
fa_controls_path = "/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy"
Controls_FA_array = np.load(fa_controls_path, allow_pickle=True)

print("Patients Shape:", Patients_FA_array.shape)
print("Controls Shape:", Controls_FA_array.shape)

# === Combine features and labels ===
X = np.vstack([Controls_FA_array, Patients_FA_array])
y = np.hstack([
    np.zeros(Controls_FA_array.shape[0], dtype=np.int64),  # 0 = Control
    np.ones(Patients_FA_array.shape[0], dtype=np.int64)    # 1 = Patient
])

# Shuffle
np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]

Patients Shape: (98, 180)
Controls Shape: (48, 180)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.nn import GATConv


class GAT_SemiSupervised(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, device, heads=2, activ="SELU", dropout=0.25):
        super(GAT_SemiSupervised, self).__init__()
        self.device = device

        # GAT layer
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, concat=True, dropout=dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim * heads)  # heads multiply feature dimension if concat=True
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * heads, output_dim)

        # Activation mapping
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu,
            "ELU": nnFn.elu
        }
        self.act = activations.get(activ, nnFn.elu)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # GAT layer
        x = self.conv1(x, edge_index)
        # x = self.bn1(x)
        x = self.act(x)
        x = self.dropout(x)

        logits = self.fc(x)
        return logits

    def modularity_loss(self, A, logits):
        S = nnFn.softmax(logits, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.outer(d, d) / (2 * m + 1e-12)

        modularity_term = (-1 / (2 * m + 1e-12)) * torch.trace(S.T @ B @ S)
        # Collapse regularization
        I_S = torch.eye(S.shape[1], device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        collapse_reg = (torch.sqrt(k) / n) * torch.norm(torch.sum(S, dim=0), p='fro') - 1
        entropy_reg = -torch.mean(torch.sum(S * torch.log(S + 1e-9), dim=1))
        return modularity_term + 0.1 * collapse_reg + 0.01 * entropy_reg


In [None]:
def create_adj(F, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1, 0).astype(np.float32)
    W = W / W.max()
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    return Data(x=node_feats, edge_index=edge_index)

In [None]:
num_nodes, num_feats = X.shape
print(f"Number of features: {num_feats}")

Number of features: 180


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
alpha = 0.5
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_mod = 0.5 #0.01  # weight for modularity loss
# lambda_sup = 5

In [None]:
W = create_adj(X, alpha)
data = load_data(W, X).to(device)
A_tensor = torch.from_numpy(W).float().to(device)
print(data)

Data(x=[146, 180], edge_index=[2, 21256])


In [None]:
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=42)

accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} ===")


    cn_idx = np.where(y == 0)[0]
    mci_idx = np.where(y == 1)[0]

    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=fold)
    cn_train_idx, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
    mci_train_idx, _ = next(sss_class.split(X[mci_idx], y[mci_idx]))

    cn_train = cn_idx[cn_train_idx]
    mci_train = mci_idx[mci_train_idx]
    train_idx_final = np.concatenate([cn_train, mci_train])
    np.random.shuffle(train_idx_final)

    print(f"Train CN: {len(cn_train)}, Train MCI: {len(mci_train)}")

    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)


    model = GAT_SemiSupervised(feats_dim, hidden_dim, num_classes, device, activ="RELU").to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss = nn.CrossEntropyLoss()


    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        logits = model(data)
        loss_sup = ce_loss(logits[train_idx_t], y_train_tensor)
        loss_unsup = model.modularity_loss(A_tensor, logits)
        total_loss = loss_sup + lambda_mod * loss_unsup

        total_loss.backward()
        optimizer.step()

        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                preds_train = logits[train_idx_t].argmax(dim=1)
                acc_train = accuracy_score(y_train_tensor.cpu(), preds_train.cpu())
            print(f"Fold {fold} Epoch {epoch}: "
                  f"TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | "
                  f"Unsup={loss_unsup.item():.6f} | TrainAcc={acc_train:.4f}")


    model.eval()
    with torch.no_grad():
        out = model(data)
        preds = out.argmax(dim=1).cpu().numpy()
        probs = torch.softmax(out, dim=1)[:, 1].cpu().numpy()  # Probability for class 1

    y_test = y[test_idx_global]
    y_pred_test = preds[test_idx_global]
    y_prob_test = probs[test_idx_global]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    auc = roc_auc_score(y_test, y_prob_test)
    ce = log_loss(y_test, y_prob_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)

    print(f"Fold {fold} → "
          f"Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | "
          f"F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")


print("\n=== Average Results Across 20 Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
print(f"CE Loss:   {np.mean(ce_losses):.4f} ± {np.std(ce_losses):.4f}")


=== Fold 1 ===
Train CN: 4, Train MCI: 9
Fold 1 Epoch 1: TotalLoss=0.555270 | Sup=0.622866 | Unsup=-0.135193 | TrainAcc=0.7692
Fold 1 Epoch 100: TotalLoss=0.472497 | Sup=0.544118 | Unsup=-0.143242 | TrainAcc=0.8462
Fold 1 Epoch 200: TotalLoss=0.118846 | Sup=0.192776 | Unsup=-0.147862 | TrainAcc=1.0000
Fold 1 Epoch 300: TotalLoss=-0.042662 | Sup=0.032032 | Unsup=-0.149387 | TrainAcc=1.0000
Fold 1 Epoch 400: TotalLoss=-0.044250 | Sup=0.032411 | Unsup=-0.153322 | TrainAcc=1.0000
Fold 1 Epoch 500: TotalLoss=-0.062590 | Sup=0.012712 | Unsup=-0.150603 | TrainAcc=1.0000
Fold 1 Epoch 600: TotalLoss=-0.072996 | Sup=0.005894 | Unsup=-0.157782 | TrainAcc=1.0000
Fold 1 Epoch 700: TotalLoss=-0.075378 | Sup=0.004330 | Unsup=-0.159416 | TrainAcc=1.0000
Fold 1 Epoch 800: TotalLoss=-0.077785 | Sup=0.002359 | Unsup=-0.160287 | TrainAcc=1.0000
Fold 1 Epoch 900: TotalLoss=-0.005168 | Sup=0.078883 | Unsup=-0.168102 | TrainAcc=0.9231
Fold 1 Epoch 1000: TotalLoss=-0.060565 | Sup=0.021735 | Unsup=-0.164601 |