In [22]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import TensorDataset, DataLoader, Subset
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected
from torch_sparse import SparseTensor

from sklearn.cluster import KMeans
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)

In [23]:
def setup_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module, and configures cuDNN for repeatable kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different convolution algorithms from run to run,
    # which defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False


setup_seed(42)
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
# Load the train/val/test splits from the .npz archive and merge them into a
# single image array and a single label vector.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code if the archive is not trusted. Confirm the file only
# holds plain numeric arrays and switch to allow_pickle=False if so.
# The context manager closes the archive's file handle once the arrays have
# been materialised; `data` must not be indexed after this block.
with np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True) as data:
    # Cast to float32 and divide by 255.
    images = np.concatenate(
        [data['train_images'], data['val_images'], data['test_images']], axis=0
    ).astype(np.float32) / 255.0

    # Drop singleton dimensions so labels can be compared element-wise later.
    labels = np.concatenate(
        [data['train_labels'], data['val_labels'], data['test_labels']], axis=0
    ).squeeze().astype(np.int64)

# Convert to 3-channel: insert a channel axis at position 1 and repeat it 3x.
images = np.repeat(images[:, None, :, :], 3, axis=1)

# from_numpy shares memory with `images` instead of copying the whole array.
X_img = torch.from_numpy(images)
y = labels

In [25]:
# Draw at most 1000 sample indices per class (labels 0 and 1), then shuffle
# the combined selection with a fixed seed.
# NOTE(review): the cap only truncates. A class with fewer than 1000 samples is
# kept in full, so the two classes are not necessarily the same size.
random.seed(42)
_per_class = []
for _cls in (0, 1):
    _members = np.nonzero(y == _cls)[0].tolist()
    _per_class.append(random.sample(_members, min(1000, len(_members))))
idx0, idx1 = _per_class

indices = idx0 + idx1
random.shuffle(indices)

# Re-order images and labels with the same permutation so they stay aligned.
X_img, y = X_img[indices], y[indices]

# shuffle=False keeps the loader order identical to `indices`.
dataset = TensorDataset(X_img, torch.tensor(y))
loader = DataLoader(dataset, shuffle=False, batch_size=64)

print("Balanced samples:", len(y))

Balanced samples: 780


In [26]:
# Fetch the DINO ViT-B/16 backbone from torch.hub and use it as a frozen
# feature extractor (eval mode, no gradients).
# NOTE(review): images are fed as loaded, with no mean/std normalisation in
# this notebook; confirm that matches what the backbone expects.
vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vit.eval()
vit.to(device)

features = []
labels_list = []

with torch.no_grad():
    for image_batch, label_batch in loader:
        # Move results back to the CPU so they can be stacked into NumPy arrays.
        batch_feats = vit(image_batch.to(device))
        features.append(batch_feats.cpu())
        labels_list.extend(label_batch.numpy())

# Stack the per-batch outputs into one feature matrix and one label vector.
X = torch.cat(features).numpy().astype(np.float32)
y = np.array(labels_list)

N, F_dim = X.shape
print("Feature matrix:", X.shape)

# Tensor copy of the features on the compute device, used as node features.
x = torch.from_numpy(X).to(device)

Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


Feature matrix: (780, 768)


