In [1]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss, roc_auc_score, roc_curve
from sklearn.ensemble import RandomForestClassifier  # (only if you want to compare)
import warnings
warnings.filterwarnings("ignore")

In [2]:
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.set_num_threads(4)

In [3]:
import numpy as np
import torch
import random

# Set seeds for reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.set_num_threads(4)

fa_cn = "Histogram_CN_FA_20bin_updated.npy"
fa_ad = "Histogram_AD_FA_20bin_updated.npy" 

X_cn = np.load(fa_cn, allow_pickle=True)
X_ad = np.load(fa_ad, allow_pickle=True)

X = np.vstack([X_cn, X_ad]).astype(np.float32)
y = np.hstack([
    np.zeros(X_cn.shape[0], dtype=np.int64),  # CN = 0
    np.ones(X_ad.shape[0], dtype=np.int64)    # AD = 1
])

perm = np.random.RandomState(SEED).permutation(len(X))
X = X[perm]
y = y[perm]

num_nodes, num_feats = X.shape
print(f"Dataset: nodes={num_nodes}, feats={num_feats}")
print(f"CN samples: {len(X_cn)}, AD samples: {len(X_ad)}")
print(f"Class balance → CN: {np.sum(y==0)}, AD: {np.sum(y==1)}")

Dataset: nodes=223, feats=180
CN samples: 133, AD samples: 90
Class balance → CN: 133, AD: 90


In [4]:
def create_adj(F, alpha=1):
    norms = np.linalg.norm(F, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    F_norm = F / norms
    W = np.dot(F_norm, F_norm.T)
    W = (W >= alpha).astype(np.float32)
    return W

In [5]:
W0 = create_adj(X, alpha=0.8)
print(f"W0: {W0.shape}")

W0: (223, 223)


In [6]:
def load_graph_torch(adj, node_feats):
    node_feats_t = torch.from_numpy(node_feats).float()
    edge_idx = np.array(np.nonzero(adj))
    edge_index = torch.from_numpy(edge_idx).long()
    return node_feats_t, edge_index
node_feats_all, edge_index_all = load_graph_torch(W0, X)
print(f"Number of edges: {edge_index_all.size(1)}")

Number of edges: 41689


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import ARMAConv

class ARMAEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ="ELU", num_stacks=2, num_layers=2):
        super(ARMAEncoder, self).__init__()

        # Activation functions dictionary
        activations = {
            "SELU": F.selu,
            "SiLU": F.silu,
            "GELU": F.gelu,
            "ELU": F.elu,
            "RELU": F.relu
        }
        self.act = activations.get(activ, F.elu)

        # ARMAConv layer (instead of GCNConv)
        self.arma = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            shared_weights=True,
            dropout=0.3
        )

        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.arma(x, edge_index)
        x = self.act(x)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)

        return logits


In [8]:
class AvgReadout(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, seq, msk=None):
        if msk is None:
            return torch.mean(seq, 0)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 0) / torch.sum(msk)

In [9]:
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        nn.init.xavier_uniform_(self.f_k.weight.data)
        if self.f_k.bias is not None:
            self.f_k.bias.data.fill_(0.0)
    def forward(self, c, h_pl, h_mi):
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)
        logits = torch.cat((sc_1, sc_2), 0)
        return logits

In [10]:
class DGI(nn.Module):
    def __init__(self, n_in, n_h, dropout=0.25):
        super().__init__()
        self.arma = ARMAEncoder(n_in, n_h, device='cuda' if torch.cuda.is_available() else 'cpu', activ=nn.ELU())
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, edge_index):
        # Create Data objects for the ARMAEncoder
        data1 = Data(x=seq1, edge_index=edge_index)
        data2 = Data(x=seq2, edge_index=edge_index)

        h_1 = self.arma(data1)
        c = self.read(h_1)
        c = self.sigm(c)
        h_2 = self.arma(data2)
        logits = self.disc(c, h_1, h_2)
        return logits, h_1

