In [6]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, log_loss, roc_auc_score
)
import warnings
warnings.filterwarnings("ignore")

# Seed NumPy, the stdlib `random` module and torch (CPU and, when available,
# CUDA) from one shared constant.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.set_num_threads(4)

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects. Only load trusted files this way; if the .npy files hold plain
# numeric arrays the flag can be dropped.

# === Load Patients ===
fa_patients_path = "/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy"
Patients_FA_array = np.load(fa_patients_path, allow_pickle=True)

# === Load Controls ===
fa_controls_path = "/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy"
Controls_FA_array = np.load(fa_controls_path, allow_pickle=True)

print("Patients Shape:", Patients_FA_array.shape)
print("Controls Shape:", Controls_FA_array.shape)

# === Combine features and labels ===
# Controls are stacked first, so the label vector is zeros followed by ones.
X = np.vstack([Controls_FA_array, Patients_FA_array])
y = np.hstack([
    np.zeros(Controls_FA_array.shape[0], dtype=np.int64),  # 0 = Control
    np.ones(Patients_FA_array.shape[0], dtype=np.int64)    # 1 = Patient
])

# Shuffle rows and labels with the same permutation. Re-seed from SEED (not a
# separate literal) so changing SEED changes the shuffle as well.
np.random.seed(SEED)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]

# ------------------------
# Adjacency building
# ------------------------
def create_adj(F, alpha=1.0):
    """Build a binary adjacency matrix by thresholding row-wise cosine similarity.

    Args:
        F: 2-D array with one feature vector per row.
        alpha: Threshold; entry (i, j) is 1 when the cosine similarity of
            rows i and j is >= alpha, otherwise 0.

    Returns:
        Square float32 array of 0/1 values, one row and column per row of F.
    """
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    # All-zero rows would divide by zero; dividing them by 1 keeps them zero.
    np.putmask(row_norms, row_norms == 0, 1.0)
    unit_rows = F / row_norms
    similarity = unit_rows @ unit_rows.T
    return (similarity >= alpha).astype(np.float32)

# Binary graph over all rows of X: connect rows whose cosine similarity is >= 0.5.
W0 = create_adj(X, 0.5)
print(f"W0: {W0.shape}")

def load_graph_torch(adj, node_feats):
    """Convert a NumPy adjacency matrix and feature matrix into torch tensors.

    Args:
        adj: Array whose nonzero entries mark edges.
        node_feats: NumPy array of node features.

    Returns:
        Tuple ``(features, edge_index)``: the features as a float32 tensor and
        an int64 tensor stacking the index arrays of the nonzero entries of
        ``adj`` (row indices first, then column indices).
    """
    features = torch.from_numpy(node_feats).to(torch.float32)
    nonzero_coords = np.vstack(np.nonzero(adj))
    return features, torch.from_numpy(nonzero_coords).to(torch.int64)

# Tensor versions of the feature matrix and of the graph's edge list.
node_feats_all, edge_index_all = load_graph_torch(W0, X)
print(f"Number of edges: {edge_index_all.shape[1]}")

