In [2]:
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import random
from torch_geometric.nn import ARMAConv

fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

fa_feature_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"
Histogram_feature_MCI_FA_array = np.load(fa_feature_path, allow_pickle=True)


X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

class ARMAEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ, stacks=1, layers=1):
        super(ARMAEncoder, self).__init__()
        self.device = device
        self.arma = ARMAConv(input_dim, hidden_dim, num_stacks=stacks, num_layers=layers)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.arma(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

class ARMA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(ARMA, self).__init__()
        self.device = device
        self.num_clusters = num_clusters

        self.online_encoder = ARMAEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # only modularity loss
        self.loss = self.modularity_loss

    def forward(self, data):
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term


def create_adj(F, cut, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1, 0).astype(np.float32)
    W = (W / W.max()).astype(np.float32)
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats)
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    row, col = edge_index
    edge_weight = torch.from_numpy(adj[row, col])
    return node_feats, edge_index, edge_weight


from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    alpha = 0.92
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 180
    K = 2

    W0 = create_adj(features, 0, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = ARMA(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        total_loss = unsup_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {total_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })


accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))


Features: (300, 180), Labels: (300,)

Epoch 0 | Loss: -0.2884
Epoch 1000 | Loss: -0.4550
Epoch 2000 | Loss: -0.4552
Epoch 3000 | Loss: -0.4549
Epoch 4000 | Loss: -0.4547
Accuracy: 0.7466666666666667 Precision: 0.8013245033112583 Recall: 0.7245508982035929 F1: 0.7610062893081762

Epoch 0 | Loss: -0.2785
Epoch 1000 | Loss: -0.4550
Epoch 2000 | Loss: -0.4550
Epoch 3000 | Loss: -0.4551
Epoch 4000 | Loss: -0.4552
Accuracy: 0.75 Precision: 0.8066666666666666 Recall: 0.7245508982035929 F1: 0.7634069400630915

Epoch 0 | Loss: -0.2520
Epoch 1000 | Loss: -0.4544
Epoch 2000 | Loss: -0.4548
Epoch 3000 | Loss: -0.4544
Epoch 4000 | Loss: -0.4543
Accuracy: 0.7433333333333333 Precision: 0.8 Recall: 0.718562874251497 F1: 0.7570977917981072

Epoch 0 | Loss: -0.2831
Epoch 1000 | Loss: -0.4547
Epoch 2000 | Loss: -0.4544
Epoch 3000 | Loss: -0.4548
Epoch 4000 | Loss: -0.4548
Accuracy: 0.7533333333333333 Precision: 0.8079470198675497 Recall: 0.7305389221556886 F1: 0.7672955974842768

Epoch 0 | Loss: -0.2794


In [3]:
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import random
from torch_geometric.nn import ARMAConv

fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

fa_feature_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"
Histogram_feature_MCI_FA_array = np.load(fa_feature_path, allow_pickle=True)


X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

class ARMAEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ, stacks=1, layers=1):
        super(ARMAEncoder, self).__init__()
        self.device = device
        self.arma = ARMAConv(input_dim, hidden_dim, num_stacks=stacks, num_layers=layers)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.arma(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

class ARMA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(ARMA, self).__init__()
        self.device = device
        self.num_clusters = num_clusters

        self.online_encoder = ARMAEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # only modularity loss
        self.loss = self.modularity_loss

    def forward(self, data):
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term


def create_adj(F, cut, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1, 0).astype(np.float32)
    W = (W / W.max()).astype(np.float32)
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats)
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    row, col = edge_index
    edge_weight = torch.from_numpy(adj[row, col])
    return node_feats, edge_index, edge_weight


from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    alpha = 0.92
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 180
    K = 2

    W0 = create_adj(features, 0, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = ARMA(feats_dim, 256, K, device, "SELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        total_loss = unsup_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {total_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })


accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))


Features: (300, 180), Labels: (300,)

Epoch 0 | Loss: -0.2884
Epoch 1000 | Loss: -0.4550
Epoch 2000 | Loss: -0.4552
Epoch 3000 | Loss: -0.4549
Epoch 4000 | Loss: -0.4547
Accuracy: 0.7466666666666667 Precision: 0.8013245033112583 Recall: 0.7245508982035929 F1: 0.7610062893081762

Epoch 0 | Loss: -0.2785
Epoch 1000 | Loss: -0.4550
Epoch 2000 | Loss: -0.4550
Epoch 3000 | Loss: -0.4551
Epoch 4000 | Loss: -0.4552
Accuracy: 0.75 Precision: 0.8066666666666666 Recall: 0.7245508982035929 F1: 0.7634069400630915

Epoch 0 | Loss: -0.2520
Epoch 1000 | Loss: -0.4544
Epoch 2000 | Loss: -0.4548
Epoch 3000 | Loss: -0.4544
Epoch 4000 | Loss: -0.4543
Accuracy: 0.7433333333333333 Precision: 0.8 Recall: 0.718562874251497 F1: 0.7570977917981072

Epoch 0 | Loss: -0.2831
Epoch 1000 | Loss: -0.4547
Epoch 2000 | Loss: -0.4544
Epoch 3000 | Loss: -0.4548
Epoch 4000 | Loss: -0.4548
Accuracy: 0.7533333333333333 Precision: 0.8079470198675497 Recall: 0.7305389221556886 F1: 0.7672955974842768

Epoch 0 | Loss: -0.2794
