In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import numpy as np
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, log_loss
)

In [13]:
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

fa_feature_path = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"
Histogram_feature_AD_FA_array = np.load(fa_feature_path, allow_pickle=True)

X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (223, 180), Labels: (223,)


In [14]:
class ARMA_SemiSupervised(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, device, activ="RELU"):
        super(ARMA_SemiSupervised, self).__init__()
        self.device = device
        self.conv1 = ARMAConv(input_dim, hidden_dim, num_stacks=1, num_layers=1)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(hidden_dim, output_dim)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu,
            "ELU": nnFn.elu
        }
        self.act = activations.get(activ, nnFn.elu)

    def forward(self, data):
      x, edge_index = data.x, data.edge_index
      x_in = x
      x = self.conv1(x, edge_index)
      x = self.bn1(x)

      if x_in.shape[1] == x.shape[1]:
          x = x + x_in  # residual (skip) connection is added

      x = self.act(x)
      x = self.dropout(x)
      logits = self.fc(x)
      return logits

    def modularity_loss(self, A, logits):
        S = nnFn.softmax(logits, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.outer(d, d) / (2 * m)

        modularity_term = (-1 / (2 * m)) * torch.trace(S.T @ B @ S)
        # Collapse regularization
        I_S = torch.eye(S.shape[1], device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        collapse_reg = (torch.sqrt(k) / n) * torch.norm(torch.sum(S, dim=0), p='fro') - 1
        entropy_reg = -torch.mean(torch.sum(S * torch.log(S + 1e-9), dim=1))
        return modularity_term + 0.1*collapse_reg + 0.01*entropy_reg

In [15]:
def create_adj(F, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1, 0).astype(np.float32)
    W = W / W.max()
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    return Data(x=node_feats, edge_index=edge_index)

In [23]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
alpha = 0.8
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_mod = 0.01 #0.01  # weight for modularity loss
# lambda_sup = 5

In [26]:
W = create_adj(X, alpha)
data = load_data(W, X).to(device)
A_tensor = torch.from_numpy(W).float().to(device)
print(data)

Data(x=[223, 180], edge_index=[2, 41689])


In [27]:
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=42)

accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} ===")


    cn_idx = np.where(y == 0)[0]
    ad_idx = np.where(y == 1)[0]

    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=fold)
    cn_train_idx, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))

    cn_train = cn_idx[cn_train_idx]
    ad_train = ad_idx[ad_train_idx]
    train_idx_final = np.concatenate([cn_train, ad_train])
    np.random.shuffle(train_idx_final)

    print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")

    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)


    model = ARMA_SemiSupervised(feats_dim, hidden_dim, num_classes, device, "SELU").to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss = nn.CrossEntropyLoss()


    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        logits = model(data)
        loss_sup = ce_loss(logits[train_idx_t], y_train_tensor)
        loss_unsup = model.modularity_loss(A_tensor, logits)
        total_loss = loss_sup + lambda_mod * loss_unsup

        total_loss.backward()
        optimizer.step()

        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                preds_train = logits[train_idx_t].argmax(dim=1)
                acc_train = accuracy_score(y_train_tensor.cpu(), preds_train.cpu())
            print(f"Fold {fold} Epoch {epoch}: "
                  f"TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | "
                  f"Unsup={loss_unsup.item():.6f} | TrainAcc={acc_train:.4f}")


    model.eval()
    with torch.no_grad():
        out = model(data)
        preds = out.argmax(dim=1).cpu().numpy()
        probs = torch.softmax(out, dim=1)[:, 1].cpu().numpy()  # Probability for class 1

    y_test = y[test_idx_global]
    y_pred_test = preds[test_idx_global]
    y_prob_test = probs[test_idx_global]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    auc = roc_auc_score(y_test, y_prob_test) # use
    ce = log_loss(y_test, y_prob_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)

    print(f"Fold {fold} → "
          f"Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | "
          f"F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")


print("\n=== Average Results Across 20 Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
print(f"CE Loss:   {np.mean(ce_losses):.4f} ± {np.std(ce_losses):.4f}")


=== Fold 1 ===
Train CN: 13, Train AD: 9
Fold 1 Epoch 1: TotalLoss=0.921408 | Sup=0.922784 | Unsup=-0.137583 | TrainAcc=0.4091
Fold 1 Epoch 100: TotalLoss=0.043735 | Sup=0.045241 | Unsup=-0.150593 | TrainAcc=1.0000
Fold 1 Epoch 200: TotalLoss=0.010042 | Sup=0.011588 | Unsup=-0.154581 | TrainAcc=1.0000
Fold 1 Epoch 300: TotalLoss=0.003279 | Sup=0.004837 | Unsup=-0.155804 | TrainAcc=1.0000
Fold 1 Epoch 400: TotalLoss=0.000775 | Sup=0.002336 | Unsup=-0.156145 | TrainAcc=1.0000
Fold 1 Epoch 500: TotalLoss=0.000049 | Sup=0.001617 | Unsup=-0.156801 | TrainAcc=1.0000
Fold 1 Epoch 600: TotalLoss=-0.000567 | Sup=0.001019 | Unsup=-0.158577 | TrainAcc=1.0000
Fold 1 Epoch 700: TotalLoss=-0.001117 | Sup=0.000470 | Unsup=-0.158687 | TrainAcc=1.0000
Fold 1 Epoch 800: TotalLoss=-0.001116 | Sup=0.000478 | Unsup=-0.159453 | TrainAcc=1.0000
Fold 1 Epoch 900: TotalLoss=-0.001260 | Sup=0.000351 | Unsup=-0.161071 | TrainAcc=1.0000
Fold 1 Epoch 1000: TotalLoss=-0.001384 | Sup=0.000238 | Unsup=-0.162181 | Tr

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss

# ==========================================
# CONFIG
# ==========================================
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hidden_dim = 512
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
batch_print_freq = 500  # print every 500 epochs

# Use same λ list as before
lambda_mod_list = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]


