In [2]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import models

from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected
from torch_sparse import SparseTensor

from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def setup_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU plus all CUDA devices), and configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark mode enabled cuDNN auto-tunes and may select different
    # kernels from run to run, so pin it off alongside ``deterministic``.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once up front, then prefer the GPU whenever one is visible.
setup_seed(42)
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)


# data = np.load('/content/drive/MyDrive/TejaswiAbburi_va797/Dataset/Medmnist_data/pneumoniamnist_224.npz', allow_pickle=True)
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects from
# the archive. It is kept so the existing file keeps loading, but it should be
# dropped if the archive holds only plain arrays — confirm against the file.
# The context manager closes the underlying zip file once the arrays are read.
with np.load('pneumoniamnist_224.npz', allow_pickle=True) as data:
    raw_images = np.concatenate(
        [data['train_images'], data['val_images'], data['test_images']], axis=0
    )
    labels = np.concatenate(
        [data['train_labels'], data['val_labels'], data['test_labels']], axis=0
    ).squeeze().astype(np.int64)

y = labels

print("Total samples:", len(y))


# Draw at most 2000 samples per class. Each class is capped independently, so
# the subset is only truly balanced when both classes have >= 2000 samples.
idx0 = np.where(y == 0)[0]
idx1 = np.where(y == 1)[0]

random.seed(42)
idx0 = random.sample(idx0.tolist(), min(2000, len(idx0)))
idx1 = random.sample(idx1.tolist(), min(2000, len(idx1)))

indices = idx0 + idx1
random.shuffle(indices)

# Subsample BEFORE the float conversion and channel replication so the large
# float32 / 3-channel arrays are only ever materialised for the kept samples.
# The values are identical to converting first and indexing afterwards.
images = raw_images[indices].astype(np.float32) / 255.0
del raw_images

# Convert grayscale → 3-channel
images = np.repeat(images[:, None, :, :], 3, axis=1)

# from_numpy shares memory with ``images`` instead of making another copy.
X_img = torch.from_numpy(images)
y = y[indices]

dataset = TensorDataset(X_img, torch.tensor(y))
loader = DataLoader(dataset, batch_size=64, shuffle=False)

print("Balanced samples:", len(y))

# ======================================================
# ResNet18 Feature Extraction
# ======================================================
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
# ``weights`` enum; IMAGENET1K_V1 is the checkpoint ``pretrained=True`` loads.
# Older releases lack the enum, so fall back to the legacy flag there.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    resnet = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    resnet = models.resnet18(pretrained=True)

# Replace the classification head with a pass-through so the forward pass
# returns the pooled backbone features instead of class logits.
resnet.fc = nn.Identity()
resnet.eval().to(device)

features = []
labels_list = []

# NOTE(review): the images fed in here appear to be scaled to [0, 1] only; the
# ImageNet mean/std normalisation the pretrained weights were trained with is
# not applied. Left unchanged to keep results comparable — confirm intent.
with torch.no_grad():
    for imgs, lbls in loader:
        imgs = imgs.to(device)
        feats = resnet(imgs)
        # Move each batch back to the CPU so GPU memory holds one batch at a time.
        features.append(feats.cpu())
        labels_list.extend(lbls.numpy())

X = torch.cat(features, dim=0).numpy().astype(np.float32)
y = np.array(labels_list)

N, F_dim = X.shape
print("Feature matrix:", X.shape)

# Node feature matrix on the training device, used as the encoder input.
x = torch.from_numpy(X).to(device)

