In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pymatting
  Downloading pymatting-1.1.14-py3-none-any.whl.metadata (7.7 kB)
Downloading pymatting-1.1.14-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.7/54.7 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymatting
Successfully installed pymatting-1.1.14


In [3]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from torch.utils.data import TensorDataset, DataLoader, Subset
import random

In [4]:
data = np.load('/content/drive/MyDrive/TejaswiAbburi_va797/Dataset/Medmnist_data/pneumoniamnist_224.npz', allow_pickle=True)

all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # grayscale → 3-channel
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

images = torch.stack([transform(img) for img in all_images])
labels = torch.tensor(all_labels).long()

In [5]:
dataset = TensorDataset(images, labels)
class0_indices = [i for i in range(len(labels)) if labels[i] == 0]
class1_indices = [i for i in range(len(labels)) if labels[i] == 1]

random.seed(42)
sampled_class0 = random.sample(class0_indices, min(2000, len(class0_indices)))
sampled_class1 = random.sample(class1_indices, min(2000, len(class1_indices)))
combined_indices = sampled_class0 + sampled_class1
random.shuffle(combined_indices)

final_dataset = Subset(dataset, combined_indices)
final_loader = DataLoader(final_dataset, batch_size=64, shuffle=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')  # DINO ViT-B/16
vit.eval().to(device)

vit_feats = []
y_list = []

with torch.no_grad():
    for imgs, lbls in final_loader:
        imgs = imgs.to(device)
        feats = vit(imgs)  # (N, 768)
        vit_feats.append(feats.cpu())
        y_list.extend(lbls.cpu().tolist())

F = torch.cat(vit_feats, dim=0).numpy().astype(np.float32)
y_labels = np.array(y_list).astype(np.int64)

print("Feature shape:", F.shape)
print("Label shape:", y_labels.shape)

Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_vitbase16_pretrain.pth


100%|██████████| 327M/327M [00:04<00:00, 71.4MB/s]


Feature shape: (3583, 768)
Label shape: (3583,)


In [7]:
kmeans = KMeans(n_clusters=2, random_state=11, max_iter=5000)
kmeans.fit(F)

klabels_trans = kmeans.transform(F)
klabels_trans = klabels_trans / (klabels_trans.sum(axis=1, keepdims=True) + 1e-10)

y_pred = np.argmin(klabels_trans, axis=1)

In [8]:
acc = accuracy_score(y_labels, y_pred)
inv_acc = accuracy_score(y_labels, 1 - y_pred)
if inv_acc > acc:
    acc = inv_acc
    y_pred = 1 - y_pred
    klabels_trans[:, [0, 1]] = klabels_trans[:, [1, 0]]  # swap probs

y_pred_proba = klabels_trans[:, 1]

prec = precision_score(y_labels, y_pred, zero_division=0)
rec = recall_score(y_labels, y_pred, zero_division=0)
f1 = f1_score(y_labels, y_pred, zero_division=0)
logloss = log_loss(y_labels, y_pred_proba)

print("===== KMeans Results (PneumoniaMNIST) =====")
print("Accuracy Score :", acc)
print("Precision Score:", prec)
print("Recall Score   :", rec)
print("F1 Score       :", f1)
print("Log Loss       :", logloss)

===== KMeans Results (PneumoniaMNIST) =====
Accuracy Score : 0.8975718671504326
Precision Score: 0.9652421652421652
Recall Score   : 0.847
F1 Score       : 0.9022636484687084
Log Loss       : 0.7869649510685244


In [9]:
print(y_pred[:20])

[0 1 0 0 1 1 0 1 1 1 0 1 0 0 1 1 0 0 1 0]


In [10]:
print(y_labels[:20])

[0 1 0 0 1 1 0 1 1 1 0 1 0 0 1 1 0 1 1 0]


In [11]:
num_runs = 10

acc_scores, prec_scores, rec_scores, f1_scores, log_losses = [], [], [], [], []

for run in range(num_runs):
    print(f"\n--- Run {run+1}/{num_runs} ---")
    np.random.seed(run)
    torch.manual_seed(run)

    kmeans = KMeans(n_clusters=2, random_state=run, max_iter=5000)
    kmeans.fit(F)

    klabels_trans = kmeans.transform(F)
    klabels_trans = klabels_trans / (klabels_trans.sum(axis=1, keepdims=True) + 1e-10)

    y_pred = np.argmin(klabels_trans, axis=1)

    acc = accuracy_score(y_labels, y_pred)
    inv_acc = accuracy_score(y_labels, 1 - y_pred)
    if inv_acc > acc:
        acc = inv_acc
        y_pred = 1 - y_pred
        klabels_trans[:, [0, 1]] = klabels_trans[:, [1, 0]]

    y_pred_proba = klabels_trans[:, 1]

    prec = precision_score(y_labels, y_pred, zero_division=0)
    rec = recall_score(y_labels, y_pred, zero_division=0)
    f1 = f1_score(y_labels, y_pred, zero_division=0)
    logloss = log_loss(y_labels, y_pred_proba)

    acc_scores.append(acc)
    prec_scores.append(prec)
    rec_scores.append(rec)
    f1_scores.append(f1)
    log_losses.append(logloss)

    print(f"Run {run+1} | Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | "
          f"F1: {f1:.4f} | LogLoss: {logloss:.4f}")

print("\n================ FINAL SUMMARY ================\n")
print(f"{'Metric':>15} | {'Mean':>10} ± {'Std':<10}")
print("-" * 50)
print(f"{'Accuracy':>15} | {np.mean(acc_scores):.4f} ± {np.std(acc_scores):.4f}")
print(f"{'Precision':>15} | {np.mean(prec_scores):.4f} ± {np.std(prec_scores):.4f}")
print(f"{'Recall':>15} | {np.mean(rec_scores):.4f} ± {np.std(rec_scores):.4f}")
print(f"{'F1 Score':>15} | {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print(f"{'Log Loss':>15} | {np.mean(log_losses):.4f} ± {np.std(log_losses):.4f}")


--- Run 1/10 ---
Run 1 | Acc: 0.8976 | Prec: 0.9652 | Rec: 0.8470 | F1: 0.9023 | LogLoss: 0.7870

--- Run 2/10 ---
Run 2 | Acc: 0.8959 | Prec: 0.9667 | Rec: 0.8425 | F1: 0.9003 | LogLoss: 0.7870

--- Run 3/10 ---
Run 3 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872

--- Run 4/10 ---
Run 4 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872

--- Run 5/10 ---
Run 5 | Acc: 0.8948 | Prec: 0.9661 | Rec: 0.8410 | F1: 0.8992 | LogLoss: 0.7869

--- Run 6/10 ---
Run 6 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872

--- Run 7/10 ---
Run 7 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872

--- Run 8/10 ---
Run 8 | Acc: 0.8951 | Prec: 0.9667 | Rec: 0.8410 | F1: 0.8995 | LogLoss: 0.7869

--- Run 9/10 ---
Run 9 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872

--- Run 10/10 ---
Run 10 | Acc: 0.9012 | Prec: 0.9703 | Rec: 0.8490 | F1: 0.9056 | LogLoss: 0.7872


         Metric 