results_summary = []

# ==========================================
# Run for each λ_mod
# ==========================================
for lambda_mod in lambda_mod_list:
    print(f"\n==============================")
    print(f" Running with λ_mod = {lambda_mod}")
    print(f"==============================")

    accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

    sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)

    for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
        print(f"\n=== Fold {fold} ===")

        cn_idx = np.where(y == 0)[0]
        ad_idx = np.where(y == 1)[0]

        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
        cn_train_idx, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
        ad_train_idx, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))

        cn_train = cn_idx[cn_train_idx]
        ad_train = ad_idx[ad_train_idx]
        train_idx_final = np.concatenate([cn_train, ad_train])
        np.random.shuffle(train_idx_final)

        print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")

        train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
        y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)

        # Initialize model
        model = ARMA_SemiSupervised(feats_dim, hidden_dim, num_classes, device, "SELU").to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        ce_loss = nn.CrossEntropyLoss()

        # Training loop
        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            logits = model(data)
            loss_sup = ce_loss(logits[train_idx_t], y_train_tensor)
            loss_unsup = model.modularity_loss(A_tensor, logits)
            total_loss = loss_sup + lambda_mod * loss_unsup

            total_loss.backward()
            optimizer.step()

            if epoch % batch_print_freq == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval = model(data)
                    preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                    acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
                print(f"Fold {fold} Epoch {epoch}: "
                      f"TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | "
                      f"Unsup={loss_unsup.item():.6f} | TrainAcc={acc_train:.4f}")

        # Evaluation
        model.eval()
        with torch.no_grad():
            out = model(data)
            preds = out.argmax(dim=1).cpu().numpy()
            probs = torch.softmax(out, dim=1)[:, 1].cpu().numpy()  # probability for "AD"

        y_test = y[test_idx_global]
        y_pred_test = preds[test_idx_global]
        y_prob_test = probs[test_idx_global]

        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        auc = roc_auc_score(y_test, y_prob_test)
        ce = log_loss(y_test, y_prob_test)

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        aucs.append(auc)
        ce_losses.append(ce)

        print(f"Fold {fold} → "
              f"Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | "
              f"F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

    # Average results per λ_mod
    mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
    mean_prec, std_prec = np.mean(precisions), np.std(precisions)
    mean_rec, std_rec = np.mean(recalls), np.std(recalls)
    mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
    mean_auc, std_auc = np.mean(aucs), np.std(aucs)
    mean_ce, std_ce = np.mean(ce_losses), np.std(ce_losses)

    results_summary.append({
        "λ_mod": lambda_mod,
        "Accuracy": f"{mean_acc:.4f} ± {std_acc:.4f}",
        "Precision": f"{mean_prec:.4f} ± {std_prec:.4f}",
        "Recall": f"{mean_rec:.4f} ± {std_rec:.4f}",
        "F1": f"{mean_f1:.4f} ± {std_f1:.4f}",
        "AUC": f"{mean_auc:.4f} ± {std_auc:.4f}",
        "CE Loss": f"{mean_ce:.4f} ± {std_ce:.4f}",
    })

    print(f"\n=== λ_mod = {lambda_mod} → Average Results ===")
    print(f"Accuracy:  {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Precision: {mean_prec:.4f} ± {std_prec:.4f}")
    print(f"Recall:    {mean_rec:.4f} ± {std_rec:.4f}")
    print(f"F1-score:  {mean_f1:.4f} ± {std_f1:.4f}")
    print(f"AUC:       {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"CE Loss:   {mean_ce:.4f} ± {std_ce:.4f}")

# ==========================================
# Final summary table for all λ_mod values
# ==========================================
print("\n\n========== FINAL SUMMARY TABLE (CN vs AD) ==========")
results_df = pd.DataFrame(results_summary)
print(results_df.to_string(index=False))


 Running with λ_mod = 0.001

=== Fold 1 ===
Train CN: 66, Train AD: 83
Fold 1 Epoch 1: TotalLoss=0.749776 | Sup=0.749910 | Unsup=-0.134613 | TrainAcc=0.4430
Fold 1 Epoch 500: TotalLoss=0.007621 | Sup=0.007770 | Unsup=-0.149480 | TrainAcc=1.0000
Fold 1 Epoch 1000: TotalLoss=0.002174 | Sup=0.002324 | Unsup=-0.149805 | TrainAcc=1.0000
Fold 1 Epoch 1500: TotalLoss=0.000532 | Sup=0.000681 | Unsup=-0.149690 | TrainAcc=1.0000
Fold 1 Epoch 2000: TotalLoss=0.000071 | Sup=0.000220 | Unsup=-0.149214 | TrainAcc=1.0000
Fold 1 Epoch 2500: TotalLoss=-0.000042 | Sup=0.000107 | Unsup=-0.149431 | TrainAcc=1.0000
Fold 1 Epoch 3000: TotalLoss=-0.000082 | Sup=0.000068 | Unsup=-0.149488 | TrainAcc=1.0000
Fold 1 Epoch 3500: TotalLoss=-0.000090 | Sup=0.000059 | Unsup=-0.149237 | TrainAcc=1.0000
Fold 1 Epoch 4000: TotalLoss=-0.000091 | Sup=0.000058 | Unsup=-0.149338 | TrainAcc=1.0000
Fold 1 Epoch 4500: TotalLoss=-0.000092 | Sup=0.000058 | Unsup=-0.149312 | TrainAcc=1.0000
Fold 1 Epoch 5000: TotalLoss=-0.00008

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss

# ==========================================
# CONFIG
# ==========================================
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hidden_dim = 512
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
batch_print_freq = 500  # print every 500 epochs

# Use same λ list as before
lambda_mod_list = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]


results_summary = []

# ==========================================
# Run for each λ_mod
# ==========================================
for lambda_mod in lambda_mod_list:
    print(f"\n==============================")
    print(f" Running with λ_mod = {lambda_mod}")
    print(f"==============================")

    accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

    sss = StratifiedShuffleSplit(n_splits=20, test_size=0.1, random_state=SEED)

    for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
        print(f"\n=== Fold {fold} ===")

        cn_idx = np.where(y == 0)[0]
        ad_idx = np.where(y == 1)[0]

        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.1, random_state=fold)
        cn_train_idx, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
        ad_train_idx, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))

        cn_train = cn_idx[cn_train_idx]
        ad_train = ad_idx[ad_train_idx]
        train_idx_final = np.concatenate([cn_train, ad_train])
        np.random.shuffle(train_idx_final)

        print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")

        train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
        y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)

        # Initialize model
        model = ARMA_SemiSupervised(feats_dim, hidden_dim, num_classes, device, "SELU").to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        ce_loss = nn.CrossEntropyLoss()

        # Training loop
        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            logits = model(data)
            loss_sup = ce_loss(logits[train_idx_t], y_train_tensor)
            loss_unsup = model.modularity_loss(A_tensor, logits)
            total_loss = loss_sup + lambda_mod * loss_unsup

            total_loss.backward()
            optimizer.step()

            if epoch % batch_print_freq == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval = model(data)
                    preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                    acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
                print(f"Fold {fold} Epoch {epoch}: "
                      f"TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | "
                      f"Unsup={loss_unsup.item():.6f} | TrainAcc={acc_train:.4f}")

        # Evaluation
        model.eval()
        with torch.no_grad():
            out = model(data)
            preds = out.argmax(dim=1).cpu().numpy()
            probs = torch.softmax(out, dim=1)[:, 1].cpu().numpy()  # probability for "AD"

        y_test = y[test_idx_global]
        y_pred_test = preds[test_idx_global]
        y_prob_test = probs[test_idx_global]

        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        auc = roc_auc_score(y_test, y_prob_test)
        ce = log_loss(y_test, y_prob_test)

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        aucs.append(auc)
        ce_losses.append(ce)

        print(f"Fold {fold} → "
              f"Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | "
              f"F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

    # Average results per λ_mod
    mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
    mean_prec, std_prec = np.mean(precisions), np.std(precisions)
    mean_rec, std_rec = np.mean(recalls), np.std(recalls)
    mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
    mean_auc, std_auc = np.mean(aucs), np.std(aucs)
    mean_ce, std_ce = np.mean(ce_losses), np.std(ce_losses)

    results_summary.append({
        "λ_mod": lambda_mod,
        "Accuracy": f"{mean_acc:.4f} ± {std_acc:.4f}",
        "Precision": f"{mean_prec:.4f} ± {std_prec:.4f}",
        "Recall": f"{mean_rec:.4f} ± {std_rec:.4f}",
        "F1": f"{mean_f1:.4f} ± {std_f1:.4f}",
        "AUC": f"{mean_auc:.4f} ± {std_auc:.4f}",
        "CE Loss": f"{mean_ce:.4f} ± {std_ce:.4f}",
    })

    print(f"\n=== λ_mod = {lambda_mod} → Average Results ===")
    print(f"Accuracy:  {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Precision: {mean_prec:.4f} ± {std_prec:.4f}")
    print(f"Recall:    {mean_rec:.4f} ± {std_rec:.4f}")
    print(f"F1-score:  {mean_f1:.4f} ± {std_f1:.4f}")
    print(f"AUC:       {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"CE Loss:   {mean_ce:.4f} ± {std_ce:.4f}")

# ==========================================
# Final summary table for all λ_mod values
# ==========================================
print("\n\n========== FINAL SUMMARY TABLE (CN vs AD) ==========")
results_df = pd.DataFrame(results_summary)
print(results_df.to_string(index=False))


 Running with λ_mod = 0.001

=== Fold 1 ===
Train CN: 119, Train AD: 150
Fold 1 Epoch 1: TotalLoss=0.723792 | Sup=0.723927 | Unsup=-0.134613 | TrainAcc=0.4424
Fold 1 Epoch 500: TotalLoss=0.013705 | Sup=0.013853 | Unsup=-0.148030 | TrainAcc=1.0000
Fold 1 Epoch 1000: TotalLoss=0.002749 | Sup=0.002897 | Unsup=-0.148342 | TrainAcc=1.0000
Fold 1 Epoch 1500: TotalLoss=0.000606 | Sup=0.000755 | Unsup=-0.148670 | TrainAcc=1.0000
Fold 1 Epoch 2000: TotalLoss=0.000297 | Sup=0.000446 | Unsup=-0.148666 | TrainAcc=1.0000
Fold 1 Epoch 2500: TotalLoss=0.000112 | Sup=0.000260 | Unsup=-0.148759 | TrainAcc=1.0000
Fold 1 Epoch 3000: TotalLoss=-0.000000 | Sup=0.000148 | Unsup=-0.148736 | TrainAcc=1.0000
Fold 1 Epoch 3500: TotalLoss=0.000006 | Sup=0.000155 | Unsup=-0.148787 | TrainAcc=1.0000
Fold 1 Epoch 4000: TotalLoss=-0.000040 | Sup=0.000109 | Unsup=-0.148759 | TrainAcc=1.0000
Fold 1 Epoch 4500: TotalLoss=0.000004 | Sup=0.000153 | Unsup=-0.148790 | TrainAcc=1.0000
Fold 1 Epoch 5000: TotalLoss=-0.000043

ARMA + CE + MOD + RECONS

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

fa_feature_path_cn = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
fa_feature_path_ad = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path_cn, allow_pickle=True)
Histogram_feature_AD_FA_array = np.load(fa_feature_path_ad, allow_pickle=True)
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

class ARMA_MAE_Mod(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_stacks=1, num_layers=1, dropout=0.2, activ="RELU"):
        super().__init__()
        self.arma1 = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            dropout=dropout
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, input_dim)
        )

        activations = {
            "LeakyReLU": nnFn.leaky_relu,
            "SELU": nnFn.selu,
            "ReLU": nnFn.relu,
            "GELU": nnFn.gelu
        }
        self.act = activations.get(activ, nnFn.selu)

    def forward(self, data):
      x, edge_index = data.x, data.edge_index
      h = self.arma1(x, edge_index)
      h = self.act(h)
      h = self.bn1(h)
      h = self.dropout(h)
      logits = self.classifier(h)
      recon = self.decoder(h)
      return logits, recon, h


    def modularity_loss(self, A, logits):
        S = nnFn.softmax(logits, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        if m.item() == 0:
            return torch.tensor(0.0, device=self.device)
        B = A - torch.outer(d, d) / (2 * m)
        modularity_term = (-1.0 / (2.0 * m)) * torch.trace(S.T @ B @ S)
        I_S = torch.eye(S.shape[1], device=A.device)

        k = torch.norm(I_S)
        n = S.shape[0]
        collapse_reg = (torch.sqrt(k) / n) * torch.norm(torch.sum(S, dim=0), p='fro') - 1.0
        entropy_reg = -torch.mean(torch.sum(S * torch.log(S + 1e-9), dim=1))
        return modularity_term + 0.1 * collapse_reg + 0.01 * entropy_reg

def create_adj(F, alpha=1.0):
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    F_norm = F / row_norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    else:
        W = W
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    nz = np.nonzero((adj > 0).astype(np.int32))
    edge_index = torch.from_numpy(np.array(nz)).long()
    return Data(x=node_feats, edge_index=edge_index)

alpha = 0.8
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_mod = 0.05 #0.05
lambda_recon = 1.0
dropout = 0.2

W = create_adj(X, alpha)
data = load_data(W, X).to(device)
A_tensor = torch.from_numpy(W).float().to(device)
print(data)

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)
accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} ===")
    cn_idx = np.where(y == 0)[0]
    ad_idx = np.where(y == 1)[0]
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
    cn_train_idx_small, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx_small, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))
    cn_train = cn_idx[cn_train_idx_small]
    ad_train = ad_idx[ad_train_idx_small]
    train_idx_final = np.concatenate([cn_train, ad_train])
    np.random.shuffle(train_idx_final)
    print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}, Total train: {len(train_idx_final)}")
    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)
    model = ARMA_MAE_Mod(
    input_dim=feats_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    num_stacks=1,
    num_layers=1,
    dropout=dropout
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss_fn = nn.CrossEntropyLoss()
    recon_loss_fn = nn.MSELoss()
    data_on_device = data
    data_x = data_on_device.x.to(device)
    data_edge_index = data_on_device.edge_index.to(device)
    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits, recon, _ = model(data_on_device)

        loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
        loss_recon = recon_loss_fn(recon, data_x)
        loss_mod = model.modularity_loss(A_tensor, logits)
        total_loss = loss_sup + lambda_recon * loss_recon + lambda_mod * loss_mod
        total_loss.backward()
        optimizer.step()
        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval, _, _ = model(data_on_device)
                preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
            print(f"Fold {fold} Epoch {epoch}: Total={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | Mod={loss_mod.item():.6f} | TrainAcc={acc_train:.4f}")
    model.eval()
    with torch.no_grad():
        logits_final, _, _ = model(data_on_device)
        preds_all = logits_final.argmax(dim=1).cpu().numpy()
        probs_all = torch.softmax(logits_final, dim=1)[:, 1].cpu().numpy()
    y_test = y[test_idx_global]
    y_pred_test = preds_all[test_idx_global]
    y_prob_test = probs_all[test_idx_global]
    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    try:
        auc = roc_auc_score(y_test, y_prob_test)
    except Exception:
        auc = float("nan")
    try:
        ce = log_loss(y_test, y_prob_test)
    except Exception:
        ce = float("nan")
    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)
    print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

