In [None]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting



In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from torch.utils.data import TensorDataset, DataLoader, Subset
import random
from torchvision import models

In [3]:
data = np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True)

# Combine train, val, and test sets
all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

In [4]:
images = all_images.astype(np.float32) / 255.0
images = np.repeat(images[:, None, :, :], 3, axis=1)  # Convert to 3 channels (N, 3, 224, 224)

# Convert to torch tensors
X = torch.tensor(images)
y = torch.tensor(all_labels).long()
print(X.shape, y.shape)

torch.Size([780, 3, 224, 224]) torch.Size([780])


In [5]:
dataset = TensorDataset(X, y)
class0_indices = [i for i in range(len(y)) if y[i] == 0]
class1_indices = [i for i in range(len(y)) if y[i] == 1]

random.seed(42)
sampled_class0 = random.sample(class0_indices, min(1000, len(class0_indices)))
sampled_class1 = random.sample(class1_indices, min(1000, len(class1_indices)))

combined_indices = sampled_class0 + sampled_class1
random.shuffle(combined_indices)

# Final subset
final_dataset = Subset(dataset, combined_indices)
final_loader = DataLoader(final_dataset, batch_size=64, shuffle=False)

In [6]:
import torch
import timm
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vit.eval().to(device)

vit_feats = []
y_list = []

with torch.no_grad():
    for imgs, lbls in final_loader:
        imgs = imgs.to(device)
        feats = vit(imgs)
        vit_feats.append(feats.cpu())
        y_list.extend(lbls.cpu().tolist())

F = torch.cat(vit_feats, dim=0).numpy().astype(np.float32)
y_labels = np.array(y_list).astype(np.int64)

print("Feature shape:", F.shape)
print("Label shape:", y_labels.shape)


Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


Feature shape: (780, 768)
Label shape: (780,)


In [7]:
def create_adj(F, alpha=1):
    F_norm = F / np.linalg.norm(F, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)
    W = (W >= alpha).astype(np.float32)
    return W

In [8]:
def asymmetrize_random(adj_matrix, seed=None):
    """
    Randomly orient each undirected edge from a symmetric adjacency matrix.
    """
    adj = np.array(adj_matrix, dtype=np.float32)
    n = adj.shape[0]
    asym = np.zeros((n, n), dtype=np.float32)
    rng = np.random.default_rng(seed)

    for i in range(n):
        for j in range(i + 1, n):
            if adj[i, j]:
                if rng.random() < 0.5:
                    asym[i, j] = adj[i, j]
                else:
                    asym[j, i] = adj[i, j]

    return asym

In [9]:
def load_data(adj, node_feats):
    node_feats = torch.from_numpy(node_feats).float()
    edge_index = torch.from_numpy(np.array(np.nonzero(adj))).long()
    return node_feats, edge_index

In [10]:
features = F # Use ResNet embeddings

device = 'cuda' if torch.cuda.is_available() else 'cpu'

W0 = create_adj(features, alpha=0.73)
# W_asym = asymmetrize_random(W0, seed=42)
node_feats, edge_index = load_data(W0, features)
data = Data(x=node_feats, edge_index=edge_index).to(device)
A = torch.from_numpy(W0).to(device)
print(data)

Data(x=[780, 768], edge_index=[2, 266936])


In [11]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv

class GATEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, heads, device, activ=nn.ELU()):
        super(GATEncoder, self).__init__()
        self.device = device

        # Layer 1
        self.gat1 = GATConv(input_dim, hidden_dim, heads=heads, concat=True)
        self.bn1 = nn.BatchNorm1d(hidden_dim * heads)

        # Layer 2
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True)
        self.bn2 = nn.BatchNorm1d(hidden_dim * heads)

        # Layer 3
        self.gat3 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True)
        self.bn3 = nn.BatchNorm1d(hidden_dim * heads)

        self.dropout = nn.Dropout(0.25)
        self.act = activ
        self.mlp = nn.Linear(hidden_dim * heads, hidden_dim * heads)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.gat1(x, edge_index)
        x = self.bn1(x)
        x = self.act(x)
        x = self.dropout(x)

        x = self.gat2(x, edge_index)
        x = self.bn2(x)
        x = self.act(x)
        x = self.dropout(x)

        x = self.gat3(x, edge_index)
        x = self.bn3(x)
        x = self.act(x)
        x = self.dropout(x)

        logits = self.mlp(x)
        return logits


