In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch, torch_geometric
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("PyG:", torch_geometric.__version__)


In [None]:
!pip install -q torch_geometric
!pip install -q class_resolver
!pip3 install pymatting

In [2]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
# import util
from PIL import Image
import matplotlib.pyplot as plt
from numpy import asarray
import tifffile as tiff
import torch.nn as nn
import torch.nn.functional as nnFn
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from sklearn.manifold import TSNE
import random
from torch_geometric.nn import ARMAConv
import copy

In [26]:
fa_feature_path = "/home/snu/Downloads/Histogram_CN_FA_20bin_updated.npy"
Histogram_feature_CN_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Load AD features
fa_feature_path = "/home/snu/Downloads/Histogram_AD_FA_20bin_updated.npy"
Histogram_feature_AD_FA_array = np.load(fa_feature_path, allow_pickle=True)

# Combine features and labels
X = np.vstack([Histogram_feature_CN_FA_array, Histogram_feature_AD_FA_array])
y = np.hstack([
    np.zeros(Histogram_feature_CN_FA_array.shape[0], dtype=np.int64),
    np.ones(Histogram_feature_AD_FA_array.shape[0], dtype=np.int64)
])
np.random.seed(42)
perm = np.random.permutation(X.shape[0])
X = X[perm]
y = y[perm]
num_nodes, num_feats = X.shape
print(f"Features: {X.shape}, Labels: {y.shape}")

Features: (223, 180), Labels: (223,)


## 1 layer

In [27]:
def sim(h1, h2, tau = 0.2):
    z1 = nnFn.normalize(h1, dim=-1, p=2)
    z2 = nnFn.normalize(h2, dim=-1, p=2)
    return torch.mm(z1, z2.t()) / tau

def contrastive_loss_wo_cross_network(h1, h2, z):
    f = lambda x: torch.exp(x)
    intra_sim = f(sim(h1, h1))
    inter_sim = f(sim(h1, h2))
    return -torch.log(inter_sim.diag() /
                     (intra_sim.sum(dim=-1) + inter_sim.sum(dim=-1) - intra_sim.diag()))

def contrastive_loss_wo_cross_view(h1, h2, z):
    f = lambda x: torch.exp(x)
    cross_sim = f(sim(h1, z))
    return -torch.log(cross_sim.diag() / cross_sim.sum(dim=-1))

In [28]:
class MLP(nn.Module):
    def __init__(self, inp_size, outp_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(), # nn.ELU()
            nn.Dropout(0.3),
            nn.Linear(hidden_size, outp_size)
        )

    def forward(self, x):
        return self.net(x)

In [29]:
class ARMAEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, device, activ="ELU", num_stacks=1, num_layers=1):
        super(ARMAEncoder, self).__init__()
        self.device = device
        # Define all activation functions
        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "ELU": nnFn.elu,
            "RELU": nnFn.relu
        }
        # Get the activation function based on the input string
        self.act = activations.get(activ, nnFn.elu)

        self.arma = ARMAConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            num_stacks=num_stacks,   # number of parallel stacks
            num_layers=num_layers,   # depth per stack
            act=self.act,               # nonlinearity inside ARMA
            shared_weights=True,     # weight sharing across layers
            dropout=0.25             # ARMA-internal dropout
        )
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.arma(x, edge_index)
        x = self.dropout(x)
        x = self.batchnorm(x)
        logits = self.mlp(x)
        return logits

In [30]:
class EMA(): # Moving Average update

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        # old: old model parameter
        # new: new model parameter
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