print("\n=== Average Results Across Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")
print(f"CE Loss:   {np.nanmean(ce_losses):.4f} ± {np.nanstd(ce_losses):.4f}")

Features: (223, 180), Labels: (223,)
Data(x=[223, 180], edge_index=[2, 41689])

=== Fold 1 ===
Train CN: 66, Train AD: 45, Total train: 111
Fold 1 Epoch 1: Total=0.941973 | Sup=0.808002 | Recon=0.140705 | Mod=-0.134679 | TrainAcc=0.5946
Fold 1 Epoch 100: Total=0.337803 | Sup=0.319216 | Recon=0.026262 | Mod=-0.153508 | TrainAcc=0.8829
Fold 1 Epoch 200: Total=0.245721 | Sup=0.242329 | Recon=0.011159 | Mod=-0.155345 | TrainAcc=0.9279
Fold 1 Epoch 300: Total=0.254026 | Sup=0.256009 | Recon=0.005816 | Mod=-0.155974 | TrainAcc=0.9459
Fold 1 Epoch 400: Total=0.240155 | Sup=0.244517 | Recon=0.003532 | Mod=-0.157890 | TrainAcc=0.9910
Fold 1 Epoch 500: Total=0.167152 | Sup=0.172695 | Recon=0.002345 | Mod=-0.157761 | TrainAcc=1.0000
Fold 1 Epoch 600: Total=0.137396 | Sup=0.143710 | Recon=0.001632 | Mod=-0.158928 | TrainAcc=0.9910
Fold 1 Epoch 700: Total=0.136654 | Sup=0.143380 | Recon=0.001202 | Mod=-0.158576 | TrainAcc=1.0000
Fold 1 Epoch 800: Total=0.121526 | Sup=0.128561 | Recon=0.000944 | Mod

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

