In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# NOTE(review): the data-loading cell further down reads its .npy files from
# /home/snu/Downloads rather than /content/drive -- confirm this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
# Publish the installed PyTorch version through the TORCH environment variable and echo it.
# Environment changes made here are inherited by shell commands launched from later cells;
# nothing in the visible cells reads TORCH back.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.8.0+cu126


In [None]:
# Install sympy into the environment of the running kernel. `%pip` is used instead of
# `!pip` because the shell's `pip` is not guaranteed to belong to the kernel's interpreter.
# NOTE(review): the version is not pinned, so re-runs may pick up a different release.
%pip install sympy



In [None]:
# Third-party dependencies, installed into the running kernel's environment with `%pip`
# (the original mixed `!pip` and `!pip3`, which may resolve to different interpreters).
# NOTE(review): versions are not pinned; the recorded output of this cell shows
# pymatting 1.1.14 being installed. pymatting and class_resolver are not imported by any
# visible cell -- confirm they are still required.
%pip install -q torch_geometric
%pip install -q class_resolver
%pip install pymatting

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pymatting
  Downloading pymatting-1.1.14-py3-none-any.whl.metadata (7.7 kB)
Downloading pymatting-1.1.14-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.7/54.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymatting
Successfully installed pymatting-1.1.14


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

In [3]:
# Load CN features
# NOTE(review): both paths below are absolute and machine-specific; consider deriving them
# from a single configurable data directory.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code from the file -- only use it on trusted files, and drop the flag if the
# arrays are plain numeric.
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Load AD features (fa_feature_path is reused and from here on points at the AD file)
fa_feature_path = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"
Histogram_feature_AD_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Combine features and labels: CN rows are stacked first and labelled 0, AD rows follow
# and are labelled 1.
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])
# Shuffle rows and labels with one fixed-seed permutation so they stay aligned and the
# order is repeatable.
np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]
# Unpacking into two names requires X to be 2-D: (rows, features per row).
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (223, 180), Labels: (223,)


In [4]:
def create_adj(F, alpha=1):
    """Build a binary adjacency matrix by thresholding pairwise cosine similarity.

    Args:
        F: 2-D array-like of shape (n_samples, n_features); one feature row per node.
        alpha: similarity threshold. Entry (i, j) is 1.0 when the cosine similarity of
            rows i and j is >= alpha, else 0.0. With the default of 1 only pairs whose
            computed similarity reaches 1 are linked, and floating-point rounding can
            leave even a row's self-similarity marginally below that.

    Returns:
        np.ndarray of dtype float32 and shape (n_samples, n_samples).

    All-zero rows have no direction; they are left as zero vectors (similarity 0 to
    every row) instead of being divided by a zero norm, which would produce NaNs and
    a RuntimeWarning.
    """
    features = np.asarray(F)
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    # Replace zero norms by 1 in place so the division is a no-op for all-zero rows;
    # editing a copy (rather than np.where with a Python scalar) keeps the dtype.
    safe_norms = row_norms.copy()
    safe_norms[safe_norms == 0] = 1
    F_norm = features / safe_norms
    W = np.dot(F_norm, F_norm.T)
    return (W >= alpha).astype(np.float32)

In [5]:
def load_data(adj, node_feats):
    """Convert a dense adjacency matrix and a feature matrix into PyTorch tensors.

    Args:
        adj: NumPy array; every non-zero entry becomes one edge.
        node_feats: NumPy array of node features.

    Returns:
        (node_feats, edge_index): the features as a float32 tensor, and an int64 tensor
        with one row per dimension of ``adj`` holding the indices of its non-zero
        entries (so 2 rows -- source and target -- for a 2-D adjacency matrix).
    """
    nonzero_indices = np.asarray(np.nonzero(adj))
    edge_index = torch.from_numpy(nonzero_indices).to(torch.int64)
    feature_tensor = torch.from_numpy(node_feats).to(torch.float32)
    return feature_tensor, edge_index