# ------------------------
# Model components
# ------------------------
class ARMAEncoder(torch.nn.Module):
    """Graph encoder: one ARMAConv layer followed by a linear projection.

    The forward pass applies, in order: ARMAConv -> activation -> dropout ->
    batch norm -> linear layer. Dropout probability is 0.3 both inside the
    ARMAConv layer and in the standalone dropout module.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Size of the ARMAConv output and of the final projection.
        num_stacks: Forwarded to ARMAConv.
        num_layers: Forwarded to ARMAConv.
        activ: Case-sensitive activation name, one of "SELU", "SiLU", "GELU",
            "ELU", "RELU". Any other value silently selects ELU.
    """

    def __init__(self, input_dim, hidden_dim, num_stacks=1, num_layers=1, activ="RELU"):
        super().__init__()
        name_to_fn = dict(SELU=F.selu, SiLU=F.silu, GELU=F.gelu, ELU=F.elu, RELU=F.relu)
        # Unknown names fall back to ELU rather than raising.
        self.act = name_to_fn[activ] if activ in name_to_fn else F.elu
        self.arma = ARMAConv(
            input_dim,
            hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            shared_weights=True,
            dropout=0.3,
        )
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode ``data.x`` over ``data.edge_index`` and return the projected node embeddings."""
        hidden = self.arma(data.x, data.edge_index)
        hidden = self.batchnorm(self.dropout(self.act(hidden)))
        return self.mlp(hidden)

class AvgReadout(nn.Module):
    """Average rows of a tensor along dim 0, optionally weighted by a mask."""

    def forward(self, seq, msk=None):
        """Return the mean of ``seq`` over dim 0.

        With ``msk`` given, each row is multiplied by its mask value and the
        summed result is divided by the total of the mask.
        """
        if msk is not None:
            weights = msk.unsqueeze(-1)
            return (seq * weights).sum(dim=0) / weights.sum()
        return seq.mean(dim=0)

class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings against one summary vector."""

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        # Replace the default initialisation: Xavier-uniform weight, zero bias.
        nn.init.xavier_uniform_(self.f_k.weight.data)
        bias = self.f_k.bias
        if bias is not None:
            bias.data.zero_()

    def forward(self, c, h_pl, h_mi):
        """Score ``h_pl`` and ``h_mi`` rows against the summary ``c``.

        Returns:
            1-D tensor holding the scores for ``h_pl`` followed by the scores
            for ``h_mi``.
        """
        summary = c.unsqueeze(0).expand_as(h_pl)
        positive = self.f_k(h_pl, summary).squeeze(1)
        negative = self.f_k(h_mi, summary).squeeze(1)
        return torch.cat((positive, negative), 0)

class DGI(nn.Module):
    """Contrastive model: ARMA encoder, mean readout and bilinear discriminator.

    Args:
        n_in: Input feature size per node.
        n_h: Embedding size produced by the encoder.
        num_stacks: Forwarded to the ARMAEncoder.
        num_layers: Forwarded to the ARMAEncoder.
    """

    def __init__(self, n_in, n_h, num_stacks=2, num_layers=1):
        super().__init__()
        self.gcn1 = ARMAEncoder(n_in, n_h, num_stacks=num_stacks, num_layers=num_layers)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, edge_index):
        """Encode both feature sets on the same graph and score them.

        Args:
            seq1: Node features whose embeddings form the summary vector.
            seq2: Second set of node features, encoded on the same edges.
            edge_index: Edge list shared by both encodings.

        Returns:
            Tuple ``(logits, h_1)``: discriminator scores for ``seq1`` followed
            by those for ``seq2``, and the embeddings of ``seq1``.
        """
        # seq1 is encoded before seq2; keep this order so that dropout and
        # batch-norm see the two inputs in the same sequence as before.
        h_1 = self.gcn1(Data(x=seq1, edge_index=edge_index))
        summary = self.sigm(self.read(h_1))
        h_2 = self.gcn1(Data(x=seq2, edge_index=edge_index))
        return self.disc(summary, h_1, h_2), h_1