fa_feature_path_cn = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
fa_feature_path_mci = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"

Histogram_feature_CN_FA_array = np.load(fa_feature_path_cn, allow_pickle=True)
Histogram_feature_MCI_FA_array = np.load(fa_feature_path_mci, allow_pickle=True)

# Stack CN and MCI data together
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),  # CN = 0
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)   # MCI = 1
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")
print("Loaded CN vs MCI dataset")

class ARMA_MAE_Mod(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_stacks=1, num_layers=1, dropout=0.2, activ="SELU"):
        super().__init__()
        self.arma1 = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            dropout=dropout
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, input_dim)
        )

        activations = {
            "LeakyReLU": nnFn.leaky_relu,
            "SELU": nnFn.selu,
            "ReLU": nnFn.relu,
            "GELU": nnFn.gelu
        }
        self.act = activations.get(activ, nnFn.selu)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        h = self.arma1(x, edge_index)
        h = self.act(h)
        h = self.bn1(h)
        h = self.dropout(h)
        logits = self.classifier(h)
        recon = self.decoder(h)
        return logits, recon, h

    def modularity_loss(self, A, logits):
        S = nnFn.softmax(logits, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        if m.item() == 0:
            return torch.tensor(0.0, device=A.device)
        B = A - torch.outer(d, d) / (2 * m)
        modularity_term = (-1.0 / (2.0 * m)) * torch.trace(S.T @ B @ S)
        I_S = torch.eye(S.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        collapse_reg = (torch.sqrt(k) / n) * torch.norm(torch.sum(S, dim=0), p='fro') - 1.0
        entropy_reg = -torch.mean(torch.sum(S * torch.log(S + 1e-9), dim=1))
        return modularity_term + 0.1 * collapse_reg + 0.01 * entropy_reg

def create_adj(F, alpha=1.0):
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    F_norm = F / row_norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    nz = np.nonzero((adj > 0).astype(np.int32))
    edge_index = torch.from_numpy(np.array(nz)).long()
    return Data(x=node_feats, edge_index=edge_index)

alpha = 0.92
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_mod = 0.05
lambda_recon = 1.0
dropout = 0.2

W = create_adj(X, alpha)
data = load_data(W, X).to(device)
A_tensor = torch.from_numpy(W).float().to(device)
print(data)

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)
accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} ===")
    cn_idx = np.where(y == 0)[0]
    mci_idx = np.where(y == 1)[0]
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
    cn_train_idx_small, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
    mci_train_idx_small, _ = next(sss_class.split(X[mci_idx], y[mci_idx]))
    cn_train = cn_idx[cn_train_idx_small]
    mci_train = mci_idx[mci_train_idx_small]
    train_idx_final = np.concatenate([cn_train, mci_train])
    np.random.shuffle(train_idx_final)
    print(f"Train CN: {len(cn_train)}, Train MCI: {len(mci_train)}, Total train: {len(train_idx_final)}")

    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)

    model = ARMA_MAE_Mod(
        input_dim=feats_dim,
        hidden_dim=hidden_dim,
        output_dim=num_classes,
        num_stacks=1,
        num_layers=1,
        dropout=dropout
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss_fn = nn.CrossEntropyLoss()
    recon_loss_fn = nn.MSELoss()
    data_on_device = data
    data_x = data_on_device.x.to(device)


    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits, recon, _ = model(data_on_device)

        loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
        loss_recon = recon_loss_fn(recon, data_x)
        loss_mod = model.modularity_loss(A_tensor, logits)
        total_loss = loss_sup + lambda_recon * loss_recon + lambda_mod * loss_mod
        total_loss.backward()
        optimizer.step()

        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval, _, _ = model(data_on_device)
                preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
            print(f"Fold {fold} Epoch {epoch}: Total={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | Mod={loss_mod.item():.6f} | TrainAcc={acc_train:.4f}")

    model.eval()
    with torch.no_grad():
        logits_final, _, _ = model(data_on_device)
        preds_all = logits_final.argmax(dim=1).cpu().numpy()
        probs_all = torch.softmax(logits_final, dim=1)[:, 1].cpu().numpy()

    y_test = y[test_idx_global]
    y_pred_test = preds_all[test_idx_global]
    y_prob_test = probs_all[test_idx_global]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    try:
        auc = roc_auc_score(y_test, y_prob_test)
    except Exception:
        auc = float("nan")
    try:
        ce = log_loss(y_test, y_prob_test)
    except Exception:
        ce = float("nan")

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)

    print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