In [6]:
# Cast the feature matrix to float32 and pick the GPU when one is available.
features = X.astype(np.float32)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Binary graph over the samples: two rows are linked when create_adj's cosine-similarity
# threshold (alpha=0.8) is met.
W0 = create_adj(features, alpha=0.8)
# NOTE(review): asymmetrize_random is not defined in any visible cell; the commented-out
# line below is dead code and can probably be removed.
# W_asym = asymmetrize_random(W0, seed=42)
node_feats, edge_index = load_data(W0, features)
data = Data(x=node_feats, edge_index=edge_index).to(device)
# Dense copy of the adjacency matrix on the target device; the training cells pass it to
# model.Reg_loss.
A = torch.from_numpy(W0).to(device)
print(data)

Data(x=[223, 180], edge_index=[2, 41689])


In [7]:
class GCNEncoder(torch.nn.Module):
    """Single-layer graph encoder: GCNConv -> Dropout -> BatchNorm1d -> Linear.

    Args:
        input_dim: size of each input node feature vector.
        hidden_dim: size of the hidden representation and of the output.
        device: stored on the instance as ``self.device``; not otherwise used here.
        activ: accepted for interface compatibility but neither stored nor applied --
            no non-linearity is used between the layers.
            NOTE(review): confirm whether an activation after ``gcn1`` was intended.
    """

    def __init__(self, input_dim, hidden_dim, device, activ):
        super().__init__()
        self.device = device
        # Sub-modules are created in a fixed order so parameter initialisation consumes
        # the random number generator in the same sequence on every construction.
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(p=0.25)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode ``data.x`` over the graph ``data.edge_index``; returns one row per node."""
        hidden = self.gcn1(data.x, data.edge_index)
        # Dropout is applied before batch normalisation.
        hidden = self.batchnorm(self.dropout(hidden))
        return self.mlp(hidden)

In [8]:
class AvgReadout(nn.Module):
    """Readout that averages a sequence of vectors along dimension 0."""

    def __init__(self):
        super().__init__()

    def forward(self, seq, msk=None):
        """Return the (optionally masked) mean of ``seq`` over its first dimension.

        Args:
            seq: tensor whose first dimension indexes the items to average.
            msk: optional weight per item. When given, the result is the sum of the
                weighted rows divided by the sum of all mask entries.
        """
        if msk is not None:
            weights = msk.unsqueeze(-1)
            return (seq * weights).sum(dim=0) / weights.sum()
        return seq.mean(dim=0)

In [9]:
class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings against a single summary vector.

    Args:
        n_h: dimensionality of both the node embeddings and the summary vector.
    """

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        # Re-initialise the bilinear form: Xavier-uniform weights, zero bias.
        bilinear = self.f_k
        nn.init.xavier_uniform_(bilinear.weight.data)
        if bilinear.bias is not None:
            bilinear.bias.data.zero_()

    def forward(self, c, h_pl, h_mi):
        """Score two sets of embeddings against the summary ``c``.

        Args:
            c: summary vector of size n_h.
            h_pl: first set of embeddings, one row per node.
            h_mi: second set of embeddings, same shape as ``h_pl``.

        Returns:
            1-D tensor with the scores for ``h_pl`` followed by the scores for ``h_mi``.
        """
        # Repeat the summary so every embedding row is paired with it.
        summary = c.unsqueeze(0).expand_as(h_pl)
        first_scores = self.f_k(h_pl, summary).squeeze(1)
        second_scores = self.f_k(h_mi, summary).squeeze(1)
        return torch.cat((first_scores, second_scores), dim=0)