class DGI_with_classifier(DGI):
    """DGI model with a linear classification head and a graph regulariser.

    Args:
        n_in: Input feature size per node.
        n_h: Embedding size.
        n_classes: Number of outputs of the linear classifier; also used as
            the number of soft clusters in the regularisation losses.
        cut: ``1`` selects ``cut_loss`` in ``Reg_loss``; any other value
            selects ``modularity_loss``.
        device: Stored on the instance; not used by the methods of this class.
        num_stacks: Forwarded to the encoder.
        num_layers: Forwarded to the encoder.
    """

    def __init__(self, n_in, n_h, n_classes=2, cut=0, device='cpu', num_stacks=1, num_layers=1):
        super().__init__(n_in, n_h, num_stacks=num_stacks, num_layers=num_layers)
        self.classifier = nn.Linear(n_h, n_classes)
        self.cut = cut
        self.device = device
        self.n_clusters = n_classes

    def get_embeddings(self, node_feats, edge_index):
        """Return the node embeddings obtained by passing ``node_feats`` as both inputs of ``forward``."""
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        """Normalised-cut term plus orthogonality term on soft assignments.

        Args:
            A: Dense square adjacency matrix.
            S: Per-node logits; softmax over dim 1 gives the soft assignments.

        Returns:
            Scalar tensor ``-(tr(C^T A C) / tr(C^T D C)) + ortho_loss`` where
            ``D`` is the diagonal degree matrix of ``A``.
        """
        C = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, C).t(), C)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, C).t(), C)
        den = torch.trace(D_pooled) + 1e-9  # epsilon guards an edgeless graph
        mincut_loss = -(num / den)
        # Distance between the normalised C^T C and the normalised identity.
        St_S = torch.matmul(C.t(), C)
        I_S = torch.eye(C.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / (torch.norm(St_S) + 1e-9) - I_S / (torch.norm(I_S) + 1e-9))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        """Negative soft modularity plus a cluster-collapse regulariser.

        Args:
            A: Dense square adjacency matrix.
            S: Per-node logits; softmax over dim 1 gives the soft assignments.

        Returns:
            Scalar tensor ``-tr(C^T B C) / (2m) + (sqrt(k) / n) * ||sum_i C_i|| - 1``
            with ``B = A - d d^T / (2m)``, ``m`` half the degree sum, ``k`` the
            number of clusters and ``n`` the number of nodes.
        """
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        # m is the edge count, i.e. half the degree sum; using the full degree
        # sum would halve the null-model term d d^T / (2m).
        m = torch.sum(d) / 2 + 1e-9
        B = A - torch.outer(d, d) / (2 * m)
        n, k = C.shape
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        # sqrt(k) / n scaling makes this term 0 for perfectly balanced clusters.
        collapse_reg_term = (k ** 0.5 / n) * torch.norm(torch.sum(C, dim=0)) - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        """Classify ``embeddings`` and return the regulariser selected by ``self.cut``."""
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        return self.modularity_loss(A, logits)

# ------------------------
# Prepare cross-validation
# ------------------------
# Outer splitter (20 iterations) and one list per metric, appended to once per fold.
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)
accuracies = []
precisions = []
recalls = []
f1_scores = []
losses = []
all_auc = []

def get_device():
    """Return the CUDA device when it is available and usable, else the CPU device.

    Usability is probed by running one small tensor addition on the GPU. If
    that raises, the error is printed and the function falls back to CPU.
    """
    try:
        if torch.cuda.is_available():
            dev = torch.device("cuda")
            # Smoke test: is_available() can be True while kernels still fail.
            torch.tensor([1.0], device=dev) + 1.0
            return dev
    except Exception as e:
        print("CUDA not usable, falling back to CPU:", e)
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup only; never let it mask the CPU fallback.
            # (Exception, not bare except, so Ctrl-C still propagates.)
            pass
    return torch.device("cpu")

device = get_device()
print("Using device:", device)

# Graph tensors on the training device. A_tensor is the dense adjacency that
# is passed to the regularisation loss.
A_tensor = torch.from_numpy(W0).float().to(device)

# Hyperparameters
hidden_dim = 512
cut = 1  # 1 -> cut_loss, otherwise modularity_loss (see Reg_loss)
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
reg_weight = 0.001
num_stacks = 1
num_layers = 1

node_feats = node_feats_all.to(device)
edge_index = edge_index_all.to(device)

# Discriminator targets: ones for the first N scores, zeros for the next N.
N = node_feats.shape[0]
lbl = torch.cat((torch.ones(N, device=device), torch.zeros(N, device=device))).float()

# Define num_feats
num_feats = X.shape[1]

# Fold-invariant objects, built once instead of once per fold.
cn_idx = np.where(y == 0)[0]  # rows labelled 0 (controls)
ad_idx = np.where(y == 1)[0]  # rows labelled 1 (patients)
y_tensor = torch.from_numpy(y).long().to(device)
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()