print("\n=== Average Results Across Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")
print(f"CE Loss:   {np.nanmean(ce_losses):.4f} ± {np.nanstd(ce_losses):.4f}")

Features: (300, 180), Labels: (300,)
Loaded CN vs MCI dataset
Data(x=[300, 180], edge_index=[2, 13604])

=== Fold 1 ===
Train CN: 66, Train MCI: 83, Total train: 149
Fold 1 Epoch 1: Total=1.050128 | Sup=0.908718 | Recon=0.148180 | Mod=-0.135393 | TrainAcc=0.4362
Fold 1 Epoch 100: Total=0.352585 | Sup=0.336424 | Recon=0.025956 | Mod=-0.195898 | TrainAcc=0.8725
Fold 1 Epoch 200: Total=0.285043 | Sup=0.284749 | Recon=0.010186 | Mod=-0.197831 | TrainAcc=0.9262
Fold 1 Epoch 300: Total=0.281662 | Sup=0.285737 | Recon=0.005868 | Mod=-0.198857 | TrainAcc=0.9329
Fold 1 Epoch 400: Total=0.213875 | Sup=0.221113 | Recon=0.003210 | Mod=-0.208972 | TrainAcc=0.9597
Fold 1 Epoch 500: Total=0.167072 | Sup=0.175272 | Recon=0.002197 | Mod=-0.207949 | TrainAcc=0.9866
Fold 1 Epoch 600: Total=0.196440 | Sup=0.205055 | Recon=0.001520 | Mod=-0.202712 | TrainAcc=0.9732
Fold 1 Epoch 700: Total=0.139971 | Sup=0.149474 | Recon=0.001109 | Mod=-0.212223 | TrainAcc=0.9866
Fold 1 Epoch 800: Total=0.098461 | Sup=0.108

