In [1]:
import numpy as np
import random
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected
from torch_sparse import SparseTensor

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    normalized_mutual_info_score,
    adjusted_rand_score,
    precision_score,
    recall_score,
    roc_auc_score
)
from sklearn.cluster import SpectralClustering
from sklearn.cluster import KMeans

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def setup_seed(seed=42):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for repeatable kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so
    # disable it explicitly in case it was enabled earlier in the session.
    torch.backends.cudnn.benchmark = False

setup_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Load the per-subject FA histogram feature matrices (one row per subject).
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code from an untrusted file -- confirm these files are trusted, or
# drop the flag if they hold plain numeric arrays.
# NOTE(review): the absolute paths are machine-specific.
fa_patients = np.load("/home/snu/Downloads/NIFD_Patients_FA_Histogram_Feature.npy", allow_pickle=True)
fa_controls = np.load("/home/snu/Downloads/NIFD_Control_FA_Histogram_Feature.npy", allow_pickle=True)

# Patients come first in the stacked matrix and get label 0; controls get label 1.
X = np.vstack((fa_patients, fa_controls)).astype(np.float32)
y = np.concatenate((
    np.zeros(len(fa_patients), dtype=np.int64),
    np.ones(len(fa_controls), dtype=np.int64),
))

# Shuffle subjects with a fixed seed, applying the same permutation to the labels.
np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X, y = X[perm], y[perm]
N, F_dim = X.shape
print("Nodes:", N, "Features:", F_dim)

# Tensor copies of the features and labels on the compute device.
x = torch.from_numpy(X).to(device)
y_torch = torch.from_numpy(y).to(device)

Nodes: 146 Features: 180


In [4]:
# Show the (subjects, features) shape of each group, one per line.
print(fa_patients.shape, fa_controls.shape, sep="\n")

(98, 180)
(48, 180)


