In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

cn_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
ad_path = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"


Histogram_feature_CN_FA_array = np.load(cn_path, allow_pickle=True)
Histogram_feature_AD_FA_array = np.load(ad_path, allow_pickle=True)

X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),  # CN → 0
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)    # AD → 1
])

num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

# Shuffle for reproducibility
rng = np.random.RandomState(42)
perm = rng.permutation(num_nodes)
X = X[perm]
y = y[perm]

def create_adj(F_data, alpha=1):
    norms = np.linalg.norm(F_data, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    F_norm = F_data / norms
    W = np.dot(F_norm, F_norm.T)
    W = np.where(W >= alpha, 1.0, 0.0).astype(np.float32)
    if W.max() > 0:
        W = W / W.max()
    return W

alpha = 0.8
W = create_adj(X, alpha=alpha)
np.fill_diagonal(W, 1.0)

def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    rows, cols = np.nonzero(adj > 0)
    edge_index = torch.from_numpy(np.vstack([rows, cols]).astype(np.int64))
    return Data(x=node_feats, edge_index=edge_index)

data = load_data(W, X).to(device)
print(data)

class GCN_AE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, device, activ="RELU", dropout=0.25):
        super(GCN_AE, self).__init__()
        self.device = device

        # --- Encoder ---
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)

        # --- Decoder ---
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # --- Activation ---
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu,
            "ELU": nnFn.elu
        }
        self.act = activations.get(activ, nnFn.relu)

    def forward(self, x, edge_index):
        # Encoder
        h = self.conv1(x, edge_index)
        h = self.bn1(h)
        h = self.act(h)
        h = self.dropout(h)

        # Classification
        logits = self.fc(h)

        # Reconstruction
        recon = self.decoder(h)

        return logits, recon, h


feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
batch_print_freq = 100
lambda_recon = 0.01

sss = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=42)
accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []
y_all = y.copy()

for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y_all), start=1):
    print(f"\n=== Fold {fold} ===")
    cn_idx = np.where(y_all == 0)[0]
    ad_idx = np.where(y_all == 1)[0]
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=fold)
    cn_train_idx, _ = next(sss_class.split(X[cn_idx], y_all[cn_idx]))
    ad_train_idx, _ = next(sss_class.split(X[ad_idx], y_all[ad_idx]))
    cn_train = cn_idx[cn_train_idx]
    ad_train = ad_idx[ad_train_idx]
    train_idx_final = np.concatenate([cn_train, ad_train])
    np.random.shuffle(train_idx_final)

    print(f"Train CN: {len(cn_train)}, Train AD: {len(ad_train)}")
    train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
    y_train_tensor = torch.from_numpy(y_all[train_idx_final]).long().to(device)

    model = GCN_AE(feats_dim, hidden_dim, num_classes, device, activ="RELU", dropout=0.25).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ce_loss_fn = nn.CrossEntropyLoss()
    recon_loss_fn = nn.MSELoss()
    data = load_data(W, X).to(device)

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        logits, recon_x, _ = model(data.x, data.edge_index)

        loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
        loss_recon = recon_loss_fn(recon_x, data.x)
        total_loss = loss_sup + lambda_recon * loss_recon

        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print(f"NaN/Inf detected at fold {fold} epoch {epoch}. Aborting epoch.")
            break

        total_loss.backward()
        optimizer.step()

        if epoch % batch_print_freq == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval, _, _ = model(data.x, data.edge_index)
                preds_train = logits_eval[train_idx_t].argmax(dim=1)
                acc_train = accuracy_score(y_train_tensor.cpu(), preds_train.cpu())
            print(f"Fold {fold} Epoch {epoch}: TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | TrainAcc={acc_train:.4f}")

    model.eval()
    with torch.no_grad():
        out_logits, _, _ = model(data.x, data.edge_index)
        preds = out_logits.argmax(dim=1).cpu().numpy()
        probs = nnFn.softmax(out_logits, dim=1)[:, 1].cpu().numpy()

    y_test = y_all[test_idx_global]
    y_pred_test = preds[test_idx_global]
    y_prob_test = probs[test_idx_global]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    try:
        auc = roc_auc_score(y_test, y_prob_test)
    except ValueError:
        auc = float('nan')
    ce = log_loss(y_test, y_prob_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    aucs.append(auc)
    ce_losses.append(ce)

    print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")


print("\n=== Average Results Across 20 Folds ===")
print(f"Accuracy:  {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"F1-score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"CE Loss:   {np.mean(ce_losses):.4f} ± {np.std(ce_losses):.4f}")
print(f"AUC:       {np.nanmean(aucs):.4f} ± {np.nanstd(aucs):.4f}")

Device: cuda
Features: (223, 180), Labels: (223,)
Data(x=[223, 180], edge_index=[2, 41689])

=== Fold 1 ===
Train CN: 93, Train AD: 63
Fold 1 Epoch 1: TotalLoss=0.726871 | Sup=0.726519 | Recon=0.035201 | TrainAcc=0.5962
Fold 1 Epoch 100: TotalLoss=0.454486 | Sup=0.454447 | Recon=0.003960 | TrainAcc=0.7821
Fold 1 Epoch 200: TotalLoss=0.440659 | Sup=0.440640 | Recon=0.001911 | TrainAcc=0.7949
Fold 1 Epoch 300: TotalLoss=0.431431 | Sup=0.431417 | Recon=0.001326 | TrainAcc=0.8013
Fold 1 Epoch 400: TotalLoss=0.416499 | Sup=0.416488 | Recon=0.001099 | TrainAcc=0.7949
Fold 1 Epoch 500: TotalLoss=0.413585 | Sup=0.413574 | Recon=0.001028 | TrainAcc=0.8141
Fold 1 Epoch 600: TotalLoss=0.396148 | Sup=0.396138 | Recon=0.000997 | TrainAcc=0.8205
Fold 1 Epoch 700: TotalLoss=0.384025 | Sup=0.384015 | Recon=0.001003 | TrainAcc=0.8141
Fold 1 Epoch 800: TotalLoss=0.376182 | Sup=0.376172 | Recon=0.001010 | TrainAcc=0.8013
Fold 1 Epoch 900: TotalLoss=0.362660 | Sup=0.362649 | Recon=0.001024 | TrainAcc=0.82

In [3]:
feats_dim = num_feats
hidden_dim = 512
num_classes = 2
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
batch_print_freq = 100

# Lambda values to test
lambda_list = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]