In [10]:
class DGI(nn.Module):
    """DGI model: shared graph encoder, mean readout and bilinear discriminator.

    Args:
        n_in: number of input features per node.
        n_h: hidden / embedding dimension.
        dropout: dropout probability used inside the encoder (default 0.25).
    """

    def __init__(self, n_in, n_h, dropout=0.25):
        super().__init__()
        self.gcn1 = GCNEncoder(n_in, n_h, device='cuda' if torch.cuda.is_available() else 'cpu', activ=nn.ELU())
        # GCNEncoder creates its dropout layer with a fixed rate, which previously made
        # the `dropout` argument a no-op. Replace the layer so the requested rate is
        # honoured; with the default of 0.25 the behaviour is unchanged.
        self.gcn1.dropout = nn.Dropout(dropout)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, edge_index):
        """Encode two feature matrices on the same graph and score them.

        Args:
            seq1: node features whose embeddings are summarised and returned.
            seq2: second set of node features (the training cells pass a row-shuffled
                copy of ``seq1``).
            edge_index: graph connectivity shared by both inputs.

        Returns:
            (logits, h_1): discriminator scores for the ``seq1`` embeddings followed by
            those for the ``seq2`` embeddings, and the ``seq1`` embeddings themselves.
        """
        # Create Data objects for the GCNEncoder
        data1 = Data(x=seq1, edge_index=edge_index)
        data2 = Data(x=seq2, edge_index=edge_index)

        h_1 = self.gcn1(data1)
        # Summary vector: sigmoid of the mean embedding of the first input.
        c = self.read(h_1)
        c = self.sigm(c)
        h_2 = self.gcn1(data2)
        logits = self.disc(c, h_1, h_2)
        return logits, h_1

