#### Training Evaluation using TrojanSAINT taken from https://github.com/DfX-NYUAD/TrojanSAINT

In [1]:
# train_graphlevel_trojansaint.py
import os, random, math
from collections import defaultdict
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool  # TrojanSAINT-style GCN backbone

# ----------------- Config -----------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NODE_CSV = "GNNDatasets/node.csv"
NODE_EDGE_CSV = "GNNDatasets/node_edges.csv"
GRAPH_CSV = "GNNDatasets/graph.csv"
GRAPH_EDGE_CSV = "GNNDatasets/graph_edges.csv"

PRETRAIN_CHECKPOINT = "node_gcn_pretrained.pth"  # keep name unchanged

# Hyperparams (tweakable)
HID_DIM = 64
PRETRAIN_LR = 1e-3
PRETRAIN_EPOCHS = 100
PRETRAIN_BATCH = None   # full-graph pretraining uses adjacency; we do whole-graph training (no batch)

FT_HEAD_LR = 1e-3
FT_FINETUNE_LR = 5e-4
FT_EPOCHS_HEAD = 60
FT_EPOCHS_FINETUNE = 120
BATCH_SIZE = 16
EARLY_STOPPING = 30
DROPOUT = 0.35

# ----------------- Load graph-level labels and split circuits -----------------
graph_df = pd.read_csv(GRAPH_CSV)
# find graph label column
graph_label_col = None
for cand in ["label_graph", "label", "is_trojan", "trojan"]:
    if cand in graph_df.columns:
        graph_label_col = cand; break
if graph_label_col is None:
    graph_df["label_graph"] = graph_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    graph_label_col = "label_graph"

circuits = graph_df["circuit_name"].tolist()
graph_labels = [int(x) for x in graph_df[graph_label_col].tolist()]

# stratified split of circuits for final graph-level evaluation
train_circuits, temp_circuits, y_train_c, y_temp_c = train_test_split(
    circuits, graph_labels, test_size=0.30, random_state=SEED, stratify=graph_labels
)
val_circuits, test_circuits, y_val_c, y_test_c = train_test_split(
    temp_circuits, y_temp_c, test_size=0.50, random_state=SEED, stratify=y_temp_c
)

print(f"Circuits split -> train:{len(train_circuits)} val:{len(val_circuits)} test:{len(test_circuits)}")

# ----------------- Prepare node data for pretraining (exclude test circuits!) -----------------
nodes_df = pd.read_csv(NODE_CSV)
edges_df = pd.read_csv(NODE_EDGE_CSV)

# identify node label column
node_label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        node_label_col = cand; break
if node_label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    node_label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# Pretrain set circuits: combine train + val circuits (we exclude graph test circuits)
pretrain_circuits = set(train_circuits + val_circuits)
print(f"Pretraining will use nodes from {len(pretrain_circuits)} circuits (train+val) and exclude {len(test_circuits)} test circuits.")

# Filter nodes/edges for pretraining graph
nodes_pretrain_df = nodes_df[nodes_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)
edges_pretrain_df = edges_df[edges_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)

# feature columns (numeric)  exclude metadata
feat_df = nodes_pretrain_df.copy()
if "gate_type" in feat_df.columns:
    # one-hot encode gate_type
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name", node_label_col}
feature_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
if len(feature_cols) == 0:
    raise SystemExit("No numeric feature columns found in node CSV. Please include features.")

# Build X_all for pretraining nodes and mapping
X_pre = feat_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
y_pre = nodes_pretrain_df[node_label_col].to_numpy(dtype=np.int64)
uids_pre = nodes_pretrain_df["uid"].tolist()
uid_to_idx_pre = {u: i for i,u in enumerate(uids_pre)}

# Some edges may include nodes not present in nodes_pretrain_df (rare)  filter edges
def map_uid(signal, circuit):
    return f"{circuit}::{signal}"

