In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.8.0+cu126


In [None]:
!pip install sympy



In [None]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pymatting
  Downloading pymatting-1.1.14-py3-none-any.whl.metadata (7.7 kB)
Downloading pymatting-1.1.14-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.7/54.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymatting
Successfully installed pymatting-1.1.14


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

In [3]:
fa_patients_path = "/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy"
Patients_FA_array = np.load(fa_patients_path, allow_pickle=True)

fa_controls_path = "/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy"
Controls_FA_array = np.load(fa_controls_path, allow_pickle=True)

print("Patients Shape:", Patients_FA_array.shape)
print("Controls Shape:", Controls_FA_array.shape)

X = np.vstack([Controls_FA_array, Patients_FA_array])
y = np.hstack([
    np.zeros(Controls_FA_array.shape[0], dtype=np.int64),  # 0 = Control
    np.ones(Patients_FA_array.shape[0], dtype=np.int64)    # 1 = Patient
])

np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]

Patients Shape: (98, 180)
Controls Shape: (48, 180)


In [4]:
def create_adj(F, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = (W >= alpha).astype(np.float32)
    return W

In [5]:
def asymmetrize_random(adj_matrix, seed=None):
    """
    Randomly orient each undirected edge from a symmetric adjacency matrix.
    """
    adj = np.array(adj_matrix, dtype=np.float32)
    n = adj.shape[0]
    asym = np.zeros((n, n), dtype=np.float32)
    rng = np.random.default_rng(seed)

    for i in range(n):
        for j in range(i + 1, n):
            if adj[i, j]:
                if rng.random() < 0.5:
                    asym[i, j] = adj[i, j]
                else:
                    asym[j, i] = adj[i, j]

    return asym

In [6]:
def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    edge_index = torch.from_numpy(np.array(np.nonzero(adj))).long()
    return node_feats, edge_index

In [7]:
features = X.astype(np.float32)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

W0 = create_adj(features, alpha=0.5)
# W_asym = asymmetrize_random(W0, seed=42)
node_feats, edge_index = load_data(W0, features)
data = Data(x=node_feats, edge_index=edge_index).to(device)
A = torch.from_numpy(W0).to(device)
print(data)

Data(x=[146, 180], edge_index=[2, 21256])


In [31]:
import torch.nn.functional as nnFn

class ARMAEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ="ELU", num_stacks=2, num_layers=1):
        super(ARMAEncoder, self).__init__()
        self.device = device
        # Define all activation functions
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "ELU": nnFn.elu,
            "RELU": nnFn.relu
        }
        # Get the activation function based on the input string
        self.act = activations.get(activ, nnFn.elu)

        self.arma = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,   # number of parallel stacks
            num_layers=num_layers,   # depth per stack
            act=self.act,               # nonlinearity inside ARMA
            shared_weights=True,     # weight sharing across layers
            dropout=0.25             # ARMA-internal dropout
        )
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.arma(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [32]:
class AvgReadout(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, seq, msk=None):
        if msk is None:
            return torch.mean(seq, 0)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 0) / torch.sum(msk)

In [33]:
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        nn.init.xavier_uniform_(self.f_k.weight.data)
        if self.f_k.bias is not None:
            self.f_k.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi):
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)
        logits = torch.cat((sc_1, sc_2), 0)
        return logits

In [34]:
class DGI(nn.Module):
    def __init__(self, n_in, n_h, dropout=0.25):
        super().__init__()
        self.arma1 = ARMAEncoder(n_in, n_h, device='cuda' if torch.cuda.is_available() else 'cpu', activ=nn.ELU())
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, edge_index):
        # Create Data objects for the ARMAEncoder
        data1 = Data(x=seq1, edge_index=edge_index)
        data2 = Data(x=seq2, edge_index=edge_index)

        h_1 = self.arma1(data1)
        c = self.read(h_1)
        c = self.sigm(c)
        h_2 = self.arma1(data2)
        logits = self.disc(c, h_1, h_2)
        return logits, h_1