In [31]:
class ARMA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_clusters, device, activ, moving_average_decay=0.5, cut=True):
        super(ARMA, self).__init__()
        self.device = device
        self.num_clusters = num_clusters
        self.cut = cut
        self.beta = 0.6

        # Use ARMA encoder instead of GCN encoder
        self.online_encoder = ARMAEncoder(input_dim, hidden_dim, device, activ)
        self.target_encoder = copy.deepcopy(self.online_encoder)

        activations = {
            "SELU": nnFn.selu,
            "SiLU": nnFn.silu,
            "GELU": nnFn.gelu,
            "RELU": nnFn.relu
        }
        self.act = activations.get(activ, nnFn.elu)

        # Predictor head
        self.online_predictor = MLP(hidden_dim, num_clusters, hidden_dim)

        # Loss selection
        self.loss = self.cut_loss if cut else self.modularity_loss
        self.target_ema_updater = EMA(moving_average_decay)

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_ma(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, data1, data2):
        x1 = self.online_encoder(data1)
        logits1 = self.online_predictor(x1)
        S1 = nnFn.softmax(logits1, dim=1)

        x2 = self.online_encoder(data2)
        logits2 = self.online_predictor(x2)
        S2 = nnFn.softmax(logits2, dim=1)

        with torch.no_grad():
            target_proj_one = self.target_encoder(data1).detach()
            target_proj_two = self.target_encoder(data2).detach()

        l1 = self.beta * contrastive_loss_wo_cross_network(x1, x2, target_proj_two) + \
             (1.0 - self.beta) * contrastive_loss_wo_cross_view(x1, x2, target_proj_two)

        l2 = self.beta * contrastive_loss_wo_cross_network(x2, x1, target_proj_one) + \
             (1.0 - self.beta) * contrastive_loss_wo_cross_view(x2, x1, target_proj_one)

        return S1, S2, logits1, logits2, l1, l2

    def modularity_loss(self, A, S):
        C = nnFn.softmax(S, dim=1)
        d = torch.sum(A, dim=1)
        m = torch.sum(A)
        B = A - torch.ger(d, d) / (2 * m)

        I_S = torch.eye(self.num_clusters, device=self.device)
        # k = torch.norm(I_S)
        k = torch.tensor(self.num_clusters, device=self.device, dtype=torch.float32)
        n = S.shape[0]

        modularity_term = (-1 / (2 * m)) * torch.trace(torch.mm(torch.mm(C.t(), B), C))
        collapse_reg_term = (torch.sqrt(k) / n) * torch.norm(torch.sum(C, dim=0), p='fro') - 1

        return modularity_term + collapse_reg_term

    def cut_loss(self, A, S):
        S = nnFn.softmax(S, dim=1)
        A_pool = torch.matmul(torch.matmul(A, S).t(), S)
        num = torch.trace(A_pool)

        D = torch.diag(torch.sum(A, dim=-1))
        D_pooled = torch.matmul(torch.matmul(D, S).t(), S)
        den = torch.trace(D_pooled)
        mincut_loss = -(num / den)

        St_S = torch.matmul(S.t(), S)
        I_S = torch.eye(self.num_clusters, device=self.device)
        ortho_loss = torch.norm(St_S / torch.norm(St_S) - I_S / torch.norm(I_S))

        return mincut_loss + ortho_loss

In [32]:
def create_adj(features, cut, alpha=1.0):
    """Return a dense W0 matrix (only once), as you originally used for A1 / unsup loss.
       We still create the dense matrix once, but all augmentations below work with edge_index.
    """
    F_norm = features / np.linalg.norm(features, axis=1, keepdims=True)
    W = np.dot(F_norm, F_norm.T)

    if cut == 0:
        W = np.where(W >= alpha, 1, 0).astype(np.float32)
        W = (W / W.max()).astype(np.float32)
    else:
        W = (W * (W >= alpha)).astype(np.float32)
    return W

In [33]:
def edge_index_from_dense(W):
    """Return edge_index as numpy array shape (2, E) and edge_weight vector."""
    rows, cols = np.nonzero(W > 0)
    edge_index = np.vstack([rows, cols]).astype(np.int64)
    edge_weight = W[rows, cols].astype(np.float32)
    return edge_index, edge_weight

In [34]:
def build_adj_list(edge_index_np, num_nodes):
    """Build adjacency list: list of neighbor arrays for each node (numpy)."""
    adj = [[] for _ in range(num_nodes)]
    src = edge_index_np[0]
    dst = edge_index_np[1]
    for s, d in zip(src, dst):
        adj[s].append(d)
    # convert to numpy arrays for speed
    adj = [np.array(neis, dtype=np.int64) if len(neis) > 0 else np.array([], dtype=np.int64) for neis in adj]
    return adj