edge_src_uids = edges_pretrain_df.apply(lambda r: map_uid(r["src"], r["circuit_name"]), axis=1)
edge_dst_uids = edges_pretrain_df.apply(lambda r: map_uid(r["dst"], r["circuit_name"]), axis=1)

edge_src_idx = edge_src_uids.map(uid_to_idx_pre).dropna().astype(int).values
edge_dst_idx = edge_dst_uids.map(uid_to_idx_pre).dropna().astype(int).values

if len(edge_src_idx) == 0:
    raise SystemExit("No edges left after filtering to pretrain circuits; check your GNNDatasets files.")

edge_index_pre = np.stack([np.concatenate([edge_src_idx, edge_dst_idx]),
                           np.concatenate([edge_dst_idx, edge_src_idx])], axis=0)  # undirected

# Scale features using labeled nodes only (within pretrain set)
# ensure labels are available
labeled_mask_pre = (y_pre >= 0)
scaler = StandardScaler()
X_pre_scaled = X_pre.copy()
if labeled_mask_pre.sum() == 0:
    raise SystemExit("No labeled nodes in pretraining set.")
X_pre_scaled[labeled_mask_pre] = scaler.fit_transform(X_pre_scaled[labeled_mask_pre])
X_pre_scaled[~labeled_mask_pre] = (X_pre_scaled[~labeled_mask_pre] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# Convert to torch
X_pre_t = torch.from_numpy(X_pre_scaled).to(DEVICE)
y_pre_t = torch.from_numpy(y_pre).to(DEVICE)
edge_index_pre_t = torch.from_numpy(edge_index_pre).long().to(DEVICE)

# Create train/val split (node-level) for early stopping on pretraining
idx_nodes = np.where(labeled_mask_pre)[0]
y_nodes = y_pre[labeled_mask_pre]
n_train_nodes, n_tmp_nodes = train_test_split(idx_nodes, test_size=0.30, random_state=SEED, stratify=y_nodes)
n_val_nodes, n_test_nodes = train_test_split(n_tmp_nodes, test_size=0.50, random_state=SEED,
                                             stratify=y_pre[n_tmp_nodes])

train_mask_nodes = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE); train_mask_nodes[n_train_nodes] = True
val_mask_nodes   = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE);   val_mask_nodes[n_val_nodes] = True
# note: we won't use node_test_nodes later

# ----------------- Node TrojanSAINT-style pretraining (GCN + BN + Dropout) -----------------
class NodeTrojanSAINT(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim, cached=False, add_self_loops=True, normalize=True)
        self.bn1 = nn.BatchNorm1d(hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim, cached=False, add_self_loops=True, normalize=True)
        self.bn2 = nn.BatchNorm1d(hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hid_dim, 2)
        nn.init.xavier_uniform_(self.head.weight)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = self.bn1(h)
        h = F.relu(h, inplace=True)
        h = self.dropout(h)
        h = self.conv2(h, edge_index)
        h = self.bn2(h)
        h = F.relu(h, inplace=True)
        h = self.dropout(h)
        return self.head(h)  # per-node logits

node_model = NodeTrojanSAINT(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM, dropout=DROPOUT).to(DEVICE)

# class weights on node-level training labels
train_labels_nodes = y_pre_t[train_mask_nodes]
num_pos = int((train_labels_nodes==1).sum().item()) if (train_labels_nodes==1).any() else 1
num_neg = int((train_labels_nodes==0).sum().item()) if (train_labels_nodes==0).any() else 1
w_pos = (num_neg + num_pos) / (2.0 * num_pos)
w_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights_nodes = torch.tensor([w_neg, w_pos], dtype=torch.float32, device=DEVICE)
crit_node = nn.CrossEntropyLoss(weight=class_weights_nodes)
opt_node = torch.optim.Adam(node_model.parameters(), lr=PRETRAIN_LR, weight_decay=5e-4)