In [35]:
class DGI_with_classifier(DGI):
    def __init__(self, n_in, n_h, n_classes=2, cut=0, dropout=0.25):
        super().__init__(n_in, n_h, dropout=dropout)
        self.classifier = nn.Linear(n_h, n_classes)
        self.cut = cut

    def get_embeddings(self, node_feats, edge_index):
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        S = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)
        mincut_loss = -(num / den)
        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(S.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)
        I_S = torch.eye(C.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        else:
            return self.modularity_loss(A, logits)

In [36]:
hidden_dim = 512
cut = 1
dropout = 0.2
model = DGI_with_classifier(features.shape[1], hidden_dim, n_classes=2, cut=cut, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay = 0.00001)
bce_loss = nn.BCEWithLogitsLoss()

num_epochs = 8000
for epoch in range(num_epochs + 1):
    model.train()
    optimizer.zero_grad()

    perm = torch.randperm(node_feats.size(0))
    corrupt_features = node_feats[perm]

    logits, embeddings = model(node_feats.to(device), corrupt_features.to(device), edge_index.to(device))

    lbl = torch.cat([
        torch.ones(node_feats.size(0)),
        torch.zeros(node_feats.size(0))
    ]).to(device)

    dgi_loss = bce_loss(logits.squeeze(), lbl)
    reg_loss = model.Reg_loss(A, embeddings)
    loss = dgi_loss + 0.1 * reg_loss

    if epoch % 500 == 0:
        print(f"Epoch {epoch} | DGI Loss: {dgi_loss.item():.4f} | Reg Loss: {reg_loss.item():.4f} | Total: {loss.item():.4f}")

    loss.backward()
    optimizer.step()

Epoch 0 | DGI Loss: 0.7129 | Reg Loss: -0.2335 | Total: 0.6896
Epoch 500 | DGI Loss: 0.7381 | Reg Loss: -0.4550 | Total: 0.6926
Epoch 1000 | DGI Loss: 0.7158 | Reg Loss: -0.4528 | Total: 0.6705
Epoch 1500 | DGI Loss: 0.7077 | Reg Loss: -0.4360 | Total: 0.6641
Epoch 2000 | DGI Loss: 0.7062 | Reg Loss: -0.4317 | Total: 0.6630
Epoch 2500 | DGI Loss: 0.6951 | Reg Loss: -0.4748 | Total: 0.6476
Epoch 3000 | DGI Loss: 0.6960 | Reg Loss: -0.4396 | Total: 0.6521
Epoch 3500 | DGI Loss: 0.6952 | Reg Loss: -0.4623 | Total: 0.6490
Epoch 4000 | DGI Loss: 0.7110 | Reg Loss: -0.4783 | Total: 0.6632
Epoch 4500 | DGI Loss: 0.6952 | Reg Loss: -0.4743 | Total: 0.6478
Epoch 5000 | DGI Loss: 0.7508 | Reg Loss: -0.4753 | Total: 0.7033
Epoch 5500 | DGI Loss: 0.6937 | Reg Loss: -0.4642 | Total: 0.6473
Epoch 6000 | DGI Loss: 0.6936 | Reg Loss: -0.4679 | Total: 0.6468
Epoch 6500 | DGI Loss: 0.6932 | Reg Loss: -0.4791 | Total: 0.6453
Epoch 7000 | DGI Loss: 0.6940 | Reg Loss: -0.4764 | Total: 0.6464
Epoch 7500 | D

In [37]:
model.eval()
with torch.no_grad():
    embeddings = model.get_embeddings(node_feats.to(device), edge_index.to(device))
    class_probabilities = F.softmax(model.classifier(embeddings), dim=1).cpu().numpy()

y_pred = np.argmax(class_probabilities, axis=1)

In [38]:
acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)
prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, class_probabilities)

