In [2]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import random
from torch_geometric.nn import GATConv

In [3]:
fa_patients_path = "/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy"
Patients_FA_array = np.load(fa_patients_path, allow_pickle=True)

fa_controls_path = "/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy"
Controls_FA_array = np.load(fa_controls_path, allow_pickle=True)

X = np.vstack([Controls_FA_array, Patients_FA_array])
y = np.hstack([
    np.zeros(Controls_FA_array.shape[0], dtype=np.int64),  # 0 = Control
    np.ones(Patients_FA_array.shape[0], dtype=np.int64)    # 1 = Patient
])

np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]

In [4]:
class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

In [5]:
class GATEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ, heads=4):
        super(GATEncoder, self).__init__()
        self.device = device
        self.gat1 = GATConv(input_dim, hidden_dim // heads, heads=heads)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gat1(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [6]:
class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(GAT, self).__init__()
        self.device = device
        self.num_clusters = num_clusters
        self.cut = 0   # always modularity

        self.online_encoder = GATEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # only modularity loss
        self.loss = self.modularity_loss

    def forward(self, data):
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)

        return S, logits

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term

In [7]:
def create_adj(F, cut, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        W = np.where(W >= alpha, 1, 0).astype(np.float32)
        W = (W / W.max()).astype(np.float32)
    else:
        W = W - (W.max() / alpha)
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats)
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    row, col = edge_index
    edge_weight = torch.from_numpy(adj[row, col])
    return node_feats, edge_index, edge_weight

In [8]:
features = X.astype(np.float32)
print(features.shape, features.dtype)

cut = 0
alpha = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = 180
K = 2

W0 = create_adj(features, 0, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(146, 180) float32
Data(x=[146, 180], edge_index=[2, 21256])


In [9]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

model = GAT(feats_dim, 512, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 5000

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    S, logits = model(data0)
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss:.4f}")

Epoch 0 | Loss: -0.2838
Epoch 100 | Loss: -0.2839
Epoch 200 | Loss: -0.2845
Epoch 300 | Loss: -0.2846
Epoch 400 | Loss: -0.2844
Epoch 500 | Loss: -0.2847
Epoch 600 | Loss: -0.2847
Epoch 700 | Loss: -0.2845
Epoch 800 | Loss: -0.2847
Epoch 900 | Loss: -0.2845
Epoch 1000 | Loss: -0.2847
Epoch 1100 | Loss: -0.2847
Epoch 1200 | Loss: -0.2847
Epoch 1300 | Loss: -0.2848
Epoch 1400 | Loss: -0.2846
Epoch 1500 | Loss: -0.2847
Epoch 1600 | Loss: -0.2847
Epoch 1700 | Loss: -0.2848
Epoch 1800 | Loss: -0.2848
Epoch 1900 | Loss: -0.2847
Epoch 2000 | Loss: -0.2846
Epoch 2100 | Loss: -0.2847
Epoch 2200 | Loss: -0.2848
Epoch 2300 | Loss: -0.2848
Epoch 2400 | Loss: -0.2847
Epoch 2500 | Loss: -0.2847
Epoch 2600 | Loss: -0.2848
Epoch 2700 | Loss: -0.2847
Epoch 2800 | Loss: -0.2848
Epoch 2900 | Loss: -0.2848
Epoch 3000 | Loss: -0.2847
Epoch 3100 | Loss: -0.2848
Epoch 3200 | Loss: -0.2848
Epoch 3300 | Loss: -0.2846
Epoch 3400 | Loss: -0.2846
Epoch 3500 | Loss: -0.2848
Epoch 3600 | Loss: -0.2848
Epoch 3700 | 

In [10]:
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()


acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred

prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.636986301369863
Precision: 0.8082191780821918
Recall: 0.6020408163265306
F1: 0.6900584795321637
Log Loss: 1.7942409911130333


In [12]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels
    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    cut = 0
    alpha = 0.5
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 180
    K = 2

    W0 = create_adj(features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = GAT(feats_dim, 512, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()


        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))


Epoch 0 | Loss: -0.2811
Epoch 1000 | Loss: -0.2847
Epoch 2000 | Loss: -0.2848
Epoch 3000 | Loss: -0.2848
Epoch 4000 | Loss: -0.2847
Accuracy: 0.7191780821917808 Precision: 0.8352941176470589 Recall: 0.7244897959183674 F1: 0.7759562841530054

Epoch 0 | Loss: -0.2841
Epoch 1000 | Loss: -0.2848
Epoch 2000 | Loss: -0.2846
Epoch 3000 | Loss: -0.2847
Epoch 4000 | Loss: -0.2846
Accuracy: 0.6986301369863014 Precision: 0.8292682926829268 Recall: 0.6938775510204082 F1: 0.7555555555555555

Epoch 0 | Loss: -0.2841
Epoch 1000 | Loss: -0.2848
Epoch 2000 | Loss: -0.2840
Epoch 3000 | Loss: -0.2848
Epoch 4000 | Loss: -0.2848
Accuracy: 0.684931506849315 Precision: 0.825 Recall: 0.673469387755102 F1: 0.7415730337078652

Epoch 0 | Loss: -0.2839
Epoch 1000 | Loss: -0.2846
Epoch 2000 | Loss: -0.2846
Epoch 3000 | Loss: -0.2846
Epoch 4000 | Loss: -0.2846
Accuracy: 0.7123287671232876 Precision: 0.7916666666666666 Recall: 0.7755102040816326 F1: 0.7835051546391752

Epoch 0 | Loss: -0.2823
Epoch 1000 | Loss: -0.