In [4]:
def create_adj(features, alpha=0.5):
    """Build a binary adjacency matrix by thresholding cosine similarity.

    Args:
        features: 2-D array with one feature vector per row.
        alpha: Similarity threshold; pairs with cosine similarity >= alpha
            are connected.

    Returns:
        float32 array of shape (n_rows, n_rows) holding 1.0 where the cosine
        similarity of two rows is at least ``alpha`` and 0.0 elsewhere. A
        non-zero row has similarity ~1 with itself, so it is connected to
        itself whenever ``alpha`` is not above that value.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    # An all-zero row has norm 0; dividing by it would give NaNs (0/0) and a
    # RuntimeWarning. Dividing by 1 instead keeps the row at zero, so its
    # similarity to every row is 0.
    norms[norms == 0] = 1.0
    f = features / norms
    W = np.dot(f, f.T)
    return (W >= alpha).astype(np.float32)

# Binary similarity graph over all subjects (cosine similarity >= 0.5).
W = create_adj(X, alpha=0.5)

rows, cols = np.nonzero(W)
# Combine the two index arrays into a single (2, E) ndarray before handing them
# to torch: building a tensor from a Python list of ndarrays takes a slow
# element-wise path and emits a UserWarning.
edge_index = torch.tensor(np.array([rows, cols]), dtype=torch.long)
edge_index = to_undirected(edge_index).to(device)

# Sparse N x N adjacency with every stored entry set to 1, used for random walks.
adj = SparseTensor(
    row=edge_index[0],
    col=edge_index[1],
    sparse_sizes=(N, N)
).fill_value(1.).to(device)

  edge_index = torch.tensor([rows, cols], dtype=torch.long)


In [5]:
def get_sim(batch, adj, wt=20, wl=2):
    """Count random-walk visits from each batch node to build a similarity matrix.

    Args:
        batch: 1-D tensor of node ids (indices into ``adj``) to start walks from.
        adj: SparseTensor adjacency on which the walks are run.
        wt: Number of walks started from every batch node.
        wl: Number of steps per walk.

    Returns:
        SparseTensor with the same sparse sizes as ``adj``. Entry (u, v) holds
        how many times walks started at batch node ``u`` visited node ``v``;
        the diagonal is overwritten with 0 by ``set_diag``. The result is built
        from Python lists and therefore lives on the CPU.
    """
    batch_size = batch.shape[0]
    batch_repeat = batch.repeat(wt)

    walks = adj.random_walk(batch_repeat, wl)
    # Some torch_sparse / torch_cluster combinations return a
    # (node_seq, edge_seq) tuple instead of a single tensor.
    if not isinstance(walks, torch.Tensor):
        walks = walks[0]
    # Drop the first column.
    # Assumes one row per start node with the start node in column 0 -- TODO confirm.
    rw = walks[:, 1:]
    # Regroup so that row i collects the steps of all wt walks begun at batch[i].
    rw = rw.t().reshape(-1, batch_size).t()

    row, col_, val = [], [], []
    for i in range(batch_size):
        rw_nodes, rw_times = torch.unique(rw[i], return_counts=True)
        row += [batch[i].item()] * rw_nodes.shape[0]
        col_ += rw_nodes.tolist()
        val += rw_times.tolist()

    # Row and column indices are node ids of `adj`, so size the result like
    # `adj`. Sizing it by the batch length is only consistent when the batch
    # contains every node.
    adj_rw = SparseTensor(
        row=torch.tensor(row),
        col=torch.tensor(col_),
        value=torch.tensor(val),
        sparse_sizes=adj.sparse_sizes()
    )
    adj_rw = adj_rw.set_diag(0.)
    return adj_rw

In [6]:
def get_mask(adj):
    """Keep only the stored entries that are not below their row's mean.

    Args:
        adj: SparseTensor with values.

    Returns:
        SparseTensor of the same size containing the entries of ``adj`` whose
        value is at least the value ``adj.mean(dim=1)`` reports for their row
        (with a 1e-10 tolerance).
    """
    storage = adj.storage
    rows, cols, vals = storage.row(), storage.col(), storage.value()
    row_mean = adj.mean(dim=1)
    # The small negative tolerance keeps entries that equal the row mean.
    keep = (vals - row_mean[rows]) > -1e-10
    return SparseTensor(
        row=rows[keep],
        col=cols[keep],
        value=vals[keep],
        sparse_sizes=adj.sizes()
    )

def scale(z):
    """Min-max rescale every row of a 2-D tensor into [0, 1].

    A constant row maps to all zeros; the 1e-12 term avoids division by zero.
    """
    lo = z.min(dim=1, keepdim=True).values
    hi = z.max(dim=1, keepdim=True).values
    span = hi - lo + 1e-12
    return (z - lo) / span


In [7]:
class MAGILoss(nn.Module):
    """Contrastive loss whose positive pairs are the entries stored in a sparse mask.

    For each node, the similarities to all *other* nodes (dot product divided
    by ``tau``) are turned into log-probabilities; the loss is the mean negative
    log-probability over the (row, col) pairs listed in the mask.
    """

    def __init__(self, tau=0.3):
        super().__init__()
        self.tau = tau  # temperature dividing the pairwise dot products

    def forward(self, z, mask):
        """Return the scalar loss for embeddings ``z`` and positive-pair ``mask``.

        Args:
            z: 2-D tensor of node embeddings, one row per node.
            mask: object exposing ``storage.row()`` / ``storage.col()`` index
                tensors that select the positive pairs.
        """
        sim = torch.mm(z, z.t()) / self.tau
        # Subtract the per-row maximum (treated as a constant) for numerical stability.
        row_max = sim.max(dim=1, keepdim=True).values
        sim = sim - row_max.detach()

        # Zero on the diagonal, one elsewhere: a node never counts as its own negative.
        not_self = torch.ones_like(sim) - torch.eye(z.size(0), device=z.device)
        denom = (torch.exp(sim) * not_self).sum(1, keepdim=True)
        log_prob = sim - torch.log(denom)

        pos_rows = mask.storage.row()
        pos_cols = mask.storage.col()
        return -log_prob[pos_rows, pos_cols].mean()

In [8]:
class Encoder(nn.Module):
    """One GCNConv layer followed by a leaky ReLU with negative slope 0.5."""

    def __init__(self, in_dim, hidden_dim=256):
        """Create the graph convolution mapping ``in_dim`` to ``hidden_dim`` features."""
        super().__init__()
        self.conv = GCNConv(in_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Return the activated convolution output for features ``x`` on graph ``edge_index``."""
        hidden = self.conv(x, edge_index)
        return F.leaky_relu(hidden, negative_slope=0.5)

In [9]:
# Repeated-run experiment: train the encoder with the contrastive loss, cluster
# the embeddings with K-Means, and score the clusters against the labels.
n_runs = 10
epochs = 5000

# auc_list is never filled below; it is kept only so the name stays defined.
acc_list, prec_list, rec_list, f1_list, auc_list = [], [], [], [], []

# NOTE(review): in the output recorded with this cell the loss stays at ~4.9767,
# which equals ln(145) = ln(N - 1) for N = 146. This loss takes that value when
# all off-diagonal similarities in a row are equal, so the encoder may be
# emitting (near-)identical embeddings for every node -- TODO check how dense
# the thresholded graph is and whether the similarity threshold is too low.