print("Accuracy:", acc_score)
print("Accuracy (inverted):", acc_score_inverted)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.4315068493150685
Accuracy (inverted): 0.5684931506849316
Precision: 0.6027397260273972
Recall: 0.4489795918367347
F1: 0.5146198830409356
Log Loss: 2.9233120571306195


In [39]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

hidden_dim   = 512
cut          = 1
dropout      = 0.25
num_runs     = 10
num_epochs   = 8000
lambda_list  = [5]
#lambda_list  = [8]
base_seed    = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_feats = node_feats.to(device)
edge_index = edge_index.to(device)
A = A.to(device)

if isinstance(y, torch.Tensor):
    y_np = y.detach().cpu().numpy().astype(int)
else:
    y_np = np.asarray(y).astype(int)

N, feats_dim = node_feats.size(0), node_feats.size(1)

all_results = []
bce_loss = nn.BCEWithLogitsLoss()

for lam in lambda_list:
    print(f"\n================ LAMBDA = {lam} ================\n")

    acc_scores, prec_scores, rec_scores, f1_scores, log_losses = [], [], [], [], []

    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")

        seed = base_seed + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


        model = DGI_with_classifier(feats_dim, hidden_dim, n_classes=2, cut=cut, dropout=dropout).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00001)
        scheduler = StepLR(optimizer, step_size=200, gamma=0.5)


        for epoch in range(num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            perm = torch.randperm(N, device=device)
            corrupt_features = node_feats[perm]

            logits, embeddings = model(node_feats, corrupt_features, edge_index)

            lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])
            dgi_loss = bce_loss(logits.squeeze(), lbl)
            reg_loss = model.Reg_loss(A, embeddings)

            loss = dgi_loss + lam * reg_loss

            if epoch % 500 == 0:
                print(f"Epoch {epoch:4d} | DGI: {dgi_loss.item():.4f} | Reg: {reg_loss.item():.4f} | "
                      f"λ*Reg: {(lam * reg_loss).item():.4f} | Total: {loss.item():.4f}")

            loss.backward()
            optimizer.step()
            scheduler.step()

        model.eval()
        with torch.no_grad():
            emb = model.get_embeddings(node_feats, edge_index)
            logits_cls = model.classifier(emb)                   # [N, 2]
            class_probabilities = F.softmax(logits_cls, dim=1).cpu().numpy()
            y_pred = np.argmax(class_probabilities, axis=1)

        acc  = accuracy_score(y_np, y_pred)
        acc_inv = accuracy_score(y_np, 1 - y_pred)

        if acc_inv > acc:
            acc = acc_inv
            y_pred = 1 - y_pred
            class_probabilities = class_probabilities[:, ::-1]

        prec = precision_score(y_np, y_pred, zero_division=0)
        rec  = recall_score(y_np, y_pred, zero_division=0)
        f1   = f1_score(y_np, y_pred, zero_division=0)
        ll   = log_loss(y_np, class_probabilities)

        print(f"Run {run+1} | Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f} | LogLoss: {ll:.4f}")

        acc_scores.append(acc)
        prec_scores.append(prec)
        rec_scores.append(rec)
        f1_scores.append(f1)
        log_losses.append(ll)

    lambda_results = {
        "lambda": lam,
        "accuracy":  (float(np.mean(acc_scores)), float(np.std(acc_scores))),
        "precision": (float(np.mean(prec_scores)), float(np.std(prec_scores))),
        "recall":    (float(np.mean(rec_scores)), float(np.std(rec_scores))),
        "f1":        (float(np.mean(f1_scores)),  float(np.std(f1_scores))),
        "log_loss":  (float(np.mean(log_losses)), float(np.std(log_losses))),
    }
    all_results.append(lambda_results)

    print(f"\n--- RESULTS FOR LAMBDA = {lam} ---")
    print(f"Accuracy : {lambda_results['accuracy'][0]:.4f} ± {lambda_results['accuracy'][1]:.4f}")
    print(f"Precision: {lambda_results['precision'][0]:.4f} ± {lambda_results['precision'][1]:.4f}")
    print(f"Recall   : {lambda_results['recall'][0]:.4f} ± {lambda_results['recall'][1]:.4f}")
    print(f"F1 Score : {lambda_results['f1'][0]:.4f} ± {lambda_results['f1'][1]:.4f}")
    print(f"Log Loss : {lambda_results['log_loss'][0]:.4f} ± {lambda_results['log_loss'][1]:.4f}")