# NOTE: the indices yielded by `sss` are never used; the outer splitter only
# determines how many folds run. The real train/test split is drawn per class
# below with random_state=fold.
for fold, (train_idx, test_idx) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} (Controls vs Patients) ===")

    # Split each class separately with the same splitter settings, taking only
    # the first split it yields.
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=fold)
    cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx, ad_test_idx = next(sss_class.split(X[ad_idx], y[ad_idx]))

    # Map positions within each class back to row indices of X / y.
    cn_train = cn_idx[cn_train_idx]
    ad_train = ad_idx[ad_train_idx]
    cn_test = cn_idx[cn_test_idx]
    ad_test = ad_idx[ad_test_idx]

    balanced_train_idx = np.concatenate([cn_train, ad_train])
    test_idx_final = np.concatenate([cn_test, ad_test])
    np.random.shuffle(balanced_train_idx)
    np.random.shuffle(test_idx_final)

    print(f"Train Controls: {len(cn_train)}, Train Patients: {len(ad_train)}")
    print(f"Test Controls: {len(cn_test)}, Test Patients: {len(ad_test)}")

    train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
    test_idx_t = torch.from_numpy(test_idx_final).long().to(device)
    # Labels of the training nodes do not change across epochs.
    train_labels = y_tensor[train_idx_t]

    # Fresh model and optimizer for every fold.
    model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut,
                                device=device, num_stacks=num_stacks, num_layers=num_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        # Corrupted input: the same node features with rows permuted.
        perm = torch.randperm(N, device=device)
        corrupt = node_feats[perm]
        logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
        dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)
        # All nodes pass through the encoder; only training nodes contribute
        # to the supervised loss.
        logits_cls = model.classifier(embeddings)
        train_logits = logits_cls[train_idx_t]
        supervised_loss = ce_loss(train_logits, train_labels)
        reg_loss_val = model.Reg_loss(A_tensor, embeddings)
        total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val
        total_loss.backward()
        optimizer.step()

        if epoch % 500 == 0 or epoch == 1:
            # Periodic training-accuracy report in eval mode; model.train() at
            # the top of the next epoch restores training mode.
            model.eval()
            with torch.no_grad():
                logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
                preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
                acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
            print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | DGI={dgi_loss.item():.4f} | "
                  f"Reg={reg_loss_val.item():.4f} | Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

    # Final predictions for every node, then restricted to the test nodes.
    model.eval()
    with torch.no_grad():
        emb_final = model.get_embeddings(node_feats, edge_index)
        logits_final = model.classifier(emb_final)
        probs = F.softmax(logits_final, dim=1).cpu().numpy()
        y_pred = np.argmax(probs, axis=1)

    y_test = y[test_idx_final]
    y_pred_test = y_pred[test_idx_final]
    y_proba_test = probs[test_idx_final, 1]  # probability of class 1

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    loss_val = log_loss(y_test, y_proba_test)
    auc_score = roc_auc_score(y_test, y_proba_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    losses.append(loss_val)
    all_auc.append(auc_score)

    # Printed after every fold: running mean/std over the folds finished so far.
    print("\n=== Average Results (Controls vs Patients) ===")
    print(f"Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
    print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
    print(f"Recall: {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
    print(f"F1: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
    print(f"LogLoss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")
    print(f"AUC: {np.mean(all_auc):.4f} ± {np.std(all_auc):.4f}")

Patients Shape: (98, 180)
Controls Shape: (48, 180)
W0: (146, 146)
Number of edges: 21256
Using device: cuda

=== Fold 1 (Controls vs Patients) ===
Train Controls: 4, Train Patients: 9
Test Controls: 44, Test Patients: 89
Epoch 1: Sup=0.6504 | DGI=0.7147 | Reg=-0.2333 | Total=1.3648 | TrainAcc=0.6923
Epoch 500: Sup=0.0116 | DGI=0.7026 | Reg=-0.2757 | Total=0.7140 | TrainAcc=1.0000
Epoch 1000: Sup=0.0650 | DGI=0.6933 | Reg=-0.2654 | Total=0.7581 | TrainAcc=1.0000
Epoch 1500: Sup=0.0003 | DGI=0.6932 | Reg=-0.2664 | Total=0.6932 | TrainAcc=1.0000
Epoch 2000: Sup=0.0024 | DGI=0.6932 | Reg=-0.2422 | Total=0.6954 | TrainAcc=1.0000

=== Average Results (Controls vs Patients) ===
Accuracy: 0.6617 ± 0.0000
Precision: 0.6746 ± 0.0000
Recall: 0.9551 ± 0.0000
F1: 0.7907 ± 0.0000
LogLoss: 2.0032 ± 0.0000
AUC: 0.5790 ± 0.0000

=== Fold 2 (Controls vs Patients) ===
Train Controls: 4, Train Patients: 9
Test Controls: 44, Test Patients: 89
Epoch 1: Sup=0.6542 | DGI=0.7093 | Reg=-0.2337 | Total=1.3632 |

In [None]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, log_loss, roc_auc_score
)
import warnings
warnings.filterwarnings("ignore")

# Seed NumPy, the stdlib `random` module and torch (CPU and, when available,
# CUDA) from one shared constant.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.set_num_threads(4)

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects. Only load trusted files this way; if the .npy files hold plain
# numeric arrays the flag can be dropped.

# === Load Patients ===
fa_patients_path = "/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy"
Patients_FA_array = np.load(fa_patients_path, allow_pickle=True)

# === Load Controls ===
fa_controls_path = "/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy"
Controls_FA_array = np.load(fa_controls_path, allow_pickle=True)

print("Patients Shape:", Patients_FA_array.shape)
print("Controls Shape:", Controls_FA_array.shape)

# === Combine features and labels ===
# Controls are stacked first, so the label vector is zeros followed by ones.
X = np.vstack([Controls_FA_array, Patients_FA_array])
y = np.hstack([
    np.zeros(Controls_FA_array.shape[0], dtype=np.int64),  # 0 = Control
    np.ones(Patients_FA_array.shape[0], dtype=np.int64)    # 1 = Patient
])

# Shuffle rows and labels with the same permutation. Re-seed from SEED (not a
# separate literal) so changing SEED changes the shuffle as well.
np.random.seed(SEED)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]

# ------------------------
# Adjacency building
# ------------------------
def create_adj(F, alpha=1.0):
    """Build a binary adjacency matrix by thresholding row-wise cosine similarity.

    Args:
        F: 2-D array with one feature vector per row.
        alpha: Threshold; entry (i, j) is 1 when the cosine similarity of
            rows i and j is >= alpha, otherwise 0.

    Returns:
        Square float32 array of 0/1 values, one row and column per row of F.
    """
    row_norms = np.linalg.norm(F, axis=1, keepdims=True)
    # All-zero rows would divide by zero; dividing them by 1 keeps them zero.
    np.putmask(row_norms, row_norms == 0, 1.0)
    unit_rows = F / row_norms
    similarity = unit_rows @ unit_rows.T
    return (similarity >= alpha).astype(np.float32)

# Binary graph over all rows of X: connect rows whose cosine similarity is >= 0.8.
W0 = create_adj(X, 0.8)
print(f"W0: {W0.shape}")

def load_graph_torch(adj, node_feats):
    """Convert a NumPy adjacency matrix and feature matrix into torch tensors.

    Args:
        adj: Array whose nonzero entries mark edges.
        node_feats: NumPy array of node features.

    Returns:
        Tuple ``(features, edge_index)``: the features as a float32 tensor and
        an int64 tensor stacking the index arrays of the nonzero entries of
        ``adj`` (row indices first, then column indices).
    """
    features = torch.from_numpy(node_feats).to(torch.float32)
    nonzero_coords = np.vstack(np.nonzero(adj))
    return features, torch.from_numpy(nonzero_coords).to(torch.int64)

# Tensor versions of the feature matrix and of the graph's edge list.
node_feats_all, edge_index_all = load_graph_torch(W0, X)
print(f"Number of edges: {edge_index_all.shape[1]}")

# ------------------------
# Model components
# ------------------------
class ARMAEncoder(torch.nn.Module):
    """Graph encoder: one ARMAConv layer followed by a linear projection.

    The forward pass applies, in order: ARMAConv -> activation -> dropout ->
    batch norm -> linear layer. Dropout probability is 0.3 both inside the
    ARMAConv layer and in the standalone dropout module.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Size of the ARMAConv output and of the final projection.
        num_stacks: Forwarded to ARMAConv.
        num_layers: Forwarded to ARMAConv.
        activ: Case-sensitive activation name, one of "SELU", "SiLU", "GELU",
            "ELU", "RELU". Any other value silently selects ELU.
    """

    def __init__(self, input_dim, hidden_dim, num_stacks=1, num_layers=1, activ="RELU"):
        super().__init__()
        name_to_fn = dict(SELU=F.selu, SiLU=F.silu, GELU=F.gelu, ELU=F.elu, RELU=F.relu)
        # Unknown names fall back to ELU rather than raising.
        self.act = name_to_fn[activ] if activ in name_to_fn else F.elu
        self.arma = ARMAConv(
            input_dim,
            hidden_dim,
            num_stacks=num_stacks,
            num_layers=num_layers,
            shared_weights=True,
            dropout=0.3,
        )
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode ``data.x`` over ``data.edge_index`` and return the projected node embeddings."""
        hidden = self.arma(data.x, data.edge_index)
        hidden = self.batchnorm(self.dropout(self.act(hidden)))
        return self.mlp(hidden)

class AvgReadout(nn.Module):
    """Average rows of a tensor along dim 0, optionally weighted by a mask."""

    def forward(self, seq, msk=None):
        """Return the mean of ``seq`` over dim 0.

        With ``msk`` given, each row is multiplied by its mask value and the
        summed result is divided by the total of the mask.
        """
        if msk is not None:
            weights = msk.unsqueeze(-1)
            return (seq * weights).sum(dim=0) / weights.sum()
        return seq.mean(dim=0)

class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings against one summary vector."""

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        # Replace the default initialisation: Xavier-uniform weight, zero bias.
        nn.init.xavier_uniform_(self.f_k.weight.data)
        bias = self.f_k.bias
        if bias is not None:
            bias.data.zero_()

    def forward(self, c, h_pl, h_mi):
        """Score ``h_pl`` and ``h_mi`` rows against the summary ``c``.

        Returns:
            1-D tensor holding the scores for ``h_pl`` followed by the scores
            for ``h_mi``.
        """
        summary = c.unsqueeze(0).expand_as(h_pl)
        positive = self.f_k(h_pl, summary).squeeze(1)
        negative = self.f_k(h_mi, summary).squeeze(1)
        return torch.cat((positive, negative), 0)

class DGI(nn.Module):
    """Contrastive model: ARMA encoder, mean readout and bilinear discriminator.

    Args:
        n_in: Input feature size per node.
        n_h: Embedding size produced by the encoder.
        num_stacks: Forwarded to the ARMAEncoder.
        num_layers: Forwarded to the ARMAEncoder.
    """

    def __init__(self, n_in, n_h, num_stacks=2, num_layers=1):
        super().__init__()
        self.gcn1 = ARMAEncoder(n_in, n_h, num_stacks=num_stacks, num_layers=num_layers)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, edge_index):
        """Encode both feature sets on the same graph and score them.

        Args:
            seq1: Node features whose embeddings form the summary vector.
            seq2: Second set of node features, encoded on the same edges.
            edge_index: Edge list shared by both encodings.

        Returns:
            Tuple ``(logits, h_1)``: discriminator scores for ``seq1`` followed
            by those for ``seq2``, and the embeddings of ``seq1``.
        """
        # seq1 is encoded before seq2; keep this order so that dropout and
        # batch-norm see the two inputs in the same sequence as before.
        h_1 = self.gcn1(Data(x=seq1, edge_index=edge_index))
        summary = self.sigm(self.read(h_1))
        h_2 = self.gcn1(Data(x=seq2, edge_index=edge_index))
        return self.disc(summary, h_1, h_2), h_1

class DGI_with_classifier(DGI):
    """DGI model with a linear classification head and a graph regulariser.

    Args:
        n_in: Input feature size per node.
        n_h: Embedding size.
        n_classes: Number of outputs of the linear classifier; also used as
            the number of soft clusters in the regularisation losses.
        cut: ``1`` selects ``cut_loss`` in ``Reg_loss``; any other value
            selects ``modularity_loss``.
        device: Stored on the instance; not used by the methods of this class.
        num_stacks: Forwarded to the encoder.
        num_layers: Forwarded to the encoder.
    """

    def __init__(self, n_in, n_h, n_classes=2, cut=0, device='cpu', num_stacks=1, num_layers=1):
        super().__init__(n_in, n_h, num_stacks=num_stacks, num_layers=num_layers)
        self.classifier = nn.Linear(n_h, n_classes)
        self.cut = cut
        self.device = device
        self.n_clusters = n_classes

    def get_embeddings(self, node_feats, edge_index):
        """Return the node embeddings obtained by passing ``node_feats`` as both inputs of ``forward``."""
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        """Normalised-cut term plus orthogonality term on soft assignments.

        Args:
            A: Dense square adjacency matrix.
            S: Per-node logits; softmax over dim 1 gives the soft assignments.

        Returns:
            Scalar tensor ``-(tr(C^T A C) / tr(C^T D C)) + ortho_loss`` where
            ``D`` is the diagonal degree matrix of ``A``.
        """
        C = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, C).t(), C)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, C).t(), C)
        den = torch.trace(D_pooled) + 1e-9  # epsilon guards an edgeless graph
        mincut_loss = -(num / den)
        # Distance between the normalised C^T C and the normalised identity.
        St_S = torch.matmul(C.t(), C)
        I_S = torch.eye(C.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / (torch.norm(St_S) + 1e-9) - I_S / (torch.norm(I_S) + 1e-9))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        """Negative soft modularity plus a cluster-collapse regulariser.

        Args:
            A: Dense square adjacency matrix.
            S: Per-node logits; softmax over dim 1 gives the soft assignments.

        Returns:
            Scalar tensor ``-tr(C^T B C) / (2m) + (sqrt(k) / n) * ||sum_i C_i|| - 1``
            with ``B = A - d d^T / (2m)``, ``m`` half the degree sum, ``k`` the
            number of clusters and ``n`` the number of nodes.
        """
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        # m is the edge count, i.e. half the degree sum; using the full degree
        # sum would halve the null-model term d d^T / (2m).
        m = torch.sum(d) / 2 + 1e-9
        B = A - torch.outer(d, d) / (2 * m)
        n, k = C.shape
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        # sqrt(k) / n scaling makes this term 0 for perfectly balanced clusters.
        collapse_reg_term = (k ** 0.5 / n) * torch.norm(torch.sum(C, dim=0)) - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        """Classify ``embeddings`` and return the regulariser selected by ``self.cut``."""
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        return self.modularity_loss(A, logits)