In [27]:
def create_adj(features, alpha=0.7):
    """Build a binary adjacency matrix by thresholding cosine similarity.

    Args:
        features: 2-D array with one feature vector per row.
        alpha: Similarity threshold; pairs with cosine similarity >= alpha
            are connected.

    Returns:
        A float32 square array holding 1.0 where two rows are connected and
        0.0 elsewhere. Rows with zero norm have no defined cosine similarity
        and therefore receive no edges at all.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    # Zero-norm rows would produce 0/0 = NaN (plus a RuntimeWarning); divide
    # them by 1 instead and mask their edges out below.
    valid = norms > 0
    f = features / np.where(valid, norms, 1.0)
    W = np.dot(f, f.T)
    connected = (W >= alpha) & valid & valid.T
    return connected.astype(np.float32)

# Threshold the feature similarities into a dense 0/1 matrix, then convert
# its non-zero entries into the edge list / sparse adjacency used downstream.
W = create_adj(X, alpha=0.7)

rows, cols = np.nonzero(W)
# Stack into one (2, E) int64 array first: building a tensor from a Python
# list of ndarrays goes through a slow element-wise path and triggers a
# PyTorch UserWarning.
edge_index = torch.from_numpy(np.stack((rows, cols)).astype(np.int64, copy=False))
edge_index = to_undirected(edge_index).to(device)

# Sparse adjacency with every stored entry set to 1, on the compute device.
adj = SparseTensor(
    row=edge_index[0],
    col=edge_index[1],
    sparse_sizes=(N, N)
).fill_value(1.).to(device)

print("Edges:", adj.nnz())

Edges: 351480


In [28]:
def get_sim(batch, adj, wt=50, wl=2):
    """Count random-walk visits from each node in ``batch``.

    Args:
        batch: 1-D tensor of start-node ids.
        adj: Sparse adjacency exposing ``random_walk(start, length)``.
        wt: How many times ``batch`` is repeated, i.e. walks started per node.
        wl: Walk length handed to ``adj.random_walk``.

    Returns:
        A SparseTensor whose entry (start, node) is the number of times
        ``node`` appeared in the walks launched from ``start``, after
        ``set_diag(0.)`` has been applied.

    NOTE(review): ``sparse_sizes`` is (len(batch), len(batch)) while row/col
    hold the node ids themselves, so this presumably relies on ``batch``
    covering ids 0..len(batch)-1; verify before calling with a subset.
    """
    n_start = batch.shape[0]
    # Drop the first column of every walk (presumably the start node itself).
    walks = adj.random_walk(batch.repeat(wt), wl)[:, 1:]
    # Regroup so that row i gathers the steps of all walks started at batch[i].
    walks = walks.t().reshape(-1, n_start).t()

    src, dst, hits = [], [], []
    for start_id, visited in zip(batch.tolist(), walks):
        reached, times = torch.unique(visited, return_counts=True)
        src.extend([start_id] * reached.shape[0])
        dst.extend(reached.tolist())
        hits.extend(times.tolist())

    visit_counts = SparseTensor(
        row=torch.tensor(src),
        col=torch.tensor(dst),
        value=torch.tensor(hits),
        sparse_sizes=(n_start, n_start)
    )
    return visit_counts.set_diag(0.)


In [29]:
def get_mask(adj):
    """Keep only the entries of ``adj`` that reach their row's mean value.

    An entry survives when its value minus ``adj.mean(dim=1)`` for its row is
    greater than -1e-10 (a small tolerance around "at least the mean").

    Args:
        adj: SparseTensor with stored values.

    Returns:
        A new SparseTensor of the same size containing the surviving entries.
    """
    storage = adj.storage
    rows, cols, vals = storage.row(), storage.col(), storage.value()

    row_mean = adj.mean(dim=1)
    keep = (vals - row_mean[rows]) > -1e-10

    return SparseTensor(
        row=rows[keep],
        col=cols[keep],
        value=vals[keep],
        sparse_sizes=adj.sizes()
    )

def scale(z):
    """Min-max rescale every row of ``z`` along dim 1.

    The 1e-12 added to the denominator avoids division by zero for rows whose
    entries are all equal; such rows come out as zeros.
    """
    lowest = z.min(dim=1, keepdim=True).values
    spread = z.max(dim=1, keepdim=True).values - lowest
    return (z - lowest) / (spread + 1e-12)

In [30]:
class MAGILoss(nn.Module):
    """Contrastive loss over the node pairs listed in a sparse mask.

    For every (row, col) pair stored in the mask, the loss is the negative
    log-softmax of the temperature-scaled similarity between the two
    embeddings, where the softmax denominator excludes the self-similarity
    term. The result is averaged over all pairs.
    """

    def __init__(self, tau=0.3):
        """Store the temperature ``tau`` that divides the similarities."""
        super().__init__()
        self.tau = tau

    def forward(self, z, mask):
        """Compute the loss.

        Args:
            z: 2-D embedding matrix, one row per node.
            mask: Object exposing ``storage.row()`` and ``storage.col()``
                index tensors that select the positive pairs.

        Returns:
            Scalar tensor with the mean negative log-probability.
        """
        n_nodes = z.size(0)
        logits = torch.mm(z, z.t()) / self.tau
        # Shift each row by its maximum (treated as a constant) before exp().
        row_max = logits.max(dim=1, keepdim=True)[0].detach()
        logits = logits - row_max

        # 1 everywhere except the diagonal, so self-pairs leave the denominator.
        off_diagonal = torch.ones_like(logits) - torch.eye(n_nodes, device=z.device)
        denom = (torch.exp(logits) * off_diagonal).sum(dim=1, keepdim=True)
        log_prob = logits - torch.log(denom)

        pos_rows = mask.storage.row()
        pos_cols = mask.storage.col()
        return -log_prob[pos_rows, pos_cols].mean()

In [31]:
class Encoder(nn.Module):
    """Single-layer graph encoder: one GCNConv followed by LeakyReLU."""

    def __init__(self, in_dim, hidden_dim=256):
        """Create the convolution mapping ``in_dim`` to ``hidden_dim`` features."""
        super().__init__()
        self.conv = GCNConv(in_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Return the activated convolution output for the given graph.

        Args:
            x: Node feature matrix passed to the convolution.
            edge_index: Edge list passed to the convolution.
        """
        hidden = self.conv(x, edge_index)
        # Negative slope of 0.2 for the leaky activation.
        return F.leaky_relu(hidden, negative_slope=0.2)