In [11]:
class DGI_with_classifier(DGI):
    def __init__(self, n_in, n_h, n_classes=2, cut=0, dropout=0.25):
        super().__init__(n_in, n_h, dropout=dropout)
        self.classifier = nn.Linear(n_h, n_classes)
        self.cut = cut

    def get_embeddings(self, node_feats, edge_index):
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        S = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)
        mincut_loss = -(num / den)
        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(S.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)
        I_S = torch.eye(C.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        else:
            return self.modularity_loss(A, logits)

In [12]:
from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)
accuracies, precisions, recalls, f1_scores, losses = [], [], [], [], []
all_y_true = []
all_y_proba = []
all_fpr, all_tpr, all_auc = [], [], []

In [13]:
def get_device():
    try:
        if torch.cuda.is_available():
            dev = torch.device("cuda")
            # test a tiny tensor operation to check CUDA health
            torch.tensor([1.0], device=dev) + 1.0
            return dev
    except Exception as e:
        print("CUDA not usable, falling back to CPU:", e)
        try:
            torch.cuda.empty_cache()
        except:
            pass
    return torch.device("cpu")

device = get_device()
print("Using device:", device)

Using device: cuda


In [14]:
A_tensor = torch.from_numpy(W0).to(device)
hidden_dim = 512
cut = 0
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
# sup_weight = 1.0
# dgi_weight = 0.3
reg_weight = 0.01

