In [40]:
# from google.colab import drive
# drive.mount('/content/drive')

In [41]:
# !pip install -q torch_geometric
# !pip install -q class_resolver
# !pip3 install pymatting

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
import torch.optim as optim
import numpy as np
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, log_loss
)
from torch.utils.data import TensorDataset, DataLoader, Subset
from torch_geometric.data import Data
from torch_geometric.nn import ARMAConv
from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import models

In [43]:
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.set_num_threads(4)

In [44]:
data = np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True)
all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

images = all_images.astype(np.float32) / 255.0
images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N,3,224,224)
X = torch.tensor(images)
y = torch.tensor(all_labels).long()
print("Images, labels shapes:", X.shape, y.shape)

Images, labels shapes: torch.Size([780, 3, 224, 224]) torch.Size([780])


In [45]:
class0_idx = [i for i in range(len(y)) if y[i] == 0]
class1_idx = [i for i in range(len(y)) if y[i] == 1]

random.seed(42)
sampled_class0 = random.sample(class0_idx, min(1000, len(class0_idx)))
sampled_class1 = random.sample(class1_idx, min(1000, len(class1_idx)))

selected_indices = sampled_class0 + sampled_class1
random.shuffle(selected_indices)

subset_dataset = Subset(TensorDataset(X, y), selected_indices)
subset_loader = DataLoader(subset_dataset, batch_size=64, shuffle=False)

In [46]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vit.eval().to(device)

vit_feats, y_list = [], []

with torch.no_grad():
    for imgs, lbls in subset_loader:
        imgs = imgs.to(device)
        feats = vit(imgs)
        vit_feats.append(feats.cpu())
        y_list.extend(lbls.cpu().tolist())

features_matrix = torch.cat(vit_feats, dim=0).numpy().astype(np.float32)
y = np.array(y_list).astype(np.int64)

num_nodes, num_feats = features_matrix.shape
print(f"Extracted ViT-DINO Features: {features_matrix.shape}, Labels: {y.shape}")

Using cache found in /home/snu/.cache/torch/hub/facebookresearch_dino_main


Extracted ViT-DINO Features: (780, 768), Labels: (780,)


In [47]:
def create_adj(features_matrix, alpha=1):
    norms = np.linalg.norm(features_matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    F_norm = features_matrix / norms
    W = np.dot(F_norm, F_norm.T)
    W = (W >= alpha).astype(np.float32)
    return W

In [48]:
W0 = create_adj(features_matrix, alpha=0.73)
print(f"W0: {W0.shape}")

W0: (780, 780)


In [49]:
def load_graph_torch(adj, node_feats):
    node_feats_t = torch.from_numpy(node_feats).float()
    edge_idx = np.array(np.nonzero(adj))
    edge_index = torch.from_numpy(edge_idx).long()
    return node_feats_t, edge_index
node_feats_all, edge_index_all = load_graph_torch(W0, features_matrix)
print(f"Number of edges: {edge_index_all.size(1)}")

Number of edges: 266936


In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GATEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ="RELU", heads=1):
        super(GATEncoder, self).__init__()

        activations = {
            "SELU": F.selu,
            "SiLU": F.silu,
            "GELU": F.gelu,
            "ELU": F.elu,
            "RELU": F.relu
        }
        self.act = activations.get(activ, F.elu)

        self.gat = GATConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            heads=heads,
            dropout=0.3,
            concat=False
        )

        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gat(x, edge_index)
        x = self.act(x)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits


In [51]:
class AvgReadout(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, seq, msk=None):
        if msk is None:
            return torch.mean(seq, 0)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 0) / torch.sum(msk)

In [52]:
class Discriminator(nn.Module):
    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        nn.init.xavier_uniform_(self.f_k.weight.data)
        if self.f_k.bias is not None:
            self.f_k.bias.data.fill_(0.0)
    def forward(self, c, h_pl, h_mi):
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)
        logits = torch.cat((sc_1, sc_2), 0)
        return logits