# ------------------------
# Prepare cross-validation
# ------------------------
# Outer splitter (20 iterations) and one list per metric, appended to once per fold.
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)
accuracies = []
precisions = []
recalls = []
f1_scores = []
losses = []
all_auc = []

def get_device():
    """Return the CUDA device when it is available and usable, else the CPU device.

    Usability is probed by running one small tensor addition on the GPU. If
    that raises, the error is printed and the function falls back to CPU.
    """
    try:
        if torch.cuda.is_available():
            dev = torch.device("cuda")
            # Smoke test: is_available() can be True while kernels still fail.
            torch.tensor([1.0], device=dev) + 1.0
            return dev
    except Exception as e:
        print("CUDA not usable, falling back to CPU:", e)
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup only; never let it mask the CPU fallback.
            # (Exception, not bare except, so Ctrl-C still propagates.)
            pass
    return torch.device("cpu")

device = get_device()
print("Using device:", device)

# Graph tensors on the training device. A_tensor is the dense adjacency that
# is passed to the regularisation loss.
A_tensor = torch.from_numpy(W0).float().to(device)

# Hyperparameters
hidden_dim = 512
cut = 1  # 1 -> cut_loss, otherwise modularity_loss (see Reg_loss)
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
reg_weight = 0.01
num_stacks = 1
num_layers = 1