print("\n================ FINAL SUMMARY FOR ALL LAMBDAS ================\n")
print(f"{'Lambda':>8} | {'Accuracy':>18} | {'Precision':>18} | {'Recall':>18} | {'F1 Score':>18} | {'Log Loss':>18}")
print("-" * 108)
for res in all_results:
    print(f"{res['lambda']:>8} | "
          f"{res['accuracy'][0]:.4f} ± {res['accuracy'][1]:.4f} | "
          f"{res['precision'][0]:.4f} ± {res['precision'][1]:.4f} | "
          f"{res['recall'][0]:.4f} ± {res['recall'][1]:.4f} | "
          f"{res['f1'][0]:.4f} ± {res['f1'][1]:.4f} | "
          f"{res['log_loss'][0]:.4f} ± {res['log_loss'][1]:.4f}")




--- Run 1/10 ---
Epoch    0 | DGI: 0.7146 | Reg: -0.2334 | λ*Reg: -1.1671 | Total: -0.4525
Epoch  500 | DGI: 0.6933 | Reg: -0.4906 | λ*Reg: -2.4531 | Total: -1.7598
Epoch 1000 | DGI: 0.6931 | Reg: -0.4867 | λ*Reg: -2.4333 | Total: -1.7402
Epoch 1500 | DGI: 0.6931 | Reg: -0.4886 | λ*Reg: -2.4430 | Total: -1.7498
Epoch 2000 | DGI: 0.6931 | Reg: -0.4888 | λ*Reg: -2.4439 | Total: -1.7508
Epoch 2500 | DGI: 0.6931 | Reg: -0.4778 | λ*Reg: -2.3888 | Total: -1.6956
Epoch 3000 | DGI: 0.6931 | Reg: -0.4840 | λ*Reg: -2.4198 | Total: -1.7266
Epoch 3500 | DGI: 0.6931 | Reg: -0.4854 | λ*Reg: -2.4271 | Total: -1.7339
Epoch 4000 | DGI: 0.6931 | Reg: -0.4923 | λ*Reg: -2.4616 | Total: -1.7684
Epoch 4500 | DGI: 0.6931 | Reg: -0.4887 | λ*Reg: -2.4435 | Total: -1.7503
Epoch 5000 | DGI: 0.6931 | Reg: -0.4713 | λ*Reg: -2.3564 | Total: -1.6633
Epoch 5500 | DGI: 0.6931 | Reg: -0.4862 | λ*Reg: -2.4312 | Total: -1.7380
Epoch 6000 | DGI: 0.6931 | Reg: -0.4881 | λ*Reg: -2.4405 | Total: -1.7473
Epoch 6500 | DGI: 

1 | 0.6699 ± 0.0193 | 0.8443 ± 0.0227 | 0.6235 ± 0.0211 | 0.7170 ± 0.0175 | 0.5985 ± 0.0127

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

hidden_dim   = 256
cut          = 0
dropout      = 0.25
num_runs     = 10
num_epochs   = 10000
lambda_list  = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]
#lambda_list  = [8]
base_seed    = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_feats = node_feats.to(device)
edge_index = edge_index.to(device)
A = A.to(device)

if isinstance(y, torch.Tensor):
    y_np = y.detach().cpu().numpy().astype(int)
else:
    y_np = np.asarray(y).astype(int)

