In [3]:
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import random
from torch_geometric.nn import ARMAConv

# Load the FA histogram feature matrices for the CN and MCI groups.
# NOTE(review): allow_pickle=True makes np.load unpickle embedded objects, which
# can execute arbitrary code -- only point these paths at trusted .npy files.
# NOTE(review): the absolute paths below tie the notebook to one machine.
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

fa_feature_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"
Histogram_feature_MCI_FA_array = np.load(fa_feature_path, allow_pickle=True)

# One label per row: 0 for every CN row, 1 for every MCI row.
cn_labels = np.zeros(len(Histogram_feature_CN_FA_array), dtype=np.int64)
mci_labels = np.ones(len(Histogram_feature_MCI_FA_array), dtype=np.int64)

# Rows are stacked CN first, then MCI, so the labels are joined in the same order.
X = np.vstack((Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array))
y = np.concatenate((cn_labels, mci_labels))

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

class MLP(nn.Module):
    """Small prediction head: Linear -> BatchNorm1d -> PReLU -> Dropout(0.3) -> Linear.

    Args:
        inp_size: width of the input features.
        outp_size: width of the output.
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        # Layers are instantiated in forward order, so randomly initialised
        # parameters are drawn from the RNG in that same order.
        stages = [
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw (un-normalised) outputs."""
        return self.net(x)