# === Cross-validation setup ===
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=42)
y_all = y.copy()

# To store all results for each lambda
lambda_results = []

# === Outer loop for lambda sweep ===
for lambda_recon in lambda_list:
    print(f"\n\n====================== λ_recon = {lambda_recon} ======================")

    accuracies, precisions, recalls, f1_scores, aucs, ce_losses = [], [], [], [], [], []

    for fold, (train_val_idx, test_idx_global) in enumerate(sss.split(X, y_all), start=1):
        print(f"\n=== Fold {fold} ===")

        control_idx = np.where(y_all == 0)[0]
        patient_idx = np.where(y_all == 1)[0]
        sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.3, random_state=fold)
        control_train_idx, _ = next(sss_class.split(X[control_idx], y_all[control_idx]))
        patient_train_idx, _ = next(sss_class.split(X[patient_idx], y_all[patient_idx]))
        control_train = control_idx[control_train_idx]
        patient_train = patient_idx[patient_train_idx]
        train_idx_final = np.concatenate([control_train, patient_train])
        np.random.shuffle(train_idx_final)

        print(f"Train Control: {len(control_train)}, Train Patient: {len(patient_train)}")

        train_idx_t = torch.from_numpy(train_idx_final).long().to(device)
        y_train_tensor = torch.from_numpy(y_all[train_idx_final]).long().to(device)

        model = GCN_AE(feats_dim, hidden_dim, num_classes, device, activ="RELU", dropout=0.25).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        ce_loss_fn = nn.CrossEntropyLoss()
        recon_loss_fn = nn.MSELoss()
        data = load_data(W, X).to(device)

        for epoch in range(1, num_epochs + 1):
            model.train()
            optimizer.zero_grad()
            logits, recon_x, _ = model(data.x, data.edge_index)

            loss_sup = ce_loss_fn(logits[train_idx_t], y_train_tensor)
            loss_recon = recon_loss_fn(recon_x, data.x)
            total_loss = loss_sup + lambda_recon * loss_recon

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print(f"NaN/Inf detected at fold {fold} epoch {epoch}. Aborting epoch.")
                break

            total_loss.backward()
            optimizer.step()

            if epoch % batch_print_freq == 0 or epoch == 1:
                model.eval()
                with torch.no_grad():
                    logits_eval, _, _ = model(data.x, data.edge_index)
                    preds_train = logits_eval[train_idx_t].argmax(dim=1)
                    acc_train = accuracy_score(y_train_tensor.cpu(), preds_train.cpu())
                print(f"Fold {fold} Epoch {epoch}: TotalLoss={total_loss.item():.6f} | Sup={loss_sup.item():.6f} | Recon={loss_recon.item():.6f} | TrainAcc={acc_train:.4f}")

        # === Evaluation ===
        model.eval()
        with torch.no_grad():
            out_logits, _, _ = model(data.x, data.edge_index)
            preds = out_logits.argmax(dim=1).cpu().numpy()
            probs = nnFn.softmax(out_logits, dim=1)[:, 1].cpu().numpy()

        y_test = y_all[test_idx_global]
        y_pred_test = preds[test_idx_global]
        y_prob_test = probs[test_idx_global]

        acc = accuracy_score(y_test, y_pred_test)
        prec = precision_score(y_test, y_pred_test, zero_division=0)
        rec = recall_score(y_test, y_pred_test, zero_division=0)
        f1 = f1_score(y_test, y_pred_test, zero_division=0)
        try:
            auc = roc_auc_score(y_test, y_prob_test)
        except ValueError:
            auc = float('nan')
        ce = log_loss(y_test, y_prob_test)

        accuracies.append(acc)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
        aucs.append(auc)
        ce_losses.append(ce)

        print(f"Fold {fold} → Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | AUC={auc:.4f} | CE Loss={ce:.4f}")

    # === Store average results for this λ ===
    avg_results = {
        "lambda": lambda_recon,
        "Accuracy": np.mean(accuracies),
        "Precision": np.mean(precisions),
        "Recall": np.mean(recalls),
        "F1-score": np.mean(f1_scores),
        "CE Loss": np.mean(ce_losses),
        "AUC": np.nanmean(aucs)
    }
    lambda_results.append(avg_results)

    print("\n=== Average Results for λ =", lambda_recon, "===")
    print(f"Accuracy:  {avg_results['Accuracy']:.4f}")
    print(f"Precision: {avg_results['Precision']:.4f}")
    print(f"Recall:    {avg_results['Recall']:.4f}")
    print(f"F1-score:  {avg_results['F1-score']:.4f}")
    print(f"CE Loss:   {avg_results['CE Loss']:.4f}")
    print(f"AUC:       {avg_results['AUC']:.4f}")


