In [22]:
# from google.colab import drive
# drive.mount('/content/drive')

In [23]:
# !pip install -q torch_geometric
# !pip install -q class_resolver
# !pip3 install pymatting


In [24]:
import numpy as np
import torch
import random
import copy
import scipy.sparse as sp

from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import models
import torch.nn as nn
import torch.nn.functional as nnFn
# torch-geometric imports
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from sklearn.manifold import TSNE

In [25]:
import torch

# Report CUDA support and pick the compute device accordingly.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
    print("GPU Name:", "No GPU")

CUDA available: True
GPU Name: NVIDIA RTX A4000


In [26]:
# Open the BreastMNIST (224x224) .npz archive; its split arrays are indexed by key in the next cell.
# NOTE(review): the absolute path ties this notebook to one machine — consider a configurable data directory.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary code
# when the file is untrusted. Presumably this archive only holds plain numeric arrays, in which case the
# flag is unnecessary — confirm before removing it.
data = np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True)

In [27]:
# Pool the train/val/test splits: the clustering below is unsupervised, so the labels
# are carried along only to score the result afterwards.
all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

# Scale pixel values to [0, 1] and replicate the single channel three times.
images = all_images.astype(np.float32) / 255.0
images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N,3,224,224)

# from_numpy wraps the existing float32 buffer instead of making a second full copy
# of the image array (torch.tensor always copies). `images` must not be mutated afterwards.
X = torch.from_numpy(images)
y = torch.tensor(all_labels).long()
print("Images, labels shapes:", X.shape, y.shape)

dataset = TensorDataset(X, y)
# Vectorised index lookup; yields the same ascending Python-int lists as a per-element scan.
class0_indices = torch.nonzero(y == 0).flatten().tolist()
class1_indices = torch.nonzero(y == 1).flatten().tolist()

# Draw at most 1000 samples per class with a fixed seed, then mix the two classes.
random.seed(42)
sampled_class0 = random.sample(class0_indices, min(1000, len(class0_indices)))
sampled_class1 = random.sample(class1_indices, min(1000, len(class1_indices)))

combined_indices = sampled_class0 + sampled_class1
random.shuffle(combined_indices)

final_dataset = Subset(dataset, combined_indices)
# shuffle=False keeps the loader order fixed, so features and labels collected from it stay aligned.
final_loader = DataLoader(final_dataset, batch_size=32, shuffle=False)

Images, labels shapes: torch.Size([780, 3, 224, 224]) torch.Size([780])


In [28]:
import torch
import timm  # NOTE(review): not referenced anywhere in this cell
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# DINO ViT-B/16 fetched through torch.hub and switched to eval mode; used here only to embed images.
vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vit.eval().to(device)

vit_feats = []
y_list = []

# NOTE(review): the loader yields images that were only scaled to [0, 1]; pretrained vision
# backbones are usually fed mean/std-normalised inputs — confirm this is intentional.
with torch.no_grad():  # inference only, no autograd graph needed
    for imgs, lbls in final_loader:
        imgs = imgs.to(device)
        feats = vit(imgs)
        vit_feats.append(feats.cpu())
        y_list += lbls.cpu().tolist()

# Stack the per-batch outputs into one float32 matrix with a matching int64 label vector.
features = torch.cat(vit_feats, dim=0).numpy().astype(np.float32)
F = features
y_labels = np.asarray(y_list, dtype=np.int64)

print("Feature shape:", F.shape)
print("Label shape:", y_labels.shape)

Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


Feature shape: (780, 768)
Label shape: (780,)


In [29]:
class MLP(nn.Module):
    """Small perceptron head: Linear -> BatchNorm1d -> PReLU -> Dropout(0.3) -> Linear.

    Args:
        inp_size: width of the incoming feature vectors.
        outp_size: width of the produced output.
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        # Layers are listed in execution order and wrapped in one Sequential container.
        stages = [
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the raw outputs."""
        out = self.net(x)
        return out

