In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.6.0+cu124


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, log_loss
from munkres import Munkres
import random

In [2]:
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.cuda.manual_seed_all(42)

In [3]:
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Load MCI features
fa_feature_path = "/home/snu/Downloads/Histogram_MCI_FA_20bin_updated.npy"
Histogram_feature_MCI_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Combine features and labels
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_MCI_FA_array.shape[0], dtype=np.int64)
])
np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (300, 180), Labels: (300,)


In [4]:
class PairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx][0], self.pairs[idx][1], self.labels[idx]

In [5]:
class SiameseNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

    def forward_once(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def forward(self, x1, x2):
        out1 = self.forward_once(x1)
        out2 = self.forward_once(x2)
        dist = F.pairwise_distance(out1, out2)
        return dist, out1, out2

In [6]:
class OrthoLinear(nn.Module):
    """Layer that enforces orthogonality constraint using Cholesky factorization"""
    def __init__(self, in_dim, out_dim, eps=1e-4):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.eps = eps

    def forward(self, x):
        Y_tilde = self.fc(x)  # shape (batch_size, out_dim)
        # Compute Y_tilde^T Y_tilde (out_dim x out_dim)
        gram = Y_tilde.T @ Y_tilde + self.eps * torch.eye(Y_tilde.shape[1], device=x.device)
        # Cholesky factorization
        L = torch.linalg.cholesky(gram)
        # Inverse of L
        L_inv = torch.inverse(L)
        # Orthogonalized output
        Y = Y_tilde @ L_inv.T
        return Y

In [7]:
class SpectralNet(nn.Module):
    def __init__(self, input_dim, n_clusters, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ortho = OrthoLinear(hidden_dim, n_clusters)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.ortho(x)

In [8]:
def contrastive_loss(distance, label, margin=1.0):
    pos = label * torch.pow(distance, 2)
    neg = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    return torch.mean(pos + neg)

In [9]:
def spectral_loss(Y, W):
    # Y: (n_samples, n_clusters)
    # W: (n_samples, n_samples)
    D = torch.diag(W.sum(axis=1))
    L = D - W
    # numerator: trace(Y^T L Y)
    num = torch.trace(Y.T @ L @ Y)
    denom = torch.trace(Y.T @ D @ Y)
    return num / (denom + 1e-12)

In [10]:
def compute_affinity(X, scale, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    distances, indices = distances[:, 1:], indices[:, 1:]
    W = np.zeros((len(X), len(X)))
    for i in range(len(X)):
        for j in range(n_neighbors):
            W[i, indices[i, j]] = np.exp(-distances[i, j] ** 2 / (2 * scale ** 2))
            W[indices[i, j], i] = W[i, indices[i, j]]
    return W

def compute_scale(X, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, _ = nbrs.kneighbors(X)
    return np.median(distances[:, -1])

In [11]:
def calculate_accuracy(y_pred, y_true, n_clusters):
    cm = confusion_matrix(y_true, y_pred)
    cost = np.zeros((n_clusters, n_clusters))
    for i in range(n_clusters):
        for j in range(n_clusters):
            cost[i, j] = cm[:, j].sum() - cm[i, j]
    m = Munkres()
    mapping = m.compute(cost.tolist())
    new_labels = np.zeros_like(y_pred)
    for row, col in mapping:
        new_labels[y_pred == row] = col
    return (new_labels == y_true).mean()

In [12]:
def train_siamese(siamese_net, dataloader, epochs=50, lr=1e-3, device='cpu'):
    siamese_net.to(device)
    opt = optim.Adam(siamese_net.parameters(), lr=lr)
    for ep in range(epochs):
        siamese_net.train()
        total_loss = 0.0
        for x1, x2, labels in dataloader:
            x1 = x1.to(device).float()
            x2 = x2.to(device).float()
            labels = labels.to(device).float()
            dist, _, _ = siamese_net(x1, x2)
            loss = contrastive_loss(dist, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item() * x1.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        print(f"[Siamese] Epoch {ep + 1}/{epochs}, Avg Loss={avg_loss:.6f}")
    siamese_net.to('cpu')
    return siamese_net

def train_spectral(spectral_net, X_train, W, epochs=50, lr=1e-3, tol=1e-6, device='cpu'):
    spectral_net.to(device)
    opt = optim.Adam([
        {'params': spectral_net.fc1.parameters()},
        {'params': spectral_net.fc2.parameters()},
        {'params': spectral_net.ortho.fc.parameters()}
    ], lr=lr)

    X_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
    W_tensor = torch.tensor(W, dtype=torch.float32, device=device)

    prev_loss = float('inf')
    for ep in range(epochs):
        spectral_net.train()
        Y = spectral_net(X_tensor)  # (n_samples, n_clusters)
        loss = spectral_loss(Y, W_tensor)
        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_item = loss.item()
        print(f"[SpectralNet] Epoch {ep + 1}/{epochs}, Loss={loss_item:.8f}")

        if abs(prev_loss - loss_item) < tol:
            print("SpectralNet converged (tol reached).")
            break
        prev_loss = loss_item
    spectral_net.to('cpu')
    return spectral_net

In [13]:
def main():
    n_clusters = 2
    hidden_dim = 256
    batch_size = 32
    n_neighbors = 20

    pairs, labels = [], []
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)

    for i in range(len(X)):
        for j in indices[i, 1:]:  # positive pairs (neighbors)
            pairs.append([X[i], X[j]])
            labels.append(1)
        all_indices = set(range(len(X)))
        neighbor_set = set(indices[i, 1:])
        non_neighbors = list(all_indices - neighbor_set - {i})
        j = np.random.choice(non_neighbors)
        pairs.append([X[i], X[j]])
        labels.append(0)

    dataset = PairDataset(pairs, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Train Siamese Network
    siamese = SiameseNet(num_feats, hidden_dim)
    siamese = train_siamese(siamese, dataloader, epochs=50)

    # Embed data with Siamese
    with torch.no_grad():
        X_embed = siamese.forward_once(torch.tensor(X, dtype=torch.float32)).numpy()

    # Compute scale and affinity matrix for SpectralNet
    scale = compute_scale(X_embed, n_neighbors=n_neighbors)
    W = compute_affinity(X_embed, scale, n_neighbors=n_neighbors)
    num_edges = np.count_nonzero(W) // 2
    print("Number of edges in the graph:", num_edges)

    # Train SpectralNet
    spectral = SpectralNet(num_feats, n_clusters, hidden_dim)
    spectral = train_spectral(spectral, X, W, epochs=50)

    # Get embeddings and predictions
    with torch.no_grad():
        Y = spectral(torch.tensor(X, dtype=torch.float32)).numpy()
        y_pred_proba = F.softmax(torch.tensor(Y), dim=1).numpy()

    # KMeans clustering on SpectralNet outputs
    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(Y)

    # Evaluate accuracy with Hungarian mapping
    acc_score = calculate_accuracy(y_pred, y, n_clusters)
    acc_score_inverted = calculate_accuracy(1 - y_pred, y, n_clusters)
    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    print(y_pred)
    prec_score = precision_score(y, y_pred)
    rec_score = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    log_loss_value = log_loss(y, y_pred_proba)

    print("Final clustering accuracy:", acc_score)
    print("Precision:", prec_score)
    print("Recall:", rec_score)
    print("F1 Score:", f1)
    print("Log Loss:", log_loss_value)

    results = {"accuracy": acc_score, "precision": prec_score, "recall": rec_score, "f1": f1, "log_loss": log_loss_value}
    return results

if __name__ == "__main__":
    main()

[Siamese] Epoch 1/50, Avg Loss=0.037101
[Siamese] Epoch 2/50, Avg Loss=0.035806
[Siamese] Epoch 3/50, Avg Loss=0.035145
[Siamese] Epoch 4/50, Avg Loss=0.034396
[Siamese] Epoch 5/50, Avg Loss=0.034955
[Siamese] Epoch 6/50, Avg Loss=0.034664
[Siamese] Epoch 7/50, Avg Loss=0.034515
[Siamese] Epoch 8/50, Avg Loss=0.034448
[Siamese] Epoch 9/50, Avg Loss=0.034291
[Siamese] Epoch 10/50, Avg Loss=0.034067
[Siamese] Epoch 11/50, Avg Loss=0.034157
[Siamese] Epoch 12/50, Avg Loss=0.034385
[Siamese] Epoch 13/50, Avg Loss=0.033981
[Siamese] Epoch 14/50, Avg Loss=0.034264
[Siamese] Epoch 15/50, Avg Loss=0.034008
[Siamese] Epoch 16/50, Avg Loss=0.033949
[Siamese] Epoch 17/50, Avg Loss=0.034007
[Siamese] Epoch 18/50, Avg Loss=0.034152
[Siamese] Epoch 19/50, Avg Loss=0.034113
[Siamese] Epoch 20/50, Avg Loss=0.033830
[Siamese] Epoch 21/50, Avg Loss=0.034101
[Siamese] Epoch 22/50, Avg Loss=0.033737
[Siamese] Epoch 23/50, Avg Loss=0.033910
[Siamese] Epoch 24/50, Avg Loss=0.033871
[Siamese] Epoch 25/50, Av

In [14]:
if __name__ == "__main__":
    num_runs = 10
    all_results = {"accuracy": [], "precision": [], "recall": [], "f1": [], "log_loss": []}

    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")
        torch.manual_seed(run)
        np.random.seed(run)
        random.seed(run)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(run)

        results = main()

        all_results["accuracy"].append(results["accuracy"])
        all_results["precision"].append(results["precision"])
        all_results["recall"].append(results["recall"])
        all_results["f1"].append(results["f1"])
        all_results["log_loss"].append(results["log_loss"])


    print("\n================ FINAL SUMMARY ================\n")
    print(f"{'Metric':>12} | {'Mean':>10} ± {'Std':<10}")
    print("-" * 40)
    for metric, values in all_results.items():
        print(f"{metric:>12} | {np.mean(values):10.4f} ± {np.std(values):<10.4f}")


--- Run 1/10 ---
[Siamese] Epoch 1/50, Avg Loss=0.038488
[Siamese] Epoch 2/50, Avg Loss=0.036755
[Siamese] Epoch 3/50, Avg Loss=0.036676
[Siamese] Epoch 4/50, Avg Loss=0.036439
[Siamese] Epoch 5/50, Avg Loss=0.036161
[Siamese] Epoch 6/50, Avg Loss=0.035895
[Siamese] Epoch 7/50, Avg Loss=0.035849
[Siamese] Epoch 8/50, Avg Loss=0.035726
[Siamese] Epoch 9/50, Avg Loss=0.035926
[Siamese] Epoch 10/50, Avg Loss=0.035517
[Siamese] Epoch 11/50, Avg Loss=0.035649
[Siamese] Epoch 12/50, Avg Loss=0.035450
[Siamese] Epoch 13/50, Avg Loss=0.035381
[Siamese] Epoch 14/50, Avg Loss=0.035414
[Siamese] Epoch 15/50, Avg Loss=0.035355
[Siamese] Epoch 16/50, Avg Loss=0.035417
[Siamese] Epoch 17/50, Avg Loss=0.035250
[Siamese] Epoch 18/50, Avg Loss=0.035369
[Siamese] Epoch 19/50, Avg Loss=0.035023
[Siamese] Epoch 20/50, Avg Loss=0.035208
[Siamese] Epoch 21/50, Avg Loss=0.035073
[Siamese] Epoch 22/50, Avg Loss=0.035182
[Siamese] Epoch 23/50, Avg Loss=0.034908
[Siamese] Epoch 24/50, Avg Loss=0.034948
[Siames

With 20 neighbors, with 32 batch size

      Metric |       Mean ± Std       
----------------------------------------
    accuracy |     0.6270 ± 0.0491    
   precision |     0.5165 ± 0.1315    
      recall |     0.4737 ± 0.1695    
          f1 |     0.4912 ± 0.1511    
    log_loss |     0.7008 ± 0.0080    

With 20 neighbors, with 64 batch size
      Metric |       Mean ± Std       
----------------------------------------
    accuracy |     0.6207 ± 0.0469    
   precision |     0.5173 ± 0.1251    
      recall |     0.4707 ± 0.1689    
          f1 |     0.4892 ± 0.1484    
    log_loss |     0.7009 ± 0.0070    