In [15]:
for fold, (train_idx, test_idx) in enumerate(sss.split(X, y)):
    print(f"\n=== Fold {fold+1} ===")

    cn_idx = np.where(y == 0)[0]
    ad_idx = np.where(y == 1)[0]

    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
    cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx, ad_test_idx = next(sss_class.split(X[ad_idx], y[ad_idx]))

    cn_train = cn_idx[cn_train_idx]
    ad_train = ad_idx[ad_train_idx]
    cn_test = cn_idx[cn_test_idx]
    ad_test = ad_idx[ad_test_idx]

    balanced_train_idx = np.concatenate([cn_train, ad_train])
    test_idx = np.concatenate([cn_test, ad_test])
    np.random.shuffle(balanced_train_idx)
    np.random.shuffle(test_idx)

    print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")
    print(f"Test CN: {len(cn_test)}, Test AD: {len(ad_test)}")

    node_feats = node_feats_all.to(device)
    edge_index = edge_index_all.to(device)
    y_tensor = torch.from_numpy(y).long().to(device)
    train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
    test_idx_t = torch.from_numpy(test_idx).long().to(device)

    N = node_feats.size(0)
    lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])

    model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    bce_loss = nn.BCEWithLogitsLoss()
    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        perm = torch.randperm(N, device=device)
        corrupt = node_feats[perm]

        logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
        dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)

        logits_cls = model.classifier(embeddings)
        train_logits = logits_cls[train_idx_t]
        train_labels = y_tensor[train_idx_t]
        supervised_loss = ce_loss(train_logits, train_labels)

        reg_loss_val = model.Reg_loss(A_tensor, embeddings)
        total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val

        total_loss.backward()
        optimizer.step()

        if epoch % 500 == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
                preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
                acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
            print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | DGI={dgi_loss.item():.4f} | Reg={reg_loss_val.item():.4f} | Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

    # ===== Evaluation =====
    model.eval()
    with torch.no_grad():
        emb_final = model.get_embeddings(node_feats, edge_index)
        logits_final = model.classifier(emb_final)
        probs = F.softmax(logits_final, dim=1).cpu().numpy()
        y_pred = np.argmax(probs, axis=1)

    y_test = y[test_idx]
    y_pred_test = y_pred[test_idx]
    y_proba_test = probs[test_idx, 1]  # probability of AD

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    loss_val = log_loss(y_test, y_proba_test)
    auc_score = roc_auc_score(y_test, y_proba_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    losses.append(loss_val)
    all_auc.append(auc_score)

    print(f"Fold {fold+1} → Acc={acc:.4f} Prec={prec:.4f} Rec={rec:.4f} F1={f1:.4f} AUC={auc_score:.4f}")

# ===== Final Summary =====
print("\n=== Average Results (CN vs AD) ===")
print(f"Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall: {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"LogLoss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")
print(f"AUC: {np.mean(all_auc):.4f} ± {np.std(all_auc):.4f}")



=== Fold 1 ===
Train CN: 66, Train AD: 45
Test CN: 67, Test AD: 45
Epoch 1: Sup=0.7012 | DGI=0.7131 | Reg=-0.2841 | Total=1.4115 | TrainAcc=0.5946
Epoch 500: Sup=0.2033 | DGI=0.6949 | Reg=-0.2889 | Total=0.8953 | TrainAcc=0.8739
Epoch 1000: Sup=0.1827 | DGI=0.6935 | Reg=-0.2920 | Total=0.8732 | TrainAcc=0.8829
Epoch 1500: Sup=0.0943 | DGI=0.6963 | Reg=-0.2933 | Total=0.7877 | TrainAcc=0.9369
Epoch 2000: Sup=0.1316 | DGI=0.6932 | Reg=-0.2917 | Total=0.8218 | TrainAcc=0.9459
Epoch 2500: Sup=0.1625 | DGI=0.6959 | Reg=-0.2878 | Total=0.8556 | TrainAcc=1.0000
Epoch 3000: Sup=0.0531 | DGI=0.6984 | Reg=-0.2873 | Total=0.7487 | TrainAcc=0.9820
Epoch 3500: Sup=0.1045 | DGI=0.6933 | Reg=-0.2872 | Total=0.7950 | TrainAcc=1.0000
Epoch 4000: Sup=0.0888 | DGI=0.6936 | Reg=-0.2857 | Total=0.7796 | TrainAcc=1.0000
Epoch 4500: Sup=0.0437 | DGI=0.6937 | Reg=-0.2884 | Total=0.7344 | TrainAcc=1.0000
Epoch 5000: Sup=0.0558 | DGI=0.6939 | Reg=-0.2916 | Total=0.7468 | TrainAcc=1.0000
Fold 1 → Acc=0.7679 Pre

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

A_tensor = torch.from_numpy(W0).to(device)
hidden_dim = 512
cut = 0
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
reg_weights = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)

results_summary = []

for reg_weight in reg_weights:
    print(f"\n================ REG_WEIGHT = {reg_weight} ================")
    accuracies, precisions, recalls, f1_scores, losses, all_auc = [], [], [], [], [], []

    for fold, (train_idx, test_idx) in enumerate(sss.split(X, y)):
        print(f"\n=== Fold {fold+1} ===")

        cn_idx = np.where(y == 0)[0]
        ad_idx = np.where(y == 1)[0]

        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
        cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
        ad_train_idx, ad_test_idx = next(sss_class.split(X[ad_idx], y[ad_idx]))

        cn_train = cn_idx[cn_train_idx]
        ad_train = ad_idx[ad_train_idx]
        cn_test = cn_idx[cn_test_idx]
        ad_test = ad_idx[ad_test_idx]

        balanced_train_idx = np.concatenate([cn_train, ad_train])
        test_idx = np.concatenate([cn_test, ad_test])
        np.random.shuffle(balanced_train_idx)
        np.random.shuffle(test_idx)

        print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")
        print(f"Test CN: {len(cn_test)}, Test AD: {len(ad_test)}")

        node_feats = node_feats_all.to(device)
        edge_index = edge_index_all.to(device)
        y_tensor = torch.from_numpy(y).long().to(device)
        train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
        test_idx_t = torch.from_numpy(test_idx).long().to(device)

        N = node_feats.size(0)
        lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])

        model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        bce_loss = nn.BCEWithLogitsLoss()
        ce_loss = nn.CrossEntropyLoss()

        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            perm = torch.randperm(N, device=device)
            corrupt = node_feats[perm]

            logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
            dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)

            logits_cls = model.classifier(embeddings)
            train_logits = logits_cls[train_idx_t]
            train_labels = y_tensor[train_idx_t]
            supervised_loss = ce_loss(train_logits, train_labels)

            reg_loss_val = model.Reg_loss(A_tensor, embeddings)
            total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val

            total_loss.backward()
            optimizer.step()

            if epoch % 500 == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
                    preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
                    acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
                print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | "
                      f"DGI={dgi_loss.item():.4f} | Reg={reg_loss_val.item():.4f} | "
                      f"Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

        # ===== Evaluation =====
        model.eval()
        with torch.no_grad():
            emb_final = model.get_embeddings(node_feats, edge_index)
            logits_final = model.classifier(emb_final)
            probs = F.softmax(logits_final, dim=1).cpu().numpy()
            y_pred = np.argmax(probs, axis=1)

        y_test = y[test_idx]
        y_pred_test = y_pred[test_idx]
        y_proba_test = probs[test_idx, 1]  # probability of AD

        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        loss_val = log_loss(y_test, y_proba_test)
        auc_score = roc_auc_score(y_test, y_proba_test)

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        losses.append(loss_val)
        all_auc.append(auc_score)

        print(f"Fold {fold+1} → Acc={acc:.4f} Prec={prec:.4f} Rec={rec:.4f} "
              f"F1={f1:.4f} AUC={auc_score:.4f}")

    mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
    mean_prec, std_prec = np.mean(precisions), np.std(precisions)
    mean_rec, std_rec = np.mean(recalls), np.std(recalls)
    mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
    mean_loss, std_loss = np.mean(losses), np.std(losses)
    mean_auc, std_auc = np.mean(all_auc), np.std(all_auc)

    results_summary.append({
        "Reg_Weight": reg_weight,
        "Accuracy": f"{mean_acc:.4f} ± {std_acc:.4f}",
        "Precision": f"{mean_prec:.4f} ± {std_prec:.4f}",
        "Recall": f"{mean_rec:.4f} ± {std_rec:.4f}",
        "F1": f"{mean_f1:.4f} ± {std_f1:.4f}",
        "LogLoss": f"{mean_loss:.4f} ± {std_loss:.4f}",
        "AUC": f"{mean_auc:.4f} ± {std_auc:.4f}"
    })

    print(f"\n=== Average Results for reg_weight = {reg_weight} (CN vs AD) ===")
    print(f"Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Precision: {mean_prec:.4f} ± {std_prec:.4f}")
    print(f"Recall: {mean_rec:.4f} ± {std_rec:.4f}")
    print(f"F1: {mean_f1:.4f} ± {std_f1:.4f}")
    print(f"LogLoss: {mean_loss:.4f} ± {std_loss:.4f}")
    print(f"AUC: {mean_auc:.4f} ± {std_auc:.4f}")