In [11]:
class DGI_with_classifier(DGI):
    """DGI with a linear head whose soft assignments are trained by a graph-partition loss.

    Args:
        n_in: number of input features per node.
        n_h: embedding dimension.
        n_classes: number of output classes / clusters of the linear head.
        cut: 1 selects ``cut_loss``; any other value selects ``modularity_loss``.
        dropout: forwarded to ``DGI``.
    """

    def __init__(self, n_in, n_h, n_classes=2, cut=0, dropout=0.25):
        super().__init__(n_in, n_h, dropout=dropout)
        self.classifier = nn.Linear(n_h, n_classes)
        self.cut = cut

    def get_embeddings(self, node_feats, edge_index):
        """Return the encoder embeddings of ``node_feats``.

        Runs the full forward pass with the same features as both inputs and discards
        the discriminator logits.
        """
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        """Normalised-cut style loss plus an orthogonality penalty.

        Args:
            A: dense square adjacency matrix (nodes x nodes).
            S: raw assignment scores (nodes x clusters); softmax is applied row-wise.

        Returns:
            Scalar tensor ``-trace(S^T A S) / trace(S^T D S)`` plus the Frobenius
            distance between the normalised ``S^T S`` and the normalised identity,
            where D is the diagonal matrix of row sums of A.
        """
        S = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)
        mincut_loss = -(num / den)
        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(S.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        """Negative soft modularity plus a collapse regulariser.

        Args:
            A: dense square adjacency matrix (nodes x nodes).
            S: raw assignment scores (nodes x clusters); softmax is applied row-wise.

        Returns:
            Scalar tensor ``-trace(C^T B C) / (2m) + sqrt(k)/n * ||sum_i C_i||_F - 1``
            with ``B = A - d d^T / (2m)``, ``d`` the row sums of A, ``m`` the sum of all
            entries of A, ``n`` the number of nodes and ``k`` the number of clusters.
        """
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.outer(d, d) / (2 * m)
        n, k = C.shape
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        # Use the cluster count itself: deriving k from the Frobenius norm of a k x k
        # identity already yields sqrt(k), and a further square root gives k ** 0.25.
        # With sqrt(k) the term is 0 for perfectly balanced assignments and grows as
        # the assignments collapse onto fewer clusters.
        collapse_reg_term = (k ** 0.5 / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        """Apply the linear head to ``embeddings`` and return the configured partition loss."""
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        else:
            return self.modularity_loss(A, logits)

In [12]:
# Single training run: DGI contrastive loss plus a weighted graph-partition regulariser.
hidden_dim = 256
cut = 1  # 1 -> cut_loss, any other value -> modularity_loss (see DGI_with_classifier.Reg_loss)
dropout = 0.25

# Seed PyTorch before the model is built so weight initialisation, dropout masks and the
# corruption permutations are repeatable across "Restart & Run All".
torch.manual_seed(42)

model = DGI_with_classifier(features.shape[1], hidden_dim, n_classes=2, cut=cut, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay = 0.00001)
bce_loss = nn.BCEWithLogitsLoss()

# Loop-invariant tensors: move the inputs to the device once and build the constant
# discriminator targets once (1 for the original-feature scores, 0 for the corrupted ones).
node_feats_dev = node_feats.to(device)
edge_index_dev = edge_index.to(device)
num_samples = node_feats_dev.size(0)
lbl = torch.cat([
    torch.ones(num_samples),
    torch.zeros(num_samples)
]).to(device)

num_epochs = 10000
for epoch in range(num_epochs + 1):
    model.train()
    optimizer.zero_grad()

    # Corruption: shuffle the feature rows across nodes while keeping the graph fixed.
    perm = torch.randperm(num_samples)
    corrupt_features = node_feats_dev[perm.to(device)]

    logits, embeddings = model(node_feats_dev, corrupt_features, edge_index_dev)

    dgi_loss = bce_loss(logits.squeeze(), lbl)
    reg_loss = model.Reg_loss(A, embeddings)
    loss = dgi_loss + 0.01 * reg_loss

    # Losses are reported before this epoch's optimizer step.
    if epoch % 500 == 0:
        print(f"Epoch {epoch} | DGI Loss: {dgi_loss.item():.4f} | Reg Loss: {reg_loss.item():.4f} | Total: {loss.item():.4f}")

    loss.backward()
    optimizer.step()

Epoch 0 | DGI Loss: 0.7121 | Reg Loss: -0.2351 | Total: 0.7097
Epoch 500 | DGI Loss: 0.6962 | Reg Loss: -0.4995 | Total: 0.6912
Epoch 1000 | DGI Loss: 0.6946 | Reg Loss: -0.4709 | Total: 0.6899
Epoch 1500 | DGI Loss: 0.6949 | Reg Loss: -0.4802 | Total: 0.6901
Epoch 2000 | DGI Loss: 0.6933 | Reg Loss: -0.5156 | Total: 0.6882
Epoch 2500 | DGI Loss: 0.6932 | Reg Loss: -0.5073 | Total: 0.6881
Epoch 3000 | DGI Loss: 0.6979 | Reg Loss: -0.5058 | Total: 0.6928
Epoch 3500 | DGI Loss: 0.6931 | Reg Loss: -0.4869 | Total: 0.6883
Epoch 4000 | DGI Loss: 0.6941 | Reg Loss: -0.5066 | Total: 0.6891
Epoch 4500 | DGI Loss: 0.6932 | Reg Loss: -0.4506 | Total: 0.6887
Epoch 5000 | DGI Loss: 0.6931 | Reg Loss: -0.5021 | Total: 0.6881
Epoch 5500 | DGI Loss: 0.6932 | Reg Loss: -0.5216 | Total: 0.6879
Epoch 6000 | DGI Loss: 0.6931 | Reg Loss: -0.5198 | Total: 0.6879
Epoch 6500 | DGI Loss: 0.6931 | Reg Loss: -0.5309 | Total: 0.6878
Epoch 7000 | DGI Loss: 0.6931 | Reg Loss: -0.5309 | Total: 0.6878
Epoch 7500 | D

In [13]:
# Inference on all nodes: evaluation mode, no gradient tracking.
model.eval()
with torch.no_grad():
    embeddings = model.get_embeddings(node_feats.to(device), edge_index.to(device))
    # Row-wise softmax over the linear head's outputs, moved to NumPy for scikit-learn.
    class_probabilities = (
        model.classifier(embeddings)
        .softmax(dim=1)
        .cpu()
        .numpy()
    )

# Predicted class = column with the highest probability.
y_pred = class_probabilities.argmax(axis=1)

In [14]:
# Score the predictions against the ground-truth labels. The "inverted" accuracy scores the
# opposite assignment of the two output columns to the two classes.
acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)
prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, class_probabilities)

for _caption, _value in (
    ("Accuracy:", acc_score),
    ("Accuracy (inverted):", acc_score_inverted),
    ("Precision:", prec_score),
    ("Recall:", rec_score),
    ("F1:", f1),
    ("Log Loss:", log_loss_value),
):
    print(_caption, _value)

Accuracy: 0.7219730941704036
Accuracy (inverted): 0.27802690582959644
Precision: 0.6166666666666667
Recall: 0.8222222222222222
F1: 0.7047619047619048
Log Loss: 2.5423930391929392


In [15]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