N, feats_dim = node_feats.size(0), node_feats.size(1)

all_results = []
bce_loss = nn.BCEWithLogitsLoss()

for lam in lambda_list:
    print(f"\n================ LAMBDA = {lam} ================\n")

    acc_scores, prec_scores, rec_scores, f1_scores, log_losses = [], [], [], [], []

    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")

        seed = base_seed + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


        model = DGI_with_classifier(feats_dim, hidden_dim, n_classes=2, cut=cut, dropout=dropout).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
        scheduler = StepLR(optimizer, step_size=200, gamma=0.5)


        for epoch in range(num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            perm = torch.randperm(N, device=device)
            corrupt_features = node_feats[perm]

            logits, embeddings = model(node_feats, corrupt_features, edge_index)

            lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])
            dgi_loss = bce_loss(logits.squeeze(), lbl)
            reg_loss = model.Reg_loss(A, embeddings)

            loss = dgi_loss + lam * reg_loss

            if epoch % 500 == 0:
                print(f"Epoch {epoch:4d} | DGI: {dgi_loss.item():.4f} | Reg: {reg_loss.item():.4f} | "
                      f"λ*Reg: {(lam * reg_loss).item():.4f} | Total: {loss.item():.4f}")

            loss.backward()
            optimizer.step()
            scheduler.step()

        model.eval()
        with torch.no_grad():
            emb = model.get_embeddings(node_feats, edge_index)
            logits_cls = model.classifier(emb)                   # [N, 2]
            class_probabilities = F.softmax(logits_cls, dim=1).cpu().numpy()
            y_pred = np.argmax(class_probabilities, axis=1)

        acc  = accuracy_score(y_np, y_pred)
        acc_inv = accuracy_score(y_np, 1 - y_pred)

        if acc_inv > acc:
            acc = acc_inv
            y_pred = 1 - y_pred
            class_probabilities = class_probabilities[:, ::-1]

        prec = precision_score(y_np, y_pred, zero_division=0)
        rec  = recall_score(y_np, y_pred, zero_division=0)
        f1   = f1_score(y_np, y_pred, zero_division=0)
        ll   = log_loss(y_np, class_probabilities)

        print(f"Run {run+1} | Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f} | LogLoss: {ll:.4f}")

        acc_scores.append(acc)
        prec_scores.append(prec)
        rec_scores.append(rec)
        f1_scores.append(f1)
        log_losses.append(ll)

    lambda_results = {
        "lambda": lam,
        "accuracy":  (float(np.mean(acc_scores)), float(np.std(acc_scores))),
        "precision": (float(np.mean(prec_scores)), float(np.std(prec_scores))),
        "recall":    (float(np.mean(rec_scores)), float(np.std(rec_scores))),
        "f1":        (float(np.mean(f1_scores)),  float(np.std(f1_scores))),
        "log_loss":  (float(np.mean(log_losses)), float(np.std(log_losses))),
    }
    all_results.append(lambda_results)

    print(f"\n--- RESULTS FOR LAMBDA = {lam} ---")
    print(f"Accuracy : {lambda_results['accuracy'][0]:.4f} ± {lambda_results['accuracy'][1]:.4f}")
    print(f"Precision: {lambda_results['precision'][0]:.4f} ± {lambda_results['precision'][1]:.4f}")
    print(f"Recall   : {lambda_results['recall'][0]:.4f} ± {lambda_results['recall'][1]:.4f}")
    print(f"F1 Score : {lambda_results['f1'][0]:.4f} ± {lambda_results['f1'][1]:.4f}")
    print(f"Log Loss : {lambda_results['log_loss'][0]:.4f} ± {lambda_results['log_loss'][1]:.4f}")