In [51]:
def aug_random_edge_edge_index(edge_index_np, drop_percent=0.2, seed=None):
    """Randomly drop edges from edge_index. Returns new edge_index (2 x E') and edge_weight placeholder."""
    rng = np.random.default_rng(seed)
    E = edge_index_np.shape[1]
    keep_mask = rng.random(E) >= drop_percent
    new_edge_index = edge_index_np[:, keep_mask]
    return new_edge_index

In [52]:
def aug_subgraph_edge_index(features_np, edge_index_np, adj_list, drop_percent=0.2, seed=None):
    """
    Sample a subgraph by selecting s_node_num nodes via neighbor expansion (BFS-like),
    then return (sub_features, sub_edge_index) with node ids remapped to [0..s-1].
    """
    rng = np.random.default_rng(seed)
    num_nodes = features_np.shape[0]
    s_node_num = int(num_nodes * (1 - drop_percent))
    if s_node_num < 1:
        s_node_num = 1

    # choose a random center node
    center_node = int(rng.integers(0, num_nodes))
    sub_nodes = [center_node]
    front_idx = 0

    # BFS-like expansion using adjacency list until we reach s_node_num
    while len(sub_nodes) < s_node_num and front_idx < len(sub_nodes):
        cur = sub_nodes[front_idx]
        neighbors = adj_list[cur]
        if neighbors.size > 0:
            # shuffle neighbors and try to add new ones
            nbrs_shuffled = neighbors.copy()
            rng.shuffle(nbrs_shuffled)
            for nb in nbrs_shuffled:
                if nb not in sub_nodes:
                    sub_nodes.append(int(nb))
                    if len(sub_nodes) >= s_node_num:
                        break
        front_idx += 1
        # if BFS stalls (no new neighbors), add random nodes
        if front_idx >= len(sub_nodes) and len(sub_nodes) < s_node_num:
            remaining = [n for n in range(num_nodes) if n not in sub_nodes]
            if not remaining:
                break
            add = int(rng.choice(remaining))
            sub_nodes.append(add)

    sub_nodes = sorted(set(sub_nodes))
    node_map = {old: new for new, old in enumerate(sub_nodes)}

    # induce edges that have both ends in sub_nodes
    src = edge_index_np[0]
    dst = edge_index_np[1]
    mask_src_in = np.isin(src, sub_nodes)
    mask_dst_in = np.isin(dst, sub_nodes)
    mask = mask_src_in & mask_dst_in
    sel_src = src[mask]
    sel_dst = dst[mask]
    # remap
    remapped_src = np.array([node_map[int(s)] for s in sel_src], dtype=np.int64)
    remapped_dst = np.array([node_map[int(d)] for d in sel_dst], dtype=np.int64)
    new_edge_index = np.vstack([remapped_src, remapped_dst])
    # sub features
    sub_features = features_np[sub_nodes, :].astype(np.float32)
    return sub_features, new_edge_index

In [53]:
def load_data_from_edge_index(node_feats_np, edge_index_np, device):
    """Return PyG Data with torch tensors. edge_index_np is (2, E) numpy."""
    node_feats = torch.from_numpy(node_feats_np).float()
    edge_index = torch.from_numpy(edge_index_np.astype(np.int64)).long()
    return node_feats, edge_index

# Data Loading and preprocessing

In [54]:
print("CN Shape:", Histogram_feature_CN_FA_array.shape)
print("AD Shape:",Histogram_feature_AD_FA_array.shape)

CN Shape: (133, 180)
AD Shape: (90, 180)


In [55]:
features = X
#features = np.concatenate((Histogram_feature_CN_FA_array, Histogram_feature_MCI_FA_array), axis=0)
features = features.astype(np.float32)
print(features.shape, features.dtype)

(223, 180) float32


In [56]:
# Required Parameters
cut = 1  # Consider n-cut loss OR Modularity loss
alpha = 0.8 #0.7 # Edge creation Threshold
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_dim = 180  # 20-bin
K = 2  # Number of clusters
epoch = [2500, 60, 100]  # Training epochs for different phases

# Define all activation functions to test
define_activations = ["SELU", "SiLU", "GELU", "ELU", "RELU"]
activ = "ELU"

In [41]:
print(features.shape)