In [53]:
class DGI(nn.Module):
    def __init__(self, n_in, n_h, heads, dropout=0.25):
        super().__init__()
        self.gat1 = GATEncoder(n_in, n_h, heads=heads, device='cuda' if torch.cuda.is_available() else 'cpu', activ=nn.ELU())
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h * heads)

    def forward(self, seq1, seq2, edge_index):
        # Create Data objects for the GATEncoder
        data1 = Data(x=seq1, edge_index=edge_index)
        data2 = Data(x=seq2, edge_index=edge_index)

        h_1 = self.gat1(data1)
        c = self.read(h_1)
        c = self.sigm(c)
        h_2 = self.gat1(data2)
        logits = self.disc(c, h_1, h_2)
        return logits, h_1

In [54]:
class DGI_with_classifier(DGI):
    def __init__(self, n_in, n_h, heads, n_classes=2, cut=0, dropout=0.25):
        super().__init__(n_in, n_h, heads, dropout=dropout)
        self.classifier = nn.Linear(n_h * heads, n_classes)
        self.cut = cut

    def get_embeddings(self, node_feats, edge_index):
        _, embeddings = self.forward(node_feats, node_feats, edge_index)
        return embeddings

    def cut_loss(self, A, S):
        S = F.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)
        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)
        mincut_loss = -(num / den)
        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(S.shape[1], device=A.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))
        return mincut_loss + ortho_loss

    def modularity_loss(self, A, S):
        C = F.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)
        I_S = torch.eye(C.shape[1], device=A.device)
        k = torch.norm(I_S)
        n = S.shape[0]
        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1
        return modularity_term + collapse_reg_term

    def Reg_loss(self, A, embeddings):
        logits = self.classifier(embeddings)
        if self.cut == 1:
            return self.cut_loss(A, logits)
        else:
            return self.modularity_loss(A, logits)

In [55]:
from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=SEED)
accuracies, precisions, recalls, f1_scores, losses = [], [], [], [], []
all_y_true = []
all_y_proba = []
all_fpr, all_tpr, all_auc = [], [], []

In [56]:
def get_device():
    try:
        if torch.cuda.is_available():
            dev = torch.device("cuda")
            # test a tiny tensor operation to check CUDA health
            torch.tensor([1.0], device=dev) + 1.0
            return dev
    except Exception as e:
        print("CUDA not usable, falling back to CPU:", e)
        try:
            torch.cuda.empty_cache()
        except:
            pass
    return torch.device("cpu")

device = get_device()
print("Using device:", device)

Using device: cuda


In [57]:
A_tensor = torch.from_numpy(W0).to(device)
hidden_dim = 512
cut = 1
num_epochs = 2000
lr = 1e-4
weight_decay = 1e-4
# sup_weight = 1.0
# dgi_weight = 0.3
reg_weight = 0.001
heads = 1
dropout = 0.25