class ARMAEncoder(nn.Module):
    """Graph encoder: ARMAConv -> Dropout(0.3) -> BatchNorm1d -> Linear.

    Args:
        input_dim: width of the node features fed to the ARMA convolution.
        hidden_dim: width of the convolution output and of the final projection.
        device: stored on the instance as ``self.device``; not otherwise used here.
        activ: accepted for signature compatibility; not used inside this class.
        stacks: forwarded to ``ARMAConv`` as ``num_stacks``.
        layers: forwarded to ``ARMAConv`` as ``num_layers``.
    """

    def __init__(self, input_dim, hidden_dim, device, activ, stacks=1, layers=1):
        super().__init__()
        self.device = device
        self.arma = ARMAConv(
            input_dim,
            hidden_dim,
            num_stacks=stacks,
            num_layers=layers,
        )
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode a graph object exposing ``x`` and ``edge_index``; returns one row per node."""
        hidden = self.arma(data.x, data.edge_index)
        # Dropout is applied before batch normalisation, then the linear projection.
        hidden = self.batchnorm(self.dropout(hidden))
        return self.mlp(hidden)

class ARMA(nn.Module):
    """Clustering model: an ``ARMAEncoder`` followed by an ``MLP`` cluster head.

    Args:
        input_dim: width of the node features.
        hidden_dim: width of the encoder output and of the head's hidden layer.
        num_clusters: number of clusters (columns of the assignment matrix).
        device: stored as ``self.device`` and forwarded to the encoder.
        activ: activation name; one of "SELU", "SiLU", "GELU", "RELU". Any other
            value (including "ELU") selects ELU. The resulting ``self.act`` is
            stored but not called by any method of this class.
    """

    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super().__init__()
        self.device = device
        self.num_clusters = num_clusters

        self.online_encoder = ARMAEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # use cut loss instead of modularity
        self.loss = self.cut_loss

    def forward(self, data):
        """Return ``(S, logits)``: softmax cluster assignments and the raw head outputs."""
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def cut_loss(self, A, S):
        """Cut-based clustering loss plus an orthogonality regulariser.

        Args:
            A: dense 2-D adjacency matrix, one row/column per node.
            S: un-normalised cluster scores (logits), one row per node. Softmax is
                applied here, so pass the ``logits`` returned by ``forward`` rather
                than the already-normalised assignments.

        Returns:
            Scalar tensor ``-trace(S^T A S) / trace(S^T D S) + ortho_loss`` where
            ``D`` is the diagonal degree matrix of ``A``.
        """
        assign = nnFn.softmax(S, dim=1)

        # Numerator: trace(S^T A S).
        num = torch.trace(torch.matmul(torch.matmul(A, assign).t(), assign))

        # Denominator: trace(S^T D S) == sum_i deg_i * ||S_i||^2. Computing it from
        # the degree vector avoids materialising an n x n diagonal matrix.
        degree = torch.sum(A, dim=-1)
        den = torch.sum(degree.unsqueeze(1) * assign * assign)
        # A graph without edges would give 0 / 0 = NaN; clamping only changes the
        # result in that degenerate case.
        den = den.clamp_min(torch.finfo(den.dtype).eps)
        mincut_loss = -(num / den)

        St_S = torch.matmul(assign.t(), assign)
        # Build the identity on the device of the data instead of the device stored
        # at construction time, which goes stale if the module is moved.
        I_S = torch.eye(self.num_clusters, device=St_S.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))

        return mincut_loss + ortho_loss


def create_adj(F, cut, alpha=1):
    """Build a binary adjacency matrix by thresholding cosine similarity.

    Args:
        F: 2-D array-like of features, one row per node.
        cut: unused; kept so existing call sites keep working.
        alpha: similarity threshold. Nodes ``i`` and ``j`` are connected when the
            cosine similarity of their rows is ``>= alpha``.

    Returns:
        ``float32`` array of shape ``(n, n)`` holding 1.0 for connected pairs and
        0.0 otherwise. Rows with zero (or NaN) norm have no defined direction and
        are left without any edges.
    """
    F = np.asarray(F)
    norms = np.linalg.norm(F, axis=1, keepdims=True)

    # Divide zero-norm rows by 1 instead of 0 so no NaNs / warnings are produced;
    # those rows are masked out of the result below.
    has_direction = norms > 0
    safe_norms = norms.copy()
    safe_norms[~has_direction] = 1

    unit_rows = F / safe_norms
    similarity = np.dot(unit_rows, unit_rows.T)

    connectable = has_direction & has_direction.T
    W = ((similarity >= alpha) & connectable).astype(np.float32)

    # W is already 0/1, so dividing by W.max() was the identity whenever an edge
    # exists and produced 0/0 = NaN when none does; the division is omitted.
    return W

def load_data(adj, node_feats):
    """Convert a dense adjacency matrix and node features into torch tensors.

    Args:
        adj: 2-D NumPy adjacency matrix; entries ``> 0`` are treated as edges.
        node_feats: NumPy array of node features (converted without copying).

    Returns:
        ``(node_feats, edge_index, edge_weight)`` where ``edge_index`` has shape
        ``(2, E)`` (row indices, then column indices) and ``edge_weight`` holds
        the ``E`` matching entries of ``adj``.
    """
    node_feats = torch.from_numpy(node_feats)

    rows, cols = np.nonzero(adj > 0)
    edge_index = torch.from_numpy(np.array([rows, cols]))

    # Index with the NumPy index arrays, not torch tensors: a one-element tensor
    # is accepted by NumPy as a scalar index, which would return a NumPy scalar
    # and make torch.from_numpy fail for graphs with exactly one edge.
    edge_weight = torch.from_numpy(adj[rows, cols])
    return node_feats, edge_index, edge_weight


from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

# Settings shared by every run; none of them depends on the seed, so they are
# defined once instead of on every loop iteration.
alpha = 0.92                                              # cosine-similarity threshold for edges
device = 'cuda' if torch.cuda.is_available() else 'cpu'
K = 2                                                     # number of clusters
num_epochs = 5000

results = []

for run_seed in range(10):
    print("\n================ Run", run_seed, "===============")

    # Seed every RNG in use so a given run_seed is repeatable.
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle nodes and labels with the same permutation.
    perm = np.random.permutation(X.shape[0])
    features = X[perm].astype(np.float32)
    labels = y[perm]

    # Take the input width from the data rather than hard-coding it.
    feats_dim = features.shape[1]

    # Similarity graph over the shuffled nodes: sparse edge list for the encoder,
    # dense adjacency for the loss.
    W0 = create_adj(features, 0, alpha)
    node_feats, edge_index, _ = load_data(W0, features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = ARMA(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    # NOTE(review): halving every 200 steps gives 25 halvings over 5000 epochs, so
    # the learning rate is effectively zero for most of training -- confirm intended.
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        # The loss applies softmax itself, so it is given the raw logits.
        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        total_loss = unsup_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {total_loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    # Cluster ids are arbitrary: score both labelings and keep the better one.
    acc_score = accuracy_score(labels, y_pred)
    acc_score_inverted = accuracy_score(labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred
        # Swap the probability columns as well, otherwise log_loss would be
        # computed against the un-flipped cluster ids.
        y_pred_proba = y_pred_proba[:, ::-1].copy()

    prec_score = precision_score(labels, y_pred)
    rec_score = recall_score(labels, y_pred)
    f1 = f1_score(labels, y_pred)
    log_loss_value = log_loss(labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })


# Collect each metric across runs, one list entry per element of `results`.
# The accuracy list is named `ccs`; the later re-print cell reads it under that name.
ccs, precisions, recalls, f1s = (
    [run[metric] for run in results]
    for metric in ("accuracy", "precision", "recall", "f1")
)

# Mean and (population) standard deviation of every metric.
print("\n===== Final Results across 10 runs =====")
for heading, scores in zip(
    ("Accuracy", "Precision", "Recall", "F1"),
    (ccs, precisions, recalls, f1s),
):
    print(heading + ": mean=", np.mean(scores), "std=", np.std(scores))

  from .autonotebook import tqdm as notebook_tqdm


Features: (300, 180), Labels: (300,)

Epoch 0 | Loss: -0.2617
Epoch 1000 | Loss: -0.8489
Epoch 2000 | Loss: -0.8479
Epoch 3000 | Loss: -0.8488
Epoch 4000 | Loss: -0.8476
Accuracy: 0.79 Precision: 0.8513513513513513 Recall: 0.7544910179640718 F1: 0.8

Epoch 0 | Loss: -0.2396
Epoch 1000 | Loss: -0.8493
Epoch 2000 | Loss: -0.8497
Epoch 3000 | Loss: -0.8509
Epoch 4000 | Loss: -0.8495
Accuracy: 0.77 Precision: 0.831081081081081 Recall: 0.7365269461077845 F1: 0.780952380952381

Epoch 0 | Loss: -0.2336
Epoch 1000 | Loss: -0.8229
Epoch 2000 | Loss: -0.8208
Epoch 3000 | Loss: -0.8231
Epoch 4000 | Loss: -0.8233
Accuracy: 0.51 Precision: 0.5675675675675675 Recall: 0.5029940119760479 F1: 0.5333333333333333

Epoch 0 | Loss: -0.2537
Epoch 1000 | Loss: -0.8451
Epoch 2000 | Loss: -0.8451
Epoch 3000 | Loss: -0.8457
Epoch 4000 | Loss: -0.8456
Accuracy: 0.7133333333333334 Precision: 0.7682119205298014 Recall: 0.6946107784431138 F1: 0.7295597484276729

Epoch 0 | Loss: -0.2460
Epoch 1000 | Loss: -0.8468
Ep

NameError: name 'accs' is not defined

In [5]:
# Re-print the cross-run summary from the metric lists built earlier in the notebook.
print("\n===== Final Results across 10 runs =====")
summary_rows = (
    ("Accuracy", ccs),
    ("Precision", precisions),
    ("Recall", recalls),
    ("F1", f1s),
)
for metric_name, per_run in summary_rows:
    print(f"{metric_name}: mean=", np.mean(per_run), "std=", np.std(per_run))


===== Final Results across 10 runs =====
Accuracy: mean= 0.7243333333333333 std= 0.08563553260443028
Precision: mean= 0.7828695389464758 std= 0.08705702265727608
Recall: mean= 0.6988023952095809 std= 0.07621745444548744
F1: mean= 0.738437529617372 std= 0.08123067198934417