# ============================
# Final summary table
# ============================
print("\n\n========== FINAL SUMMARY TABLE (CN vs AD) ==========")
results_df = pd.DataFrame(results_summary)
print(results_df.to_string(index=False))




=== Fold 1 ===
Train CN: 66, Train AD: 45
Test CN: 67, Test AD: 45
Epoch 1: Sup=0.6975 | DGI=0.7137 | Reg=-0.2841 | Total=1.4109 | TrainAcc=0.6396
Epoch 500: Sup=0.2330 | DGI=0.6951 | Reg=-0.2875 | Total=0.9278 | TrainAcc=0.8468
Epoch 1000: Sup=0.1725 | DGI=0.6932 | Reg=-0.2918 | Total=0.8654 | TrainAcc=0.8649
Epoch 1500: Sup=0.1161 | DGI=0.6943 | Reg=-0.2876 | Total=0.8101 | TrainAcc=0.9099
Epoch 2000: Sup=0.1372 | DGI=0.6932 | Reg=-0.2915 | Total=0.8300 | TrainAcc=0.9369
Epoch 2500: Sup=0.1276 | DGI=0.6936 | Reg=-0.2837 | Total=0.8209 | TrainAcc=0.9730
Epoch 3000: Sup=0.0764 | DGI=0.6935 | Reg=-0.2921 | Total=0.7696 | TrainAcc=0.9640
Epoch 3500: Sup=0.0372 | DGI=0.6932 | Reg=-0.2921 | Total=0.7301 | TrainAcc=0.9820
Epoch 4000: Sup=0.0375 | DGI=0.6932 | Reg=-0.2912 | Total=0.7304 | TrainAcc=1.0000
Epoch 4500: Sup=0.0656 | DGI=0.6959 | Reg=-0.2916 | Total=0.7612 | TrainAcc=1.0000
Epoch 5000: Sup=0.0268 | DGI=0.6937 | Reg=-0.2890 | Total=0.7202 | TrainAcc=1.0000
Fold 1 → Acc=0.7321 Pr

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