# ======================================================
# Build cosine similarity graph
# ======================================================
def create_adj(features, alpha=0.9):
    """Build a binary adjacency matrix by thresholding cosine similarity.

    Args:
        features: 2-D array with one feature vector per row.
        alpha: Similarity threshold; two rows are connected when their cosine
            similarity is ``>= alpha``.

    Returns:
        A float32 array of shape ``(n_rows, n_rows)`` holding 1.0 for
        connected pairs and 0.0 otherwise. Each non-zero row has
        self-similarity of about 1, so it is connected to itself whenever
        ``alpha`` is below that.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    # An all-zero row has norm 0; dividing by it would produce NaNs and a
    # RuntimeWarning. Clamping keeps such a row at zero, i.e. similarity 0 to
    # every row, so it receives no edges for any positive ``alpha``.
    f = features / np.maximum(norms, 1e-12)
    W = np.dot(f, f.T)
    return (W >= alpha).astype(np.float32)

W = create_adj(X, alpha=0.9)

rows, cols = np.nonzero(W)
# Stack into a single ndarray first: calling torch.tensor() on a Python list
# of ndarrays is very slow and emits a UserWarning.
edge_index = torch.from_numpy(np.stack([rows, cols])).long()
edge_index = to_undirected(edge_index).to(device)

# Sparse adjacency with every stored edge weight set to 1, used for the
# random walks that pick positive pairs.
adj = SparseTensor(
    row=edge_index[0],
    col=edge_index[1],
    sparse_sizes=(N, N)
).fill_value(1.).to(device)

print("Edges:", adj.nnz())

# ======================================================
# MAGI utilities
# ======================================================
def get_sim(batch, adj, wt=50, wl=2):
    """Count random-walk visits from each batch node to build a similarity matrix.

    Args:
        batch: 1-D tensor of start-node ids.
        adj: Sparse adjacency exposing ``random_walk`` (a torch_sparse
            ``SparseTensor`` in this file).
        wt: Number of walks started from every batch node.
        wl: Length passed to ``adj.random_walk`` for every walk.

    Returns:
        A ``SparseTensor`` whose entry (batch[i], j) is the number of times
        node j appeared in the walks started from batch[i], after
        ``set_diag(0.)`` has been applied.
    """
    batch_size = batch.shape[0]
    # Tile the start nodes so each one launches ``wt`` independent walks.
    batch_repeat = batch.repeat(wt)

    # Drop the first column of every walk — presumably the start node itself;
    # TODO confirm against torch_sparse's random_walk return layout.
    rw = adj.random_walk(batch_repeat, wl)[:, 1:]
    # Regroup so that row i gathers the visited nodes of all walks that
    # started at batch[i] (walk k of the tiled batch starts at batch[k % batch_size]).
    rw = rw.t().reshape(-1, batch_size).t()

    row, col, val = [], [], []
    for i in range(batch_size):
        # Distinct nodes reached from batch[i] and how often each was visited.
        nodes, counts = torch.unique(rw[i], return_counts=True)
        row += [batch[i].item()] * nodes.shape[0]
        col += nodes.tolist()
        val += counts.tolist()

    # NOTE(review): row/col hold node ids as they appear in ``batch`` / ``adj``
    # while the size is (batch_size, batch_size); this lines up when ``batch``
    # is arange(num_nodes) — verify before calling with a subset of nodes.
    # The index/value tensors are created without a device argument.
    adj_rw = SparseTensor(
        row=torch.tensor(row),
        col=torch.tensor(col),
        value=torch.tensor(val),
        sparse_sizes=(batch_size, batch_size)
    )
    # Clear the diagonal so a node is not counted as similar to itself.
    return adj_rw.set_diag(0.)

def get_mask(adj):
    """Keep only the entries of ``adj`` that reach their row's mean value.

    Args:
        adj: Sparse matrix exposing ``mean``, ``sizes`` and a ``storage`` with
            ``row()``, ``col()`` and ``value()`` accessors.

    Returns:
        A new ``SparseTensor`` of the same size containing the stored entries
        whose value is not below ``adj.mean(dim=1)`` for their row (with a
        1e-10 tolerance).
    """
    row_mean = adj.mean(dim=1)
    storage = adj.storage
    rows, cols, vals = storage.row(), storage.col(), storage.value()

    # Small negative tolerance so values equal to the mean are retained.
    keep = (vals - row_mean[rows]) > -1e-10

    return SparseTensor(
        row=rows[keep],
        col=cols[keep],
        value=vals[keep],
        sparse_sizes=adj.sizes(),
    )

def scale(z):
    """Min-max scale every row of ``z`` into the range [0, 1].

    Args:
        z: 2-D tensor; each row is rescaled independently.

    Returns:
        Tensor of the same shape. A row with constant values maps to zeros
        because of the 1e-12 added to the denominator.
    """
    row_min = z.min(dim=1, keepdim=True).values
    row_range = z.max(dim=1, keepdim=True).values - row_min
    return (z - row_min) / (row_range + 1e-12)

# ======================================================
# MAGI Loss
# ======================================================
class MAGILoss(nn.Module):
    """Contrastive loss over the positive pairs listed in a sparse mask.

    Every row of the similarity matrix is turned into log-probabilities over
    all *other* samples, and the loss is the negative mean log-probability of
    the (row, col) pairs stored in ``mask``.
    """

    def __init__(self, tau=0.3):
        """Store the temperature ``tau`` that divides the similarities."""
        super().__init__()
        self.tau = tau

    def forward(self, z, mask):
        """Return the scalar loss.

        Args:
            z: 2-D embedding tensor, one row per sample.
            mask: Object whose ``storage.row()`` / ``storage.col()`` give the
                index tensors of the positive pairs.
        """
        logits = (z @ z.t()) / self.tau
        # Subtract the per-row maximum (without gradient) for numerical stability.
        row_max, _ = logits.max(dim=1, keepdim=True)
        logits = logits - row_max.detach()

        # Exclude each sample's similarity with itself from the denominator.
        off_diag = torch.ones_like(logits) - torch.eye(z.size(0), device=z.device)
        denom = (torch.exp(logits) * off_diag).sum(dim=1, keepdim=True)
        log_prob = logits - torch.log(denom)

        pairs = mask.storage
        return -log_prob[pairs.row(), pairs.col()].mean()

# ======================================================
# Encoder (GCN)
# ======================================================
class Encoder(nn.Module):
    """Single graph-convolution layer followed by a leaky ReLU."""

    def __init__(self, in_dim, hidden_dim=256):
        """Create the ``GCNConv`` layer mapping ``in_dim`` to ``hidden_dim``."""
        super().__init__()
        self.conv = GCNConv(in_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Return the activated node embeddings for features ``x`` on ``edge_index``."""
        hidden = self.conv(x, edge_index)
        return F.leaky_relu(hidden, negative_slope=0.2)