In [32]:
# Train the encoder from scratch `n_runs` times with different seeds, cluster
# the learned embeddings with 2-means, and report mean/std of the metrics.
n_runs = 10
epochs = 3000

accs, precs, recs, f1s = [], [], [], []


def _embed(model):
    """Encode the full graph, min-max scale each row, then L2-normalise it."""
    return F.normalize(scale(model(x, edge_index)), dim=1)


for run in range(n_runs):
    print(f"\n===== Run {run+1}/{n_runs} =====")
    setup_seed(42 + run)

    model = Encoder(F_dim, 256).to(device)
    loss_fn = MAGILoss(tau=0.3)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

    # Positive pairs come from random walks over all N nodes; they are sampled
    # once per run (after seeding) and reused for every epoch.
    batch = torch.arange(N, device=device)
    adj_rw = get_sim(batch, adj)
    # Keep the pair indices on the compute device so they are not transferred
    # again on each of the `epochs` loss evaluations.
    mask = get_mask(adj_rw).to(device)

    # The mode does not change inside the loop, so set it once.
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        z = _embed(model)

        loss = loss_fn(z, mask)
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(f"Epoch {epoch} | Loss {loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        z = _embed(model).cpu().numpy()

    km = KMeans(n_clusters=2, n_init=20, random_state=run)
    y_pred = km.fit_predict(z)

    # Cluster ids are arbitrary: use whichever of the two labelings agrees
    # better with the ground truth.
    if accuracy_score(y, 1 - y_pred) > accuracy_score(y, y_pred):
        y_pred = 1 - y_pred

    acc = accuracy_score(y, y_pred)
    prec = precision_score(y, y_pred)
    rec = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)

    accs.append(acc)
    precs.append(prec)
    recs.append(rec)
    f1s.append(f1)
    print(
        f"Run {run+1} \u2192 "
        f"ACC: {acc:.4f}, "
        f"PREC: {prec:.4f}, "
        f"REC: {rec:.4f}, "
        f"F1: {f1:.4f}"
    )

# Report the actual number of runs rather than a hard-coded count.
print(f"\n===== MAGI + KMeans (BreastMNIST, {n_runs} Runs) =====")
print(f"ACC : {np.mean(accs):.4f} \u00b1 {np.std(accs):.4f}")
print(f"PREC: {np.mean(precs):.4f} \u00b1 {np.std(precs):.4f}")
print(f"REC : {np.mean(recs):.4f} \u00b1 {np.std(recs):.4f}")
print(f"F1  : {np.mean(f1s):.4f} \u00b1 {np.std(f1s):.4f}")


===== Run 1/10 =====
Epoch 0 | Loss 6.6386
Epoch 500 | Loss 6.4795
Epoch 1000 | Loss 6.4772
Epoch 1500 | Loss 6.4787
Epoch 2000 | Loss 6.4831
Epoch 2500 | Loss 6.4741
Run 1 → ACC: 0.6628, PREC: 0.7241, REC: 0.8702, F1: 0.7904

===== Run 2/10 =====
Epoch 0 | Loss 6.6375
Epoch 500 | Loss 6.4819
Epoch 1000 | Loss 6.4825
Epoch 1500 | Loss 6.4780
Epoch 2000 | Loss 6.4767
Epoch 2500 | Loss 6.4850
Run 2 → ACC: 0.6603, PREC: 0.7233, REC: 0.8667, F1: 0.7885

===== Run 3/10 =====
Epoch 0 | Loss 6.6406
Epoch 500 | Loss 6.4806
Epoch 1000 | Loss 6.4822
Epoch 1500 | Loss 6.4749
Epoch 2000 | Loss 6.4801
Epoch 2500 | Loss 6.4746
Run 3 → ACC: 0.6603, PREC: 0.7246, REC: 0.8632, F1: 0.7878

===== Run 4/10 =====
Epoch 0 | Loss 6.6412
Epoch 500 | Loss 6.4813
Epoch 1000 | Loss 6.4798
Epoch 1500 | Loss 6.4771
Epoch 2000 | Loss 6.4769
Epoch 2500 | Loss 6.4758
Run 4 → ACC: 0.6551, PREC: 0.7223, REC: 0.8579, F1: 0.7843

===== Run 5/10 =====
Epoch 0 | Loss 6.6380
Epoch 500 | Loss 6.4816
Epoch 1000 | Loss 6.4828