print("\n================ FINAL SUMMARY FOR ALL LAMBDAS ================\n")
print(f"{'Lambda':>8} | {'Accuracy':>18} | {'Precision':>18} | {'Recall':>18} | {'F1 Score':>18} | {'Log Loss':>18}")
print("-" * 108)
for res in all_results:
    print(f"{res['lambda']:>8} | "
          f"{res['accuracy'][0]:.4f} ± {res['accuracy'][1]:.4f} | "
          f"{res['precision'][0]:.4f} ± {res['precision'][1]:.4f} | "
          f"{res['recall'][0]:.4f} ± {res['recall'][1]:.4f} | "
          f"{res['f1'][0]:.4f} ± {res['f1'][1]:.4f} | "
          f"{res['log_loss'][0]:.4f} ± {res['log_loss'][1]:.4f}")




--- Run 1/10 ---
Epoch    0 | DGI: 0.7187 | Reg: -0.2839 | λ*Reg: -0.0003 | Total: 0.7184
Epoch  500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 1000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 1500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 2000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 2500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 3000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 3500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 4000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 4500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 5000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 5500 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 6000 | DGI: 0.6931 | Reg: -0.2841 | λ*Reg: -0.0003 | Total: 0.6929
Epoch 6500 | DGI: 0.6931 | Reg:

================ FINAL SUMMARY FOR ALL LAMBDAS ================

  Lambda |           Accuracy |          Precision |             Recall |           F1 Score |           Log Loss
------------------------------------------------------------------------------------------------------------
      10 | 0.6849 ± 0.0156 | 0.8672 ± 0.0123 | 0.6265 ± 0.0200 | 0.7273 ± 0.0158 | 0.6057 ± 0.0015
      12 | 0.6849 ± 0.0147 | 0.8683 ± 0.0122 | 0.6255 ± 0.0177 | 0.7271 ± 0.0145 | 0.6055 ± 0.0015


with 1e-4 both

================ FINAL SUMMARY FOR ALL LAMBDAS ================

  Lambda |           Accuracy |          Precision |             Recall |           F1 Score |           Log Loss
------------------------------------------------------------------------------------------------------------
   0.001 | 0.5767 ± 0.0239 | 0.7258 ± 0.0417 | 0.6061 ± 0.0747 | 0.6558 ± 0.0315 | 0.6883 ± 0.0027
   0.005 | 0.5747 ± 0.0280 | 0.7356 ± 0.0404 | 0.5796 ± 0.0560 | 0.6453 ± 0.0315 | 0.6890 ± 0.0025
   0.009 | 0.5788 ± 0.0319 | 0.7419 ± 0.0397 | 0.5765 ± 0.0523 | 0.6465 ± 0.0339 | 0.6892 ± 0.0026
    0.01 | 0.5774 ± 0.0336 | 0.7440 ± 0.0434 | 0.5704 ± 0.0499 | 0.6435 ± 0.0338 | 0.6892 ± 0.0026
    0.05 | 0.5733 ± 0.0348 | 0.7646 ± 0.0577 | 0.5327 ± 0.0480 | 0.6255 ± 0.0348 | 0.6869 ± 0.0040
    0.09 | 0.5829 ± 0.0423 | 0.7777 ± 0.0631 | 0.5367 ± 0.0496 | 0.6326 ± 0.0387 | 0.6853 ± 0.0067
     0.1 | 0.5815 ± 0.0409 | 0.7769 ± 0.0621 | 0.5347 ± 0.0424 | 0.6314 ± 0.0347 | 0.6848 ± 0.0071
     0.3 | 0.5877 ± 0.0324 | 0.7974 ± 0.0472 | 0.5204 ± 0.0376 | 0.6284 ± 0.0307 | 0.6747 ± 0.0121
     0.5 | 0.6041 ± 0.0304 | 0.8200 ± 0.0395 | 0.5276 ± 0.0390 | 0.6409 ± 0.0314 | 0.6675 ± 0.0159
     0.9 | 0.6014 ± 0.0239 | 0.8247 ± 0.0320 | 0.5163 ± 0.0189 | 0.6349 ± 0.0213 | 0.6548 ± 0.0207
       1 | 0.6027 ± 0.0247 | 0.8219 ± 0.0311 | 0.5214 ± 0.0191 | 0.6379 ± 0.0223 | 0.6520 ± 0.0214
       2 | 0.6295 ± 0.0213 | 0.8377 ± 0.0157 | 0.5561 ± 0.0382 | 0.6675 ± 0.0287 | 0.6308 ± 0.0204
       5 | 0.6767 ± 0.0156 | 0.8652 ± 0.0205 | 0.6143 ± 0.0119 | 0.7184 ± 0.0131 | 0.6093 ± 0.0026
       8 | 0.6863 ± 0.0175 | 0.8740 ± 0.0144 | 0.6224 ± 0.0233 | 0.7269 ± 0.0181 | 0.6063 ± 0.0016

