In [None]:
# Mount Google Drive only when the notebook is actually running on Colab.
# `google.colab` does not exist on a regular Jupyter kernel, and the dataset
# below is read from a local path, so a missing Colab runtime must not abort
# a "Restart & Run All".
try:
    from google.colab import drive
except ImportError:
    # Not on Colab: nothing to mount.
    drive = None

if drive is not None:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting




In [2]:
import numpy as np
import torch
import random
import copy
import scipy.sparse as sp

from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import models
import torch.nn as nn
import torch.nn.functional as nnFn

# torch-geometric imports
from torch_geometric.nn import ARMAConv
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from sklearn.manifold import TSNE

In [3]:
import torch
# Report CUDA visibility and pick the torch device (query the runtime once).
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
device = torch.device('cuda') if cuda_ok else torch.device('cpu')
gpu_name = torch.cuda.get_device_name(0) if cuda_ok else "No GPU"
print("GPU Name:", gpu_name)

CUDA available: True
GPU Name: NVIDIA RTX A4000


In [4]:
# Location of the .npz archive. Defaults to the original path; set the
# BREASTMNIST_NPZ environment variable to run the notebook on another machine
# without editing this cell.
import os

DATA_PATH = os.environ.get('BREASTMNIST_NPZ', '/home/snu/Downloads/breastmnist_224.npz')

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code from an untrusted file. If the archive only holds plain
# numeric arrays (presumably the case here -- confirm), switch to allow_pickle=False.
data = np.load(DATA_PATH, allow_pickle=True)

In [5]:
# Pool the train / val / test arrays stored in the archive into one image
# array and one (squeezed) label array.
all_images = np.concatenate(
    (data['train_images'], data['val_images'], data['test_images']), axis=0
)
all_labels = np.concatenate(
    (data['train_labels'], data['val_labels'], data['test_labels']), axis=0
).squeeze()

# Divide by 255, then insert a channel axis and repeat it three times so the
# tensor has the 3-channel layout passed to the backbone later on.
images = all_images.astype(np.float32) / 255.0
images = np.repeat(images[:, np.newaxis, :, :], 3, axis=1)  # (N,3,224,224)

X = torch.tensor(images)
y = torch.tensor(all_labels).long()
print("Images, labels shapes:", X.shape, y.shape)

dataset = TensorDataset(X, y)

# Positions of each class, in ascending order.
class0_indices = torch.nonzero(y == 0, as_tuple=True)[0].tolist()
class1_indices = torch.nonzero(y == 1, as_tuple=True)[0].tolist()

# Draw at most 1000 samples per class with a fixed seed.
random.seed(42)
n_class0 = min(1000, len(class0_indices))
n_class1 = min(1000, len(class1_indices))
sampled_class0 = random.sample(class0_indices, n_class0)
sampled_class1 = random.sample(class1_indices, n_class1)

combined_indices = [*sampled_class0, *sampled_class1]
random.shuffle(combined_indices)

# shuffle=False keeps the loader order equal to `combined_indices`, so features
# and labels collected from it stay aligned.
final_dataset = Subset(dataset, combined_indices)
final_loader = DataLoader(final_dataset, batch_size=32, shuffle=False)

Images, labels shapes: torch.Size([780, 3, 224, 224]) torch.Size([780])


In [6]:
import torch
import timm
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# Backbone fetched through torch.hub (entry point 'dino_vitb16'); it is used
# purely as a frozen feature extractor: eval mode, no gradients.
vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vit.eval().to(device)

vit_feats, y_list = [], []

with torch.no_grad():
    for imgs, lbls in final_loader:
        imgs = imgs.to(device)
        feats = vit(imgs)
        # Move each batch back to the CPU right away so GPU memory is not accumulated.
        vit_feats.append(feats.cpu())
        y_list += lbls.cpu().tolist()

# One row of features per image, in loader order, with matching labels.
F = torch.cat(vit_feats, dim=0).numpy().astype(np.float32)
y_labels = np.asarray(y_list, dtype=np.int64)

print("Feature shape:", F.shape)
print("Label shape:", y_labels.shape)
features = F

Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


Feature shape: (780, 768)
Label shape: (780,)


In [7]:
class MLP(nn.Module):
    """Small prediction head.

    Pipeline: Linear(inp_size -> hidden_size) -> BatchNorm1d -> PReLU ->
    Dropout(p=0.3) -> Linear(hidden_size -> outp_size).

    Args:
        inp_size: number of input features per row.
        outp_size: number of output features per row.
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, inp_size, outp_size, hidden_size):
        super(MLP, self).__init__()
        # Built as a list first; nn.Sequential indexes the stages 0..4 in this order.
        stages = [
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the stacked layers and return the result."""
        out = self.net(x)
        return out

In [None]:
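# NOTE: superseded draft of ARMAEncoder (a single ARMAConv layer), kept commented out
# for reference only; the active ARMAEncoder class is defined in a later cell.
# class ARMAEncoder(torch.nn.Module):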
#     def __init__(self, input_dim, hidden_dim, device, activ, stacks=3, layers=3):
#         super(ARMAEncoder, self).__init__()
#         self.device = device
#         self.arma = ARMAConv(input_dim, hidden_dim, num_stacks=stacks, num_layers=layers)
#         self.batchnorm = nn.BatchNorm1d(hidden_dim)
#         self.dropout = nn.Dropout(0.3)
#         self.mlp = nn.Linear(hidden_dim, hidden_dim)

#     def forward(self, data):
#         x, edge_index = data.x, data.edge_index
#         x = self.arma(x, edge_index)
#         x = self.dropout(x)
#         x = self.batchnorm(x)
#         logits = self.mlp(x)
#         return logits

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch_geometric.nn import ARMAConv

class ARMAEncoder(torch.nn.Module):
    """Encoder made of three ARMAConv layers followed by a linear projection.

    Each stage applies, in order: ARMAConv -> BatchNorm1d -> activation ->
    Dropout(p=0.3). The first ARMAConv maps `input_dim` to `hidden_dim`; the
    other two keep `hidden_dim`. The chosen activation is also handed to every
    ARMAConv through its `act` argument.

    Args:
        input_dim: feature size of `data.x`.
        hidden_dim: output size of every stage and of the final linear layer.
        device: stored on the instance as `self.device`; not otherwise used here.
        activ: one of "SELU", "SiLU", "GELU", "ELU", "RELU". Any other value
            silently falls back to ELU.
        num_stacks: forwarded to every ARMAConv.
        num_layers: forwarded to every ARMAConv.
    """

    # Name -> functional activation lookup.
    _ACTIVATIONS = {
        "SELU": nnFn.selu,
        "SiLU": nnFn.silu,
        "GELU": nnFn.gelu,
        "ELU": nnFn.elu,
        "RELU": nnFn.relu,
    }

    def __init__(self, input_dim, hidden_dim, device, activ="ELU",
                 num_stacks=1, num_layers=3):
        super().__init__()
        self.device = device
        self.act = self._ACTIVATIONS.get(activ, nnFn.elu)

        # Register arma1/bn1, arma2/bn2, arma3/bn3 in that order so attribute
        # names and parameter ordering stay arma{i} / bn{i}.
        stage_inputs = (input_dim, hidden_dim, hidden_dim)
        for idx, in_dim in enumerate(stage_inputs, start=1):
            conv = ARMAConv(
                in_channels=in_dim,
                out_channels=hidden_dim,
                num_stacks=num_stacks,
                num_layers=num_layers,
                act=self.act,
                shared_weights=True,
                dropout=0.25,
            )
            setattr(self, f"arma{idx}", conv)
            setattr(self, f"bn{idx}", nn.BatchNorm1d(hidden_dim))

        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        """Encode a graph object exposing `.x` and `.edge_index`.

        Returns the output of the final linear layer, one row per row of `data.x`.
        """
        h, edge_index = data.x, data.edge_index

        stages = (
            (self.arma1, self.bn1),
            (self.arma2, self.bn2),
            (self.arma3, self.bn3),
        )
        for conv, norm in stages:
            h = conv(h, edge_index)
            h = norm(h)
            h = self.act(h)
            h = self.dropout(h)

        return self.mlp(h)


In [13]:
class ARMA(nn.Module):
    """Clustering model: ARMAEncoder -> MLP head producing per-node cluster logits.

    Args:
        input_dim: node feature size fed to the encoder.
        hidden_dim: encoder output size and hidden width of the MLP head.
        num_clusters: number of output columns (clusters).
        device: stored as `self.device` and forwarded to the encoder.
        activ: activation name forwarded to the encoder; unknown names fall
            back to ELU.

    Attributes:
        loss: alias of `cut_loss`, so callers can use `model.loss(A, logits)`.
    """

    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ):
        super().__init__()
        self.device = device
        self.num_clusters = num_clusters

        self.online_encoder = ARMAEncoder(input_dim, hidden_dim, device, activ)

        # Same lookup table as ARMAEncoder ("ELU" is listed explicitly instead of
        # only being reachable through the fallback).
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "ELU": nnFn.elu,
            "RELU": nnFn.relu,
        }
        self.act = activations.get(activ, nnFn.elu)

        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # use cut loss instead of modularity
        self.loss = self.cut_loss

    def forward(self, data):
        """Return `(S, logits)`: softmax cluster assignments and the raw logits."""
        x = self.online_encoder(data)
        logits = self.online_predictor(x)
        S = nnFn.softmax(logits, dim=1)
        return S, logits

    def cut_loss(self, A, S):
        """Relaxed normalised-cut loss plus an orthogonality penalty.

        Args:
            A: dense (N, N) adjacency matrix.
            S: (N, K) *raw logits*. Softmax is applied here, so do not pass
                already-normalised assignments (they would be softmaxed twice).

        Returns:
            Scalar tensor `-(tr(S^T A^T S) / tr(S^T D S)) + ortho`, where `D` is
            the diagonal matrix of row sums of `A` and
            `ortho = || S^T S / ||S^T S|| - I / ||I|| ||`.
        """
        S = nnFn.softmax(S, dim=1)

        # tr((A S)^T S) == sum((A S) * S): no need to form the K x K product.
        num = (torch.matmul(A, S) * S).sum()

        # tr((D S)^T S) == sum_i d_i * ||S_i||^2: avoids materialising the dense
        # N x N diagonal degree matrix.
        degrees = torch.sum(A, dim=-1)
        den = (degrees.unsqueeze(-1) * S * S).sum()
        # An edgeless graph gives den == 0 (and num == 0); substitute a tiny value
        # so the ratio is 0 instead of NaN. Non-zero denominators are left untouched.
        tiny = torch.full_like(den, torch.finfo(S.dtype).eps)
        den = torch.where(den == 0, tiny, den)
        mincut_loss = -(num / den)

        St_S = torch.matmul(S.t(), S)
        # Build the identity next to S (same device / dtype) instead of relying on
        # self.device, which goes stale if the module is moved after construction.
        I_S = torch.eye(S.shape[1], device=S.device, dtype=S.dtype)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))

        return mincut_loss + ortho_loss

In [14]:
def create_adj(F, cut, alpha=1):
    """Build a dense adjacency matrix from the cosine similarity of the rows of `F`.

    Args:
        F: 2-D array, one feature vector per row.
        cut: if 0, return a binary float32 matrix with 1 where the similarity is
            `>= alpha`; otherwise return the similarity matrix shifted by
            `W.max() / alpha`.
        alpha: similarity threshold (`cut == 0`) or divisor of the shift
            (`cut != 0`, where `alpha == 0` would divide by zero).
            With the default `alpha=1` even self-similarity can miss the
            threshold because of floating-point rounding.

    Returns:
        (N, N) array: float32 0/1 entries when `cut == 0`, otherwise the dtype
        produced by the similarity computation.
    """
    norms = np.linalg.norm(F, axis=1, keepdims=True)
    # All-zero rows have norm 0; dividing by 1 instead leaves them as zero
    # vectors (similarity 0 to everything) rather than turning into NaN.
    norms[norms == 0] = 1.0
    F_norm = F / norms
    W = np.dot(F_norm, F_norm.T)
    if cut == 0:
        # The thresholded matrix is already 0/1. The former `W / W.max()` was a
        # no-op whenever an edge existed and 0/0 = NaN when none did.
        return (W >= alpha).astype(np.float32)
    return W - (W.max() / alpha)

In [15]:
def load_data(adj, node_feats):
    """Convert a dense adjacency matrix and a feature matrix to torch tensors.

    Args:
        adj: 2-D NumPy array; entries `> 0` are treated as edges.
        node_feats: NumPy array of node features.

    Returns:
        Tuple `(node_feats, edge_index, edge_weight)`:
        the features as a tensor, a `(2, E)` tensor of (row, col) positions of
        the positive entries, and the `E` corresponding values of `adj`.
    """
    # Row/column positions of every strictly positive entry.
    rows, cols = np.nonzero(adj > 0)
    edge_index = torch.from_numpy(np.array((rows, cols)))
    edge_weight = torch.from_numpy(adj[rows, cols])
    return torch.from_numpy(node_feats), edge_index, edge_weight

In [16]:
print(features.shape, features.dtype)

# ---- Graph-construction settings (each defined exactly once) ----
cut = 0        # 0 -> thresholded 0/1 adjacency; non-zero -> shifted similarity (see create_adj)
alpha = 0.73   # Edge creation threshold
K = 2          # Number of clusters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed(42)

# Derive the feature size from the data instead of hard-coding it.
feats_dim = features.shape[1]

# Dense adjacency -> edge list for the graph object, plus a dense copy of the
# adjacency on the training device for the loss.
W0 = create_adj(features, cut, alpha)
node_feats, edge_index, _ = load_data(W0, features)
data0 = Data(x=node_feats, edge_index=edge_index).to(device)
A1 = torch.from_numpy(W0).float().to(device)
print(data0)

(780, 768) float32
Data(x=[780, 768], edge_index=[2, 266936])


In [17]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

# Seed torch before the model is built: weight initialisation and dropout draw
# from torch's RNG, so seeding NumPy alone does not make this run reproducible.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

model = ARMA(feats_dim, 256, K, device, "ELU").to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
# NOTE(review): halving the LR every 200 steps over 5000 steps means 25 halvings
# (1e-4 -> ~3e-12), so the later epochs barely update the weights.
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

num_epochs = 5000

# Training mode is loop-invariant; set it once.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Full-batch step on the whole graph. model.loss applies softmax itself,
    # so it receives the raw logits rather than S.
    S, logits = model(data0)
    unsup_loss = model.loss(A1, logits)

    total_loss = unsup_loss
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {total_loss.item():.4f}")


Epoch 0 | Loss: -0.2412
Epoch 100 | Loss: -0.6002
Epoch 200 | Loss: -0.6141
Epoch 300 | Loss: -0.6142
Epoch 400 | Loss: -0.6185
Epoch 500 | Loss: -0.6180
Epoch 600 | Loss: -0.6200
Epoch 700 | Loss: -0.6207
Epoch 800 | Loss: -0.6199
Epoch 900 | Loss: -0.6203
Epoch 1000 | Loss: -0.6213
Epoch 1100 | Loss: -0.6210
Epoch 1200 | Loss: -0.6203
Epoch 1300 | Loss: -0.6213
Epoch 1400 | Loss: -0.6211
Epoch 1500 | Loss: -0.6210
Epoch 1600 | Loss: -0.6211
Epoch 1700 | Loss: -0.6200
Epoch 1800 | Loss: -0.6211
Epoch 1900 | Loss: -0.6208
Epoch 2000 | Loss: -0.6211
Epoch 2100 | Loss: -0.6210
Epoch 2200 | Loss: -0.6203
Epoch 2300 | Loss: -0.6214
Epoch 2400 | Loss: -0.6214
Epoch 2500 | Loss: -0.6208
Epoch 2600 | Loss: -0.6206
Epoch 2700 | Loss: -0.6199
Epoch 2800 | Loss: -0.6211
Epoch 2900 | Loss: -0.6211
Epoch 3000 | Loss: -0.6214
Epoch 3100 | Loss: -0.6217
Epoch 3200 | Loss: -0.6206
Epoch 3300 | Loss: -0.6213
Epoch 3400 | Loss: -0.6206
Epoch 3500 | Loss: -0.6213
Epoch 3600 | Loss: -0.6207
Epoch 3700 | 

In [18]:
# Evaluate the trained model on the same graph (clustering is unsupervised, so
# the labels are only used for scoring).
model.eval()
with torch.no_grad():
    S, logits = model(data0)
    y_pred = torch.argmax(logits, dim=1).cpu().numpy()
    y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

# Cluster ids are arbitrary: score both id->class mappings and keep the better one.
acc_score = accuracy_score(y_labels, y_pred)
acc_score_inverted = accuracy_score(y_labels, 1 - y_pred)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred
    # Swap the probability columns too, otherwise log_loss is computed against
    # the rejected mapping and contradicts the other metrics.
    y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

prec_score = precision_score(y_labels, y_pred)
rec_score = recall_score(y_labels, y_pred)
f1 = f1_score(y_labels, y_pred)
log_loss_value = log_loss(y_labels, y_pred_proba)

print("Accuracy:", acc_score)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.5576923076923077
Precision: 0.7494456762749445
Recall: 0.5929824561403508
F1: 0.6620959843290891
Log Loss: 6.528569908284446


In [19]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW

results = []

# Store the initial extracted features and labels (from the sampled dataset)
initial_extracted_features = features
initial_sampled_labels = y_labels

# Settings that do not change between runs, defined once instead of per iteration.
num_runs = 10
cut = 0
alpha = 0.73
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = initial_extracted_features.shape[1]
K = 2
num_epochs = 5000

for run_seed in range(num_runs):
    print("\n================ Run", run_seed, "================")

    # Set seeds for reproducibility
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run_seed)

    # Shuffle features and labels for the current run (same permutation for
    # both, so rows and labels stay aligned).
    perm = np.random.permutation(initial_extracted_features.shape[0])
    current_run_features = initial_extracted_features[perm]
    current_run_labels = initial_sampled_labels[perm]

    # Rebuild the graph from the permuted features.
    W0 = create_adj(current_run_features, cut, alpha)
    node_feats, edge_index, _ = load_data(W0, current_run_features)
    data0 = Data(x=node_feats, edge_index=edge_index).to(device)
    A1 = torch.from_numpy(W0).float().to(device)

    model = ARMA(feats_dim, 256, K, device, "ELU").to(device)
    optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # model.loss applies softmax itself, so it receives the raw logits.
        S, logits = model(data0)
        unsup_loss = model.loss(A1, logits)

        unsup_loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Loss: {unsup_loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        S, logits = model(data0)
        y_pred = torch.argmax(logits, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits, dim=1).cpu().numpy()

    # Cluster ids are arbitrary: keep whichever id->class mapping scores higher.
    acc_score = accuracy_score(current_run_labels, y_pred)
    acc_score_inverted = accuracy_score(current_run_labels, 1 - y_pred)

    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred
        # Keep the probabilities consistent with the swapped labels so that
        # log_loss is scored under the same mapping as the other metrics.
        y_pred_proba = np.ascontiguousarray(y_pred_proba[:, ::-1])

    prec_score = precision_score(current_run_labels, y_pred)
    rec_score = recall_score(current_run_labels, y_pred)
    f1 = f1_score(current_run_labels, y_pred)
    log_loss_value = log_loss(current_run_labels, y_pred_proba)

    print("Accuracy:", acc_score, "Precision:", prec_score, "Recall:", rec_score, "F1:", f1)

    results.append({
        "seed": run_seed,
        "accuracy": acc_score,
        "precision": prec_score,
        "recall": rec_score,
        "f1": f1,
        "log_loss": log_loss_value
    })

# Aggregate the per-run metrics.
accs = [r["accuracy"] for r in results]
precisions = [r["precision"] for r in results]
recalls = [r["recall"] for r in results]
f1s = [r["f1"] for r in results]
log_losses = [r["log_loss"] for r in results]

print(f"\n===== Final Results across {num_runs} runs =====")
print("Accuracy: mean=", np.mean(accs), "std=", np.std(accs))
print("Precision: mean=", np.mean(precisions), "std=", np.std(precisions))
print("Recall: mean=", np.mean(recalls), "std=", np.std(recalls))
print("F1: mean=", np.mean(f1s), "std=", np.std(f1s))
print("Log loss: mean=", np.mean(log_losses), "std=", np.std(log_losses))


Epoch 0 | Loss: -0.2393
Epoch 1000 | Loss: -0.6193
Epoch 2000 | Loss: -0.6201
Epoch 3000 | Loss: -0.6200
Epoch 4000 | Loss: -0.6199
Accuracy: 0.5884615384615385 Precision: 0.7588357588357588 Recall: 0.6403508771929824 F1: 0.6945765937202664

Epoch 0 | Loss: -0.2431
Epoch 1000 | Loss: -0.6203
Epoch 2000 | Loss: -0.6216
Epoch 3000 | Loss: -0.6212
Epoch 4000 | Loss: -0.6208
Accuracy: 0.5897435897435898 Precision: 0.757201646090535 Recall: 0.6456140350877193 F1: 0.696969696969697

Epoch 0 | Loss: -0.2407
Epoch 1000 | Loss: -0.6214
Epoch 2000 | Loss: -0.6209
Epoch 3000 | Loss: -0.6208
Epoch 4000 | Loss: -0.6211
Accuracy: 0.5615384615384615 Precision: 0.7556053811659192 Recall: 0.5912280701754385 F1: 0.6633858267716536

Epoch 0 | Loss: -0.2454
Epoch 1000 | Loss: -0.6063
Epoch 2000 | Loss: -0.6076
Epoch 3000 | Loss: -0.6066
Epoch 4000 | Loss: -0.6075
Accuracy: 0.5858974358974359 Precision: 0.7990314769975787 Recall: 0.5789473684210527 F1: 0.671414038657172

Epoch 0 | Loss: -0.2466
Epoch 1000