Masking ARMA + CE + MOD + RECONS

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv


SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

fa_feature_path_cn = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
fa_feature_path_ad = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy" # Assuming this is the AD data

Histogram_feature_CN_FA_array = np.load(fa_feature_path_cn, allow_pickle=True)
Histogram_feature_AD_FA_array = np.load(fa_feature_path_ad, allow_pickle=True)
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

# ==========================
#  Model Definition
# ==========================
class ARMA_MAE_Mod(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_stacks=1, num_layers=1, dropout=0.2, activ="SELU"):
        super().__init__()
        self.arma1 = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            dropout=dropout
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, input_dim)
        )

        activations = {
            "LeakyReLU": nnFn.leaky_relu,
            "SELU": nnFn.selu,
            "ReLU": nnFn.relu,
            "GELU": nnFn.gelu
        }
        self.act = activations.get(activ, nnFn.selu)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        h = self.arma1(x, edge_index)
        h = self.act(h)
        h = self.bn1(h)
        h = self.dropout(h)
        logits = self.classifier(h)
        recon = self.decoder(h)
        return logits, recon, h

    def modularity_loss(self, A, logits):
        S = nnFn.softmax(logits, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        if m.item() == 0:
            return torch.tensor(0.0, device=A.device)
        B = A - torch.outer(d, d) / (2 * m)
        modularity_term = (-1.0 / (2.0 * m)) * torch.trace(S.T @ B @ S)
        I_S = torch.eye(S.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        collapse_reg = (torch.sqrt(k) / n) * torch.norm(torch.sum(S, dim=0), p='fro') - 1.0
        entropy_reg = -torch.mean(torch.sum(S * torch.log(S + 1e-9), dim=1))
        return modularity_term + 0.1 * collapse_reg + 0.01 * entropy_reg

# ==========================
#  Utility Functions
# ==========================
def create_adj(F, alpha=1.0):
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    F_norm = F / row_norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    return W

def mask_edges_symmetric(W, remove_ratio=0.1, seed=42):
    np.random.seed(seed)
    W_masked = W.copy()
    upper_indices = np.triu_indices_from(W, k=1)
    existing_edges = np.where(W[upper_indices] > 0)[0]
    num_edges = len(existing_edges)
    num_remove = int(remove_ratio * num_edges)
    remove_idx = np.random.choice(existing_edges, size=num_remove, replace=False)
    rows_to_remove = upper_indices[0][remove_idx]
    cols_to_remove = upper_indices[1][remove_idx]
    W_masked[rows_to_remove, cols_to_remove] = 0.0
    W_masked[cols_to_remove, rows_to_remove] = 0.0
    return W_masked

def mask_edges_asymmetric(W, remove_ratio=0.1, seed=42):
    np.random.seed(seed)
    W_masked = W.copy()
    non_diag = np.where(~np.eye(W.shape[0], dtype=bool))
    existing_edges = np.where(W[non_diag] > 0)[0]
    num_edges = len(existing_edges)
    num_remove = int(remove_ratio * num_edges)
    remove_idx = np.random.choice(existing_edges, size=num_remove, replace=False)
    rows_flat, cols_flat = non_diag
    rows_to_remove = rows_flat[remove_idx]
    cols_to_remove = cols_flat[remove_idx]
    W_masked[rows_to_remove, cols_to_remove] = 0.0
    return W_masked

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    nz = np.nonzero((adj > 0).astype(np.int32))
    edge_index = torch.from_numpy(np.array(nz)).long()
    return Data(x=node_feats, edge_index=edge_index)

# ==========================
#  Experiment Configuration
# ==========================
alpha = 0.92
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_mod = 0.05
lambda_recon = 1.0
dropout = 0.2

# ==========================
#  Run Two Experiments
# ==========================
def run_experiment(W_masked, data_label=""):
    data = load_data(W_masked, X).to(device)
    A_tensor = torch.from_numpy(W_masked).float().to(device)

    sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)
    accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

    for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
        print(f"\n=== {data_label} | Fold {fold} ===")

        cn_idx = np.where(y == 0)[0]
        ad_idx = np.where(y == 1)[0]
        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=fold)
        cn_train_idx_small, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
        ad_train_idx_small, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))
        cn_train = cn_idx[cn_train_idx_small]
        ad_train = ad_idx[ad_train_idx_small]
        train_idx_final = np.concatenate([cn_train, ad_train])
        np.random.shuffle(train_idx_final)

        train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
        y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)

        model = ARMA_MAE_Mod(
            input_dim=num_feats,
            hidden_dim=hidden_dim,
            output_dim=num_classes,
            num_stacks=1,
            num_layers=1,
            dropout=dropout
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        ce_loss_fn = nn.CrossEntropyLoss()
        recon_loss_fn = nn.MSELoss()

        data_on_device = data
        data_x = data_on_device.x.to(device)

        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()
            logits, recon, _ = model(data_on_device)
            loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
            loss_recon = recon_loss_fn(recon, data_x)
            loss_mod = model.modularity_loss(A_tensor, logits)
            total_loss = loss_sup + lambda_recon * loss_recon + lambda_mod * loss_mod
            total_loss.backward()
            optimizer.step()

            if epoch % batch_print_freq == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval, _, _ = model(data_on_device)
                    preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                    acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
                print(f"Epoch {epoch}: Total={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | Mod={loss_mod.item():.6f} | TrainAcc={acc_train:.4f}")

        # Evaluation
        model.eval()
        with torch.no_grad():
            logits_final, _, _ = model(data_on_device)
            preds_all = logits_final.argmax(dim=1).cpu().numpy()
            probs_all = torch.softmax(logits_final, dim=1)[:, 1].cpu().numpy()

        y_test = y[test_idx_global]
        y_pred_test = preds_all[test_idx_global]
        y_prob_test = probs_all[test_idx_global]
        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        auc = roc_auc_score(y_test, y_prob_test) if len(np.unique(y_test)) > 1 else float("nan")
        ce = log_loss(y_test, y_prob_test) if len(np.unique(y_test)) > 1 else float("nan")

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        aucs.append(auc)
        ce_losses.append(ce)
        print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE={ce:.4f}")

    print(f"\n=== {data_label} | Average Results ===")
    print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
    print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
    print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
    print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
    print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")
    print(f"CE Loss:   {np.nanmean(ce_losses):.4f} ± {np.nanstd(ce_losses):.4f}")

# ==========================
#  Run Symmetric and Asymmetric
# ==========================
print("\n================ SYMMETRIC EDGE REMOVAL (20%) ================")
W = create_adj(X, alpha)
W_sym_masked = mask_edges_symmetric(W, remove_ratio=0.5, seed=SEED)
run_experiment(W_sym_masked, data_label="Symmetric Masking (CN vs AD)")

print("\n================ ASYMMETRIC EDGE REMOVAL (20%) ================")
W = create_adj(X, alpha)
W_asym_masked = mask_edges_asymmetric(W, remove_ratio=0.5, seed=SEED)
run_experiment(W_asym_masked, data_label="Asymmetric Masking (CN vs AD)")

Features: (300, 180), Labels: (300,)


=== Symmetric Masking | Fold 1 ===
Epoch 1: Total=0.901895 | Sup=0.760507 | Recon=0.148204 | Mod=-0.136326 | TrainAcc=0.4483
Epoch 100: Total=0.252618 | Sup=0.235272 | Recon=0.026891 | Mod=-0.190918 | TrainAcc=0.9310
Epoch 200: Total=0.115698 | Sup=0.115503 | Recon=0.010346 | Mod=-0.203012 | TrainAcc=1.0000
Epoch 300: Total=0.111476 | Sup=0.116088 | Recon=0.005890 | Mod=-0.210042 | TrainAcc=1.0000
Epoch 400: Total=0.044856 | Sup=0.052184 | Recon=0.003160 | Mod=-0.209764 | TrainAcc=1.0000
Epoch 500: Total=0.032036 | Sup=0.040899 | Recon=0.002106 | Mod=-0.219407 | TrainAcc=1.0000
Epoch 600: Total=0.046887 | Sup=0.056544 | Recon=0.001487 | Mod=-0.222871 | TrainAcc=1.0000
Epoch 700: Total=0.002858 | Sup=0.013096 | Recon=0.001103 | Mod=-0.226817 | TrainAcc=1.0000
Epoch 800: Total=-0.003963 | Sup=0.006744 | Recon=0.000876 | Mod=-0.231663 | TrainAcc=1.0000
Epoch 900: Total=0.001118 | Sup=0.011650 | Recon=0.000708 | Mod=-0.224803 | TrainAcc=1.0000
Epoch 1

ARMA + CE + RECON

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

fa_feature_path_cn = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
fa_feature_path_ad = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy" # Assuming this is the AD data
Histogram_feature_CN_FA_array = np.load(fa_feature_path_cn, allow_pickle=True)
Histogram_feature_AD_FA_array = np.load(fa_feature_path_ad, allow_pickle=True)
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

class ARMA_MAE_Mod(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_stacks=1, num_layers=1, dropout=0.2, activ="SELU"):
        super().__init__()
        self.arma1 = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            dropout=dropout
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, input_dim)
        )
        activations = {
            "LeakyReLU": nnFn.leaky_relu,
            "SELU": nnFn.selu,
            "ReLU": nnFn.relu,
            "GELU": nnFn.gelu
        }
        self.act = activations.get(activ, nnFn.selu)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        h = self.arma1(x, edge_index)
        h = self.act(h)
        h = self.bn1(h)
        h = self.dropout(h)
        logits = self.classifier(h)
        recon = self.decoder(h)
        return logits, recon, h