with 1-3, 1e-4

================ FINAL SUMMARY FOR ALL LAMBDAS ================

  Lambda |           Accuracy |          Precision |             Recall |           F1 Score |           Log Loss
------------------------------------------------------------------------------------------------------------
   0.001 | 0.6212 ± 0.0575 | 0.7334 ± 0.0713 | 0.7449 ± 0.2504 | 0.7037 ± 0.1095 | 0.6931 ± 0.0000
   0.005 | 0.6034 ± 0.0487 | 0.7564 ± 0.0639 | 0.6265 ± 0.1474 | 0.6722 ± 0.0649 | 0.6929 ± 0.0004
   0.009 | 0.6034 ± 0.0384 | 0.7728 ± 0.0593 | 0.5969 ± 0.1252 | 0.6636 ± 0.0524 | 0.6926 ± 0.0009
    0.01 | 0.5925 ± 0.0446 | 0.7830 ± 0.0525 | 0.5469 ± 0.0607 | 0.6416 ± 0.0494 | 0.6926 ± 0.0010
    0.05 | 0.6164 ± 0.0223 | 0.8261 ± 0.0191 | 0.5439 ± 0.0456 | 0.6545 ± 0.0318 | 0.6793 ± 0.0118
    0.09 | 0.6096 ± 0.0223 | 0.8224 ± 0.0194 | 0.5347 ± 0.0448 | 0.6466 ± 0.0321 | 0.6675 ± 0.0117
     0.1 | 0.6110 ± 0.0205 | 0.8229 ± 0.0200 | 0.5367 ± 0.0414 | 0.6485 ± 0.0292 | 0.6655 ± 0.0114
     0.3 | 0.6103 ± 0.0230 | 0.8220 ± 0.0209 | 0.5357 ± 0.0363 | 0.6479 ± 0.0283 | 0.6510 ± 0.0094
     0.5 | 0.6130 ± 0.0211 | 0.8220 ± 0.0175 | 0.5408 ± 0.0348 | 0.6517 ± 0.0265 | 0.6459 ± 0.0093
     0.9 | 0.6219 ± 0.0203 | 0.8275 ± 0.0169 | 0.5520 ± 0.0324 | 0.6617 ± 0.0250 | 0.6390 ± 0.0094
       1 | 0.6233 ± 0.0217 | 0.8290 ± 0.0182 | 0.5531 ± 0.0332 | 0.6629 ± 0.0261 | 0.6375 ± 0.0094
       2 | 0.6363 ± 0.0285 | 0.8387 ± 0.0258 | 0.5673 ± 0.0377 | 0.6762 ± 0.0320 | 0.6249 ± 0.0091
       5 | 0.6712 ± 0.0309 | 0.8603 ± 0.0287 | 0.6092 ± 0.0342 | 0.7129 ± 0.0299 | 0.6079 ± 0.0042
       8 | 0.6699 ± 0.0086 | 0.8558 ± 0.0075 | 0.6112 ± 0.0148 | 0.7130 ± 0.0100 | 0.6054 ± 0.0025