(223, 180)


In [42]:
F_norm = features / np.linalg.norm(features, axis=1, keepdims=True)
W = np.dot(F_norm, F_norm.T)
print(np.array(np.nonzero(W>0.6)).shape)

(2, 49259)


In [43]:
F_norm = features / np.linalg.norm(features, axis=1, keepdims=True)
W = np.dot(F_norm, F_norm.T)
print(W)

[[1.0000001  0.8858711  0.8792206  ... 0.8945457  0.8821603  0.902783  ]
 [0.8858711  1.0000001  0.9067889  ... 0.8917234  0.88389105 0.9146811 ]
 [0.8792206  0.9067889  0.9999999  ... 0.87847316 0.90214026 0.92830694]
 ...
 [0.8945457  0.8917234  0.87847316 ... 1.0000002  0.8659749  0.91381127]
 [0.8821603  0.88389105 0.90214026 ... 0.8659749  1.         0.8924861 ]
 [0.902783   0.9146811  0.92830694 ... 0.91381127 0.8924861  1.        ]]


# Fully connected directed graph
# ONLY single directed edge between any two nodes

In [None]:
# num_nodes = F.shape[0]
# edge_index = torch.combinations(torch.arange(num_nodes), r=2).T
# def create_adjacency_matrix(edge_index, num_nodes):
#     adj = torch.zeros((num_nodes, num_nodes))
#     adj[edge_index[0], edge_index[1]] = 1
#     return adj

# W = create_adjacency_matrix(edge_index, num_nodes).numpy()
# print("W Shape= " , W.shape)

In [44]:
W0 = create_adj(features, cut, alpha)  # shape (N, N) dense
A1 = torch.from_numpy(W0).float().to(device)

edge_index_np, edge_weight_np = edge_index_from_dense(W0)  # numpy edge_index (2, E)
num_nodes = features.shape[0]
adj_list = build_adj_list(edge_index_np, num_nodes)  # adjacency list for fast subgraph sampling

# convert features to numpy (we'll slice them in augmentations)
features_np = features.copy()

# Build initial Data object (full graph)
node_feats_full, edge_index_full = load_data_from_edge_index(features_np, edge_index_np, device)
data0 = Data(x=node_feats_full.to(device), edge_index=edge_index_full.to(device))
print("Data0:", data0)

Data0: Data(x=[223, 180], edge_index=[2, 41689])


# Model initialization

## Contrastive Loss

In [45]:
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW
import torch.nn as nn

model = ARMA(feats_dim, 256, K, device, activ, cut=cut).to(device)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=200, gamma=0.5)
criterion = nn.CrossEntropyLoss()