def create_adj(F, alpha=1.0):
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    F_norm = F / row_norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    else:
        W = W
    return W

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    nz = np.nonzero((adj > 0).astype(np.int32))
    edge_index = torch.from_numpy(np.array(nz)).long()
    return Data(x=node_feats, edge_index=edge_index)

alpha = 0.92
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 0.0001
weight_decay = 1e-4
batch_print_freq = 100
lambda_recon = 1.0
dropout = 0.2

W = create_adj(X, alpha)
data = load_data(W, X).to(device)
print(data)

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)
accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} ===")
    cn_idx = np.where(y == 0)[0]
    ad_idx = np.where(y == 1)[0]
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=fold)
    cn_train_idx_small, _ = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx_small, _ = next(sss_class.split(X[ad_idx], y[ad_idx]))
    cn_train = cn_idx[cn_train_idx_small]
    ad_train = ad_idx[ad_train_idx_small]
    train_idx_final = np.concatenate([cn_train, ad_train])
    np.random.shuffle(train_idx_final)
    print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}, Total train: {len(train_idx_final)}")
    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y[train_idx_final]).long().to(device)
    model = ARMA_MAE_Mod(
        input_dim=feats_dim,
        hidden_dim=hidden_dim,
        output_dim=num_classes,
        num_stacks=1,
        num_layers=1,
        dropout=dropout
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss_fn = nn.CrossEntropyLoss()
    recon_loss_fn = nn.MSELoss()
    data_on_device = data
    data_x = data_on_device.x.to(device)
    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits, recon, _ = model(data_on_device)
        loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
        loss_recon = recon_loss_fn(recon, data_x)
        total_loss = loss_sup + lambda_recon * loss_recon
        total_loss.backward()
        optimizer.step()
        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval, _, _ = model(data_on_device)
                preds_train = logits_eval[train_idx_t].argmax(dim=1).cpu().numpy()
                acc_train = accuracy_score(y_train_tensor.cpu().numpy(), preds_train)
            print(f"Fold {fold} Epoch {epoch}: Total={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | TrainAcc={acc_train:.4f}")
    model.eval()
    with torch.no_grad():
        logits_final, _, _ = model(data_on_device)
        preds_all = logits_final.argmax(dim=1).cpu().numpy()
        probs_all = torch.softmax(logits_final, dim=1)[:, 1].cpu().numpy()
    y_test = y[test_idx_global]
    y_pred_test = preds_all[test_idx_global]
    y_prob_test = probs_all[test_idx_global]
    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    try:
        auc = roc_auc_score(y_test, y_prob_test)
    except Exception:
        auc = float("nan")
    try:
        ce = log_loss(y_test, y_prob_test)
    except Exception:
        ce = float("nan")
    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)
    print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

