In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting




In [4]:
import numpy as np
import torch
import random
import copy
import scipy.sparse as sp

from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import models
import torch.nn as nn
import torch.nn.functional as nnFn

# torch-geometric imports
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from sklearn.manifold import TSNE

In [5]:
import torch
print("CUDA available:", torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

CUDA available: True
GPU Name: NVIDIA A100-SXM4-80GB


In [6]:
data = np.load('/content/drive/MyDrive/TejaswiAbburi_va797/Dataset/Medmnist_data/pneumoniamnist_224.npz', allow_pickle=True)

In [7]:
all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

images = all_images.astype(np.float32) / 255.0
images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N,3,224,224)

X = torch.tensor(images)
y = torch.tensor(all_labels).long()
print("Images, labels shapes:", X.shape, y.shape)

dataset = TensorDataset(X, y)
class0_indices = [i for i in range(len(y)) if y[i] == 0]
class1_indices = [i for i in range(len(y)) if y[i] == 1]

random.seed(42)
sampled_class0 = random.sample(class0_indices, min(2000, len(class0_indices)))
sampled_class1 = random.sample(class1_indices, min(2000, len(class1_indices)))

combined_indices = sampled_class0 + sampled_class1
random.shuffle(combined_indices)

final_dataset = Subset(dataset, combined_indices)
final_loader = DataLoader(final_dataset, batch_size=64, shuffle=False)

Images, labels shapes: torch.Size([5856, 3, 224, 224]) torch.Size([5856])


In [8]:
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Identity()
resnet = resnet.to(device)
resnet.eval()

resnet_feats = []
y_list = []
with torch.no_grad():
    for imgs, labels in final_loader:
        imgs = imgs.to(device)
        feats = resnet(imgs)
        resnet_feats.append(feats.cpu())
        y_list.extend(labels.cpu().tolist())

features = torch.cat(resnet_feats, dim=0).numpy().astype(np.float32)  # shape (N, feat_dim)
y_labels = np.array(y_list).astype(np.float32)
print("Feature shape:", features.shape, "Label shape:", y_labels.shape)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 224MB/s]


Feature shape: (3583, 512) Label shape: (3583,)


In [9]:
class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

In [10]:
class GATEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ, heads=4):
        super(GATEncoder, self).__init__()
        self.device = device
        self.gat1 = GATConv(input_dim, hidden_dim // heads, heads=heads)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gat1(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [11]:
class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super(GAT, self).__init__()
        self.device = device
        self.num_clusters = num_clusters
        self.cut = 0   # always modularity

        self.online_encoder = GATEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # only modularity loss
        self.loss = self.modularity_loss

    def forward(self, data):
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)

        return S, logits

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term

In [12]:
def create_adj(F, cut, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        W = np.where(W >= alpha, 1, 0).astype(np.float32)
        W = (W / W.max()).astype(np.float32)
    else:
        W = W - (W.max() / alpha)
    return W

In [13]:
def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats)
    edge_index = torch.from_numpy(np.array(np.nonzero((adj > 0))))
    row, col = edge_index
    edge_weight = torch.from_numpy(adj[row, col])
    return node_feats, edge_index, edge_weight