# Pretraining loop
best_val = -1.0; best_state = None; patience_cnt = 0
print("Node pretraining sizes (labeled): train=%d val=%d" % (train_mask_nodes.sum().item(), val_mask_nodes.sum().item()))
for epoch in range(1, PRETRAIN_EPOCHS+1):
    node_model.train()
    opt_node.zero_grad()
    logits_nodes = node_model(X_pre_t, edge_index_pre_t)
    loss = crit_node(logits_nodes[train_mask_nodes], y_pre_t[train_mask_nodes])
    loss.backward()
    nn.utils.clip_grad_norm_(node_model.parameters(), 2.0)  # stable like TrojanSAINT configs
    opt_node.step()
    if epoch % 5 == 0 or epoch == 1:
        node_model.eval()
        with torch.no_grad():
            logits_val = node_model(X_pre_t, edge_index_pre_t)
            preds_val = logits_val.argmax(dim=1)
            if val_mask_nodes.sum() > 0:
                val_acc = (preds_val[val_mask_nodes] == y_pre_t[val_mask_nodes]).float().mean().item()
            else:
                val_acc = 0.0
        print(f"Pretrain Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f}")
        if val_acc > best_val + 1e-5:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in node_model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= EARLY_STOPPING:
                print("Early stopping node pretrain.")
                break

if best_state is not None:
    node_model.load_state_dict(best_state)
torch.save(node_model.state_dict(), PRETRAIN_CHECKPOINT)
print("Saved node-level model checkpoint:", PRETRAIN_CHECKPOINT)

# ----------------- Build graph-level dataset (per-circuit Data objects) -----------------
print("\nBuilding graph-level Data objects for fine-tuning (train/val/test circuits) ...")
# load full nodes/edges for graphs (use nodes_df, edges_df from earlier)
nodes_full_df = pd.read_csv(NODE_CSV)
edges_full_df = pd.read_csv(GRAPH_EDGE_CSV)

# Prepare feature scaler: use the scaler fitted during pretraining (we already have StandardScaler scaler)
# For nodes not seen in pretrain set, we'll apply same scaler transform using scaler.mean_/var_
def node_uid(circuit, node):
    return f"{circuit}::{node}"

# build uid->feature map for all nodes (apply scaler to full dataset)
feat_full_df = nodes_full_df.copy()
if "gate_type" in feat_full_df.columns:
    gate_oh = pd.get_dummies(feat_full_df["gate_type"], prefix="gt")
    feat_full_df = pd.concat([feat_full_df.drop(columns=["gate_type"]), gate_oh], axis=1)
# ensure feature columns compatible: if full has extra gate dummies or missing ones compared to pretrain, align
for col in feature_cols:
    if col not in feat_full_df.columns:
        feat_full_df[col] = 0.0
feat_full_df = feat_full_df[["circuit_name","node"] + feature_cols]

# scale using pretrain scaler (note: scaler was fitted on pretrain labeled nodes' features)
full_X = feat_full_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
full_X_scaled = (full_X - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)  # consistent transform
feat_full_df["scaled_feat"] = list(full_X_scaled.tolist())

