In [1]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from numpy import asarray
import tifffile as tiff
import torch.nn as nn
import torch.nn.functional as nnFn
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from sklearn.manifold import TSNE
import random
from torch_geometric.nn import GCNConv
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Load MCI features
fa_feature_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"
Histogram_feature_MCI_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Combine features and labels
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)
])

np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (300, 180), Labels: (300,)


In [3]:
class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

In [4]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ):
        super(GCNEncoder, self).__init__()
        self.device = device
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gcn1(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [5]:
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(GCN, self).__init__()
        self.device = device
        self.num_clusters = num_clusters
        self.cut = 0   # always modularity

        self.online_encoder = GCNEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_size=hidden_dim)

        # attach modularity loss
        self.loss = self.modularity_loss

    def forward(self, data):
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term


In [6]:
def create_adj(F, cut, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        W = np.where(W >= alpha, 1, 0).astype(np.float32)
        W = (W / W.max()).astype(np.float32)
    else:
        W = W - (W.max() / alpha)
    return W

In [7]:
def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats)
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    row, col = edge_index
    edge_weight = torch.from_numpy(adj[row, col])
    return node_feats, edge_index, edge_weight

In [8]:
features = X.astype(np.float32)
print(features.shape, features.dtype)

cut = 0
alpha = 0.92
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = 180
K = 2

W0 = create_adj(features, 0, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(300, 180) float32
Data(x=[300, 180], edge_index=[2, 13604])


In [9]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

model = GCN(feats_dim, 256, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 5000

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    S, logits = model(data0)
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss:.4f}")


Epoch 0 | Loss: -0.2761
Epoch 100 | Loss: -0.4470
Epoch 200 | Loss: -0.4503
Epoch 300 | Loss: -0.4516
Epoch 400 | Loss: -0.4522
Epoch 500 | Loss: -0.4527
Epoch 600 | Loss: -0.4531
Epoch 700 | Loss: -0.4531
Epoch 800 | Loss: -0.4534
Epoch 900 | Loss: -0.4533
Epoch 1000 | Loss: -0.4532
Epoch 1100 | Loss: -0.4534
Epoch 1200 | Loss: -0.4534
Epoch 1300 | Loss: -0.4535
Epoch 1400 | Loss: -0.4536
Epoch 1500 | Loss: -0.4534
Epoch 1600 | Loss: -0.4535
Epoch 1700 | Loss: -0.4536
Epoch 1800 | Loss: -0.4535
Epoch 1900 | Loss: -0.4536
Epoch 2000 | Loss: -0.4534
Epoch 2100 | Loss: -0.4535
Epoch 2200 | Loss: -0.4534
Epoch 2300 | Loss: -0.4536
Epoch 2400 | Loss: -0.4534
Epoch 2500 | Loss: -0.4536
Epoch 2600 | Loss: -0.4535
Epoch 2700 | Loss: -0.4535
Epoch 2800 | Loss: -0.4534
Epoch 2900 | Loss: -0.4537
Epoch 3000 | Loss: -0.4536
Epoch 3100 | Loss: -0.4536
Epoch 3200 | Loss: -0.4535
Epoch 3300 | Loss: -0.4535
Epoch 3400 | Loss: -0.4535
Epoch 3500 | Loss: -0.4535
Epoch 3600 | Loss: -0.4535
Epoch 3700 | 

In [10]:
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()


acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred

prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.75
Precision: 0.8026315789473685
Recall: 0.7305389221556886
F1: 0.7648902821316614
Log Loss: 6.559525478557995


In [12]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels
    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    cut = 0
    alpha = 0.92
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 180
    K = 2

    W0 = create_adj(features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = GCN(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()


        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))



Epoch 0 | Loss: -0.2698
Epoch 1000 | Loss: -0.4534
Epoch 2000 | Loss: -0.4538
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4538
Accuracy: 0.7533333333333333 Precision: 0.8079470198675497 Recall: 0.7305389221556886 F1: 0.7672955974842768

Epoch 0 | Loss: -0.2701
Epoch 1000 | Loss: -0.4537
Epoch 2000 | Loss: -0.4539
Epoch 3000 | Loss: -0.4535
Epoch 4000 | Loss: -0.4539
Accuracy: 0.7466666666666667 Precision: 0.7973856209150327 Recall: 0.7305389221556886 F1: 0.7625

Epoch 0 | Loss: -0.2745
Epoch 1000 | Loss: -0.4529
Epoch 2000 | Loss: -0.4536
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4534
Accuracy: 0.75 Precision: 0.8026315789473685 Recall: 0.7305389221556886 F1: 0.7648902821316614

Epoch 0 | Loss: -0.2808
Epoch 1000 | Loss: -0.4537
Epoch 2000 | Loss: -0.4539
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4537
Accuracy: 0.75 Precision: 0.8026315789473685 Recall: 0.7305389221556886 F1: 0.7648902821316614

Epoch 0 | Loss: -0.2795
Epoch 1000 | Loss: -0.4534
Epoch 2000 | Loss: -

SELU

In [13]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels
    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    cut = 0
    alpha = 0.92
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 180
    K = 2

    W0 = create_adj(features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = GCN(feats_dim, 256, K, device, "SELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()


        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))



Epoch 0 | Loss: -0.2698
Epoch 1000 | Loss: -0.4534
Epoch 2000 | Loss: -0.4538
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4538
Accuracy: 0.7533333333333333 Precision: 0.8079470198675497 Recall: 0.7305389221556886 F1: 0.7672955974842768

Epoch 0 | Loss: -0.2701
Epoch 1000 | Loss: -0.4537
Epoch 2000 | Loss: -0.4539
Epoch 3000 | Loss: -0.4535
Epoch 4000 | Loss: -0.4539
Accuracy: 0.7466666666666667 Precision: 0.7973856209150327 Recall: 0.7305389221556886 F1: 0.7625

Epoch 0 | Loss: -0.2745
Epoch 1000 | Loss: -0.4529
Epoch 2000 | Loss: -0.4536
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4534
Accuracy: 0.75 Precision: 0.8026315789473685 Recall: 0.7305389221556886 F1: 0.7648902821316614

Epoch 0 | Loss: -0.2808
Epoch 1000 | Loss: -0.4537
Epoch 2000 | Loss: -0.4539
Epoch 3000 | Loss: -0.4536
Epoch 4000 | Loss: -0.4537
Accuracy: 0.75 Precision: 0.8026315789473685 Recall: 0.7305389221556886 F1: 0.7648902821316614

Epoch 0 | Loss: -0.2795
Epoch 1000 | Loss: -0.4534
Epoch 2000 | Loss: -