In [30]:
class GCNEncoder(torch.nn.Module):
    """Single-layer graph encoder: GCNConv -> Dropout(0.3) -> BatchNorm1d -> Linear.

    Args:
        input_dim: node-feature width fed to the graph convolution.
        hidden_dim: width of the convolution output and of the final linear layer.
        device: stored as ``self.device``; not otherwise used by this class.
        activ: accepted in the signature but not used by this class.
    """

    def __init__(self, input_dim, hidden_dim, device, activ):
        super().__init__()
        self.device = device
        # Sub-modules are created in a fixed order so seeded weight initialisation is unchanged.
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Embed the nodes of ``data``, an object exposing ``x`` and ``edge_index``."""
        hidden = self.gcn1(data.x, data.edge_index)
        # Dropout is applied before batch normalisation, in that order.
        hidden = self.batchnorm(self.dropout(hidden))
        return self.mlp(hidden)

In [31]:
class GCN(nn.Module):
    """Graph clustering network: a GCN encoder followed by an MLP head that scores clusters.

    Args:
        input_dim: node-feature width.
        hidden_dim: width of the encoder output and of the head's hidden layer.
        num_clusters: number of output columns (clusters).
        device: stored as ``self.device`` and forwarded to the encoder.
        activ: key into the activation table ("SELU", "SiLU", "GELU", "RELU");
            any other value selects ELU.
    """

    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super().__init__()
        self.device = device
        self.num_clusters = num_clusters
        # NOTE(review): the original comment here read "always modularity", but nothing in this
        # class reads `self.cut`, and the loss attached below is the trace-ratio cut term plus
        # an orthogonality penalty (see cut_loss).
        self.cut = 0

        self.online_encoder = GCNEncoder(input_dim, hidden_dim, device, activ)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        # Stored for callers that may want it; forward() itself does not apply it.
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_size=hidden_dim)

        self.loss = self.cut_loss

    def forward(self, data):
        """Return ``(S, logits)``: row-wise softmax assignments and the raw head outputs."""
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def cut_loss(self, A, S):
        """Unsupervised clustering loss on a dense adjacency matrix.

        Computes ``-tr(S^T A S) / tr(S^T D S) + || S^T S / ||S^T S|| - I / ||I|| ||`` where
        ``D`` is the diagonal matrix of row sums of ``A``.

        Args:
            A: dense (N, N) adjacency matrix.
            S: (N, num_clusters) raw scores. Softmax is applied inside this method, so pass
                logits, not probabilities that have already been normalised.

        Returns:
            Scalar tensor holding the sum of the cut term and the orthogonality term.
        """
        S = nnFn.softmax(S, dim=1)

        # tr(S^T A S) equals the sum of the element-wise product (A S) * S.
        num = (torch.matmul(A, S) * S).sum()

        # tr(S^T D S) = sum_i deg_i * sum_k S_ik^2, so no dense N x N diagonal matrix is needed.
        deg = torch.sum(A, dim=-1)
        den = (deg.unsqueeze(-1) * S * S).sum()
        # A graph with no edges has zero total degree; avoid 0/0 -> NaN.
        den = torch.where(den == 0, torch.ones_like(den), den)
        mincut_loss = -(num / den)

        St_S = torch.matmul(S.t(), S)
        # Build the identity next to S (same device and dtype) rather than on the stored
        # `self.device`, which can be out of sync with where the tensors actually live.
        I_S = torch.eye(self.num_clusters, device=S.device, dtype=S.dtype)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))

        return mincut_loss + ortho_loss

In [32]:
def create_adj(F, cut, alpha=1):
    """Build a dense similarity graph from row-wise cosine similarity.

    Args:
        F: 2-D array with one feature vector per row.
        cut: ``0`` selects a binary adjacency (1 where cosine similarity >= ``alpha``);
            any other value returns the similarity matrix shifted by ``max / alpha``.
        alpha: edge threshold when ``cut == 0``, divisor of the shift otherwise.

    Returns:
        (N, N) array. For ``cut == 0`` it is float32 with entries in {0, 1}.
    """
    norms = np.linalg.norm(F, axis=1, keepdims=True)
    # An all-zero row has norm 0; dividing by 1 instead keeps it a zero vector
    # (cosine 0 to everything) rather than filling the matrix with NaN.
    F_norm = F / np.where(norms > 0, norms, np.ones_like(norms))
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        W = (W >= alpha).astype(np.float32)
        # With no edge above the threshold the maximum is 0; skip the rescale to avoid 0/0.
        peak = W.max() if W.size else 0
        if peak > 0:
            W = (W / peak).astype(np.float32)
    else:
        W = W - (W.max() / alpha)
    return W

In [33]:
def load_data(adj, node_feats):
    """Convert a dense adjacency matrix and a feature matrix into torch tensors.

    Args:
        adj: 2-D NumPy array; every strictly positive entry becomes an edge.
        node_feats: NumPy array of node features (wrapped without copying).

    Returns:
        ``(node_feats, edge_index, edge_weight)`` where ``edge_index`` is a (2, E) tensor of
        row/column positions and ``edge_weight`` holds the matching entries of ``adj``.
    """
    src, dst = np.nonzero(adj > 0)
    edge_index = torch.from_numpy(np.array((src, dst)))
    edge_weight = torch.from_numpy(adj[src, dst])
    return torch.from_numpy(node_feats), edge_index, edge_weight

In [34]:
print(features.shape, features.dtype)

# Graph-construction and clustering settings (each assigned exactly once).
cut = 0  # Consider n-cut loss OR Modularity loss (by default cut = 0)
alpha = 0.73  # Edge creation Threshold
K = 2  # Number of clusters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed(42)
# Take the input width from the feature matrix so it cannot drift from the extractor's output.
feats_dim = features.shape[1]

# Build the thresholded similarity graph, honouring `cut` instead of a hard-coded 0.
W0 = create_adj(features, cut, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
# Dense adjacency on the training device, consumed by the loss.
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(780, 768) float32
Data(x=[780, 768], edge_index=[2, 266936])


In [35]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

# Seed torch before the model is built so weight initialisation and dropout masks are
# repeatable, matching the seeding done in the multi-run cell.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

model = GCN(feats_dim, 256, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
# Halve the learning rate every 200 optimisation steps.
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 2500

# Training mode is constant for the whole loop, so it is set once up front.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Full-batch step: the whole graph goes through the model each epoch.
    S, logits = model(data0)
    # The loss applies softmax itself, so it receives the raw logits.
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss.item():.4f}")


Epoch 0 | Loss: -0.2297
Epoch 100 | Loss: -0.5975
Epoch 200 | Loss: -0.6071
Epoch 300 | Loss: -0.6091
Epoch 400 | Loss: -0.6103
Epoch 500 | Loss: -0.6107
Epoch 600 | Loss: -0.6120
Epoch 700 | Loss: -0.6129
Epoch 800 | Loss: -0.6128
Epoch 900 | Loss: -0.6131
Epoch 1000 | Loss: -0.6130
Epoch 1100 | Loss: -0.6130
Epoch 1200 | Loss: -0.6132
Epoch 1300 | Loss: -0.6132
Epoch 1400 | Loss: -0.6140
Epoch 1500 | Loss: -0.6131
Epoch 1600 | Loss: -0.6133
Epoch 1700 | Loss: -0.6139
Epoch 1800 | Loss: -0.6128
Epoch 1900 | Loss: -0.6131
Epoch 2000 | Loss: -0.6132
Epoch 2100 | Loss: -0.6137
Epoch 2200 | Loss: -0.6141
Epoch 2300 | Loss: -0.6136
Epoch 2400 | Loss: -0.6137


In [36]:
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()


# Cluster ids are arbitrary, so score both possible id -> class mappings and keep the better one.
acc_score = accuracy_score(y_labels, y_pred)
acc_score_inverted = accuracy_score(y_labels, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred
    # The probability columns must follow the same swap; otherwise log loss is computed
    # against the mapping that was just rejected.
    y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

prec_score = precision_score(y_labels, y_pred)
rec_score = recall_score(y_labels, y_pred)
f1 = f1_score(y_labels, y_pred)
log_loss_value = log_loss(y_labels, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.5935897435897436
Precision: 0.8154613466334164
Recall: 0.5736842105263158
F1: 0.6735324407826982
Log Loss: 3.565511247664375


In [38]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

# Store the initial extracted features and labels (from the sampled dataset)
initial_extracted_features = features
initial_sampled_labels = y_labels

# Settings shared by every run; assigned once instead of on each iteration.
cut = 0
alpha = 0.73
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = 768
K = 2
num_epochs = 5000

for run_seed in range(10):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels for the current run (same permutation keeps them aligned)
    perm = np.random.permutation(initial_extracted_features.shape[0])
    current_run_features = initial_extracted_features[perm]
    current_run_labels = initial_sampled_labels[perm]

    # Rebuild the graph for this node ordering.
    W0 = create_adj(current_run_features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, current_run_features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    # Fresh model, optimiser and schedule per seed.
    model = GCN(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()

        S, logits = model(data0) # S for probabilities, logits for loss
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        S, _ = model(data0) # Get probabilities S
        y_pred = torch.argmax(S, dim=1).cpu().numpy()
        y_pred_proba = S.cpu().numpy()

    # Cluster ids are arbitrary: keep whichever id -> class mapping scores higher.
    acc_score = accuracy_score(current_run_labels, y_pred)
    acc_score_inverted = accuracy_score(current_run_labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred
        # Swap the probability columns as well so log loss uses the same mapping as y_pred.
        y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

    prec_score = precision_score(current_run_labels, y_pred)
    rec_score = recall_score(current_run_labels, y_pred)
    f1 = f1_score(current_run_labels, y_pred)
    log_loss_value = log_loss(current_run_labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

# Aggregate the per-seed metrics.
accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]

print("\n===== Final Results across 10 runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))



Epoch 0 | Loss: -0.2352
Epoch 1000 | Loss: -0.6019
Epoch 2000 | Loss: -0.6017
Epoch 3000 | Loss: -0.6028
Epoch 4000 | Loss: -0.6022
Accuracy: 0.5846153846153846 Precision: 0.8121827411167513 Recall: 0.5614035087719298 F1: 0.6639004149377593

Epoch 0 | Loss: -0.2547
Epoch 1000 | Loss: -0.6074
Epoch 2000 | Loss: -0.6085
Epoch 3000 | Loss: -0.6088
Epoch 4000 | Loss: -0.6086
Accuracy: 0.5089743589743589 Precision: 0.7441253263707572 Recall: 0.5 F1: 0.5981112277019937

Epoch 0 | Loss: -0.2336
Epoch 1000 | Loss: -0.6152
Epoch 2000 | Loss: -0.6156
Epoch 3000 | Loss: -0.6156
Epoch 4000 | Loss: -0.6158
Accuracy: 0.5397435897435897 Precision: 0.7670886075949367 Recall: 0.531578947368421 F1: 0.627979274611399

Epoch 0 | Loss: -0.2322
Epoch 1000 | Loss: -0.6023
Epoch 2000 | Loss: -0.6025
Epoch 3000 | Loss: -0.6028
Epoch 4000 | Loss: -0.6024
Accuracy: 0.5858974358974359 Precision: 0.8142493638676844 Recall: 0.5614035087719298 F1: 0.6645898234683282

Epoch 0 | Loss: -0.2394
Epoch 1000 | Loss: -0.61