In [58]:
for fold, (train_idx, test_idx) in enumerate(sss.split(features_matrix, y)):
    print(f"\n=== Fold {fold+1} ===")

    # BreastMNIST classes
    benign_idx = np.where(y == 0)[0]
    malignant_idx = np.where(y == 1)[0]

    sss_class = StratifiedShuffleSplit(
        n_splits=20, test_size=0.9, random_state=fold
    )

    benign_train_idx, benign_test_idx = next(
        sss_class.split(features_matrix[benign_idx], y[benign_idx])
    )
    malignant_train_idx, malignant_test_idx = next(
        sss_class.split(features_matrix[malignant_idx], y[malignant_idx])
    )

    benign_train = benign_idx[benign_train_idx]
    malignant_train = malignant_idx[malignant_train_idx]
    benign_test = benign_idx[benign_test_idx]
    malignant_test = malignant_idx[malignant_test_idx]

    balanced_train_idx = np.concatenate([benign_train, malignant_train])
    test_idx = np.concatenate([benign_test, malignant_test])

    np.random.shuffle(balanced_train_idx)
    np.random.shuffle(test_idx)

    print(f"Train Benign: {len(benign_train)}, Train Malignant: {len(malignant_train)}")
    print(f"Test Benign: {len(benign_test)}, Test Malignant: {len(malignant_test)}")

    node_feats = node_feats_all.to(device)
    edge_index = edge_index_all.to(device)
    y_tensor = torch.from_numpy(y).long().to(device)
    train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
    test_idx_t = torch.from_numpy(test_idx).long().to(device)

    N = node_feats.size(0)
    lbl = torch.cat([
        torch.ones(N, device=device),
        torch.zeros(N, device=device)
    ])

    model = DGI_with_classifier(num_feats, hidden_dim, heads=heads, n_classes=2, cut=cut, dropout=dropout).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    bce_loss = nn.BCEWithLogitsLoss()
    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        perm = torch.randperm(N, device=device)
        corrupt = node_feats[perm]

        logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
        dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)

        logits_cls = model.classifier(embeddings)
        train_logits = logits_cls[train_idx_t]
        train_labels = y_tensor[train_idx_t]
        supervised_loss = ce_loss(train_logits, train_labels)

        reg_loss_val = model.Reg_loss(A_tensor, embeddings)
        total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val

        total_loss.backward()
        optimizer.step()

        if epoch % 500 == 0 or epoch == 1:
            model.eval()
            with torch.no_grad():
                logits_eval = model.classifier(
                    model.get_embeddings(node_feats, edge_index)
                )
                preds_train = torch.argmax(
                    logits_eval[train_idx_t], dim=1
                ).cpu().numpy()
                acc = accuracy_score(
                    train_labels.cpu().numpy(), preds_train
                )

            print(
                f"Epoch {epoch}: "
                f"Sup={supervised_loss.item():.4f} | "
                f"DGI={dgi_loss.item():.4f} | "
                f"Reg={reg_loss_val.item():.4f} | "
                f"Total={total_loss.item():.4f} | "
                f"TrainAcc={acc:.4f}"
            )

    model.eval()
    with torch.no_grad():
        emb_final = model.get_embeddings(node_feats, edge_index)
        logits_final = model.classifier(emb_final)
        probs = torch.softmax(logits_final, dim=1).cpu().numpy()
        y_pred = np.argmax(probs, axis=1)

    y_test = y[test_idx]
    y_pred_test = y_pred[test_idx]
    y_proba_test = probs[test_idx, 1]

    acc = accuracy_score(y_test, y_pred_test)
    prec = precision_score(y_test, y_pred_test, zero_division=0)
    rec = recall_score(y_test, y_pred_test, zero_division=0)
    f1 = f1_score(y_test, y_pred_test, zero_division=0)
    loss_val = log_loss(y_test, y_proba_test)
    auc_score = roc_auc_score(y_test, y_proba_test)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)
    losses.append(loss_val)
    all_auc.append(auc_score)

    print(
        f"Fold {fold+1} \u2192 "
        f"Acc={acc:.4f} "
        f"Prec={prec:.4f} "
        f"Rec={rec:.4f} "
        f"F1={f1:.4f} "
        f"AUC={auc_score:.4f}"
    )

print("\n=== Average Results ===")
print(f"Accuracy: {np.mean(accuracies):.4f} \u00B1 {np.std(accuracies):.4f}")
print(f"Precision: {np.mean(precisions):.4f} \u00B1 {np.std(precisions):.4f}")
print(f"Recall: {np.mean(recalls):.4f} \u00B1 {np.std(recalls):.4f}")
print(f"F1: {np.mean(f1_scores):.4f} \u00B1 {np.std(f1_scores):.4f}")
print(f"LogLoss: {np.mean(losses):.4f} \u00B1 {np.std(losses):.4f}")
print(f"AUC: {np.mean(all_auc):.4f} \u00B1 {np.std(all_auc):.4f}")


=== Fold 1 ===
Train Benign: 21, Train Malignant: 57
Test Benign: 189, Test Malignant: 513
Epoch 1: Sup=0.7134 | DGI=0.7141 | Reg=-0.2370 | Total=1.4272 | TrainAcc=0.7821
Epoch 500: Sup=0.0858 | DGI=0.6931 | Reg=-0.3457 | Total=0.7786 | TrainAcc=0.9231
Epoch 1000: Sup=0.0860 | DGI=0.6932 | Reg=-0.3480 | Total=0.7789 | TrainAcc=0.9487
Epoch 1500: Sup=0.0961 | DGI=0.6940 | Reg=-0.3150 | Total=0.7897 | TrainAcc=0.9231
Epoch 2000: Sup=0.0940 | DGI=0.6931 | Reg=-0.3234 | Total=0.7869 | TrainAcc=0.9231
Fold 1 → Acc=0.7165 Prec=0.8066 Rec=0.8051 F1=0.8059 AUC=0.7329