A_tensor = torch.from_numpy(W0).to(device)
hidden_dim = 512
cut = 0
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
reg_weights = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)

results_summary = []

for reg_weight in reg_weights:
    print(f"\n================ REG_WEIGHT = {reg_weight} ================")
    accuracies, precisions, recalls, f1_scores, losses, all_auc = [], [], [], [], [], []

    for fold, (train_idx, test_idx) in enumerate(sss.split(X, y)):
        print(f"\n=== Fold {fold+1} ===")

        cn_idx = np.where(y == 0)[0]
        ad_idx = np.where(y == 1)[0]

        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.1, random_state=fold)
        cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
        ad_train_idx, ad_test_idx = next(sss_class.split(X[ad_idx], y[ad_idx]))

        cn_train = cn_idx[cn_train_idx]
        ad_train = ad_idx[ad_train_idx]
        cn_test = cn_idx[cn_test_idx]
        ad_test = ad_idx[ad_test_idx]

        balanced_train_idx = np.concatenate([cn_train, ad_train])
        test_idx = np.concatenate([cn_test, ad_test])
        np.random.shuffle(balanced_train_idx)
        np.random.shuffle(test_idx)

        print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")
        print(f"Test CN: {len(cn_test)}, Test AD: {len(ad_test)}")

        node_feats = node_feats_all.to(device)
        edge_index = edge_index_all.to(device)
        y_tensor = torch.from_numpy(y).long().to(device)
        train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
        test_idx_t = torch.from_numpy(test_idx).long().to(device)

        N = node_feats.size(0)
        lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])

        model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        bce_loss = nn.BCEWithLogitsLoss()
        ce_loss = nn.CrossEntropyLoss()

        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            perm = torch.randperm(N, device=device)
            corrupt = node_feats[perm]

            logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
            dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)

            logits_cls = model.classifier(embeddings)
            train_logits = logits_cls[train_idx_t]
            train_labels = y_tensor[train_idx_t]
            supervised_loss = ce_loss(train_logits, train_labels)

            reg_loss_val = model.Reg_loss(A_tensor, embeddings)
            total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val

            total_loss.backward()
            optimizer.step()

            if epoch % 500 == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
                    preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
                    acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
                print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | "
                      f"DGI={dgi_loss.item():.4f} | Reg={reg_loss_val.item():.4f} | "
                      f"Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

        # ===== Evaluation =====
        model.eval()
        with torch.no_grad():
            emb_final = model.get_embeddings(node_feats, edge_index)
            logits_final = model.classifier(emb_final)
            probs = F.softmax(logits_final, dim=1).cpu().numpy()
            y_pred = np.argmax(probs, axis=1)

        y_test = y[test_idx]
        y_pred_test = y_pred[test_idx]
        y_proba_test = probs[test_idx, 1]  # probability of AD

        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        loss_val = log_loss(y_test, y_proba_test)
        auc_score = roc_auc_score(y_test, y_proba_test)

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        losses.append(loss_val)
        all_auc.append(auc_score)

        print(f"Fold {fold+1} → Acc={acc:.4f} Prec={prec:.4f} Rec={rec:.4f} "
              f"F1={f1:.4f} AUC={auc_score:.4f}")

    mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
    mean_prec, std_prec = np.mean(precisions), np.std(precisions)
    mean_rec, std_rec = np.mean(recalls), np.std(recalls)
    mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
    mean_loss, std_loss = np.mean(losses), np.std(losses)
    mean_auc, std_auc = np.mean(all_auc), np.std(all_auc)

    results_summary.append({
        "Reg_Weight": reg_weight,
        "Accuracy": f"{mean_acc:.4f} ± {std_acc:.4f}",
        "Precision": f"{mean_prec:.4f} ± {std_prec:.4f}",
        "Recall": f"{mean_rec:.4f} ± {std_rec:.4f}",
        "F1": f"{mean_f1:.4f} ± {std_f1:.4f}",
        "LogLoss": f"{mean_loss:.4f} ± {std_loss:.4f}",
        "AUC": f"{mean_auc:.4f} ± {std_auc:.4f}"
    })

    print(f"\n=== Average Results for reg_weight = {reg_weight} (CN vs AD) ===")
    print(f"Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Precision: {mean_prec:.4f} ± {std_prec:.4f}")
    print(f"Recall: {mean_rec:.4f} ± {std_rec:.4f}")
    print(f"F1: {mean_f1:.4f} ± {std_f1:.4f}")
    print(f"LogLoss: {mean_loss:.4f} ± {std_loss:.4f}")
    print(f"AUC: {mean_auc:.4f} ± {std_auc:.4f}")