for run in range(n_runs):
    print(f"\n========== Run {run+1}/{n_runs} ==========")

    # A different seed per run, so the runs differ but are repeatable.
    setup_seed(42 + run)

    encoder = Encoder(F_dim, 256).to(device)
    criterion = MAGILoss(tau=0.3)
    optimizer = torch.optim.Adam(
        encoder.parameters(), lr=5e-4, weight_decay=1e-3
    )

    # Positive pairs come from random-walk co-visits over all nodes and are
    # computed once per run.
    batch = torch.arange(N, device=device)
    adj_rw = get_sim(batch, adj)
    mask = get_mask(adj_rw)

    # Training mode is set once; nothing inside the loop switches to eval mode.
    encoder.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        z = encoder(x, edge_index)
        z = scale(z)
        z = F.normalize(z, dim=1)

        loss = criterion(z, mask)
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(f"Run {run+1} | Epoch {epoch} | Loss {loss.item():.4f}")


    # Embed all nodes with the trained encoder (same post-processing as in training).
    encoder.eval()
    with torch.no_grad():
        z = encoder(x, edge_index)
        z = scale(z)
        z = F.normalize(z, dim=1).cpu().numpy()

    kmeans = KMeans(n_clusters=2, n_init=20, random_state=run)
    y_pred = kmeans.fit_predict(z)

    # Cluster ids are arbitrary: keep whichever of the two labelings agrees
    # better with the ground truth.
    acc1 = accuracy_score(y, y_pred)
    acc2 = accuracy_score(y, 1 - y_pred)
    if acc2 > acc1:
        y_pred = 1 - y_pred

    # Precision / recall / F1 use the default positive label 1, which is the
    # control group in this notebook's labelling. zero_division=0 makes the
    # "no positive predictions" case an explicit 0 instead of a warning.
    acc  = accuracy_score(y, y_pred)
    prec = precision_score(y, y_pred, zero_division=0)
    rec  = recall_score(y, y_pred, zero_division=0)
    f1   = f1_score(y, y_pred, zero_division=0)


    acc_list.append(acc)
    prec_list.append(prec)
    rec_list.append(rec)
    f1_list.append(f1)


    print(
        f"Run {run+1} → "
        f"ACC: {acc:.4f}, "
        f"PREC: {prec:.4f}, "
        f"REC: {rec:.4f}, "
        f"F1: {f1:.4f}, "
    )
# Mean and standard deviation over the runs; the header reports the actual run count.
print(f"\n===== MAGI + K-Means ({n_runs} Runs) =====")
print(f"ACC : {np.mean(acc_list):.4f} \u00b1 {np.std(acc_list):.4f}")
print(f"PREC: {np.mean(prec_list):.4f} \u00b1 {np.std(prec_list):.4f}")
print(f"REC : {np.mean(rec_list):.4f} \u00b1 {np.std(rec_list):.4f}")
print(f"F1  : {np.mean(f1_list):.4f} \u00b1 {np.std(f1_list):.4f}")


Run 1 | Epoch 0 | Loss 4.9767
Run 1 | Epoch 500 | Loss 4.9767
Run 1 | Epoch 1000 | Loss 4.9766
Run 1 | Epoch 1500 | Loss 4.9766
Run 1 | Epoch 2000 | Loss 4.9766
Run 1 | Epoch 2500 | Loss 4.9766
Run 1 | Epoch 3000 | Loss 4.9766
Run 1 | Epoch 3500 | Loss 4.9766
Run 1 | Epoch 4000 | Loss 4.9766
Run 1 | Epoch 4500 | Loss 4.9766
Run 1 → ACC: 0.6781, PREC: 0.5294, REC: 0.1875, F1: 0.2769, 

Run 2 | Epoch 0 | Loss 4.9767
Run 2 | Epoch 500 | Loss 4.9767
Run 2 | Epoch 1000 | Loss 4.9767
Run 2 | Epoch 1500 | Loss 4.9767
Run 2 | Epoch 2000 | Loss 4.9767
Run 2 | Epoch 2500 | Loss 4.9767
Run 2 | Epoch 3000 | Loss 4.9767
Run 2 | Epoch 3500 | Loss 4.9767
Run 2 | Epoch 4000 | Loss 4.9767
Run 2 | Epoch 4500 | Loss 4.9767
Run 2 → ACC: 0.6644, PREC: 0.0000, REC: 0.0000, F1: 0.0000, 

Run 3 | Epoch 0 | Loss 4.9767
Run 3 | Epoch 500 | Loss 4.9767
Run 3 | Epoch 1000 | Loss 4.9767
Run 3 | Epoch 1500 | Loss 4.9767
Run 3 | Epoch 2000 | Loss 4.9767
Run 3 | Epoch 2500 | Loss 4.9767
Run 3 | Epoch 3000 | Loss 4.9