=== Fold 2 ===
Train Benign: 21, Train Malignant: 57
Test Benign: 189, Test Malignant: 513
Epoch 1: Sup=0.7205 | DGI=0.7141 | Reg=-0.2381 | Total=1.4343 | TrainAcc=0.2692
Epoch 500: Sup=0.0007 | DGI=0.6931 | Reg=-0.4402 | Total=0.6934 | TrainAcc=1.0000
Epoch 1000: Sup=0.0002 | DGI=0.6960 | Reg=-0.5051 | Total=0.6957 | TrainAcc=1.0000
Epoch 1500: Sup=0.0001 | DGI=0.6931 | Reg=-0.5341 | Total=0.6927 | TrainAcc=1.0000
Epoch 2000: S

In [59]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import pandas as pd
# from sklearn.model_selection import StratifiedShuffleSplit
# from sklearn.metrics import (
#     accuracy_score, precision_score, recall_score, f1_score,
#     log_loss, roc_auc_score
# )

# A_tensor = torch.from_numpy(W0).to(device)
# hidden_dim = 512
# cut = 0
# num_epochs = 5000
# lr = 1e-4
# weight_decay = 1e-4
# reg_weights = [0.001, 0.005, 0.009, 0.01, 0.05, 0.09, 0.1, 0.3, 0.5, 0.9, 1, 2, 5, 8]
# sss = StratifiedShuffleSplit(n_splits=20, test_size=0.9, random_state=SEED)

# results_summary = []

# for reg_weight in reg_weights:
#     print(f"\n================ REG_WEIGHT = {reg_weight} ================")
#     accuracies, precisions, recalls, f1_scores, losses, all_auc = [], [], [], [], [], []

#     for fold, (train_idx, test_idx) in enumerate(sss.split(X, y)):
#         print(f"\n=== Fold {fold+1} ===")

#         cn_idx = np.where(y == 0)[0]
#         mci_idx = np.where(y == 1)[0]
#         sss_class = StratifiedShuffleSplit(n_splits=20, test_size=0.5, random_state=fold)
#         cn_train_idx, cn_test_idx = next(sss_class.split(X[cn_idx], y[cn_idx]))
#         mci_train_idx, mci_test_idx = next(sss_class.split(X[mci_idx], y[mci_idx]))

#         cn_train = cn_idx[cn_train_idx]
#         mci_train = mci_idx[mci_train_idx]
#         cn_test = cn_idx[cn_test_idx]
#         mci_test = mci_idx[mci_test_idx]

#         balanced_train_idx = np.concatenate([cn_train, mci_train])
#         test_idx = np.concatenate([cn_test, mci_test])
#         np.random.shuffle(balanced_train_idx)
#         np.random.shuffle(test_idx)

#         print(f"Train CN: {len(cn_train)}, Train MCI: {len(mci_train)}")
#         print(f"Test CN: {len(cn_test)}, Test MCI: {len(mci_test)}")

#         node_feats = node_feats_all.to(device)
#         edge_index = edge_index_all.to(device)
#         y_tensor = torch.from_numpy(y).long().to(device)
#         train_idx_t = torch.from_numpy(balanced_train_idx).long().to(device)
#         test_idx_t = torch.from_numpy(test_idx).long().to(device)

#         N = node_feats.size(0)
#         lbl = torch.cat([torch.ones(N, device=device), torch.zeros(N, device=device)])

#         model = DGI_with_classifier(num_feats, hidden_dim, n_classes=2, cut=cut).to(device)
#         optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
#         bce_loss = nn.BCEWithLogitsLoss()
#         ce_loss = nn.CrossEntropyLoss()