# Experiment configuration for the repeated-run evaluation.
hidden_dim   = 256
cut          = 1
dropout      = 0.25
num_runs     = 10
num_epochs   = 10000
# Weights tried for the regularisation term; each value gets num_runs independent runs.
lambda_list  = [0.01]
# Run r is seeded with base_seed + r.
base_seed    = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the graph tensors to the device once; these reassign the notebook-level names.
node_feats = node_feats.to(device)
edge_index = edge_index.to(device)
A = A.to(device)

# Ground-truth labels as an integer NumPy array, whether y is a tensor or array-like.
if isinstance(y, torch.Tensor):
    y_np = y.detach().cpu().numpy().astype(int)
else:
    y_np = np.asarray(y).astype(int)

N, feats_dim = node_feats.size(0), node_feats.size(1)

# One summary dict per lambda value is appended to all_results.
all_results = []
bce_loss = nn.BCEWithLogitsLoss()

# For every regularisation weight: train num_runs freshly initialised models, evaluate each
# on all nodes, and aggregate mean / standard deviation of the metrics.
for lam in lambda_list:
    print(f"\n================ LAMBDA = {lam} ================\n")

    # Per-run metric values for this lambda.
    acc_scores, prec_scores, rec_scores, f1_scores, log_losses = [], [], [], [], []

    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")

        # Distinct but reproducible seed per run.
        seed = base_seed + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


        model = DGI_with_classifier(feats_dim, hidden_dim, n_classes=2, cut=cut, dropout=dropout).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00001)
        # NOTE(review): scheduler.step() is called once per epoch, so the learning rate is
        # halved every 200 epochs -- 50 times over the num_epochs + 1 = 10001 epochs,
        # i.e. down to 0.001 * 0.5**50. Confirm this decay schedule is intended.
        scheduler = StepLR(optimizer, step_size=200, gamma=0.5)


        for epoch in range(num_epochs + 1):
            model.train()
            optimizer.zero_grad()

            # Corruption: shuffle the feature rows across nodes; the graph is unchanged.
            perm = torch.randperm(N, device=device)
            corrupt_features = node_feats[perm]

            logits, embeddings = model(node_feats, corrupt_features, edge_index)

            # Targets: 1 for the scores of the original features (first N logits),
            # 0 for the scores of the corrupted features (last N logits).
            lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])
            dgi_loss = bce_loss(logits.squeeze(), lbl)
            reg_loss = model.Reg_loss(A, embeddings)

            loss = dgi_loss + lam * reg_loss

            # Losses are reported before this epoch's optimizer step.
            if epoch % 500 == 0:
                print(f"Epoch {epoch:4d} | DGI: {dgi_loss.item():.4f} | Reg: {reg_loss.item():.4f} | "
                      f"λ*Reg: {(lam * reg_loss).item():.4f} | Total: {loss.item():.4f}")

            loss.backward()
            optimizer.step()
            scheduler.step()

        # Evaluate on all nodes in eval mode without gradient tracking.
        model.eval()
        with torch.no_grad():
            emb = model.get_embeddings(node_feats, edge_index)
            logits_cls = model.classifier(emb)                   # [N, 2]
            class_probabilities = F.softmax(logits_cls, dim=1).cpu().numpy()
            y_pred = np.argmax(class_probabilities, axis=1)

        acc  = accuracy_score(y_np, y_pred)
        acc_inv = accuracy_score(y_np, 1 - y_pred)

        # The linear head is trained only through the unsupervised Reg_loss, so which
        # output column corresponds to which class is arbitrary. Keep whichever of the two
        # assignments agrees better with the labels, swapping predictions and probability
        # columns together.
        # NOTE(review): the ground-truth labels are used to choose the assignment, so the
        # reported scores are the better of the two labelings.
        if acc_inv > acc:
            acc = acc_inv
            y_pred = 1 - y_pred
            class_probabilities = class_probabilities[:, ::-1]

        prec = precision_score(y_np, y_pred, zero_division=0)
        rec  = recall_score(y_np, y_pred, zero_division=0)
        f1   = f1_score(y_np, y_pred, zero_division=0)
        ll   = log_loss(y_np, class_probabilities)

        print(f"Run {run+1} | Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f} | LogLoss: {ll:.4f}")

        acc_scores.append(acc)
        prec_scores.append(prec)
        rec_scores.append(rec)
        f1_scores.append(f1)
        log_losses.append(ll)

    # (mean, std) per metric over the runs; np.std is called with its default arguments.
    lambda_results = {
        "lambda": lam,
        "accuracy":  (float(np.mean(acc_scores)), float(np.std(acc_scores))),
        "precision": (float(np.mean(prec_scores)), float(np.std(prec_scores))),
        "recall":    (float(np.mean(rec_scores)), float(np.std(rec_scores))),
        "f1":        (float(np.mean(f1_scores)),  float(np.std(f1_scores))),
        "log_loss":  (float(np.mean(log_losses)), float(np.std(log_losses))),
    }
    all_results.append(lambda_results)

    print(f"\n--- RESULTS FOR LAMBDA = {lam} ---")
    print(f"Accuracy : {lambda_results['accuracy'][0]:.4f} ± {lambda_results['accuracy'][1]:.4f}")
    print(f"Precision: {lambda_results['precision'][0]:.4f} ± {lambda_results['precision'][1]:.4f}")
    print(f"Recall   : {lambda_results['recall'][0]:.4f} ± {lambda_results['recall'][1]:.4f}")
    print(f"F1 Score : {lambda_results['f1'][0]:.4f} ± {lambda_results['f1'][1]:.4f}")
    print(f"Log Loss : {lambda_results['log_loss'][0]:.4f} ± {lambda_results['log_loss'][1]:.4f}")