num_epochs = 5000
lambda_contrastive = 0.3
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
for epoch in range(num_epochs):
    # --- Augmentations using edge_index or adjacency list (fast, sparse) ---
    # 1) Random edge drop on edge_index
    W_aug1_edge_index = aug_random_edge_edge_index(edge_index_np, drop_percent=0.2, seed=epoch)

    # 2) Subgraph via adjacency list (returns sub_features and sub_edge_index)
    W_aug2_edge_index = aug_random_edge_edge_index(edge_index_np, drop_percent=0.2, seed=epoch + 999)
    features_aug2 = features_np.copy()

    # 3) Feature augmentations (keep these as numpy operations)
    # Feature dropout (column-wise)
    rng = np.random.default_rng(epoch)
    mask = rng.random(features_np.shape) >= 0.2
    features_aug1 = (features_np * mask.astype(np.float32))

    # Feature cell dropout (random cell zeroing)
    aug_feat2 = features_np.copy()
    num_nodes_local, feat_dim = aug_feat2.shape
    drop_feat_num = int(num_nodes_local * feat_dim * 0.2)
    # random positions to zero
    flat_idx = rng.choice(num_nodes_local * feat_dim, size=drop_feat_num, replace=False)
    rows = (flat_idx // feat_dim)
    cols = (flat_idx % feat_dim)
    aug_feat2[rows, cols] = 0.0
    features_aug2_feat = aug_feat2.astype(np.float32)

    # --- Build PyG Data objects for the two views ---
    # view1: features_aug1 with W_aug1_edge_index
    node_feats1, edge_index1 = load_data_from_edge_index(features_aug1, W_aug1_edge_index, device)
    data1 = Data(x=node_feats1.to(device), edge_index=edge_index1.to(device))

    # view2: features_aug2 (from subgraph) and its edge_index
    node_feats2, edge_index2 = load_data_from_edge_index(features_aug2, W_aug2_edge_index, device)
    data2 = Data(x=node_feats2.to(device), edge_index=edge_index2.to(device))

    # --- Training step ---
    model.train()
    optimizer.zero_grad()

    S1, S2, logits1, logits2, l1, l2 = model(data1, data2)

    unsup_loss = model.loss(A1, logits1)
    cont_loss = ((l1 + l2) / 2).mean()
    total_loss = unsup_loss + lambda_contrastive * cont_loss

    total_loss.backward()
    optimizer.step()
    scheduler.step()
    model.update_ma()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Total: {total_loss.item():.4f} | Unsup: {unsup_loss.item():.4f} | Cont: {cont_loss.item():.4f}")

Epoch 0 | Total: 1.3923 | Unsup: -0.2323 | Cont: 5.4152
Epoch 100 | Total: 0.9523 | Unsup: -0.4906 | Cont: 4.8096
Epoch 200 | Total: 0.8381 | Unsup: -0.5151 | Cont: 4.5108
Epoch 300 | Total: 0.8166 | Unsup: -0.5157 | Cont: 4.4410
Epoch 400 | Total: 0.7884 | Unsup: -0.5209 | Cont: 4.3643
Epoch 500 | Total: 0.7681 | Unsup: -0.5293 | Cont: 4.3247
Epoch 600 | Total: 0.7800 | Unsup: -0.5042 | Cont: 4.2805
Epoch 700 | Total: 0.7455 | Unsup: -0.5309 | Cont: 4.2547
Epoch 800 | Total: 0.7438 | Unsup: -0.5309 | Cont: 4.2491
Epoch 900 | Total: 0.7446 | Unsup: -0.5296 | Cont: 4.2473
Epoch 1000 | Total: 0.7546 | Unsup: -0.5204 | Cont: 4.2501
Epoch 1100 | Total: 0.7499 | Unsup: -0.5276 | Cont: 4.2581
Epoch 1200 | Total: 0.7589 | Unsup: -0.5117 | Cont: 4.2352
Epoch 1300 | Total: 0.7448 | Unsup: -0.5198 | Cont: 4.2153
Epoch 1400 | Total: 0.7363 | Unsup: -0.5196 | Cont: 4.1863
Epoch 1500 | Total: 0.7525 | Unsup: -0.5216 | Cont: 4.2469
Epoch 1600 | Total: 0.7446 | Unsup: -0.5275 | Cont: 4.2401
Epoch 170

In [48]:
model.eval()
with torch.no_grad():
        S1, _, logits1,_,_,_ = model(data0, data0)
        y_pred = torch.argmax(logits1, dim=1).cpu().numpy()
        y_pred_proba = nnFn.softmax(logits1, dim=1).cpu().numpy()
        print(y_pred)
        print(y_pred_proba.max(axis=-1))

[0 0 0 1 0 1 0 1 1 1 0 0 1 0 0 1 0 0 1 0 1 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 1
 0 0 1 1 0 1 0 1 0 0 0 1 1 1 1 1 1 0 0 1 0 1 0 1 0 0 0 0 1 1 0 1 1 0 0 0 1
 1 0 0 0 0 1 1 1 1 0 1 1 1 0 1 0 1 0 1 1 1 1 0 0 0 0 0 0 0 0 1 1 0 0 1 1 0
 1 0 1 1 1 1 1 0 0 0 0 0 1 1 0 1 0 0 0 1 1 1 0 1 0 0 1 0 0 0 1 1 0 0 1 0 0
 0 1 1 1 0 0 0 1 1 1 1 0 0 0 1 0 0 1 1 1 0 1 0 0 1 0 0 1 0 0 1 1 1 1 0 0 0
 0 1 1 0 1 0 0 1 1 1 1 1 1 0 1 1 0 0 1 0 1 1 0 0 1 1 1 1 0 0 0 1 0 0 0 0 1
 1]
[0.99950635 0.99900204 0.97252727 0.9999993  0.98879904 0.9999896
 0.7966758  0.8000775  0.9946951  0.9999839  0.99999964 0.9963307
 0.9996388  0.99898666 0.9921062  0.99958104 0.9974285  0.9985008
 0.94155043 0.9999994  0.99385214 0.99967504 0.9997987  1.
 0.99936026 0.9999999  1.         0.69033384 0.9994986  0.9999982
 0.9979888  0.9973404  1.         0.99999714 1.         1.
 0.9845279  0.9403521  0.9997644  1.         0.99999976 0.8632848
 1.         0.6533867  0.99997663 0.99994695 0.9985109  0.99998605
 0.99999726 0.9999999  0.943098

In [49]:
acc_score = accuracy_score(y, y_pred)
acc_score_inverted = accuracy_score(y, 1 - y_pred)

print("Accuracy Score:", acc_score)
print("Accuracy Score Inverted:", acc_score_inverted)

if acc_score_inverted > acc_score:
    acc_score = acc_score_inverted
    y_pred = 1 - y_pred

prec_score = precision_score(y, y_pred)
rec_score = recall_score(y, y_pred)
f1 = f1_score(y, y_pred)
log_loss_value = log_loss(y, y_pred_proba)

print("Precision Score:", prec_score)
print("Recall Score:", rec_score)
print("F1 Score:", f1)
print("Log Loss:", log_loss_value)

Accuracy Score: 0.7533632286995515
Accuracy Score Inverted: 0.24663677130044842
Precision Score: 0.6605504587155964
Recall Score: 0.8
F1 Score: 0.7236180904522613
Log Loss: 1.7112029405443685


In [57]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnFn
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss
from torch_geometric.data import Data
import random
import scipy.sparse as sp

# -------------------- Hyperparameters --------------------
num_runs = 10
num_epochs = 5000
lr = 1e-4
weight_decay = 1e-4
lambda_list = [0.001, 0.009, 0.1, 0.5, 1, 8]
base_seed = 42

# -------------------- Initialize results storage --------------------
all_results = []

# -------------------- Loop over different lambda values --------------------
for lam in lambda_list:
    print(f"\n================ LAMBDA = {lam} ================\n")

    acc_scores = []
    prec_scores = []
    rec_scores = []
    f1_scores = []
    log_losses = []

    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")
        torch.manual_seed(run)
        np.random.seed(run)
        random.seed(run)

        # -------------------- Initialize model and optimizer --------------------
        model = ARMA(feats_dim, 256, K, device, activ, cut).to(device)
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=200, gamma=0.5)

        # -------------------- Training loop --------------------
        for epoch in range(num_epochs):

            # --- Augmentation for adjacency ---
            W_aug1 = aug_random_edge(W0, drop_percent=0.2, seed=epoch)  # drop edges
            W_aug2 = aug_subgraph(features, sp.csr_matrix(W0), drop_percent=0.2)[1]  # subgraph

            # --- Augmentation for features ---
            features_aug1 = aug_feature_dropout(features, drop_percent=0.2, seed=epoch)
            features_aug2 = aug_feature_dropout_cell(torch.from_numpy(features), drop_percent=0.2).numpy()

            # --- Load data for PyG ---
            node_feats1, edge_index1, _ = load_data(W_aug1, features_aug1)
            node_feats2, edge_index2, _ = load_data(W_aug2, features_aug2)
            data1 = Data(x=node_feats1, edge_index=edge_index1).to(device)
            data2 = Data(x=node_feats2, edge_index=edge_index2).to(device)

            # --- Training step ---
            model.train()
            optimizer.zero_grad()

            S1, S2, logits1, logits2, l1, l2 = model(data1, data2)

            unsup_loss = model.loss(A1, logits1)
            cont_loss = ((l1 + l2) / 2).mean()
            total_loss = unsup_loss + lam * cont_loss

            total_loss.backward()
            optimizer.step()
            scheduler.step()
            model.update_ma()

            if epoch % 500 == 0:
                print(f"Epoch {epoch} | Total: {total_loss:.4f} | Unsup: {unsup_loss:.4f} | Cont: {cont_loss:.4f}")

        # -------------------- Evaluation --------------------
        model.eval()
        with torch.no_grad():
            S1, _, logits1, _, _, _ = model(data0, data0)
            y_pred = torch.argmax(logits1, dim=1).cpu().numpy()
            y_pred_proba = nnFn.softmax(logits1, dim=1).cpu().numpy()

            acc_score = accuracy_score(y, y_pred)
            acc_score_inverted = accuracy_score(y, 1 - y_pred)
            if acc_score_inverted > acc_score:
                acc_score = acc_score_inverted
                y_pred = 1 - y_pred

            print(f"Run {run+1} Accuracy: {acc_score:.4f}")

            prec_score = precision_score(y, y_pred)
            rec_score = recall_score(y, y_pred)
            f1 = f1_score(y, y_pred)
            log_loss_value = log_loss(y, y_pred_proba)

            acc_scores.append(acc_score)
            prec_scores.append(prec_score)
            rec_scores.append(rec_score)
            f1_scores.append(f1)
            log_losses.append(log_loss_value)

    # -------------------- Store results for this lambda --------------------
    lambda_results = {
        "lambda": lam,
        "accuracy": (np.mean(acc_scores), np.std(acc_scores)),
        "precision": (np.mean(prec_scores), np.std(prec_scores)),
        "recall": (np.mean(rec_scores), np.std(rec_scores)),
        "f1": (np.mean(f1_scores), np.std(f1_scores)),
        "log_loss": (np.mean(log_losses), np.std(log_losses))
    }
    all_results.append(lambda_results)

    print(f"\n--- RESULTS FOR LAMBDA = {lam} ---")
    print(f"Accuracy: {lambda_results['accuracy'][0]:.4f} ± {lambda_results['accuracy'][1]:.4f}")
    print(f"Precision: {lambda_results['precision'][0]:.4f} ± {lambda_results['precision'][1]:.4f}")
    print(f"Recall: {lambda_results['recall'][0]:.4f} ± {lambda_results['recall'][1]:.4f}")
    print(f"F1 Score: {lambda_results['f1'][0]:.4f} ± {lambda_results['f1'][1]:.4f}")
    print(f"Log Loss: {lambda_results['log_loss'][0]:.4f} ± {lambda_results['log_loss'][1]:.4f}")

# -------------------- Final Summary --------------------
print("\n================ FINAL SUMMARY FOR ALL LAMBDAS ================\n")
print(f"{'Lambda':>8} | {'Accuracy':>18} | {'Precision':>18} | {'Recall':>18} | {'F1 Score':>18} | {'Log Loss':>18}")
print("-" * 108)
for res in all_results:
    print(f"{res['lambda']:>8} | "
          f"{res['accuracy'][0]:.4f} ± {res['accuracy'][1]:.4f} | "
          f"{res['precision'][0]:.4f} ± {res['precision'][1]:.4f} | "
          f"{res['recall'][0]:.4f} ± {res['recall'][1]:.4f} | "
          f"{res['f1'][0]:.4f} ± {res['f1'][1]:.4f} | "
          f"{res['log_loss'][0]:.4f} ± {res['log_loss'][1]:.4f}")





--- Run 1/10 ---


NameError: name 'aug_random_edge' is not defined

In [None]:
import numpy as np
import torch

def aug_feature_dropout(features_np, drop_percent=0.2, seed=None):
    rng = np.random.default_rng(seed)
    mask = rng.random(features_np.shape) >= drop_percent
    return (features_np * mask).astype(np.float32)

def aug_feature_dropout_cell(features_np, drop_percent=0.2, seed=None):
    rng = np.random.default_rng(seed)
    aug_feat = features_np.copy()
    num_nodes_local, feat_dim = aug_feat.shape
    drop_feat_num = int(num_nodes_local * feat_dim * drop_percent)
    flat_idx = rng.choice(num_nodes_local * feat_dim, size=drop_feat_num, replace=False)
    rows = (flat_idx // feat_dim)
    cols = (flat_idx % feat_dim)
    aug_feat[rows, cols] = 0.0
    return aug_feat.astype(np.float32)