#         for epoch in range(1, num_epochs + 1):
#             model.train()
#             optimizer.zero_grad()
#             perm = torch.randperm(N, device=device)
#             corrupt = node_feats[perm]
#             logits_dgi, embeddings = model(node_feats, corrupt, edge_index)
#             dgi_loss = bce_loss(logits_dgi.squeeze(), lbl)
#             logits_cls = model.classifier(embeddings)
#             train_logits = logits_cls[train_idx_t]
#             train_labels = y_tensor[train_idx_t]
#             supervised_loss = ce_loss(train_logits, train_labels)
#             reg_loss_val = model.Reg_loss(A_tensor, embeddings)
#             total_loss = supervised_loss + dgi_loss + reg_weight * reg_loss_val
#             total_loss.backward()
#             optimizer.step()
#             if epoch % 500 == 0 or epoch == 1:
#                 model.eval()
#                 with torch.no_grad():
#                     logits_eval = model.classifier(model.get_embeddings(node_feats, edge_index))
#                     preds_train = torch.argmax(logits_eval[train_idx_t], dim=1).cpu().numpy()
#                     acc = accuracy_score(train_labels.cpu().numpy(), preds_train)
#                 print(f"Epoch {epoch}: Sup={supervised_loss.item():.4f} | "
#                       f"DGI={dgi_loss.item():.4f} | Reg={reg_loss_val.item():.4f} | "
#                       f"Total={total_loss.item():.4f} | TrainAcc={acc:.4f}")

#         model.eval()
#         with torch.no_grad():
#             emb_final = model.get_embeddings(node_feats, edge_index)
#             logits_final = model.classifier(emb_final)
#             probs = F.softmax(logits_final, dim=1).cpu().numpy()
#             y_pred = np.argmax(probs, axis=1)

#         y_test = y[test_idx]
#         y_pred_test = y_pred[test_idx]
#         y_proba_test = probs[test_idx, 1]

#         acc = accuracy_score(y_test, y_pred_test)
#         prec = precision_score(y_test, y_pred_test, zero_division=0)
#         rec = recall_score(y_test, y_pred_test, zero_division=0)
#         f1 = f1_score(y_test, y_pred_test, zero_division=0)
#         loss_val = log_loss(y_test, y_proba_test)
#         auc_score = roc_auc_score(y_test, y_proba_test)

#         accuracies.append(acc)
#         precisions.append(prec)
#         recalls.append(rec)
#         f1_scores.append(f1)
#         losses.append(loss_val)
#         all_auc.append(auc_score)
#         print(f"Fold {fold+1} → Acc={acc:.4f} Prec={prec:.4f} Rec={rec:.4f} "
#               f"F1={f1:.4f} AUC={auc_score:.4f}")

#     mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
#     mean_prec, std_prec = np.mean(precisions), np.std(precisions)
#     mean_rec, std_rec = np.mean(recalls), np.std(recalls)
#     mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
#     mean_loss, std_loss = np.mean(losses), np.std(losses)
#     mean_auc, std_auc = np.mean(all_auc), np.std(all_auc)

#     results_summary.append({
#         "Reg_Weight": reg_weight,
#         "Accuracy": f"{mean_acc:.4f} ± {std_acc:.4f}",
#         "Precision": f"{mean_prec:.4f} ± {std_prec:.4f}",
#         "Recall": f"{mean_rec:.4f} ± {std_rec:.4f}",
#         "F1": f"{mean_f1:.4f} ± {std_f1:.4f}",
#         "LogLoss": f"{mean_loss:.4f} ± {std_loss:.4f}",
#         "AUC": f"{mean_auc:.4f} ± {std_auc:.4f}"
#     })

#     print(f"\n=== Average Results for reg_weight = {reg_weight} ===")
#     print(f"Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
#     print(f"Precision: {mean_prec:.4f} ± {std_prec:.4f}")
#     print(f"Recall: {mean_rec:.4f} ± {std_rec:.4f}")
#     print(f"F1: {mean_f1:.4f} ± {std_f1:.4f}")
#     print(f"LogLoss: {mean_loss:.4f} ± {std_loss:.4f}")
#     print(f"AUC: {mean_auc:.4f} ± {std_auc:.4f}")

# # ============================
# # Final summary table
# # ============================
# print("\n\n========== FINAL SUMMARY TABLE ==========")
# results_df = pd.DataFrame(results_summary)
# print(results_df.to_string(index=False))