In [14]:
print(features.shape, features.dtype)
cut = 0 # Consider n-cut loss OR Modularity loss (by default cut = 0)
alpha = 0.9 # Edge creation Threshold
device = 'cuda' if torch.cuda.is_available() else 'cpu'
K = 2  # Number of clusters
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = 512
K = 2
W0 = create_adj(features, 0, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(3583, 512) float32
Data(x=[3583, 512], edge_index=[2, 1642249])


In [15]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

model = GAT(feats_dim, 256, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 5000

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    S, logits = model(data0)
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss:.4f}")


Epoch 0 | Loss: -0.2756
Epoch 100 | Loss: -0.4643
Epoch 200 | Loss: -0.4659
Epoch 300 | Loss: -0.4663
Epoch 400 | Loss: -0.4667
Epoch 500 | Loss: -0.4668
Epoch 600 | Loss: -0.4670
Epoch 700 | Loss: -0.4670
Epoch 800 | Loss: -0.4671
Epoch 900 | Loss: -0.4671
Epoch 1000 | Loss: -0.4672
Epoch 1100 | Loss: -0.4671
Epoch 1200 | Loss: -0.4672
Epoch 1300 | Loss: -0.4672
Epoch 1400 | Loss: -0.4672
Epoch 1500 | Loss: -0.4672
Epoch 1600 | Loss: -0.4672
Epoch 1700 | Loss: -0.4672
Epoch 1800 | Loss: -0.4672
Epoch 1900 | Loss: -0.4672
Epoch 2000 | Loss: -0.4672
Epoch 2100 | Loss: -0.4672
Epoch 2200 | Loss: -0.4672
Epoch 2300 | Loss: -0.4672
Epoch 2400 | Loss: -0.4673
Epoch 2500 | Loss: -0.4672
Epoch 2600 | Loss: -0.4672
Epoch 2700 | Loss: -0.4672
Epoch 2800 | Loss: -0.4672
Epoch 2900 | Loss: -0.4672
Epoch 3000 | Loss: -0.4672
Epoch 3100 | Loss: -0.4672
Epoch 3200 | Loss: -0.4672
Epoch 3300 | Loss: -0.4672
Epoch 3400 | Loss: -0.4672
Epoch 3500 | Loss: -0.4672
Epoch 3600 | Loss: -0.4672
Epoch 3700 | 

In [16]:
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()


acc_score = accuracy_score(y_labels, y_pred)
acc_score_inverted = accuracy_score(y_labels, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred

prec_score = precision_score(y_labels, y_pred)
rec_score = recall_score(y_labels, y_pred)
f1 = f1_score(y_labels, y_pred)
log_loss_value = log_loss(y_labels, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.912922132291376
Precision: 0.9627192982456141
Recall: 0.878
F1: 0.9184100418410042
Log Loss: 8.042749452230495


In [17]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

# Store the initial extracted features and labels (from the sampled dataset)
initial_extracted_features = features
initial_sampled_labels = y_labels

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels for the current run
    perm = np.random.permutation(initial_extracted_features.shape[0])
    current_run_features = initial_extracted_features[perm]
    current_run_labels = initial_sampled_labels[perm]

    cut = 0
    alpha = 0.9
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feats_dim = 512
    K = 2

    W0 = create_adj(current_run_features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, current_run_features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = GAT(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    num_epochs = 5000

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()


        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss:.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    acc_score = accuracy_score(current_run_labels, y_pred)
    acc_score_inverted = accuracy_score(current_run_labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec_score = precision_score(current_run_labels, y_pred)
    rec_score = recall_score(current_run_labels, y_pred)
    f1 = f1_score(current_run_labels, y_pred)
    log_loss_value = log_loss(current_run_labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs ===telek")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))


Epoch 0 | Loss: -0.2841
Epoch 1000 | Loss: -0.4672
Epoch 2000 | Loss: -0.4673
Epoch 3000 | Loss: -0.4673
Epoch 4000 | Loss: -0.4673
Accuracy: 0.9120848451018699 Precision: 0.9616438356164384 Recall: 0.8775 F1: 0.9176470588235294

Epoch 0 | Loss: -0.2822
Epoch 1000 | Loss: -0.4673
Epoch 2000 | Loss: -0.4674
Epoch 3000 | Loss: -0.4674
Epoch 4000 | Loss: -0.4674
Accuracy: 0.9118057493720346 Precision: 0.9621295279912184 Recall: 0.8765 F1: 0.9173207744636316

Epoch 0 | Loss: -0.2839
Epoch 1000 | Loss: -0.4673
Epoch 2000 | Loss: -0.4673
Epoch 3000 | Loss: -0.4673
Epoch 4000 | Loss: -0.4673
Accuracy: 0.9118057493720346 Precision: 0.9631463146314632 Recall: 0.8755 F1: 0.9172341540073337

Epoch 0 | Loss: -0.2838
Epoch 1000 | Loss: -0.4672
Epoch 2000 | Loss: -0.4673
Epoch 3000 | Loss: -0.4673
Epoch 4000 | Loss: -0.4673
Accuracy: 0.9067820262349986 Precision: 0.9587004405286343 Recall: 0.8705 F1: 0.9124737945492662

Epoch 0 | Loss: -0.2839
Epoch 1000 | Loss: -0.4672
Epoch 2000 | Loss: -0.4673
E