# ======================================================
# 10-Run MAGI + KMeans Evaluation
# ======================================================
n_runs = 10
epochs = 3000

accs, precs, recs, f1s = [], [], [], []


def _embed(net):
    """Encode all nodes, min-max scale each row, then L2-normalise it."""
    return F.normalize(scale(net(x, edge_index)), dim=1)


for run in range(n_runs):
    print(f"\n===== Run {run+1}/{n_runs} =====")
    # A different seed per run; everything below draws from these generators.
    setup_seed(42 + run)

    model = Encoder(F_dim, 256).to(device)
    loss_fn = MAGILoss(tau=0.3)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

    # Positive pairs come from random walks over the full node set and are
    # computed once per run, then reused for every epoch.
    batch = torch.arange(N, device=device)
    adj_rw = get_sim(batch, adj)
    mask = get_mask(adj_rw)

    # Full-batch training: one optimisation step per epoch.
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()

        z = _embed(model)
        loss = loss_fn(z, mask)
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print(f"Epoch {epoch} | Loss {loss.item():.4f}")

    # Embeddings
    model.eval()
    with torch.no_grad():
        z = _embed(model).cpu().numpy()

    # KMeans
    km = KMeans(n_clusters=2, n_init=20, random_state=run)
    y_pred = km.fit_predict(z)

    # Align labels: cluster ids are arbitrary, so use whichever of the two
    # possible assignments agrees better with the ground truth.
    flipped = 1 - y_pred
    if accuracy_score(y, flipped) > accuracy_score(y, y_pred):
        y_pred = flipped

    for history, metric in (
        (accs, accuracy_score),
        (precs, precision_score),
        (recs, recall_score),
        (f1s, f1_score),
    ):
        history.append(metric(y, y_pred))

    print(
        f"Run {run+1} → "
        f"ACC: {accs[-1]:.4f}, "
        f"PREC: {precs[-1]:.4f}, "
        f"REC: {recs[-1]:.4f}, "
        f"F1: {f1s[-1]:.4f}"
    )

# ======================================================
# Results
# ======================================================
print("\n===== MAGI + KMeans (PneumoniaMNIST, 10 Runs) =====")
print(f"ACC : {np.mean(accs):.4f} ± {np.std(accs):.4f}")
print(f"PREC: {np.mean(precs):.4f} ± {np.std(precs):.4f}")
print(f"REC : {np.mean(recs):.4f} ± {np.std(recs):.4f}")
print(f"F1  : {np.mean(f1s):.4f} ± {np.std(f1s):.4f}")


Total samples: 5856
Balanced samples: 3583




Feature matrix: (3583, 512)


  edge_index = torch.tensor([rows, cols], dtype=torch.long)


Edges: 1642013

===== Run 1/10 =====
Epoch 0 | Loss 8.1003
Epoch 500 | Loss 7.3648
Epoch 1000 | Loss 7.3471
Epoch 1500 | Loss 7.3380
Epoch 2000 | Loss 7.3324
Epoch 2500 | Loss 7.3294
Run 1 → ACC: 0.9191, PREC: 0.9416, REC: 0.9115, F1: 0.9263

===== Run 2/10 =====
Epoch 0 | Loss 8.1141
Epoch 500 | Loss 7.3642
Epoch 1000 | Loss 7.3478
Epoch 1500 | Loss 7.3461
Epoch 2000 | Loss 7.3343
Epoch 2500 | Loss 7.3331
Run 2 → ACC: 0.9199, PREC: 0.9422, REC: 0.9125, F1: 0.9271

===== Run 3/10 =====
Epoch 0 | Loss 8.1174
Epoch 500 | Loss 7.3626
Epoch 1000 | Loss 7.3457
Epoch 1500 | Loss 7.3413
Epoch 2000 | Loss 7.3329
Epoch 2500 | Loss 7.3303
Run 3 → ACC: 0.9191, PREC: 0.9412, REC: 0.9120, F1: 0.9264

===== Run 4/10 =====
Epoch 0 | Loss 8.1160
Epoch 500 | Loss 7.3644
Epoch 1000 | Loss 7.3471
Epoch 1500 | Loss 7.3463
Epoch 2000 | Loss 7.3355
Epoch 2500 | Loss 7.3283
Run 4 → ACC: 0.9193, PREC: 0.9403, REC: 0.9135, F1: 0.9267

===== Run 5/10 =====
Epoch 0 | Loss 8.0975
Epoch 500 | Loss 7.3646
Epoch 100