node_feats = node_feats_all.to(device)
edge_index = edge_index_all.to(device)

# Discriminator targets: ones for the first N scores, zeros for the next N.
N = node_feats.shape[0]
lbl = torch.cat((torch.ones(N, device=device), torch.zeros(N, device=device))).float()

# Define num_feats
num_feats = X.shape[1]

# Fold-invariant objects, built once instead of once per fold.
cn_idx = np.where(y == 0)[0]  # rows labelled 0 (controls)
ad_idx = np.where(y == 1)[0]  # rows labelled 1 (patients)
y_tensor = torch.from_numpy(y).long().to(device)
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()

# NOTE: the indices yielded by `sss` are never used; the outer splitter only
# determines how many folds run. The real train/test split is drawn per class
# below with random_state=fold.
for fold, (train_idx, test_idx) in enumerate(sss.split(X, y), start=1):
    print(f"\n=== Fold {fold} (Controls vs Patients) ===")

    # Split each class separately with the same splitter settings, taking only
    # the first split it yields.
    sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.1, random_state=fold)
    cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
    ad_train_idx, ad_test_idx = next(sss_class.split(X[ad_idx], y[ad_idx]))

    # Map positions within each class back to row indices of X / y.
    cn_train = cn_idx[cn_train_idx]
    ad_train = ad_idx[ad_train_idx]
    cn_test = cn_idx[cn_test_idx]
    ad_test = ad_idx[ad_test_idx]

    balanced_train_idx = np.concatenate([cn_train, ad_train])
    test_idx_final = np.concatenate([cn_test, ad_test])
    np.random.shuffle(balanced_train_idx)
    np.random.shuffle(test_idx_final)

    print(f"Train Controls: {len(cn_train)}, Train Patients: {len(ad_train)}")
    print(f"Test Controls: {len(cn_test)}, Test Patients: {len(ad_test)}")

    train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
    test_idx_t = torch.from_numpy(test_idx_final).long().to(device)
    # Labels of the training nodes do not change across epochs.
    train_labels = y_tensor[train_idx_t]

    # Fresh model and optimizer for every fold.
    model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut,
                                device=device, num_stacks=num_stacks, num_layers=num_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()
        # Corrupted input: the same node features with rows permuted.
        perm = torch.randperm(N, device=device)
        corrupt = node_feats[perm]
        logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
        dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)
        # All nodes pass through the encoder; only training nodes contribute
        # to the supervised loss.
        logits_cls = model.classifier(embeddings)
        train_logits = logits_cls[train_idx_t]
        supervised_loss = ce_loss(train_logits, train_labels)
        reg_loss_val = model.Reg_loss(A_tensor, embeddings)
        total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val
        total_loss.backward()
        optimizer.step()

        if epoch % 500 == 0 or epoch == 1:
            # Periodic training-accuracy report in eval mode; model.train() at
            # the top of the next epoch restores training mode.
            model.eval()
            with torch.no_grad():
                logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
                preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
                acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
            print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | DGI={dgi_loss.item():.4f} | "
                  f"Reg={reg_loss_val.item():.4f} | Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

    # Final predictions for every node, then restricted to the test nodes.
    model.eval()
    with torch.no_grad():
        emb_final = model.get_embeddings(node_feats, edge_index)
        logits_final = model.classifier(emb_final)
        probs = F.softmax(logits_final, dim=1).cpu().numpy()
        y_pred = np.argmax(probs, axis=1)

    y_test = y[test_idx_final]
    y_pred_test = y_pred[test_idx_final]
    y_proba_test = probs[test_idx_final, 1]  # probability of class 1

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    loss_val = log_loss(y_test, y_proba_test)
    auc_score = roc_auc_score(y_test, y_proba_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    losses.append(loss_val)
    all_auc.append(auc_score)

    # Printed after every fold: running mean/std over the folds finished so far.
    print("\n=== Average Results (Controls vs Patients) ===")
    print(f"Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
    print(f"Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
    print(f"Recall: {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
    print(f"F1: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
    print(f"LogLoss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")
    print(f"AUC: {np.mean(all_auc):.4f} ± {np.std(all_auc):.4f}")

Patients Shape: (98, 180)
Controls Shape: (48, 180)
W0: (146, 146)
Number of edges: 13046
Using device: cuda

=== Fold 1 (Controls vs Patients) ===
Train Controls: 43, Train Patients: 88
Test Controls: 5, Test Patients: 10
Epoch 1: Sup=0.7137 | DGI=0.7151 | Reg=-0.2374 | Total=1.4264 | TrainAcc=0.6641
Epoch 500: Sup=0.3139 | DGI=0.6946 | Reg=-0.2925 | Total=1.0056 | TrainAcc=0.9313
Epoch 1000: Sup=0.2526 | DGI=0.6935 | Reg=-0.3039 | Total=0.9431 | TrainAcc=0.9771
Epoch 1500: Sup=0.0992 | DGI=0.6940 | Reg=-0.2747 | Total=0.7904 | TrainAcc=0.9924
Epoch 2000: Sup=0.0849 | DGI=0.6960 | Reg=-0.2864 | Total=0.7780 | TrainAcc=1.0000

=== Average Results (Controls vs Patients) ===
Accuracy: 0.5333 ± 0.0000
Precision: 1.0000 ± 0.0000
Recall: 0.3000 ± 0.0000
F1: 0.4615 ± 0.0000
LogLoss: 1.0476 ± 0.0000
AUC: 0.8000 ± 0.0000

=== Fold 2 (Controls vs Patients) ===
Train Controls: 43, Train Patients: 88
Test Controls: 5, Test Patients: 10
Epoch 1: Sup=0.7008 | DGI=0.7091 | Reg=-0.2371 | Total=1.4075