# === FINAL SUMMARY ACROSS ALL λ ===
print("\n\n====================== FINAL SUMMARY ACROSS ALL λ ======================")
print(f"{'Lambda':<8} | {'Accuracy':<9} | {'Precision':<9} | {'Recall':<8} | {'F1-score':<9} | {'AUC':<8} | {'CE Loss':<8}")
print("-" * 72)
for res in lambda_results:
    print(f"{res['lambda']:<8} | {res['Accuracy']:<9.4f} | {res['Precision']:<9.4f} | {res['Recall']:<8.4f} | {res['F1-score']:<9.4f} | {res['AUC']:<8.4f} | {res['CE Loss']:<8.4f}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
=== Fold 6 ===
Train Control: 93, Train Patient: 63
Fold 6 Epoch 1: TotalLoss=0.714168 | Sup=0.712622 | Recon=0.030922 | TrainAcc=0.5962
Fold 6 Epoch 100: TotalLoss=0.468489 | Sup=0.468372 | Recon=0.002344 | TrainAcc=0.7821
Fold 6 Epoch 200: TotalLoss=0.434445 | Sup=0.434392 | Recon=0.001047 | TrainAcc=0.7949
Fold 6 Epoch 300: TotalLoss=0.406412 | Sup=0.406376 | Recon=0.000716 | TrainAcc=0.8141
Fold 6 Epoch 400: TotalLoss=0.381861 | Sup=0.381833 | Recon=0.000570 | TrainAcc=0.8333
Fold 6 Epoch 500: TotalLoss=0.366689 | Sup=0.366663 | Recon=0.000525 | TrainAcc=0.8397
Fold 6 Epoch 600: TotalLoss=0.360134 | Sup=0.360109 | Recon=0.000493 | TrainAcc=0.8590
Fold 6 Epoch 700: TotalLoss=0.332275 | Sup=0.332250 | Recon=0.000497 | TrainAcc=0.8654
Fold 6 Epoch 800: TotalLoss=0.319618 | Sup=0.319594 | Recon=0.000490 | TrainAcc=0.8462
Fold 6 Epoch 900: TotalLoss=0.309439 | Sup=0.309414 | Recon=0.000493 | TrainAcc=0.8718
Fold 6 Epoch 10