In [12]:
class AvgReadout(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, seq, msk=None):
        if msk is None:
            return torch.mean(seq, 0)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 0) / torch.sum(msk)

In [13]:
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        nn.init.xavier_uniform_(self.f_k.weight.data)
        if self.f_k.bias is not None:
            self.f_k.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi):
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)
        logits = torch.cat((sc_1, sc_2), 0)
        return logits

In [14]:
class DGI(nn.Module):
    def __init__(self, n_in, n_h, heads, dropout=0.25):
        super().__init__()
        self.gat1 = GATEncoder(n_in, n_h, heads=heads, device='cuda' if torch.cuda.is_available() else 'cpu', activ=nn.ELU())
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h * heads)

    def forward(self, seq1, seq2, edge_index):
        # Create Data objects for the GATEncoder
        data1 = Data(x=seq1, edge_index=edge_index)
        data2 = Data(x=seq2, edge_index=edge_index)

        h_1 = self.gat1(data1)
        c = self.read(h_1)
        c = self.sigm(c)
        h_2 = self.gat1(data2)
        logits = self.disc(c, h_1, h_2)
        return logits, h_1

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class DGI_with_classifier(DGI):
    def __init__(self, n_in, n_h, heads, n_classes=2, cut=0, dropout=0.25):
        super().__init__(n_in, n_h, heads, dropout=dropout)
        self.classifier = nn.Linear(n_h * heads, n_classes)
        self.cut = cut


    def get_embeddings(self, node_feats, edge_index):
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        # ensure tensor
        if isinstance(A, np.ndarray):
            A = torch.from_numpy(A).float().to(S.device)

        S = F.softmax(S, dim=1)   # cluster assignment
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)

        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)

        mincut_loss = -(num / den)

        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(S.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))

        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        # ensure tensor
        if isinstance(A, np.ndarray):
            A = torch.from_numpy(A).float().to(S.device)

        C = F.softmax(S, dim=1)   # cluster assignment
        d = torch.sum(A, dim=1)
        m = torch.sum(A)

        B = A - torch.outer(d, d) / (2 * m)

        I_S = torch.eye(C.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        # classifier output used as soft cluster assignment
        logits = self.classifier(embeddings)

        if self.cut == 1:
            return self.cut_loss(A, logits)
        else:
            return self.modularity_loss(A, logits)

In [16]:
hidden_dim = 256
cut = 1
dropout = 0.25
heads = 1
# make sure adjacency is a tensor on GPU
if isinstance(A, np.ndarray):
    A = torch.from_numpy(A).float().to(device)
else:
    A = A.float().to(device)

model = DGI_with_classifier(features.shape[1], hidden_dim, heads=heads, n_classes=2, cut=cut, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
bce_loss = nn.BCEWithLogitsLoss()

num_epochs = 2500
for epoch in range(num_epochs + 1):
    model.train()
    optimizer.zero_grad()

    perm = torch.randperm(node_feats.size(0))
    corrupt_features = node_feats[perm]

    logits, embeddings = model(node_feats.to(device), corrupt_features.to(device), edge_index.to(device))

    lbl = torch.cat([
        torch.ones(node_feats.size(0)),
        torch.zeros(node_feats.size(0))
    ]).to(device)

    dgi_loss = bce_loss(logits.squeeze(), lbl)
    reg_loss = model.Reg_loss(A, embeddings)
    loss = dgi_loss + 0.01 * reg_loss

    if epoch % 500 == 0:
        print(f"Epoch {epoch} | DGI Loss: {dgi_loss.item():.4f} | Reg Loss: {reg_loss.item():.4f} | Total: {loss.item():.4f}")

    loss.backward()
    optimizer.step()


Epoch 0 | DGI Loss: 0.7096 | Reg Loss: -0.2570 | Total: 0.7071
Epoch 500 | DGI Loss: 0.7898 | Reg Loss: -0.5026 | Total: 0.7848
Epoch 1000 | DGI Loss: 0.1945 | Reg Loss: -0.3198 | Total: 0.1913
Epoch 1500 | DGI Loss: 0.5244 | Reg Loss: -0.3516 | Total: 0.5209
Epoch 2000 | DGI Loss: 0.0450 | Reg Loss: -0.2780 | Total: 0.0422
Epoch 2500 | DGI Loss: 0.5423 | Reg Loss: -0.2952 | Total: 0.5393


In [17]:
model.eval()
with torch.no_grad():
    embeddings = model.get_embeddings(node_feats.to(device), edge_index.to(device))
    class_probabilities = F.softmax(model.classifier(embeddings), dim=1).cpu().numpy()

y_pred = np.argmax(class_probabilities, axis=1)

In [18]:
# Extract the true labels for the subset used for prediction
y_subset = y[combined_indices].cpu().numpy()

acc_score = accuracy_score(y_subset, y_pred)
acc_score_inverted = accuracy_score(y_subset, 1 - y_pred)
prec_score = precision_score(y_subset, y_pred)
rec_score = recall_score(y_subset, y_pred)
f1 = f1_score(y_subset, y_pred)
log_loss_value = log_loss(y_subset, class_probabilities)

print("Accuracy:", acc_score)
print("Accuracy (inverted):", acc_score_inverted)
print("Precision:", prec_score)
print("Recall:", rec_score)
print("F1:", f1)
print("Log Loss:", log_loss_value)

Accuracy: 0.7153846153846154
Accuracy (inverted): 0.2846153846153846
Precision: 0.732620320855615
Recall: 0.9614035087719298
F1: 0.8315629742033384
Log Loss: 0.6053487977663149


In [19]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
import torch.nn.functional as F
import torch

NUM_RUNS = 10

chosen_acc_list = []   # <-- store best of acc vs inverted
prec_list = []
rec_list = []
f1_list = []
logloss_list = []

for run in range(NUM_RUNS):
    print(f"\n===== RUN {run+1}/{NUM_RUNS} =====")

    hidden_dim = 256
    cut = 1
    dropout = 0.25
    heads = 1

    if isinstance(A, np.ndarray):
        A_gpu = torch.from_numpy(A).float().to(device)
    else:
        A_gpu = A.float().to(device)

    model = DGI_with_classifier(
        features.shape[1], hidden_dim,
        heads=heads, n_classes=2,
        cut=cut, dropout=dropout
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
    bce_loss = nn.BCEWithLogitsLoss()

    num_epochs = 2500
    for epoch in range(num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        perm = torch.randperm(node_feats.size(0))
        corrupt_features = node_feats[perm]

        logits, embeddings = model(
            node_feats.to(device),
            corrupt_features.to(device),
            edge_index.to(device)
        )

        lbl = torch.cat([
            torch.ones(node_feats.size(0)),
            torch.zeros(node_feats.size(0))
        ]).to(device)

        dgi_loss = bce_loss(logits.squeeze(), lbl)
        reg_loss = model.Reg_loss(A_gpu, embeddings)
        loss = dgi_loss + 0.001 * reg_loss

        if epoch % 500 == 0:
            print(f"Epoch {epoch} | DGI Loss: {dgi_loss.item():.4f} | "
                  f"Reg Loss: {reg_loss.item():.4f} | Total: {loss.item():.4f}")

        loss.backward()
        optimizer.step()

    # -------- Evaluation per RUN --------
    model.eval()
    with torch.no_grad():
        embeddings = model.get_embeddings(node_feats.to(device), edge_index.to(device))
        class_probabilities = F.softmax(model.classifier(embeddings), dim=1).cpu().numpy()

    y_pred = np.argmax(class_probabilities, axis=1)
    y_subset = y[combined_indices].cpu().numpy()

    acc = accuracy_score(y_subset, y_pred)
    acc_inv = accuracy_score(y_subset, 1 - y_pred)

    # ---------- choose whichever is higher ----------
    chosen_acc = max(acc, acc_inv)

    prec = precision_score(y_subset, y_pred)
    rec = recall_score(y_subset, y_pred)
    f1 = f1_score(y_subset, y_pred)
    ll = log_loss(y_subset, class_probabilities)

    print(f"Run {run+1} Results:")
    print(f"Accuracy: {acc:.4f}")
    print(f"Accuracy (Inverted): {acc_inv:.4f}")
    print(f"--> Selected Accuracy: {chosen_acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1: {f1:.4f}")
    print(f"Log Loss: {ll:.4f}")

    chosen_acc_list.append(chosen_acc)
    prec_list.append(prec)
    rec_list.append(rec)
    f1_list.append(f1)
    logloss_list.append(ll)

# ----- Print mean ± std -----
def mean_std(arr):
    return np.mean(arr), np.std(arr)

metrics = {
    "Selected Accuracy": chosen_acc_list,
    "Precision": prec_list,
    "Recall": rec_list,
    "F1": f1_list,
    "Log Loss": logloss_list
}

print("\n===== FINAL RESULTS OVER 10 RUNS =====")
for name, values in metrics.items():
    m, s = mean_std(values)
    print(f"{name}: {m:.4f} ± {s:.4f}")



===== RUN 1/10 =====
Epoch 0 | DGI Loss: 0.7077 | Reg Loss: -0.2562 | Total: 0.7074
Epoch 500 | DGI Loss: 0.4361 | Reg Loss: -0.3884 | Total: 0.4358
Epoch 1000 | DGI Loss: 0.1814 | Reg Loss: -0.4408 | Total: 0.1809
Epoch 1500 | DGI Loss: 0.0718 | Reg Loss: -0.4582 | Total: 0.0713
Epoch 2000 | DGI Loss: 0.0226 | Reg Loss: -0.4603 | Total: 0.0222
Epoch 2500 | DGI Loss: 0.0864 | Reg Loss: -0.4385 | Total: 0.0859
Run 1 Results:
Accuracy: 0.2667
Accuracy (Inverted): 0.7333
--> Selected Accuracy: 0.7333
Precision: 0.3333
Recall: 0.0035
F1: 0.0069
Log Loss: 1.9904

===== RUN 2/10 =====
Epoch 0 | DGI Loss: 0.7005 | Reg Loss: -0.2434 | Total: 0.7003
Epoch 500 | DGI Loss: 0.0906 | Reg Loss: -0.5109 | Total: 0.0901
Epoch 1000 | DGI Loss: 0.0283 | Reg Loss: -0.5221 | Total: 0.0278
Epoch 1500 | DGI Loss: 0.9798 | Reg Loss: -0.4992 | Total: 0.9793
Epoch 2000 | DGI Loss: 0.0299 | Reg Loss: -0.5356 | Total: 0.0293
Epoch 2500 | DGI Loss: 0.0232 | Reg Loss: -0.5302 | Total: 0.0226
Run 2 Results:
Accura

In [None]:
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import StepLR
# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

# hidden_dim   = 256
# cut          = 0
# dropout      = 0.25
# num_runs     = 10
# heads = 2
# num_epochs   = 8000
# lambda_list  = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]
# base_seed    = 42

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# node_feats = node_feats.to(device)
# edge_index = edge_index.to(device)
# A = A.to(device)

# # Use the y_subset created earlier
# y_subset_np = y_subset.astype(int)

# N, feats_dim = node_feats.size(0), node_feats.size(1)

# all_results = []
# bce_loss = nn.BCEWithLogitsLoss()

# for lam in lambda_list:
#     print(f"\n================ LAMBDA = {lam} ================\n")

#     acc_scores, prec_scores, rec_scores, f1_scores, log_losses = [], [], [], [], []

#     for run in range(num_runs):
#         print(f"\n--- Run {run+1}/{num_runs} ---")

#         seed = base_seed + run
#         torch.manual_seed(seed)
#         np.random.seed(seed)
#         if torch.cuda.is_available():
#             torch.cuda.manual_seed_all(seed)


#         model = DGI_with_classifier(features.shape[1], hidden_dim, heads=heads, n_classes=2, cut=cut, dropout=dropout).to(device)
#         optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
#         scheduler = StepLR(optimizer, step_size=200, gamma=0.5)


#         for epoch in range(num_epochs + 1):
#             model.train()
#             optimizer.zero_grad()

#             perm = torch.randperm(N, device=device)
#             corrupt_features = node_feats[perm]

#             logits, embeddings = model(node_feats, corrupt_features, edge_index)

#             lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])
#             dgi_loss = bce_loss(logits.squeeze(), lbl)
#             reg_loss = model.Reg_loss(A, embeddings)

#             loss = dgi_loss + lam * reg_loss

#             if epoch % 500 == 0:
#                 print(f"Epoch {epoch:4d} | DGI: {dgi_loss.item():.4f} | Reg: {reg_loss.item():.4f} | "
#                       f"λ*Reg: {(lam * reg_loss).item():.4f} | Total: {loss.item():.4f}")

#             loss.backward()
#             optimizer.step()
#             scheduler.step()

#         model.eval()
#         with torch.no_grad():
#             emb = model.get_embeddings(node_feats, edge_index)
#             logits_cls = model.classifier(emb)                   # [N, 2]
#             class_probabilities = F.softmax(logits_cls, dim=1).cpu().numpy()
#             y_pred = np.argmax(class_probabilities, axis=1)

#         acc  = accuracy_score(y_subset_np, y_pred)
#         acc_inv = accuracy_score(y_subset_np, 1 - y_pred)

#         if acc_inv > acc:
#             acc = acc_inv
#             y_pred = 1 - y_pred
#             class_probabilities = class_probabilities[:, ::-1]

#         prec = precision_score(y_subset_np, y_pred, zero_division=0)
#         rec  = recall_score(y_subset_np, y_pred, zero_division=0)
#         f1   = f1_score(y_subset_np, y_pred, zero_division=0)
#         ll   = log_loss(y_subset_np, class_probabilities)

#         print(f"Run {run+1} | Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f} | LogLoss: {ll:.4f}")

#         acc_scores.append(acc)
#         prec_scores.append(prec)
#         rec_scores.append(rec)
#         f1_scores.append(f1)
#         log_losses.append(ll)

#     lambda_results = {
#         "lambda": lam,
#         "accuracy":  (float(np.mean(acc_scores)), float(np.std(acc_scores))),
#         "precision": (float(np.mean(prec_scores)), float(np.std(prec_scores))),
#         "recall":    (float(np.mean(rec_scores)), float(np.std(rec_scores))),
#         "f1":        (float(np.mean(f1_scores)),  float(np.std(f1_scores))),
#         "log_loss":  (float(np.mean(log_losses)), float(np.std(log_losses))),
#     }
#     all_results.append(lambda_results)

#     print(f"\n--- RESULTS FOR LAMBDA = {lam} ---")
#     print(f"Accuracy : {lambda_results['accuracy'][0]:.4f} ± {lambda_results['accuracy'][1]:.4f}")
#     print(f"Precision: {lambda_results['precision'][0]:.4f} ± {lambda_results['precision'][1]:.4f}")
#     print(f"Recall   : {lambda_results['recall'][0]:.4f} ± {lambda_results['recall'][1]:.4f}")
#     print(f"F1 Score : {lambda_results['f1'][0]:.4f} ± {lambda_results['f1'][1]:.4f}")
#     print(f"Log Loss : {lambda_results['log_loss'][0]:.4f} ± {lambda_results['log_loss'][1]:.4f}")

# print("\n================ FINAL SUMMARY FOR ALL LAMBDAS ================\n")
# print(f"{'Lambda':>8} | {'Accuracy':>18} | {'Precision':>18} | {'Recall':>18} | {'F1 Score':>18} | {'Log Loss':>18}")
# print("-" * 108)
# for res in all_results:
#     print(f"{res['lambda']:>8} | "
#           f"{res['accuracy'][0]:.4f} ± {res['accuracy'][1]:.4f} | "
#           f"{res['precision'][0]:.4f} ± {res['precision'][1]:.4f} | "
#           f"{res['recall'][0]:.4f} ± {res['recall'][1]:.4f} | "
#           f"{res['f1'][0]:.4f} ± {res['f1'][1]:.4f} | "
#           f"{res['log_loss'][0]:.4f} ± {res['log_loss'][1]:.4f}")