In [1]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import random
from torch_geometric.nn import GATConv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the per-class feature matrices (CN = label 0, AD = label 1).
cn_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
ad_feature_path = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, so these files must come from a trusted source.
Histogram_feature_CN_FA_array = np.load(cn_feature_path, allow_pickle=True)
Histogram_feature_AD_FA_array = np.load(ad_feature_path, allow_pickle=True)

# Stack CN rows above AD rows; the label vector follows the same order
# (one 0 per CN row, then one 1 per AD row).
X = np.vstack((Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array))
y = np.repeat(
    np.array([0, 1], dtype=np.int64),
    [Histogram_feature_CN_FA_array.shape[0], Histogram_feature_AD_FA_array.shape[0]],
)

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (223, 180), Labels: (223,)


In [3]:
class MLP(nn.Module):
    """Small prediction head.

    The stack is Linear -> BatchNorm1d -> PReLU -> Dropout(0.3) -> Linear,
    stored as ``self.net``.

    Args:
        inp_size: Width of the input features.
        outp_size: Width of the output.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        # Layers are instantiated in forward order so parameter
        # initialisation consumes the RNG in the same sequence as the stack.
        stages = [
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        out = self.net(x)
        return out

In [4]:
class GATEncoder(torch.nn.Module):
    """Single-layer graph-attention encoder.

    The forward pass is GATConv -> Dropout(0.3) -> BatchNorm1d -> Linear.
    The GAT layer uses ``heads`` attention heads of width
    ``hidden_dim // heads`` each; their concatenated output is expected to be
    ``hidden_dim`` wide so it matches the batch-norm and linear layers.

    Args:
        input_dim: Width of the node features fed to the GAT layer.
        hidden_dim: Width of the encoder output. Must be a positive multiple
            of ``heads``.
        device: Stored on the instance as ``self.device``; not otherwise used
            by this class.
        activ: Accepted for signature compatibility; not used by this class
            (no activation is applied between the layers).
        heads: Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, device, activ, heads=4):
        super(GATEncoder, self).__init__()
        # Fail early: with a non-divisible hidden_dim the concatenated head
        # output would be (hidden_dim // heads) * heads wide and only blow up
        # later, inside BatchNorm1d, with an opaque shape error.
        if heads <= 0 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be a positive multiple of heads ({heads})"
            )
        self.device = device
        self.gat1 = GATConv(input_dim, hidden_dim // heads, heads=heads)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode the nodes of ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The output of the final linear layer, one row per node.
        """
        x, edge_index = data.x, data.edge_index
        x = self.gat1(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [6]:
class GAT(nn.Module):
    """GAT encoder plus MLP head trained with an unsupervised cut loss.

    Args:
        input_dim: Width of the input node features.
        hidden_dim: Width of the encoder output and of the head's hidden layer.
        num_clusters: Number of output clusters (columns of the assignment).
        device: Stored as ``self.device`` and forwarded to the encoder.
        activ: Name used to pick ``self.act`` ("SELU", "SiLU", "GELU" or
            "RELU"; anything else selects ELU). ``self.act`` is stored but not
            applied anywhere in this class.

    Attributes:
        loss: Alias of :meth:`cut_loss`, called as ``model.loss(A, logits)``.
    """

    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(GAT, self).__init__()
        self.device = device
        self.num_clusters = num_clusters
        self.cut = 0

        self.online_encoder = GATEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # Training objective: normalised-cut term plus orthogonality
        # regulariser (see cut_loss).
        self.loss = self.cut_loss

    def forward(self, data):
        """Return ``(S, logits)`` for the graph in ``data``.

        ``logits`` is the head output and ``S`` is its row-wise softmax.
        """
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)

        return S, logits

    def cut_loss(self, A, S):
        """Normalised-cut loss with an orthogonality regulariser.

        Args:
            A: Dense square adjacency matrix with one row per node.
            S: Unnormalised cluster scores with one row per node; softmax is
                applied here, so pass logits rather than probabilities.

        Returns:
            Scalar tensor ``-tr(P^T A P) / tr(P^T D P) + ||G/||G|| - I/||I||||``
            where ``P = softmax(S)``, ``D`` is the diagonal matrix of row sums
            of ``A`` and ``G = P^T P``.
        """
        P = nnFn.softmax(S, dim=1)

        # tr((A P)^T P) equals the sum of the element-wise product.
        num = torch.sum(torch.matmul(A, P) * P)

        # tr(P^T D P) = sum_i d_i * sum_k P_ik^2, so the dense n x n diagonal
        # matrix never has to be materialised.
        degree = torch.sum(A, dim=-1)
        den = torch.sum(degree.unsqueeze(-1) * P * P)
        mincut_loss = -(num / den)

        gram = torch.matmul(P.t(), P)
        # Build the identity next to the data rather than on self.device,
        # which is only the value captured at construction time.
        identity = torch.eye(self.num_clusters, dtype=gram.dtype, device=gram.device)
        ortho_loss = torch.norm(gram / torch.norm(gram) - identity / torch.norm(identity))

        return mincut_loss + ortho_loss

In [7]:
def create_adj(F, cut, alpha=1):
    """Build a dense adjacency matrix from row-wise cosine similarity.

    Args:
        F: 2-D array with one feature row per node.
        cut: Mode switch. ``0`` produces a binary graph; any other value
            produces a shifted similarity matrix.
        alpha: For ``cut == 0`` the similarity threshold (entries with
            similarity ``>= alpha`` become 1, the rest 0). Otherwise the
            divisor in the shift ``W - W.max() / alpha`` (must be non-zero).

    Returns:
        Square array with one row/column per node. For ``cut == 0`` it is
        float32 with entries in {0, 1}.
    """
    norms = np.linalg.norm(F, axis=1, keepdims=True)
    # A zero-norm row has no direction. Dividing by 1 keeps it all-zero
    # (similarity 0 to everything) instead of producing 0/0 = NaN, which in
    # the cut != 0 branch would turn W.max() and the whole matrix into NaN.
    norms[norms == 0] = 1
    F_norm = F / norms
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        W = np.where(W >= alpha, 1, 0).astype(np.float32)
        peak = W.max()
        # If no pair reaches the threshold the matrix is all zeros; skip the
        # rescale so it stays zeros rather than becoming 0/0 = NaN.
        if peak > 0:
            W = (W / peak).astype(np.float32)
    else:
        W = W - (W.max() / alpha)
    return W

def load_data(adj, node_feats):
    """Convert a dense adjacency matrix and node features to tensors.

    Args:
        adj: 2-D NumPy adjacency matrix; strictly positive entries are edges.
        node_feats: NumPy array of node features.

    Returns:
        Tuple ``(node_feats, edge_index, edge_weight)`` of tensors, where
        ``edge_index`` has one (row, column) pair per positive entry of
        ``adj`` and ``edge_weight`` holds the matching ``adj`` values.
    """
    # Locate the edges in NumPy first, then wrap everything as tensors.
    src, dst = np.nonzero(adj > 0)
    weights = adj[src, dst]
    return (
        torch.from_numpy(node_feats),
        torch.from_numpy(np.stack((src, dst))),
        torch.from_numpy(weights),
    )

In [8]:
# Work on a float32 copy of the feature matrix.
features = X.astype(np.float32)
print(features.shape, features.dtype)

cut = 0          # 0 -> thresholded binary graph in create_adj
alpha = 0.8      # cosine-similarity threshold for an edge
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Read the input width from the data so it cannot drift out of sync with the
# feature file.
feats_dim = features.shape[1]
K = 2            # number of clusters

# Pass the configured `cut` rather than a literal so the setting above is the
# single source of truth.
W0 = create_adj(features, cut, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
# Dense adjacency, consumed by the cut loss.
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(223, 180) float32
Data(x=[223, 180], edge_index=[2, 41689])


In [9]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

# Seed every RNG in play so weight initialisation and dropout, and therefore
# the metrics reported for this single run, are repeatable.
single_run_seed = 0
random.seed(single_run_seed)
np.random.seed(single_run_seed)
torch.manual_seed(single_run_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(single_run_seed)

model = GAT(feats_dim, 256, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
# The learning rate is halved every 200 optimiser steps.
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 5000

# Nothing inside the loop switches modes, so train mode is set once.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Full-batch step: the loss takes the raw logits and applies softmax
    # itself.
    S, logits = model(data0)
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss.item():.4f}")

Epoch 0 | Loss: -0.2465
Epoch 100 | Loss: -0.4341
Epoch 200 | Loss: -0.4759
Epoch 300 | Loss: -0.4821
Epoch 400 | Loss: -0.4843
Epoch 500 | Loss: -0.4973
Epoch 600 | Loss: -0.5012
Epoch 700 | Loss: -0.4965
Epoch 800 | Loss: -0.5071
Epoch 900 | Loss: -0.5000
Epoch 1000 | Loss: -0.5038
Epoch 1100 | Loss: -0.5080
Epoch 1200 | Loss: -0.4977
Epoch 1300 | Loss: -0.5077
Epoch 1400 | Loss: -0.5106
Epoch 1500 | Loss: -0.4925
Epoch 1600 | Loss: -0.5090
Epoch 1700 | Loss: -0.5019
Epoch 1800 | Loss: -0.5043
Epoch 1900 | Loss: -0.5007
Epoch 2000 | Loss: -0.5083
Epoch 2100 | Loss: -0.5062
Epoch 2200 | Loss: -0.5080
Epoch 2300 | Loss: -0.5091
Epoch 2400 | Loss: -0.5004
Epoch 2500 | Loss: -0.5056
Epoch 2600 | Loss: -0.5078
Epoch 2700 | Loss: -0.5044
Epoch 2800 | Loss: -0.5099
Epoch 2900 | Loss: -0.5031
Epoch 3000 | Loss: -0.4988
Epoch 3100 | Loss: -0.5077
Epoch 3200 | Loss: -0.5065
Epoch 3300 | Loss: -0.5058
Epoch 3400 | Loss: -0.5076
Epoch 3500 | Loss: -0.4999
Epoch 3600 | Loss: -0.5032
Epoch 3700 | 

In [10]:
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()


# Cluster ids are arbitrary, so score both possible cluster -> class mappings
# and keep the better one.
acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred
    # Swap the probability columns too; otherwise log_loss would score the
    # probabilities against the opposite class from the one y_pred now uses.
    y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.7937219730941704
Precision: 0.7391304347826086
Recall: 0.7555555555555555
F1: 0.7472527472527473
Log Loss: 1.166678209503574


In [11]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

# Repeat the unsupervised training with different seeds and summarise the
# clustering metrics across runs.
results = []

# Settings shared by every run (hoisted out of the loop: they never change).
cut = 0
alpha = 0.8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = X.shape[1]  # input width follows the data
K = 2
num_epochs = 5000

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels with the same permutation
    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    W0 = create_adj(features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = GAT(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    # Nothing inside the epoch loop switches modes, so set train mode once.
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()

        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    # Cluster ids are arbitrary: keep whichever cluster -> class mapping
    # scores higher.
    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred
        # Keep the probabilities aligned with the swapped labels; without
        # this log_loss is computed against the wrong column.
        y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]
log_losses = [r["log_loss"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))
# The per-run log loss was collected but never summarised.
print("Log Loss: mean=", np.mean(log_losses), "std=", np.std(log_losses))


Epoch 0 | Loss: -0.2398
Epoch 1000 | Loss: -0.5147
Epoch 2000 | Loss: -0.5158
Epoch 3000 | Loss: -0.5159
Epoch 4000 | Loss: -0.5166
Accuracy: 0.7668161434977578 Precision: 0.69 Recall: 0.7666666666666667 F1: 0.7263157894736842

Epoch 0 | Loss: -0.2386
Epoch 1000 | Loss: -0.5079
Epoch 2000 | Loss: -0.5119
Epoch 3000 | Loss: -0.5059
Epoch 4000 | Loss: -0.5104
Accuracy: 0.7892376681614349 Precision: 0.7263157894736842 Recall: 0.7666666666666667 F1: 0.745945945945946

Epoch 0 | Loss: -0.2417
Epoch 1000 | Loss: -0.5159
Epoch 2000 | Loss: -0.5207
Epoch 3000 | Loss: -0.5206
Epoch 4000 | Loss: -0.5194
Accuracy: 0.7668161434977578 Precision: 0.69 Recall: 0.7666666666666667 F1: 0.7263157894736842

Epoch 0 | Loss: -0.2363
Epoch 1000 | Loss: -0.5179
Epoch 2000 | Loss: -0.5150
Epoch 3000 | Loss: -0.5207
Epoch 4000 | Loss: -0.5163
Accuracy: 0.7668161434977578 Precision: 0.69 Recall: 0.7666666666666667 F1: 0.7263157894736842

Epoch 0 | Loss: -0.2409
Epoch 1000 | Loss: -0.5122
Epoch 2000 | Loss: -0.5