# Final table: one row per lambda value, each metric printed as "mean ± std" from the
# (mean, std) tuples stored in all_results.
print("\n================ FINAL SUMMARY FOR ALL LAMBDAS ================\n")
print(f"{'Lambda':>8} | {'Accuracy':>18} | {'Precision':>18} | {'Recall':>18} | {'F1 Score':>18} | {'Log Loss':>18}")
print("-" * 108)
for res in all_results:
    print(f"{res['lambda']:>8} | "
          f"{res['accuracy'][0]:.4f} ± {res['accuracy'][1]:.4f} | "
          f"{res['precision'][0]:.4f} ± {res['precision'][1]:.4f} | "
          f"{res['recall'][0]:.4f} ± {res['recall'][1]:.4f} | "
          f"{res['f1'][0]:.4f} ± {res['f1'][1]:.4f} | "
          f"{res['log_loss'][0]:.4f} ± {res['log_loss'][1]:.4f}")




--- Run 1/10 ---
Epoch    0 | DGI: 0.7155 | Reg: -0.2433 | λ*Reg: -0.0024 | Total: 0.7130
Epoch  500 | DGI: 0.6932 | Reg: -0.5248 | λ*Reg: -0.0052 | Total: 0.6879
Epoch 1000 | DGI: 0.6931 | Reg: -0.5274 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 1500 | DGI: 0.6931 | Reg: -0.5278 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 2000 | DGI: 0.6931 | Reg: -0.5342 | λ*Reg: -0.0053 | Total: 0.6878
Epoch 2500 | DGI: 0.6931 | Reg: -0.5258 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 3000 | DGI: 0.6931 | Reg: -0.5290 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 3500 | DGI: 0.6931 | Reg: -0.5317 | λ*Reg: -0.0053 | Total: 0.6878
Epoch 4000 | DGI: 0.6931 | Reg: -0.5286 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 4500 | DGI: 0.6931 | Reg: -0.5314 | λ*Reg: -0.0053 | Total: 0.6878
Epoch 5000 | DGI: 0.6931 | Reg: -0.5292 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 5500 | DGI: 0.6931 | Reg: -0.5275 | λ*Reg: -0.0053 | Total: 0.6879
Epoch 6000 | DGI: 0.6931 | Reg: -0.5324 | λ*Reg: -0.0053 | Total: 0.6878
Epoch 6500 | DGI: 0.6931 | Reg:

Accuracy: 0.7533333333333333
Precision: 0.7360406091370558
Recall: 0.8682634730538922
F1: 0.7967032967032966
Log Loss: 0.5772490248961657