# ============================
# Final summary table
# ============================
print("\n\n========== FINAL SUMMARY TABLE (CN vs AD) ==========")
results_df = pd.DataFrame(results_summary)
print(results_df.to_string(index=False))




=== Fold 1 ===
Train CN: 119, Train AD: 81
Test CN: 14, Test AD: 9
Epoch 1: Sup=0.6928 | DGI=0.7114 | Reg=-0.2840 | Total=1.4039 | TrainAcc=0.5950
Epoch 500: Sup=0.3188 | DGI=0.6949 | Reg=-0.2871 | Total=1.0134 | TrainAcc=0.7800
Epoch 1000: Sup=0.2102 | DGI=0.6972 | Reg=-0.2854 | Total=0.9071 | TrainAcc=0.8050
Epoch 1500: Sup=0.2775 | DGI=0.6934 | Reg=-0.2850 | Total=0.9707 | TrainAcc=0.8400
Epoch 2000: Sup=0.2026 | DGI=0.6944 | Reg=-0.2851 | Total=0.8967 | TrainAcc=0.8150
Epoch 2500: Sup=0.2446 | DGI=0.6939 | Reg=-0.2873 | Total=0.9382 | TrainAcc=0.8400
Epoch 3000: Sup=0.1368 | DGI=0.6969 | Reg=-0.2869 | Total=0.8334 | TrainAcc=0.8650
Epoch 3500: Sup=0.1190 | DGI=0.6950 | Reg=-0.2856 | Total=0.8138 | TrainAcc=0.9000
Epoch 4000: Sup=0.0954 | DGI=0.7172 | Reg=-0.2871 | Total=0.8123 | TrainAcc=0.9550
Epoch 4500: Sup=0.1173 | DGI=0.6946 | Reg=-0.2833 | Total=0.8116 | TrainAcc=0.9600
Epoch 5000: Sup=0.0826 | DGI=0.6932 | Reg=-0.2861 | Total=0.7755 | TrainAcc=0.9550
Fold 1 → Acc=0.7391 Pr