print("\n=== Average Results Across Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")
print(f"CE Loss:   {np.nanmean(ce_losses):.4f} ± {np.nanstd(ce_losses):.4f}")

Features: (300, 180), Labels: (300,)
Data(x=[300, 180], edge_index=[2, 13604])

=== Fold 1 ===
Train CN: 13, Train MCI: 16, Total train: 29
Fold 1 Epoch 1: Total=0.898942 | Sup=0.750762 | Recon=0.148180 | TrainAcc=0.4483
Fold 1 Epoch 100: Total=0.280364 | Sup=0.253722 | Recon=0.026642 | TrainAcc=0.9310
Fold 1 Epoch 200: Total=0.141014 | Sup=0.130706 | Recon=0.010308 | TrainAcc=1.0000
Fold 1 Epoch 300: Total=0.167314 | Sup=0.161501 | Recon=0.005813 | TrainAcc=1.0000
Fold 1 Epoch 400: Total=0.076450 | Sup=0.073255 | Recon=0.003195 | TrainAcc=1.0000
Fold 1 Epoch 500: Total=0.055084 | Sup=0.052965 | Recon=0.002119 | TrainAcc=1.0000
Fold 1 Epoch 600: Total=0.084500 | Sup=0.082955 | Recon=0.001545 | TrainAcc=1.0000
Fold 1 Epoch 700: Total=0.020026 | Sup=0.018924 | Recon=0.001102 | TrainAcc=1.0000
Fold 1 Epoch 800: Total=0.011563 | Sup=0.010683 | Recon=0.000880 | TrainAcc=1.0000
Fold 1 Epoch 900: Total=0.019120 | Sup=0.018423 | Recon=0.000697 | TrainAcc=1.0000
Fold 1 Epoch 1000: Total=0.01434