In [27]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.2.0+cu121


In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
import numpy as np
import random
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, log_loss
from munkres import Munkres

In [29]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [30]:
class PairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx][0], self.pairs[idx][1], self.labels[idx]

In [31]:
class SiameseNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

    def forward_once(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def forward(self, x1, x2):
        out1 = self.forward_once(x1)
        out2 = self.forward_once(x2)
        dist = F.pairwise_distance(out1, out2)
        return dist, out1, out2

In [32]:
class OrthoLinear(nn.Module):
    def __init__(self, in_dim, out_dim, eps=1e-2):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.eps = eps

    def forward(self, x):
        Y_tilde = self.fc(x)
        gram = Y_tilde.T @ Y_tilde + self.eps * torch.eye(Y_tilde.shape[1], device=x.device)
        L = torch.linalg.cholesky(gram)
        L_inv = torch.inverse(L)
        return Y_tilde @ L_inv.T

In [33]:
class SpectralNet(nn.Module):
    def __init__(self, input_dim, n_clusters, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ortho = OrthoLinear(hidden_dim, n_clusters)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.ortho(x)

In [34]:
def contrastive_loss(distance, label, margin=1.0):
    pos = label * torch.pow(distance, 2)
    neg = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    return torch.mean(pos + neg)

def spectral_loss(Y, W):
    D = torch.diag(W.sum(axis=1))
    L = D - W
    num = torch.trace(Y.T @ L @ Y)
    denom = torch.trace(Y.T @ D @ Y)
    return num / (denom + 1e-12)

In [35]:
def compute_affinity(X, scale, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    distances, indices = distances[:, 1:], indices[:, 1:]
    W = np.zeros((len(X), len(X)))
    for i in range(len(X)):
        for j in range(n_neighbors):
            W[i, indices[i, j]] = np.exp(-distances[i, j] ** 2 / (2 * scale ** 2))
            W[indices[i, j], i] = W[i, indices[i, j]]
    return W

def compute_scale(X, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, _ = nbrs.kneighbors(X)
    return np.median(distances[:, -1])

def calculate_accuracy(y_pred, y_true, n_clusters):
    cm = confusion_matrix(y_true, y_pred)
    cost = np.zeros((n_clusters, n_clusters))
    for i in range(n_clusters):
        for j in range(n_clusters):
            cost[i, j] = cm[:, j].sum() - cm[i, j]
    m = Munkres()
    mapping = m.compute(cost.tolist())
    new_labels = np.zeros_like(y_pred)
    for row, col in mapping:
        new_labels[y_pred == row] = col
    return (new_labels == y_true).mean()

In [36]:
def train_siamese(siamese_net, dataloader, epochs=50, lr=1e-3, device='cpu'):
    siamese_net.to(device)
    opt = optim.Adam(siamese_net.parameters(), lr=lr)
    for ep in range(epochs):
        siamese_net.train()
        total_loss = 0.0
        for x1, x2, labels in dataloader:
            x1 = x1.to(device).float()
            x2 = x2.to(device).float()
            labels = labels.to(device).float()
            dist, _, _ = siamese_net(x1, x2)
            loss = contrastive_loss(dist, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item() * x1.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        print(f"[Siamese] Epoch {ep + 1}/{epochs}, Avg Loss={avg_loss:.6f}")
    siamese_net.to('cpu')
    return siamese_net

def train_spectral(spectral_net, X_train, W, epochs=50, lr=1e-3, tol=1e-6, device='cpu'):
    spectral_net.to(device)
    opt = optim.Adam([
        {'params': spectral_net.fc1.parameters()},
        {'params': spectral_net.fc2.parameters()},
        {'params': spectral_net.ortho.fc.parameters()}
    ], lr=lr)

    X_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
    W_tensor = torch.tensor(W, dtype=torch.float32, device=device)

    prev_loss = float('inf')
    for ep in range(epochs):
        spectral_net.train()
        Y = spectral_net(X_tensor)  # (n_samples, n_clusters)
        loss = spectral_loss(Y, W_tensor)
        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_item = loss.item()
        print(f"[SpectralNet] Epoch {ep + 1}/{epochs}, Loss={loss_item:.8f}")

        if abs(prev_loss - loss_item) < tol:
            print("SpectralNet converged (tol reached).")
            break
        prev_loss = loss_item
    spectral_net.to('cpu')
    return spectral_net

In [37]:
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ====== Load PneumoniaMNIST ======
    data = np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True)
    all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
    all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

    images = all_images.astype(np.float32) / 255.0
    images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N, 3, 224, 224)
    labels = all_labels.astype(np.int64)

    # ====== Select 2000 samples per class ======
    selected_indices = []
    num_per_class = 1000
    classes = np.unique(labels)
    for c in classes:
        class_idx = np.where(labels == c)[0]
        chosen = np.random.choice(class_idx, size=min(num_per_class, len(class_idx)), replace=False)
        selected_indices.extend(chosen)

    selected_indices = np.array(selected_indices)
    images = images[selected_indices]
    labels = labels[selected_indices]

    # ====== Create dataset ======
    dataset = TensorDataset(torch.tensor(images), torch.tensor(labels))
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # ====== Load ViT-DINO ======
    vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
    vit.eval().to(device)

    feats, y_list = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device).float()
            f = vit(imgs)
            if isinstance(f, (list, tuple)):
                f = f[0]
            feats.append(f.cpu())
            y_list.append(lbls)

    X = torch.cat(feats, dim=0).numpy().astype(np.float32)
    y = torch.cat(y_list, dim=0).numpy().astype(np.int64)

    # Shuffle
    perm = np.random.permutation(len(X))
    X, y = X[perm], y[perm]
    print("Balanced subset:", X.shape, y.shape)

    num_nodes, num_feats = X.shape
    n_clusters = 2
    hidden_dim = 256
    batch_size = 16
    n_neighbors = 20

    # ====== Build siamese pairs ======
    pairs, labels_pairs = [], []
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    for i in range(len(X)):
        for j in indices[i, 1:]:
            pairs.append([X[i], X[j]])
            labels_pairs.append(1)
        all_indices = set(range(len(X)))
        neighbor_set = set(indices[i, 1:])
        non_neighbors = list(all_indices - neighbor_set - {i})
        j = np.random.choice(non_neighbors)
        pairs.append([X[i], X[j]])
        labels_pairs.append(0)

    dataset_pairs = PairDataset(pairs, labels_pairs)
    dataloader_pairs = DataLoader(dataset_pairs, batch_size=batch_size, shuffle=True)

    # ====== Train Siamese ======
    siamese = SiameseNet(num_feats, hidden_dim)
    siamese = train_siamese(siamese, dataloader_pairs, epochs=50, device=device)

    with torch.no_grad():
        X_embed = siamese.forward_once(torch.tensor(X, dtype=torch.float32)).numpy()

    # ====== Train SpectralNet ======
    scale = compute_scale(X_embed, n_neighbors=n_neighbors)
    W = compute_affinity(X_embed, scale, n_neighbors=n_neighbors)
    spectral = SpectralNet(num_feats, n_clusters, hidden_dim)
    spectral = train_spectral(spectral, X, W, epochs=50, device=device)

    with torch.no_grad():
        Y = spectral(torch.tensor(X, dtype=torch.float32)).numpy()
        y_pred_proba = F.softmax(torch.tensor(Y), dim=1).numpy()

    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(Y)

    acc_score = calculate_accuracy(y_pred, y, n_clusters)
    acc_score_inverted = calculate_accuracy(1 - y_pred, y, n_clusters)
    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec = precision_score(y, y_pred)
    rec = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    ll = log_loss(y, y_pred_proba)

    print("Final clustering accuracy:", acc_score)
    print("Precision:", prec)
    print("Recall:", rec)
    print("F1 Score:", f1)
    print("Log Loss:", ll)

    return {"accuracy": acc_score, "precision": prec, "recall": rec, "f1": f1, "log_loss": ll}

# --------------------------
# Multi-runs
# --------------------------
if __name__ == "__main__":
    num_runs = 10
    all_results = {"accuracy": [], "precision": [], "recall": [], "f1": [], "log_loss": []}
    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")
        set_seed(run)
        res = main()
        for k in all_results.keys():
            all_results[k].append(res[k])

    print("\n=== FINAL SUMMARY ===")
    for metric, vals in all_results.items():
        print(f"{metric:>10} | mean={np.mean(vals):.4f} ± {np.std(vals):.4f}")



--- Run 1/10 ---


Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 15.72 GiB of which 46.62 MiB is free. Process 235935 has 3.78 GiB memory in use. Process 239319 has 694.00 MiB memory in use. Process 239627 has 1.38 GiB memory in use. Process 239987 has 1.38 GiB memory in use. Process 240113 has 694.00 MiB memory in use. Including non-PyTorch memory, this process has 1.98 GiB memory in use. Process 241651 has 2.11 GiB memory in use. Process 242043 has 3.21 GiB memory in use. Of the allocated memory 1.68 GiB is allocated by PyTorch, and 57.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

=== FINAL SUMMARY ===
  accuracy | mean=0.5271 ± 0.0186
 precision | mean=0.6707 ± 0.0621
    recall | mean=0.3043 ± 0.2325
        f1 | mean=0.3661 ± 0.1580
  log_loss | mean=0.6926 ± 0.0041 (batch size = 32)

=== FINAL SUMMARY ===
  accuracy | mean=0.5439 ± 0.0262
 precision | mean=0.6301 ± 0.0744
    recall | mean=0.3763 ± 0.3009
        f1 | mean=0.3889 ± 0.2028
  log_loss | mean=0.6909 ± 0.0050(batch size = 16)