# build edges by circuit
edges_full_df["src_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["src"].astype(str)
edges_full_df["dst_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["dst"].astype(str)
edges_by_circuit = defaultdict(list)
for _, r in edges_full_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# build per-circuit Data (only circuits present in graph_labels)
graph_data_list = []
graph_names = []
graph_target = []
for _, row in graph_df.iterrows():
    ckt = row["circuit_name"]
    lbl = int(row[graph_label_col])
    # nodes of this circuit
    sub_nodes = feat_full_df[feat_full_df["circuit_name"]==ckt]
    if sub_nodes.shape[0] == 0: 
        continue
    node_names = sub_nodes["node"].tolist()
    uid_map = {n:i for i,n in enumerate(node_names)}
    X_nodes = np.vstack(sub_nodes["scaled_feat"].values).astype(np.float32)
    # build edge_index
    srcs, dsts = [], []
    if ckt in edges_by_circuit:
        for u,v in edges_by_circuit[ckt]:
            if u in uid_map and v in uid_map:
                srcs.extend([uid_map[u], uid_map[v]])
                dsts.extend([uid_map[v], uid_map[u]])
    if len(srcs) == 0:
        # skip graphs without edges (unlikely)
        continue
    edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
    data = Data(x=torch.tensor(X_nodes, dtype=torch.float), edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    graph_data_list.append(data)
    graph_names.append(ckt)
    graph_target.append(lbl)

print(f"Built {len(graph_data_list)} graphs.")

# Create train/val/test lists by circuit split we made earlier
def filter_by_circuit(list_data, circuits_set):
    idxs = [i for i,d in enumerate(list_data) if d.circuit_name in circuits_set]
    return [list_data[i] for i in idxs]

train_graphs = filter_by_circuit(graph_data_list, set(train_circuits))
val_graphs   = filter_by_circuit(graph_data_list, set(val_circuits))
test_graphs  = filter_by_circuit(graph_data_list, set(test_circuits))

print(f"Graph counts -> train: {len(train_graphs)}, val: {len(val_graphs)}, test: {len(test_graphs)}")

# ----------------- Graph classifier reusing encoder from node_model (TrojanSAINT-style GCN + BN) -----------------
class GraphTrojanSAINT(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim, cached=False, add_self_loops=True, normalize=True)
        self.bn1 = nn.BatchNorm1d(hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim, cached=False, add_self_loops=True, normalize=True)
        self.bn2 = nn.BatchNorm1d(hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hid_dim, 2)
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index); h = self.bn1(h); h = F.relu(h, inplace=True); h = self.dropout(h)
        h = self.conv2(h, edge_index); h = self.bn2(h); h = F.relu(h, inplace=True); h = self.dropout(h)
        g = global_mean_pool(h, batch)
        return self.classifier(g)

graph_model = GraphTrojanSAINT(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM, dropout=DROPOUT).to(DEVICE)

# Transfer encoder weights from node_model -> graph_model where shapes match
node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")
gstate = graph_model.state_dict()
copied = 0
for k in gstate.keys():
    if k in node_state and gstate[k].shape == node_state[k].shape:
        gstate[k] = node_state[k]
        copied += 1
graph_model.load_state_dict(gstate)
print(f"Transferred {copied} parameter tensors from node encoder to graph encoder.")

from torch_geometric.loader import DataLoader

# --- training/eval (unchanged flow) ---
def train_graphs_fn(model, train_set, val_set, test_set, freeze_encoder=True):
    # dataloaders
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    if freeze_encoder:
        for p in model.conv1.parameters(): p.requires_grad = False
        for p in model.bn1.parameters(): p.requires_grad = False
        for p in model.conv2.parameters(): p.requires_grad = False
        for p in model.bn2.parameters(): p.requires_grad = False
    else:
        for p in model.parameters(): p.requires_grad = True

    # class weights
    ytrain = np.array([int(d.y.item()) for d in train_set])
    if len(np.unique(ytrain)) == 2:
        counts = np.bincount(ytrain); w = torch.tensor([ (counts.sum()/counts[0]), (counts.sum()/counts[1]) ], dtype=torch.float32).to(DEVICE)
        criterion = nn.CrossEntropyLoss(weight=w)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=FT_HEAD_LR if freeze_encoder else FT_FINETUNE_LR, weight_decay=5e-4)
    best_val = -1.0; best_state = None; pcount = 0

    for epoch in range(1, FT_EPOCHS_HEAD + 1 if freeze_encoder else FT_EPOCHS_FINETUNE + 1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            logits = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(logits, batch.y.view(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs

        # validation
        model.eval()
        ys, ps = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                out = model(batch.x, batch.edge_index, batch.batch)
                preds = out.argmax(dim=1)
                ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
        val_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
        if epoch % 5 == 0 or epoch == 1:
            print(f"Fine-tune ({'frozen' if freeze_encoder else 'all'}) Epoch {epoch:03d} | Val Acc {val_acc:.4f} | AvgLoss {total_loss / max(1,len(train_set)):.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            pcount = 0
        else:
            pcount += 1
            if pcount >= EARLY_STOPPING:
                print("Early stopping fine-tune stage.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # test eval
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch)
            preds = out.argmax(dim=1)
            ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
    test_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
    return test_acc, ys, ps

# Stage 1: freeze encoder, train head
print("\nStage 1: freeze encoder and train classifier head")
for p in graph_model.conv1.parameters(): p.requires_grad = False
for p in graph_model.bn1.parameters(): p.requires_grad = False
for p in graph_model.conv2.parameters(): p.requires_grad = False
for p in graph_model.bn2.parameters(): p.requires_grad = False
# Stage 1 call:
test_acc_1, ys1, ps1 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=True)
print("After head training -> Test Acc: {:.4f}".format(test_acc_1))

# Stage 2: unfreeze and fine-tune all (lower lr)
print("\nStage 2: unfreeze encoder and fine-tune entire model")
for p in graph_model.parameters(): p.requires_grad = True
# update global lr for finetune
FT_HEAD_LR = FT_FINETUNE_LR
# Stage 2 call:
test_acc_2, ys2, ps2 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=False)
print("After full fine-tune -> Test Acc: {:.4f}".format(test_acc_2))

# Final report
final_ys, final_ps = (ys2, ps2) if len(ys2)>0 else (ys1, ps1)
print("\nFinal Evaluation (Graph-level after transfer)")
print("=============================================")
print(f"Test Accuracy: {accuracy_score(final_ys, final_ps):.4f}\n")
print("Classification Report:")
print(classification_report(final_ys, final_ps, digits=4))
print("Confusion Matrix:")
print(confusion_matrix(final_ys, final_ps))




Circuits split -> train:64 val:14 test:14
Pretraining will use nodes from 78 circuits (train+val) and exclude 14 test circuits.
Node pretraining sizes (labeled): train=130706 val=28008
Pretrain Epoch 001 | Loss 0.9511 | Val 0.2940
Pretrain Epoch 005 | Loss 0.7267 | Val 0.6488
Pretrain Epoch 010 | Loss 0.5281 | Val 0.9220
Pretrain Epoch 015 | Loss 0.3861 | Val 0.9845
Pretrain Epoch 020 | Loss 0.2874 | Val 0.9946
Pretrain Epoch 025 | Loss 0.2167 | Val 0.9971
Pretrain Epoch 030 | Loss 0.1668 | Val 0.9984
Pretrain Epoch 035 | Loss 0.1302 | Val 0.9992
Pretrain Epoch 040 | Loss 0.1067 | Val 0.9993
Pretrain Epoch 045 | Loss 0.0867 | Val 0.9993
Pretrain Epoch 050 | Loss 0.0707 | Val 0.9986
Pretrain Epoch 055 | Loss 0.0628 | Val 0.9969
Pretrain Epoch 060 | Loss 0.0535 | Val 0.9969
Pretrain Epoch 065 | Loss 0.0459 | Val 0.9970
Pretrain Epoch 070 | Loss 0.0410 | Val 0.9998
Pretrain Epoch 075 | Loss 0.0364 | Val 0.9998
Pretrain Epoch 080 | Loss 0.0320 | Val 0.9999
Pretrain Epoch 085 | Loss 0.0288 

  node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")


Fine-tune (frozen) Epoch 001 | Val Acc 0.5000 | AvgLoss 0.9971
Fine-tune (frozen) Epoch 005 | Val Acc 0.6429 | AvgLoss 0.9000
Fine-tune (frozen) Epoch 010 | Val Acc 0.8571 | AvgLoss 0.5870
Fine-tune (frozen) Epoch 015 | Val Acc 0.9286 | AvgLoss 0.7016
Fine-tune (frozen) Epoch 020 | Val Acc 1.0000 | AvgLoss 0.5106
Fine-tune (frozen) Epoch 025 | Val Acc 1.0000 | AvgLoss 0.2519
Fine-tune (frozen) Epoch 030 | Val Acc 1.0000 | AvgLoss 0.1758
Fine-tune (frozen) Epoch 035 | Val Acc 1.0000 | AvgLoss 0.1928
Fine-tune (frozen) Epoch 040 | Val Acc 1.0000 | AvgLoss 0.1375
Fine-tune (frozen) Epoch 045 | Val Acc 1.0000 | AvgLoss 0.1183
Early stopping fine-tune stage.
After head training -> Test Acc: 1.0000

Stage 2: unfreeze encoder and fine-tune entire model
Fine-tune (all) Epoch 001 | Val Acc 0.9286 | AvgLoss 0.4813
Fine-tune (all) Epoch 005 | Val Acc 1.0000 | AvgLoss 0.1793
Fine-tune (all) Epoch 010 | Val Acc 1.0000 | AvgLoss 0.2301
Fine-tune (all) Epoch 015 | Val Acc 1.0000 | AvgLoss 0.1879
Fine

#### All in One, same perturbation across all metric.

In [2]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from torch_geometric.data import Data

# ----------------- Robustness Metric Functions -----------------
def jacobian_sensitivity(model, batch):
    x = batch.x.clone().detach().to(DEVICE).requires_grad_(True)
    out = model(x, batch.edge_index, batch.batch)
    pred_class = out.argmax(dim=1)
    loss = F.nll_loss(F.log_softmax(out, dim=1), pred_class)
    grads = torch.autograd.grad(loss, x, retain_graph=False, create_graph=False)[0]
    return grads.norm(p="fro").item() / x.size(0)

def local_lipschitz(model, batch):
    x = batch.x.clone().detach().to(DEVICE).requires_grad_(True)
    out = model(x, batch.edge_index, batch.batch)
    pred_class = out.argmax(dim=1)
    loss = F.nll_loss(F.log_softmax(out, dim=1), pred_class)
    grads = torch.autograd.grad(loss, x, retain_graph=False, create_graph=False)[0]
    return grads.norm(2).item() / x.size(0)

def hessian_curvature(model, batch):
    x = batch.x.clone().detach().to(DEVICE).requires_grad_(True)
    out = model(x, batch.edge_index, batch.batch)
    pred_class = out.argmax(dim=1)
    loss = F.nll_loss(F.log_softmax(out, dim=1), pred_class)
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    hvp = torch.autograd.grad(grad.sum(), x, retain_graph=False)[0]
    return hvp.norm(2).item() / x.size(0)

def prediction_margin(model, batch):
    out = model(batch.x.to(DEVICE), batch.edge_index, batch.batch)
    probs = F.softmax(out, dim=1)
    top2 = torch.topk(probs, 2, dim=1).values
    return (top2[:,0] - top2[:,1]).mean().item()

def adv_robustness_radius(model, batch, eps=1e-3):
    x = batch.x.clone().detach().to(DEVICE)
    noise = torch.randn_like(x) * eps
    out1 = model(x, batch.edge_index, batch.batch).argmax(dim=1)
    out2 = model(x + noise, batch.edge_index, batch.batch).argmax(dim=1)
    return (out1 != out2).float().mean().item()

def stability_under_noise(model, batch, sigma=0.05, trials=5):
    diffs = []
    x = batch.x.clone().detach().to(DEVICE)
    out_orig = model(x, batch.edge_index, batch.batch)
    for _ in range(trials):
        noise = torch.randn_like(x) * sigma
        out_noisy = model(x+noise, batch.edge_index, batch.batch)
        diffs.append((out_noisy - out_orig).norm().item())
    return np.mean(diffs)

# ----------------- Perturbation Function -----------------
EPS = 0.3
ALPHA = 0.05
STEPS = 15
NOISE_STD = 0.1

def perturb_graph(batch, pgd=True, gaussian=True):
    x = batch.x.clone().detach().to(DEVICE)
    x.requires_grad_(True)

    delta = torch.zeros_like(x).to(DEVICE)
    if pgd:
        for _ in range(STEPS):
            out = graph_model(x + delta, batch.edge_index, batch.batch)
            loss = F.cross_entropy(out, batch.y.view(-1))
            loss.backward()
            grad = x.grad.detach()
            delta = (delta + ALPHA * grad.sign()).clamp(-EPS, EPS)
            x.grad.zero_()
    x_pert = (x + delta).detach()

    if gaussian:
        noise = torch.randn_like(x_pert) * NOISE_STD
        x_pert = x_pert + noise

    return x_pert.detach()

# ----------------- Metric Collection -----------------
def collect_metrics(batch, perturbed=False):
    if perturbed:
        x_used = perturb_graph(batch, pgd=True, gaussian=True)
    else:
        x_used = batch.x.clone().detach().to(DEVICE)
    b = Data(x=x_used, edge_index=batch.edge_index, y=batch.y, batch=batch.batch).to(DEVICE)
    return {
        "jac": jacobian_sensitivity(graph_model, b),
        "lip": local_lipschitz(graph_model, b),
        "hess": hessian_curvature(graph_model, b),
        "marg": prediction_margin(graph_model, b),
        "rad": adv_robustness_radius(graph_model, b),
        "stab": stability_under_noise(graph_model, b)
    }

orig_metrics, pert_metrics = [], []
ys_true, ys_pred, ys_pred_pert = [], [], []

for batch in test_graphs:
    batch = batch.to(DEVICE)

    # metrics before/after perturbation
    orig_m = collect_metrics(batch, perturbed=False)
    pert_m = collect_metrics(batch, perturbed=True)
    orig_metrics.append(orig_m); pert_metrics.append(pert_m)

    # predictions
    with torch.no_grad():
        out = graph_model(batch.x, batch.edge_index, batch.batch)
        ys_true.append(batch.y.item())
        ys_pred.append(out.argmax(dim=1).item())
    x_pert = perturb_graph(batch, pgd=True, gaussian=True)
    with torch.no_grad():
        out_p = graph_model(x_pert, batch.edge_index, batch.batch)
        ys_pred_pert.append(out_p.argmax(dim=1).item())

# ----------------- Summarize Results -----------------
def summarize_with_relerr(name, key):
    orig_vals = [m[key] for m in orig_metrics]
    pert_vals = [m[key] for m in pert_metrics]
    rel_errs = [abs(p-o)/(abs(o)+1e-8) for o,p in zip(orig_vals, pert_vals)]
    print(f"{name:<25} orig={np.mean(orig_vals):.4f} pert={np.mean(pert_vals):.4f} "
          f"rel.err={np.mean(rel_errs):.4f}")

print("\nRobustness Metrics (Graph-Level, Before vs After Perturbation)")
print("================================================================")
summarize_with_relerr("Jacobian Sensitivity", "jac")
summarize_with_relerr("Local Lipschitz", "lip")
summarize_with_relerr("Hessian Curvature", "hess")
summarize_with_relerr("Prediction Margin", "marg")
summarize_with_relerr("Adv Robustness Radius", "rad")
summarize_with_relerr("Stability under Noise", "stab")

# Perturbation success
ys_true = np.array(ys_true); ys_pred = np.array(ys_pred); ys_pred_pert = np.array(ys_pred_pert)
success_rate = np.mean(ys_pred != ys_pred_pert) * 100
print("\nPerturbation Success Evaluation")
print("================================")
print(f"Perturbation Success Rate: {success_rate:.2f}%")

print("\nPerformance on Perturbed Samples Only")
print("-------------------------------------")
print("Accuracy:", accuracy_score(ys_true, ys_pred_pert))
print("Classification Report:")
print(classification_report(ys_true, ys_pred_pert, digits=4))
print("Confusion Matrix:")
print(confusion_matrix(ys_true, ys_pred_pert))



Robustness Metrics (Graph-Level, Before vs After Perturbation)
Jacobian Sensitivity      orig=0.0000 pert=0.0002 rel.err=2.2547
Local Lipschitz           orig=0.0000 pert=0.0002 rel.err=2.2547
Hessian Curvature         orig=0.0000 pert=0.0003 rel.err=2.6463
Prediction Margin         orig=0.6324 pert=0.4551 rel.err=0.6368
Adv Robustness Radius     orig=0.0000 pert=0.0000 rel.err=0.0000
Stability under Noise     orig=0.0037 pert=0.0047 rel.err=1.1367

Perturbation Success Evaluation
Perturbation Success Rate: 92.86%

Performance on Perturbed Samples Only
-------------------------------------
Accuracy: 0.07142857142857142
Classification Report:
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         4
           1     0.2000    0.1000    0.1333        10

    accuracy                         0.0714        14
   macro avg     0.1000    0.0500    0.0667        14
weighted avg     0.1429    0.0714    0.0952        14

Confusion Matrix:
[[0 

In [3]:
# ----------------- Summarize Results (with mean ± std) -----------------
def summarize_with_relerr(name, key):
    orig_vals = np.array([m[key] for m in orig_metrics])
    pert_vals = np.array([m[key] for m in pert_metrics])
    rel_errs = np.abs(pert_vals - orig_vals) / (np.abs(orig_vals) + 1e-8)

    mean_orig, std_orig = orig_vals.mean(), orig_vals.std()
    mean_pert, std_pert = pert_vals.mean(), pert_vals.std()
    mean_rel, std_rel = rel_errs.mean(), rel_errs.std()

    print(f"{name:<25}")
    print(f"  Clean : avg={mean_orig:.4f} ± {std_orig:.4f}, "
          f"avg_relerr={mean_rel:.4e} ± {std_rel:.4e}")
    print(f"  Pert. : avg={mean_pert:.4f} ± {std_pert:.4f}, "
          f"avg_relerr={mean_rel:.4e} ± {std_rel:.4e}\n")

print("\nRobustness Metrics (Graph-Level, Before vs After Perturbation)")
print("================================================================")
summarize_with_relerr("Jacobian Sensitivity", "jac")
summarize_with_relerr("Local Lipschitz", "lip")
summarize_with_relerr("Hessian Curvature", "hess")
summarize_with_relerr("Prediction Margin", "marg")
summarize_with_relerr("Adv Robustness Radius", "rad")
summarize_with_relerr("Stability under Noise", "stab")



Robustness Metrics (Graph-Level, Before vs After Perturbation)
Jacobian Sensitivity     
  Clean : avg=0.0000 ± 0.0001, avg_relerr=2.2547e+00 ± 3.2201e+00
  Pert. : avg=0.0002 ± 0.0005, avg_relerr=2.2547e+00 ± 3.2201e+00

Local Lipschitz          
  Clean : avg=0.0000 ± 0.0001, avg_relerr=2.2547e+00 ± 3.2201e+00
  Pert. : avg=0.0002 ± 0.0005, avg_relerr=2.2547e+00 ± 3.2201e+00

Hessian Curvature        
  Clean : avg=0.0000 ± 0.0001, avg_relerr=2.6463e+00 ± 3.8600e+00
  Pert. : avg=0.0003 ± 0.0008, avg_relerr=2.6463e+00 ± 3.8600e+00

Prediction Margin        
  Clean : avg=0.6324 ± 0.2019, avg_relerr=6.3677e-01 ± 1.0521e+00
  Pert. : avg=0.4551 ± 0.2025, avg_relerr=6.3677e-01 ± 1.0521e+00

Adv Robustness Radius    
  Clean : avg=0.0000 ± 0.0000, avg_relerr=0.0000e+00 ± 0.0000e+00
  Pert. : avg=0.0000 ± 0.0000, avg_relerr=0.0000e+00 ± 0.0000e+00

Stability under Noise    
  Clean : avg=0.0037 ± 0.0045, avg_relerr=1.1367e+00 ± 1.3012e+00
  Pert. : avg=0.0047 ± 0.0042, avg_relerr=1.1367e