#### Training Evaluation

In [1]:
import os, random, math
from collections import defaultdict
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# ----------------- Config -----------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NODE_CSV = "GNNDatasets/node.csv"
NODE_EDGE_CSV = "GNNDatasets/node_edges.csv"
GRAPH_CSV = "GNNDatasets/graph.csv"
GRAPH_EDGE_CSV = "GNNDatasets/graph_edges.csv"

PRETRAIN_CHECKPOINT = "node_gcn_pretrained.pth"

# Hyperparams (tweakable)
HID_DIM = 64
PRETRAIN_LR = 1e-3
PRETRAIN_EPOCHS = 100
PRETRAIN_BATCH = None   # full-graph pretraining uses adjacency; we do whole-graph training (no batch)

FT_HEAD_LR = 1e-3
FT_FINETUNE_LR = 5e-4
FT_EPOCHS_HEAD = 60
FT_EPOCHS_FINETUNE = 120
BATCH_SIZE = 16
EARLY_STOPPING = 30
DROPOUT = 0.35

# ----------------- Load graph-level labels and split circuits -----------------
graph_df = pd.read_csv(GRAPH_CSV)
# find graph label column
graph_label_col = None
for cand in ["label_graph", "label", "is_trojan", "trojan"]:
    if cand in graph_df.columns:
        graph_label_col = cand; break
if graph_label_col is None:
    graph_df["label_graph"] = graph_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    graph_label_col = "label_graph"

circuits = graph_df["circuit_name"].tolist()
graph_labels = [int(x) for x in graph_df[graph_label_col].tolist()]

# stratified split of circuits for final graph-level evaluation
train_circuits, temp_circuits, y_train_c, y_temp_c = train_test_split(
    circuits, graph_labels, test_size=0.30, random_state=SEED, stratify=graph_labels
)
val_circuits, test_circuits, y_val_c, y_test_c = train_test_split(
    temp_circuits, y_temp_c, test_size=0.50, random_state=SEED, stratify=y_temp_c
)

print(f"Circuits split -> train:{len(train_circuits)} val:{len(val_circuits)} test:{len(test_circuits)}")

# ----------------- Prepare node data for pretraining (exclude test circuits!) -----------------
nodes_df = pd.read_csv(NODE_CSV)
edges_df = pd.read_csv(NODE_EDGE_CSV)

# identify node label column
node_label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        node_label_col = cand; break
if node_label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    node_label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# Pretrain set circuits: combine train + val circuits (we exclude graph test circuits)
pretrain_circuits = set(train_circuits + val_circuits)
print(f"Pretraining will use nodes from {len(pretrain_circuits)} circuits (train+val) and exclude {len(test_circuits)} test circuits.")

# Filter nodes/edges for pretraining graph
nodes_pretrain_df = nodes_df[nodes_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)
edges_pretrain_df = edges_df[edges_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)

# feature columns (numeric)  exclude metadata
feat_df = nodes_pretrain_df.copy()
if "gate_type" in feat_df.columns:
    # one-hot encode gate_type
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name", node_label_col}
feature_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
if len(feature_cols) == 0:
    raise SystemExit("No numeric feature columns found in node CSV. Please include features.")

# Build X_all for pretraining nodes and mapping
X_pre = feat_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
y_pre = nodes_pretrain_df[node_label_col].to_numpy(dtype=np.int64)
uids_pre = nodes_pretrain_df["uid"].tolist()
uid_to_idx_pre = {u: i for i,u in enumerate(uids_pre)}

# Some edges may include nodes not present in nodes_pretrain_df (rare)  filter edges
def map_uid(signal, circuit):
    return f"{circuit}::{signal}"

edge_src_uids = edges_pretrain_df.apply(lambda r: map_uid(r["src"], r["circuit_name"]), axis=1)
edge_dst_uids = edges_pretrain_df.apply(lambda r: map_uid(r["dst"], r["circuit_name"]), axis=1)

edge_src_idx = edge_src_uids.map(uid_to_idx_pre).dropna().astype(int).values
edge_dst_idx = edge_dst_uids.map(uid_to_idx_pre).dropna().astype(int).values

if len(edge_src_idx) == 0:
    raise SystemExit("No edges left after filtering to pretrain circuits; check your GNNDatasets files.")

edge_index_pre = np.stack([np.concatenate([edge_src_idx, edge_dst_idx]),
                           np.concatenate([edge_dst_idx, edge_src_idx])], axis=0)  # undirected

# Scale features using labeled nodes only (within pretrain set)
# ensure labels are available
labeled_mask_pre = (y_pre >= 0)
scaler = StandardScaler()
X_pre_scaled = X_pre.copy()
if labeled_mask_pre.sum() == 0:
    raise SystemExit("No labeled nodes in pretraining set.")
X_pre_scaled[labeled_mask_pre] = scaler.fit_transform(X_pre_scaled[labeled_mask_pre])
X_pre_scaled[~labeled_mask_pre] = (X_pre_scaled[~labeled_mask_pre] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# Convert to torch
X_pre_t = torch.from_numpy(X_pre_scaled).to(DEVICE)
y_pre_t = torch.from_numpy(y_pre).to(DEVICE)
edge_index_pre_t = torch.from_numpy(edge_index_pre).long().to(DEVICE)

# Create train/val split (node-level) for early stopping on pretraining
idx_nodes = np.where(labeled_mask_pre)[0]
y_nodes = y_pre[labeled_mask_pre]
n_train_nodes, n_tmp_nodes = train_test_split(idx_nodes, test_size=0.30, random_state=SEED, stratify=y_nodes)
n_val_nodes, n_test_nodes = train_test_split(n_tmp_nodes, test_size=0.50, random_state=SEED,
                                             stratify=y_pre[n_tmp_nodes])

train_mask_nodes = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE); train_mask_nodes[n_train_nodes] = True
val_mask_nodes   = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE);   val_mask_nodes[n_val_nodes] = True
# note: we won't use node_test_nodes later

# ----------------- Node GCN pretraining (PyG-style conv) -----------------
class NodeGCN(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.head = nn.Linear(hid_dim, 2)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return self.head(x)  # per-node logits

node_model = NodeGCN(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM).to(DEVICE)

# class weights on node-level training labels
train_labels_nodes = y_pre_t[train_mask_nodes]
classes, counts = torch.unique(train_labels_nodes, return_counts=True)
num_pos = int((train_labels_nodes==1).sum().item()) if (train_labels_nodes==1).any() else 1
num_neg = int((train_labels_nodes==0).sum().item()) if (train_labels_nodes==0).any() else 1
w_pos = (num_neg + num_pos) / (2.0 * num_pos)
w_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights_nodes = torch.tensor([w_neg, w_pos], dtype=torch.float32, device=DEVICE)
crit_node = nn.CrossEntropyLoss(weight=class_weights_nodes)
opt_node = torch.optim.Adam(node_model.parameters(), lr=PRETRAIN_LR, weight_decay=5e-4)

# Pretraining loop
best_val = -1.0; best_state = None; patience_cnt = 0
print("Node pretraining sizes (labeled): train=%d val=%d" % (train_mask_nodes.sum().item(), val_mask_nodes.sum().item()))
for epoch in range(1, PRETRAIN_EPOCHS+1):
    node_model.train()
    opt_node.zero_grad()
    logits_nodes = node_model(X_pre_t, edge_index_pre_t)
    loss = crit_node(logits_nodes[train_mask_nodes], y_pre_t[train_mask_nodes])
    loss.backward()
    opt_node.step()
    if epoch % 5 == 0 or epoch == 1:
        node_model.eval()
        with torch.no_grad():
            logits_val = node_model(X_pre_t, edge_index_pre_t)
            preds_val = logits_val.argmax(dim=1)
            if val_mask_nodes.sum() > 0:
                val_acc = (preds_val[val_mask_nodes] == y_pre_t[val_mask_nodes]).float().mean().item()
            else:
                val_acc = 0.0
        print(f"Pretrain Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f}")
        if val_acc > best_val + 1e-5:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in node_model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= EARLY_STOPPING:
                print("Early stopping node pretrain.")
                break

if best_state is not None:
    node_model.load_state_dict(best_state)
torch.save(node_model.state_dict(), PRETRAIN_CHECKPOINT)
print("Saved node-level model checkpoint:", PRETRAIN_CHECKPOINT)

# ----------------- Build graph-level dataset (per-circuit Data objects) -----------------
print("\nBuilding graph-level Data objects for fine-tuning (train/val/test circuits) ...")
# load full nodes/edges for graphs (use nodes_df, edges_df from earlier)
nodes_full_df = pd.read_csv(NODE_CSV)
edges_full_df = pd.read_csv(GRAPH_EDGE_CSV)

# Prepare feature scaler: use the scaler fitted during pretraining (we already have StandardScaler scaler)
# For nodes not seen in pretrain set, we'll apply same scaler transform using scaler.mean_/var_
def node_uid(circuit, node):
    return f"{circuit}::{node}"

# build uid->feature map for all nodes (apply scaler to full dataset)
feat_full_df = nodes_full_df.copy()
if "gate_type" in feat_full_df.columns:
    gate_oh = pd.get_dummies(feat_full_df["gate_type"], prefix="gt")
    feat_full_df = pd.concat([feat_full_df.drop(columns=["gate_type"]), gate_oh], axis=1)
# ensure feature columns compatible: if full has extra gate dummies or missing ones compared to pretrain, align
for col in feature_cols:
    if col not in feat_full_df.columns:
        feat_full_df[col] = 0.0
feat_full_df = feat_full_df[["circuit_name","node"] + feature_cols]

# scale using pretrain scaler (note: scaler was fitted on pretrain labeled nodes' features)
full_X = feat_full_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
full_X_scaled = (full_X - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)  # consistent transform
feat_full_df["scaled_feat"] = list(full_X_scaled.tolist())

# build edges by circuit
edges_full_df["src_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["src"].astype(str)
edges_full_df["dst_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["dst"].astype(str)
edges_by_circuit = defaultdict(list)
for _, r in edges_full_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# build per-circuit Data (only circuits present in graph_labels)
graph_data_list = []
graph_names = []
graph_target = []
for _, row in graph_df.iterrows():
    ckt = row["circuit_name"]
    lbl = int(row[graph_label_col])
    # nodes of this circuit
    sub_nodes = feat_full_df[feat_full_df["circuit_name"]==ckt]
    if sub_nodes.shape[0] == 0: 
        continue
    node_names = sub_nodes["node"].tolist()
    uid_map = {n:i for i,n in enumerate(node_names)}
    X_nodes = np.vstack(sub_nodes["scaled_feat"].values).astype(np.float32)
    # build edge_index
    srcs, dsts = [], []
    if ckt in edges_by_circuit:
        for u,v in edges_by_circuit[ckt]:
            if u in uid_map and v in uid_map:
                srcs.extend([uid_map[u], uid_map[v]])
                dsts.extend([uid_map[v], uid_map[u]])
    if len(srcs) == 0:
        # skip graphs without edges (unlikely)
        continue
    edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
    data = Data(x=torch.tensor(X_nodes, dtype=torch.float), edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    graph_data_list.append(data)
    graph_names.append(ckt)
    graph_target.append(lbl)

print(f"Built {len(graph_data_list)} graphs.")

# Create train/val/test lists by circuit split we made earlier
def filter_by_circuit(list_data, circuits_set):
    idxs = [i for i,d in enumerate(list_data) if d.circuit_name in circuits_set]
    return [list_data[i] for i in idxs]

train_graphs = filter_by_circuit(graph_data_list, set(train_circuits))
val_graphs   = filter_by_circuit(graph_data_list, set(val_circuits))
test_graphs  = filter_by_circuit(graph_data_list, set(test_circuits))

print(f"Graph counts -> train: {len(train_graphs)}, val: {len(val_graphs)}, test: {len(test_graphs)}")

# ----------------- Graph classifier reusing conv layers from node_model -----------------
class GraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hid_dim, 2)
    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.conv2(h, edge_index)
        g = global_mean_pool(h, batch)
        return self.classifier(g)

graph_model = GraphClassifier(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM).to(DEVICE)

# Copy conv weights from node_model -> graph_model (shapes align because both use GCNConv)
node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")
# node_state contains keys: conv1.lin.weight, conv1.lin.bias? Check keys
for k_src, v_src in node_state.items():
    if "conv1" in k_src and "weight" in k_src:
        # copy to graph_model conv1 weight if exists
        if k_src in graph_model.state_dict() and graph_model.state_dict()[k_src].shape == v_src.shape:
            graph_model.state_dict()[k_src].copy_(v_src)
    if "conv2" in k_src and "weight" in k_src:
        if k_src in graph_model.state_dict() and graph_model.state_dict()[k_src].shape == v_src.shape:
            graph_model.state_dict()[k_src].copy_(v_src)

from torch_geometric.loader import DataLoader

# --- replace the function definition (rename it) ---
def train_graphs_fn(model, train_set, val_set, test_set, freeze_encoder=True):
    # dataloaders
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    if freeze_encoder:
        for p in model.conv1.parameters(): p.requires_grad = False
        for p in model.conv2.parameters(): p.requires_grad = False
    else:
        for p in model.parameters(): p.requires_grad = True

    # class weights
    ytrain = np.array([int(d.y.item()) for d in train_set])
    if len(np.unique(ytrain)) == 2:
        counts = np.bincount(ytrain); w = torch.tensor([ (counts.sum()/counts[0]), (counts.sum()/counts[1]) ], dtype=torch.float32).to(DEVICE)
        criterion = nn.CrossEntropyLoss(weight=w)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=FT_HEAD_LR if freeze_encoder else FT_FINETUNE_LR, weight_decay=5e-4)
    best_val = -1.0; best_state = None; pcount = 0

    for epoch in range(1, FT_EPOCHS_HEAD + 1 if freeze_encoder else FT_EPOCHS_FINETUNE + 1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            logits = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(logits, batch.y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs

        # validation
        model.eval()
        ys, ps = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                out = model(batch.x, batch.edge_index, batch.batch)
                preds = out.argmax(dim=1)
                ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
        val_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
        if epoch % 5 == 0 or epoch == 1:
            print(f"Fine-tune ({'frozen' if freeze_encoder else 'all'}) Epoch {epoch:03d} | Val Acc {val_acc:.4f} | AvgLoss {total_loss / max(1,len(train_set)):.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            pcount = 0
        else:
            pcount += 1
            if pcount >= EARLY_STOPPING:
                print("Early stopping fine-tune stage.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # test eval
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch)
            preds = out.argmax(dim=1)
            ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
    test_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
    return test_acc, ys, ps

# Stage 1: freeze encoder, train head
print("\nStage 1: freeze encoder and train classifier head")
for p in graph_model.conv1.parameters(): p.requires_grad = False
for p in graph_model.conv2.parameters(): p.requires_grad = False
# Stage 1 call:
test_acc_1, ys1, ps1 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=True)
print("After head training -> Test Acc: {:.4f}".format(test_acc_1))

# Stage 2: unfreeze and fine-tune all (lower lr)
print("\nStage 2: unfreeze encoder and fine-tune entire model")
for p in graph_model.parameters(): p.requires_grad = True
# update global lr for finetune
FT_HEAD_LR = FT_FINETUNE_LR
# Stage 2 call:
test_acc_2, ys2, ps2 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=False)
print("After full fine-tune -> Test Acc: {:.4f}".format(test_acc_2))

# Final report
final_ys, final_ps = (ys2, ps2) if len(ys2)>0 else (ys1, ps1)
print("\nFinal Evaluation (Graph-level after transfer)")
print("=============================================")
print(f"Test Accuracy: {accuracy_score(final_ys, final_ps):.4f}\n")
print("Classification Report:")
print(classification_report(final_ys, final_ps, digits=4))
print("Confusion Matrix:")
print(confusion_matrix(final_ys, final_ps))



Circuits split -> train:64 val:14 test:14
Pretraining will use nodes from 78 circuits (train+val) and exclude 14 test circuits.
Node pretraining sizes (labeled): train=130706 val=28008
Pretrain Epoch 001 | Loss 0.7781 | Val 0.6488
Pretrain Epoch 005 | Loss 0.6883 | Val 0.6977
Pretrain Epoch 010 | Loss 0.5939 | Val 0.9208
Pretrain Epoch 015 | Loss 0.5102 | Val 0.9453
Pretrain Epoch 020 | Loss 0.4307 | Val 0.9594
Pretrain Epoch 025 | Loss 0.3524 | Val 0.9863
Pretrain Epoch 030 | Loss 0.2770 | Val 0.9936
Pretrain Epoch 035 | Loss 0.2106 | Val 0.9944
Pretrain Epoch 040 | Loss 0.1548 | Val 0.9961
Pretrain Epoch 045 | Loss 0.1138 | Val 0.9991
Pretrain Epoch 050 | Loss 0.0838 | Val 0.9996
Pretrain Epoch 055 | Loss 0.0618 | Val 0.9997
Pretrain Epoch 060 | Loss 0.0478 | Val 0.9998
Pretrain Epoch 065 | Loss 0.0380 | Val 0.9998
Pretrain Epoch 070 | Loss 0.0316 | Val 0.9998
Pretrain Epoch 075 | Loss 0.0256 | Val 0.9998
Pretrain Epoch 080 | Loss 0.0221 | Val 0.9998
Pretrain Epoch 085 | Loss 0.0190 

  node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")


Fine-tune (frozen) Epoch 001 | Val Acc 1.0000 | AvgLoss 0.5945
Fine-tune (frozen) Epoch 005 | Val Acc 1.0000 | AvgLoss 0.4577
Fine-tune (frozen) Epoch 010 | Val Acc 1.0000 | AvgLoss 0.2530
Fine-tune (frozen) Epoch 015 | Val Acc 1.0000 | AvgLoss 0.1850
Fine-tune (frozen) Epoch 020 | Val Acc 1.0000 | AvgLoss 0.1366
Fine-tune (frozen) Epoch 025 | Val Acc 1.0000 | AvgLoss 0.0988
Fine-tune (frozen) Epoch 030 | Val Acc 1.0000 | AvgLoss 0.1019
Early stopping fine-tune stage.
After head training -> Test Acc: 0.9286

Stage 2: unfreeze encoder and fine-tune entire model
Fine-tune (all) Epoch 001 | Val Acc 1.0000 | AvgLoss 0.5179
Fine-tune (all) Epoch 005 | Val Acc 1.0000 | AvgLoss 0.3262
Fine-tune (all) Epoch 010 | Val Acc 1.0000 | AvgLoss 0.2033
Fine-tune (all) Epoch 015 | Val Acc 1.0000 | AvgLoss 0.1076
Fine-tune (all) Epoch 020 | Val Acc 1.0000 | AvgLoss 0.0639
Fine-tune (all) Epoch 025 | Val Acc 1.0000 | AvgLoss 0.0353
Fine-tune (all) Epoch 030 | Val Acc 1.0000 | AvgLoss 0.0246
Early stoppin

#### Jacobain 

In [10]:
# ============================================================
# Graph-level: PGD perturbation -> evaluation -> Jacobian+FD
# ============================================================
import torch, numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, accuracy_score
from torch_geometric.loader import DataLoader

# ---------------------- PARAMETERS (tuned to induce large shifts) ----------------------
PER_CLASS = 250        # graphs per class to perturb (adjust based on your class counts)
EPSILON   = 20.0       # L2 radius on flattened node-feature tensor (strong)
ALPHA     = 10.2       # PGD step size (normalized by grad norm)
NUM_ITERS = 100        # PGD iterations
FD_EPS    = 1e-1      # finite-difference epsilon for Jacobian check
SEED      = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ---------------------- Sanity checks & setup ----------------------
required = ["graph_model", "test_graphs", "DEVICE"]
missing = [v for v in required if v not in globals()]
if missing:
    raise RuntimeError(f"Missing required variables in notebook: {missing}. "
                       "Run your training cell first to define graph_model/test_graphs/DEVICE.")

model = graph_model.to(DEVICE)
model.eval()

# We'll reuse BATCH_SIZE if present; otherwise default small batch for eval loaders
BATCH_EVAL = globals().get("BATCH_SIZE", 16)

# Freeze randomness for selection
rng = np.random.default_rng(SEED)

# ---------------------- Build a stable list (order) of test graphs ----------------------
test_list = list(test_graphs)   # preserve order
test_labels = np.array([int(d.y.item()) for d in test_list])
print(f"Test graphs: {len(test_list)} | class counts = {np.bincount(test_labels) if len(test_labels)>0 else '[]'}")

# ---------------------- Select graphs to perturb (PER_CLASS per class) ----------------------
selected_idxs = []
for cls in [0, 1]:
    idxs = np.where(test_labels == cls)[0]
    if len(idxs) == 0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected_idxs.extend(chosen.tolist())
selected_idxs = np.array(sorted(selected_idxs), dtype=np.int64)
print("Selected to perturb:", {0:int((test_labels[selected_idxs]==0).sum()),
                               1:int((test_labels[selected_idxs]==1).sum())})

# ---------------------- PGD perturbation for selected graphs ----------------------
perturbed_map = {}  # idx -> perturbed node-feature matrix (Tensor on DEVICE)
print("\nRunning PGD on selected graphs (strong settings)...")

for idx in selected_idxs:
    data = test_list[int(idx)]
    x_orig = data.x.detach().to(DEVICE)                 # [N, F]
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_orig.shape
    batch_zero = torch.zeros(n_nodes, dtype=torch.long, device=DEVICE)
    y_true = data.y.view(-1).to(DEVICE)                 # shape [1]

    # init random direction with L2 = EPSILON on the flattened x
    delta = torch.randn_like(x_orig, device=DEVICE)
    delta = delta * (EPSILON / (delta.view(-1).norm() + 1e-12))
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    for _ in range(NUM_ITERS):
        out = model(x_adv, edge_index, batch_zero)      # [1, C]
        loss = F.cross_entropy(out, y_true)             # targeted to true label; pushes away from correct
        grad_x = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]

        gnorm = grad_x.view(-1).norm().item()
        if gnorm == 0:
            break
        step = (ALPHA * grad_x) / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()

        # Project back to L2 ball around x_orig
        delta = x_adv - x_orig
        dnorm = delta.view(-1).norm().item()
        if dnorm > EPSILON:
            delta = delta * (EPSILON / (dnorm + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach().clone()

print("? PGD finished.")

# ---------------------- Full test evaluation (perturbed selected + originals) ----------------------
y_true_list, y_pred_list = [], []
with torch.no_grad():
    for i, data in enumerate(test_list):
        x_in = perturbed_map[i] if i in perturbed_map else data.x.to(DEVICE)
        edge_index = data.edge_index.to(DEVICE)
        batch_zero = torch.zeros(x_in.size(0), dtype=torch.long, device=DEVICE)
        out = model(x_in, edge_index, batch_zero)          # [1, C]
        pred = int(out.argmax(dim=1).item())
        y_pred_list.append(pred)
        y_true_list.append(int(data.y.item()))

y_true_arr = np.array(y_true_list)
y_pred_arr = np.array(y_pred_list)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected + rest original) =============")
acc = float((y_true_arr == y_pred_arr).mean())
prec, rec, f1, _ = precision_recall_fscore_support(y_true_arr, y_pred_arr, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(y_true_arr, y_pred_arr, labels=[0,1], target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(y_true_arr, y_pred_arr, labels=[0,1]))

# Flip statistics for the selected set
orig_sel_preds, adv_sel_preds = [], []
with torch.no_grad():
    for idx in selected_idxs:
        d = test_list[int(idx)]
        ez = torch.zeros(d.x.size(0), dtype=torch.long, device=DEVICE)
        po = model(d.x.to(DEVICE), d.edge_index.to(DEVICE), ez).argmax(dim=1).item()
        pa = model(perturbed_map[int(idx)], d.edge_index.to(DEVICE), ez).argmax(dim=1).item()
        orig_sel_preds.append(int(po)); adv_sel_preds.append(int(pa))
orig_sel_preds = np.array(orig_sel_preds); adv_sel_preds = np.array(adv_sel_preds)
num_flips = int((orig_sel_preds != adv_sel_preds).sum())
print(f"\nSelected graphs: {len(selected_idxs)}. Flipped after attack: {num_flips} ({100.0*num_flips/len(selected_idxs):.2f}%).")

# ---------------------- Jacobian & FD relative error at PERTURBED graphs ----------------------
print("\nComputing Jacobian norms & FD relative error at the PERTURBED graphs...")
per_sample_info = []   # (idx, label, jacobian_fro_norm, fd_rel_err)

for idx in selected_idxs:
    d = test_list[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE)  # [N,F]
    x_adv = x_adv.requires_grad_(True)
    edge_index = d.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    batch_zero = torch.zeros(n_nodes, dtype=torch.long, device=DEVICE)

    # flatten node features -> vector
    x_flat = x_adv.view(-1).detach().clone().requires_grad_(True)

    def f_flat(z):
        x_mat = z.view_as(x_adv)
        out = model(x_mat, edge_index, batch_zero)   # [1, C]
        return out.squeeze(0)                        # (C,)

    # Compute Jacobian J \in R^{C x D} where D = N*F
    try:
        J = torch.autograd.functional.jacobian(f_flat, x_flat)  # (C, D)
    except RuntimeError:
        # Row-by-row fallback (rarely needed)
        C = int(model(x_adv, edge_index, batch_zero).shape[1])
        rows = []
        for out_i in range(C):
            def scalar_f(z, i=out_i): return f_flat(z)[i]
            r = torch.autograd.functional.jacobian(scalar_f, x_flat)  # (D,)
            rows.append(r.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    J = J.detach()
    jac_frob = float(torch.norm(J, p='fro').item())

    # FD relative error at the perturbed point
    delta_fd = FD_EPS * torch.randn_like(x_flat).to(DEVICE)
    pred_change = J @ delta_fd                          # (C,)
    f0  = f_flat(x_flat).detach()
    f0p = f_flat(x_flat + delta_fd).detach()
    actual_change = f0p - f0
    rel_err = float((torch.norm(pred_change - actual_change) /
                    (torch.norm(actual_change) + 1e-8)).item())

    per_sample_info.append((int(idx), int(d.y.item()), jac_frob, rel_err))

# ---------------------- Aggregate & print ----------------------
if len(per_sample_info) > 0:
    arr = np.array(per_sample_info, dtype=object)  # columns: idx, label, J_frob, FD_rel
    clean_mask = (arr[:,1] == 0)
    troj_mask  = (arr[:,1] == 1)

    def mstd(a): 
        return (float(np.mean(a)), float(np.std(a))) if len(a) else (0.0, 0.0)

    cJ, cJsd = mstd(arr[clean_mask, 2].astype(float))
    tJ, tJsd = mstd(arr[troj_mask,  2].astype(float))
    cE, cEsd = mstd(arr[clean_mask, 3].astype(float))
    tE, tEsd = mstd(arr[troj_mask,  3].astype(float))

    print("\n--- Jacobian Frobenius Norms & FD Relative Errors (Graph-level) ---")
    print(f" Clean graphs :  avg_norm={cJ:.4f} ± {cJsd:.4f}, avg_FDrel={cE:.4e} ± {cEsd:.4e}")
    print(f" Trojan graphs:  avg_norm={tJ:.4f} ± {tJsd:.4f}, avg_FDrel={tE:.4e} ± {tEsd:.4e}")

    print("\nSample preview (first 6): (idx,label,jacobian_frob,FD_rel_err)")
    for t in per_sample_info[:6]:
        print(t)
else:
    print("No Jacobian samples computed (empty selected set).")

print("\nDone. (Order: PGD -> full-test evaluation -> Jacobian @ perturbed inputs.)")


Test graphs: 14 | class counts = [ 4 10]
Selected to perturb: {0: 4, 1: 10}

Running PGD on selected graphs (strong settings)...
? PGD finished.

Accuracy: 71.43%
Precision: 0.6786, Recall: 0.7143, F1: 0.6797

Classification report:
              precision    recall  f1-score   support

       clean     0.5000    0.2500    0.3333         4
      trojan     0.7500    0.9000    0.8182        10

    accuracy                         0.7143        14
   macro avg     0.6250    0.5750    0.5758        14
weighted avg     0.6786    0.7143    0.6797        14

Confusion Matrix:
[[1 3]
 [1 9]]

Selected graphs: 14. Flipped after attack: 3 (21.43%).

Computing Jacobian norms & FD relative error at the PERTURBED graphs...

--- Jacobian Frobenius Norms & FD Relative Errors (Graph-level) ---
 Clean graphs :  avg_norm=0.0178 ± 0.0071, avg_FDrel=6.4374e-01 ± 4.7658e-01
 Trojan graphs:  avg_norm=0.0210 ± 0.0291, avg_FDrel=1.2029e+00 ± 7.0398e-01

Sample preview (first 6): (idx,label,jacobian_frob,FD_

#### Local Lipschitz Constants

In [5]:
# ============================================================
# Graph-level: Local Lipschitz Constants (with PGD perturbations)
# ============================================================
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ---------------------- PARAMETERS ----------------------
PER_CLASS = 25      # graphs per class to perturb
EPSILON   = 50.0     # L2 budget
ALPHA     = 10.0     # PGD step size
NUM_ITERS = 100     # PGD iterations
FD_EPS    = 1e-1   # finite-diff epsilon
SEED      = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ---------------------- Sanity check ----------------------
if 'graph_model' not in globals() or 'test_graphs' not in globals():
    raise RuntimeError("Need `graph_model` and `test_graphs` defined from training cell.")

model = graph_model.to(DEVICE)
model.eval()

# ---------------------- Build dataset view ----------------------
dataset = list(test_graphs)
labels_np = np.array([int(d.y.item()) for d in dataset])

# ---------------------- Select graphs (PER_CLASS/class) ----------------------
rng = np.random.default_rng(SEED)
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_np == cls)[0]
    if len(idxs) == 0: 
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected.extend(chosen.tolist())
selected = np.array(sorted(selected), dtype=np.int64)
print("Selected perturbation pool:", {0:int((labels_np[selected]==0).sum()), 1:int((labels_np[selected]==1).sum())})

# ---------------------- PGD: Lipschitz-directed ----------------------
perturbed_map = {}
orig_preds_list, adv_preds_list = [], []

print("\nRunning Lipschitz-directed PGD for selected graphs...")
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)       # [N,F]
    edge_index = data.edge_index.to(DEVICE)
    y_true = torch.tensor([int(data.y.item())], device=DEVICE)

    batch_zero = torch.zeros(x_orig.size(0), dtype=torch.long, device=DEVICE)

    # define f_local
    def f_local(x):
        return model(x, edge_index, batch_zero).squeeze(0)

    # Jacobian at x_orig
    try:
        J = torch.autograd.functional.jacobian(f_local, x_orig)   # (C,N,F)
    except RuntimeError:
        logits0 = f_local(x_orig).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(x, cidx=c): return f_local(x)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_orig)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    C, N, Fdim = J.shape
    J_flat = J.reshape(C, N*Fdim)

    # leading right-singular vector
    try:
        _, S, Vh = torch.linalg.svd(J_flat, full_matrices=False)
        v = Vh[0,:].detach()
    except RuntimeError:
        v = torch.randn(J_flat.shape[1], device=DEVICE)
    v = v / (v.norm() + 1e-12)

    v_mat = v.view_as(x_orig)
    x_adv = (x_orig + 0.5*EPSILON*v_mat).detach().clone().requires_grad_(True)

    # PGD loop
    for it in range(NUM_ITERS):
        logits = model(x_adv, edge_index, batch_zero)
        loss = F.cross_entropy(logits, y_true)
        grad = torch.autograd.grad(loss, x_adv)[0]
        gnorm = grad.view(-1).norm().item()
        if gnorm == 0: break
        step = ALPHA * grad / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()

        delta = (x_adv - x_orig).view(-1)
        dnorm = delta.norm().item()
        if dnorm > EPSILON:
            delta = delta * (EPSILON / (dnorm+1e-12))
            x_adv = (x_orig + delta.view_as(x_orig)).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach()

    # record flips
    with torch.no_grad():
        p_orig = f_local(x_orig).argmax().item()
        p_adv  = f_local(x_adv).argmax().item()
    orig_preds_list.append(p_orig); adv_preds_list.append(p_adv)

print("? Finished perturbations.")

# ---------------------- Evaluate full test set ----------------------
print("\n============= Robustness Evaluation (Full Test Set) =============")
y_true_all, y_pred_all = [], []
with torch.no_grad():
    for i, d in enumerate(dataset):
        x_eval = perturbed_map[i] if i in perturbed_map else d.x.to(DEVICE)
        edge_index = d.edge_index.to(DEVICE)
        batch_zero = torch.zeros(x_eval.size(0), dtype=torch.long, device=DEVICE)
        logits = model(x_eval, edge_index, batch_zero)
        y_true_all.append(int(d.y.item()))
        y_pred_all.append(int(logits.argmax(dim=1).item()))

y_true_all = np.array(y_true_all)
y_pred_all = np.array(y_pred_all)

acc = (y_true_all == y_pred_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(y_true_all, y_pred_all, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification Report:")
print(classification_report(y_true_all, y_pred_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(y_true_all, y_pred_all, labels=[0,1]))

num_flips = int((np.array(orig_preds_list) != np.array(adv_preds_list)).sum())
print(f"\nSelected graphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ---------------------- Compute Local Lipschitz on perturbed graphs ----------------------
print("\nComputing Local Lipschitz constants + FD relative errors...")
per_sample_info = []
for idx in selected:
    d = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = d.edge_index.to(DEVICE)
    label = int(d.y.item())
    batch_zero = torch.zeros(x_adv.size(0), dtype=torch.long, device=DEVICE)

    def f_local_adv(x): return model(x, edge_index, batch_zero).squeeze(0)

    # Jacobian
    try:
        J = torch.autograd.functional.jacobian(f_local_adv, x_adv).detach()
    except RuntimeError:
        logits0 = f_local_adv(x_adv).detach()
        C = logits0.shape[0]; rows = []
        for c in range(C):
            def scalar_f(x, cidx=c): return f_local_adv(x)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_adv)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    C, N, Fdim = J.shape
    J_flat = J.reshape(C, N*Fdim)

    # spectral norm
    try:
        _, S, _ = torch.linalg.svd(J_flat, full_matrices=False)
        sigma_max = float(S[0].item())
    except RuntimeError:
        JJT = (J_flat @ J_flat.T).cpu().numpy()
        eigvals = np.linalg.eigvalsh(JJT)
        sigma_max = float(np.sqrt(max(eigvals.max(), 0.0)))

    # FD relative error
    delta_fd = FD_EPS * torch.randn(N*Fdim, device=DEVICE)
    pred_change = J_flat @ delta_fd
    f0  = f_local_adv(x_adv).detach()
    f0p = f_local_adv((x_adv + delta_fd.view_as(x_adv))).detach()
    actual_change = f0p - f0
    fd_rel_err = (torch.norm(pred_change - actual_change) / (torch.norm(actual_change)+1e-8)).item()

    per_sample_info.append((int(idx), label, sigma_max, fd_rel_err))

# ---------------------- Aggregate ----------------------
clean_stats = [p for p in per_sample_info if p[1]==0]
troj_stats  = [p for p in per_sample_info if p[1]==1]

def aggs(stats):
    if not stats: return (0,0,0,0)
    Ls = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean,cL_std,cE_mean,cE_std = aggs(clean_stats)
tL_mean,tL_std,tE_mean,tE_std = aggs(troj_stats)

print("\n--- Local Lipschitz Constants (graph-level) ---")
print(f" Clean:  avg_L={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_L={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,L,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)


Selected perturbation pool: {0: 4, 1: 10}

Running Lipschitz-directed PGD for selected graphs...
? Finished perturbations.

Accuracy: 35.71%
Precision: 0.4490, Recall: 0.3571, F1: 0.3881

Classification Report:
              precision    recall  f1-score   support

       clean     0.1429    0.2500    0.1818         4
      trojan     0.5714    0.4000    0.4706        10

    accuracy                         0.3571        14
   macro avg     0.3571    0.3250    0.3262        14
weighted avg     0.4490    0.3571    0.3881        14

Confusion Matrix:
[[1 3]
 [6 4]]

Selected graphs: 14. Flipped after perturbation: 8 (57.14%).

Computing Local Lipschitz constants + FD relative errors...

--- Local Lipschitz Constants (graph-level) ---
 Clean:  avg_L=0.0200 ± 0.0084, avg_FDrel=4.0474e-01 ± 2.2832e-01
 Trojan: avg_L=0.0167 ± 0.0198, avg_FDrel=9.6798e-01 ± 2.9770e-01

Sample preview (first 6): (idx,label,L,FD_rel_err)
(0, 0, 0.023152219131588936, 0.19119632244110107)
(1, 0, 0.02562996372580

#### Hessian-based Curvature Analysis

In [13]:
# ----------------------- Hessian-Based Curvature (Graph-Level, notebook-friendly) -----------------------
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, accuracy_score

# ---------- PARAMETERS (tweak if needed) ----------
PER_CLASS = 250        # graphs per class to include in curvature pool (will be clipped by available graphs)
FD_EPS = 5e-1          # finite-difference epsilon for FD error
TRIALS_PER_GRAPH = 5   # FD trials per graph
PERT_P = 50.0          # L2 magnitude (applied per-node-feature matrix) for Hessian-aligned perturbation
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ---------- locate test graphs in notebook (support test_graphs or test_loader.dataset) ----------
required = ["graph_model", "DEVICE"]
missing = [v for v in required if v not in globals()]
if missing:
    raise RuntimeError(f"Missing required variables in notebook: {missing}. Run your training/fine-tune cell first.")

if "test_graphs" in globals():
    test_list = list(test_graphs)
elif "test_loader" in globals():
    # try to materialize dataset from test_loader.dataset
    try:
        ds = test_loader.dataset
        test_list = [ds[i] for i in range(len(ds))]
    except Exception:
        # fallback: try to iterate loader (may yield batched Data objects)
        test_list = []
        for b in test_loader:
            # if loader yields batched Data, break them out to individual graphs is complex;
            # usually `test_graphs` is present in your notebook so this branch is rarely used.
            raise RuntimeError("Cannot materialize individual graphs from test_loader. Prefer having `test_graphs` list in notebook.")
else:
    raise RuntimeError("Need `test_graphs` or `test_loader` defined in the notebook environment.")

print(f"Found {len(test_list)} test graphs.")

# ---------- quick helpers ----------
device = DEVICE
model = graph_model.to(device)
model.eval()
rng = np.random.default_rng(SEED)

def batch_for(x):
    # single-graph batching: all nodes belong to batch index 0
    return torch.zeros(x.size(0), dtype=torch.long, device=device)

# ---------- base predictions (clean) ----------
base_preds = []
with torch.no_grad():
    for data in test_list:
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        out = model(x, edge_index, batch_for(x))    # [1, C]
        base_preds.append(int(out.argmax(dim=1).item()))
base_preds = np.array(base_preds)
labels_all = np.array([int(d.y.item()) for d in test_list])
print("Class counts in test set:", np.bincount(labels_all) if len(labels_all)>0 else "[]")

# ---------- pick PER_CLASS graphs / class ----------
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_all == cls)[0]
    if len(idxs) == 0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False).tolist()
    selected.extend(chosen)
selected = np.array(sorted(selected), dtype=np.int64)
print(f"Selected pool size -> clean: {(labels_all[selected]==0).sum()}, trojan: {(labels_all[selected]==1).sum()}")

# ---------- compute Hessian-curvature proxies on selected graphs ----------
per_sample_info = []   # tuples: (idx, label, lambda_proxy, avg_fd_rel_err)
print("\nComputing Hessian-curvature proxies on selected graphs... (this will compute per-graph grads)")

for idx in selected:
    data = test_list[int(idx)]
    x0 = data.x.detach().clone().to(device).requires_grad_(True)   # [N, F]
    edge_index = data.edge_index.to(device)

    # predicted class at x0 (use that class to estimate curvature proxy)
    with torch.no_grad():
        logits = model(x0, edge_index, batch_for(x0))
    pred_class = int(logits.argmax(dim=1).item())

    # scalar function h(x) = log P(pred_class | x)
    def h_pred(x):
        logits = model(x, edge_index, batch_for(x))    # [1, C]
        lp = F.log_softmax(logits.squeeze(0), dim=0)
        return lp[pred_class]

    h0 = h_pred(x0)
    # gradient of log-prob of predicted class wrt node features (shape [N, F])
    g = torch.autograd.grad(h0, x0, retain_graph=False, create_graph=False)[0].detach()

    lambda_proxy = float((g.norm(p=2).item() ** 2))

    # finite-difference second-order relative error (to check local quadratic approx)
    errs = []
    for _ in range(TRIALS_PER_GRAPH):
        delta = FD_EPS * torch.randn_like(x0).to(device)
        gt_delta = float((g * delta).sum().item())
        pred_second = 0.5 * (gt_delta ** 2)
        # actual second-order term: h(x0+delta) - h0 - g^T delta
        actual_second = float((h_pred((x0 + delta).detach()) - h0 - (g * delta).sum()).item())
        rel_err = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        errs.append(rel_err)
    avg_rel_err = float(np.mean(errs))
    per_sample_info.append((int(idx), int(data.y.item()), lambda_proxy, avg_rel_err))

# ---------- construct Hessian-aligned perturbations ----------
print("\nConstructing Hessian-aligned perturbations (using negative gradient of true label log-prob)...")
perturbed_map = {}
for (idx, label, lam, fd_err) in per_sample_info:
    data = test_list[int(idx)]
    x0 = data.x.detach().clone().to(device).requires_grad_(True)
    edge_index = data.edge_index.to(device)

    # define h_true = log P(true_label | x) so negative grad reduces true-class score
    true_label = int(data.y.item())
    def h_true(x):
        logits = model(x, edge_index, batch_for(x))
        lp = F.log_softmax(logits.squeeze(0), dim=0)
        return lp[true_label]

    g_true = torch.autograd.grad(h_true(x0), x0, retain_graph=False, create_graph=False)[0].detach()
    gnorm = g_true.norm().item()
    if gnorm < 1e-12:
        dir_vec = torch.randn_like(x0).to(device)
    else:
        dir_vec = - g_true / (gnorm + 1e-12)   # direction to reduce true class score

    delta = (PERT_P * dir_vec).detach()
    perturbed_map[int(idx)] = (x0 + delta).detach()

print(f"Perturbations built for {len(perturbed_map)} selected graphs.")

# ---------- evaluate model on entire test set: clean vs perturbed ----------
print("\n================ Robustness Evaluation (Full Test Set: clean vs perturbed) ===============")
preds_clean = []
preds_pert  = []
with torch.no_grad():
    for i, data in enumerate(test_list):
        x_clean = data.x.to(device)
        eidx = data.edge_index.to(device)
        out_clean = model(x_clean, eidx, batch_for(x_clean))
        preds_clean.append(int(out_clean.argmax(dim=1).item()))

        if int(i) in perturbed_map:
            x_eval = perturbed_map[int(i)].to(device)
        else:
            x_eval = x_clean
        out_pert = model(x_eval, eidx, batch_for(x_eval))
        preds_pert.append(int(out_pert.argmax(dim=1).item()))

preds_clean = np.array(preds_clean)
preds_pert  = np.array(preds_pert)

# ----- metrics printing helper -----
def print_metrics(name, preds, labels):
    acc = (preds == labels).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    print(f"\n--- {name} ---")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")
    print("Classification Report:")
    print(classification_report(labels, preds, target_names=['clean','trojan'], digits=4))
    print("Confusion Matrix:")
    print(confusion_matrix(labels, preds, labels=[0,1]))

print_metrics("CLEAN (original features)", preds_clean, labels_all)
print_metrics("PERTURBED (Hessian-aligned on selected graphs)", preds_pert, labels_all)

# ---------- flip statistics for selected set ----------
orig_sel_preds = base_preds[selected]
adv_sel_preds  = np.array([preds_pert[i] for i in selected])
num_flips = int((orig_sel_preds != adv_sel_preds).sum())
print(f"\nSelected graphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ---------- aggregate curvature stats by class ----------
clean_stats = [p for p in per_sample_info if p[1] == 0]
troj_stats  = [p for p in per_sample_info if p[1] == 1]

def summarize(stats):
    if not stats: return (0.0, 0.0, 0.0, 0.0)
    Ls = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)

print("\n--- Hessian Curvature Stats (grad outer-product proxy) ---")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,lambda_proxy,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)

print("\nDone. (Hessian-proxy computation -> perturbations -> clean/perturbed evaluation.)")


Found 14 test graphs.
Class counts in test set: [ 4 10]
Selected pool size -> clean: 4, trojan: 10

Computing Hessian-curvature proxies on selected graphs... (this will compute per-graph grads)

Constructing Hessian-aligned perturbations (using negative gradient of true label log-prob)...
Perturbations built for 14 selected graphs.


--- CLEAN (original features) ---
Accuracy: 92.86%
Precision: 0.9351, Recall: 0.9286, F1: 0.9252
Classification Report:
              precision    recall  f1-score   support

       clean     1.0000    0.7500    0.8571         4
      trojan     0.9091    1.0000    0.9524        10

    accuracy                         0.9286        14
   macro avg     0.9545    0.8750    0.9048        14
weighted avg     0.9351    0.9286    0.9252        14

Confusion Matrix:
[[ 3  1]
 [ 0 10]]

--- PERTURBED (Hessian-aligned on selected graphs) ---
Accuracy: 50.00%
Precision: 0.5333, Recall: 0.5000, F1: 0.5146
Classification Report:
              precision    recall  f1-

#### Prediction Margin

In [16]:
# ----------------------- Prediction Margin (Graph-Level, notebook-friendly) -----------------------
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, accuracy_score

# ----- parameters -----
PER_CLASS = 250       # graphs per class
EPSILON = 50.0         # PGD perturbation budget (L2 norm)
ALPHA = 0.9           # PGD step size
NUM_ITERS = 150        # PGD iterations
FD_EPS = 1e-1         # finite-difference perturbation
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ----- locate dataset -----
if "test_graphs" in globals():
    dataset = list(test_graphs)
elif "test_loader" in globals():
    dataset = [test_loader.dataset[i] for i in range(len(test_loader.dataset))]
else:
    raise RuntimeError("Need `test_graphs` or `test_loader` defined in notebook.")

labels_np = np.array([int(d.y.item()) for d in dataset])
n_test = len(dataset)
print(f"Found {n_test} test graphs.")

# ----- select PER_CLASS graphs per class -----
rng = np.random.default_rng(SEED)
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_np == cls)[0].tolist()
    if len(idxs) == 0: continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False).tolist()
    selected.extend(chosen)
selected = np.array(sorted(selected), dtype=np.int64)
print(f"Selected pool: clean={int((labels_np[selected]==0).sum())}, trojan={int((labels_np[selected]==1).sum())}")

# helper for batching
def batch_for(x): 
    return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

# ----- base predictions -----
model = graph_model.to(DEVICE)
model.eval()
base_preds = []
with torch.no_grad():
    for data in dataset:
        logits = model(data.x.to(DEVICE), data.edge_index.to(DEVICE), batch_for(data.x.to(DEVICE)))
        base_preds.append(int(logits.argmax()))
base_preds = np.array(base_preds)

# ----- adversarial perturbations (PGD) -----
perturbed_map = {}
print("\nRunning PGD perturbations on selected graphs...")
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    y_scalar = int(data.y.item())
    target = torch.tensor([y_scalar], dtype=torch.long, device=DEVICE)

    # random init inside epsilon-ball
    delta = torch.randn_like(x_orig).to(DEVICE)
    delta = EPSILON * delta / (delta.norm() + 1e-12)
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    for it in range(NUM_ITERS):
        logits = model(x_adv, edge_index, batch_for(x_adv))
        loss = F.cross_entropy(logits, target)
        grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        step = ALPHA * grad / (grad.norm() + 1e-12)
        x_adv = (x_adv + step).detach()
        delta = x_adv - x_orig
        if delta.norm() > EPSILON:
            delta = delta * (EPSILON / (delta.norm() + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach()
print("? Finished perturbations.")

# ----- evaluate on full test set (clean vs perturbed) -----
preds_clean, preds_pert = [], []
with torch.no_grad():
    for i, data in enumerate(dataset):
        # clean
        logits = model(data.x.to(DEVICE), data.edge_index.to(DEVICE), batch_for(data.x.to(DEVICE)))
        preds_clean.append(int(logits.argmax()))

        # perturbed if selected, else clean
        if int(i) in perturbed_map:
            x_eval = perturbed_map[int(i)]
        else:
            x_eval = data.x.to(DEVICE)
        logits = model(x_eval, data.edge_index.to(DEVICE), batch_for(x_eval))
        preds_pert.append(int(logits.argmax()))

preds_clean = np.array(preds_clean)
preds_pert  = np.array(preds_pert)

labels_all = labels_np

def print_metrics(name, preds, labels):
    acc = (preds == labels).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    print(f"\n--- {name} ---")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")
    print("Classification Report:")
    print(classification_report(labels, preds, target_names=['clean','trojan'], digits=4))
    print("Confusion Matrix:")
    print(confusion_matrix(labels, preds, labels=[0,1]))

print_metrics("CLEAN (original features)", preds_clean, labels_all)
print_metrics("PERTURBED (PGD, selected graphs)", preds_pert, labels_all)

# ----- flip stats -----
orig_sel_preds = base_preds[selected]
adv_sel_preds  = np.array([preds_pert[i] for i in selected])
num_flips = int((orig_sel_preds != adv_sel_preds).sum())
print(f"\nSelected graphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ----- compute prediction margin + FD relative error -----
per_sample_info = []  # (idx, label, margin, FD_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_eval = perturbed_map[int(idx)]
    edge_index = data.edge_index.to(DEVICE)

    with torch.no_grad():
        logits = model(x_eval, edge_index, batch_for(x_eval)).squeeze(0)
    pred_class = int(logits.argmax().item())
    pred_logit = logits[pred_class].item()
    other_logits = logits.clone(); other_logits[pred_class] = -float("inf")
    second_max = other_logits.max().item()
    margin = pred_logit - second_max

    # finite-difference check
    delta = FD_EPS * torch.randn_like(x_eval).to(DEVICE)
    with torch.no_grad():
        logits_p = model(x_eval + delta, edge_index, batch_for(x_eval)).squeeze(0)
    pred_logit_p = logits_p[pred_class].item()
    other_logits_p = logits_p.clone(); other_logits_p[pred_class] = -float("inf")
    second_max_p = other_logits_p.max().item()
    margin_p = pred_logit_p - second_max_p
    rel_err = abs(margin - margin_p) / (abs(margin_p) + 1e-12)

    per_sample_info.append((int(idx), int(data.y.item()), float(margin), float(rel_err)))

# ----- aggregate stats -----
def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ms = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ms.mean(), Ms.std(), Es.mean(), Es.std())

clean_stats = [p for p in per_sample_info if p[1]==0]
troj_stats  = [p for p in per_sample_info if p[1]==1]

cM_mean, cM_std, cE_mean, cE_std = summarize(clean_stats)
tM_mean, tM_std, tE_mean, tE_std = summarize(troj_stats)

print("\n--- Prediction Margin Stats (on perturbed graphs) ---")
print(f" Clean:  avg_margin={cM_mean:.4f} ± {cM_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_margin={tM_mean:.4f} ± {tM_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,margin,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)

print("\n? Done. (Prediction Margin: PGD perturbations -> clean/perturbed eval -> margin stats.)")


Found 14 test graphs.
Selected pool: clean=4, trojan=10

Running PGD perturbations on selected graphs...
? Finished perturbations.

--- CLEAN (original features) ---
Accuracy: 92.86%
Precision: 0.9351, Recall: 0.9286, F1: 0.9252
Classification Report:
              precision    recall  f1-score   support

       clean     1.0000    0.7500    0.8571         4
      trojan     0.9091    1.0000    0.9524        10

    accuracy                         0.9286        14
   macro avg     0.9545    0.8750    0.9048        14
weighted avg     0.9351    0.9286    0.9252        14

Confusion Matrix:
[[ 3  1]
 [ 0 10]]

--- PERTURBED (PGD, selected graphs) ---
Accuracy: 42.86%
Precision: 0.4940, Recall: 0.4286, F1: 0.4540
Classification Report:
              precision    recall  f1-score   support

       clean     0.1667    0.2500    0.2000         4
      trojan     0.6250    0.5000    0.5556        10

    accuracy                         0.4286        14
   macro avg     0.3958    0.3750    0

#### Adversarial Robustness Radius

In [18]:
# ----------------------- Adversarial Robustness Radius (Graph-Level) -----------------------
import time
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ----- helpers -----
def f_for_graph(x_tensor, data):
    """Return logits for given graph with node features replaced by x_tensor."""
    x_tensor = x_tensor.to(DEVICE)
    with torch.no_grad():
        out = graph_model(x_tensor, data.edge_index.to(DEVICE), batch_for(x_tensor))
    return out.squeeze(0)

def adversarial_radius_for_graph(data, x0, initial_epsilon=1e-3, growth_factor=1.25,
                                 max_epsilon=20.0, bs_iters=10, num_trials=6):
    """Estimate minimal perturbation radius that flips prediction (around perturbed point)."""
    with torch.no_grad():
        base_out = graph_model(x0, data.edge_index.to(DEVICE), batch_for(x0))
        y0 = int(base_out.argmax().item())

    def is_same(x):
        out = f_for_graph(x, data)
        return int(out.argmax().item()) == y0

    radii = []
    for _ in range(num_trials):
        d = torch.randn_like(x0)
        d = d / (d.norm() + 1e-12)

        eps = initial_epsilon
        while eps < max_epsilon and is_same(x0 + eps * d):
            eps *= growth_factor

        if eps >= max_epsilon:
            candidate = max_epsilon
        else:
            low, high = eps / growth_factor, eps
            for _ in range(bs_iters):
                mid = 0.5 * (low + high)
                if is_same(x0 + mid * d):
                    low = mid
                else:
                    high = mid
            candidate = float(high)
        radii.append(candidate)
    return float(min(radii))

def adversarial_radius_relerr(data, x0):
    r1 = adversarial_radius_for_graph(data, x0, growth_factor=2.05, bs_iters=10, num_trials=6)
    r2 = adversarial_radius_for_graph(data, x0, growth_factor=1.8, bs_iters=12, num_trials=6)
    rel_err = abs(r1 - r2) / (abs(r2) + 1e-12)
    return r1, rel_err

# ----- ARR computation on selected perturbed graphs -----
class_names = ['clean', 'trojan']
class_adv_radius = {cn: [] for cn in class_names}
class_rel_errors = {cn: [] for cn in class_names}
all_radii, all_rel_errs = [], []

t0 = time.time()
print("\nComputing Adversarial Robustness Radius (ARR) for selected perturbed graphs...")
for i, idx in enumerate(selected):
    idx = int(idx)
    data = dataset[idx]
    label = int(data.y.item())
    cn = class_names[label]

    x0 = perturbed_map[idx].clone().detach().to(DEVICE)  # start from perturbed features
    r, rel = adversarial_radius_relerr(data, x0)

    class_adv_radius[cn].append(r)
    class_rel_errors[cn].append(rel)
    all_radii.append(r)
    all_rel_errs.append(rel)

    if (i+1) % 20 == 0:
        print(f"  processed {i+1}/{len(selected)} graphs...")

t1 = time.time()
print(f"? Done ARR computation. Time elapsed: {t1-t0:.1f}s")

# ----- Reporting ARR aggregates -----
print("\n--- Adversarial Robustness Radius (ARR) Stats ---")
print("{:<10s} {:>14s} {:>22s}".format("Class", "Avg Radius ± Std", "Avg Rel. Error ± Std"))
print("-"*52)
for cn in class_names:
    if class_adv_radius[cn]:
        print("{:<10s} {:>7.4f} ± {:<7.4f} {:>14.4e} ± {:<10.4e}".format(
            cn, np.mean(class_adv_radius[cn]), np.std(class_adv_radius[cn]),
            np.mean(class_rel_errors[cn]), np.std(class_rel_errors[cn])
        ))
    else:
        print("{:<10s} {:>10s}".format(cn, "-"))

print("\nOverall ARR: Avg Radius: {:.4f} ± {:.4f}".format(np.mean(all_radii), np.std(all_radii)))
print("Overall ARR: Avg Relative Error: {:.4e} ± {:.4e}".format(np.mean(all_rel_errs), np.std(all_rel_errs)))

# ------------------ Evaluate model on full test set (selected perturbed + others clean) ------------------
graph_model.eval()
with torch.no_grad():
    all_logits = []
    all_labels = []
    for i, data in enumerate(dataset):
        x_in = perturbed_map[i].to(DEVICE) if i in perturbed_map else data.x.to(DEVICE)
        logits = graph_model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        all_logits.append(logits.cpu())
        all_labels.append(int(data.y.item()))

    all_logits = torch.cat(all_logits, dim=0)
    preds_all = all_logits.argmax(dim=1).numpy()
    labels_all = np.array(all_labels)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected graphs + others clean) =============")
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# ------------------ Flip statistics for selected graphs ------------------
with torch.no_grad():
    orig_preds = []
    for data in dataset:
        x_in = data.x.to(DEVICE)
        logits = graph_model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        orig_preds.append(logits.cpu())
    orig_preds = torch.cat(orig_preds, dim=0).argmax(dim=1).numpy()

adv_sel_preds = preds_all[selected]
num_flips = int((orig_preds[selected] != adv_sel_preds).sum())
print(f"\nSelected graphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0 * num_flips/len(selected):.2f}%).")



Computing Adversarial Robustness Radius (ARR) for selected perturbed graphs...
? Done ARR computation. Time elapsed: 14.3s

--- Adversarial Robustness Radius (ARR) Stats ---
Class      Avg Radius ± Std   Avg Rel. Error ± Std
----------------------------------------------------
clean      20.0000 ± 0.0000      0.0000e+00 ± 0.0000e+00
trojan     20.0000 ± 0.0000      0.0000e+00 ± 0.0000e+00

Overall ARR: Avg Radius: 20.0000 ± 0.0000
Overall ARR: Avg Relative Error: 0.0000e+00 ± 0.0000e+00

Accuracy: 42.86%
Precision: 0.4940, Recall: 0.4286, F1: 0.4540

Classification report:
              precision    recall  f1-score   support

       clean     0.1667    0.2500    0.2000         4
      trojan     0.6250    0.5000    0.5556        10

    accuracy                         0.4286        14
   macro avg     0.3958    0.3750    0.3778        14
weighted avg     0.4940    0.4286    0.4540        14

Confusion Matrix:
[[1 3]
 [5 5]]

Selected graphs: 14. Flipped after perturbation: 7 (50.00%

#### Stability Under Input Noise

In [22]:
# ============================================================
# Stability Under Input Noise (SUIN) - Graph classification
# ============================================================
import torch
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import time

# --- Parameters ---
NOISE_SIGMA        = 0.8   # Gaussian noise stddev for stability metric
NUM_NOISE_SAMPLES  = 50    # Monte Carlo samples per graph
RELERR_RESAMPLES   = 10     # repeats to estimate relative error

print("\n--- Stability Under Input Noise (graph-level) ---")
print(f"Selected graphs: clean={(labels_np[selected]==0).sum()}, trojan={(labels_np[selected]==1).sum()}")

# ============================================================
# Step 1: Evaluate model on full dataset (perturbed + original)
# ============================================================
graph_model.eval()
with torch.no_grad():
    all_logits, all_labels = [], []
    for i, data in enumerate(dataset):
        # use perturbed features if available
        x_in = perturbed_map[i].to(DEVICE) if i in perturbed_map else data.x.to(DEVICE)
        logits = graph_model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        all_logits.append(logits.cpu())
        all_labels.append(int(data.y.item()))

    all_logits = torch.cat(all_logits, dim=0)
    preds_all  = all_logits.argmax(dim=1).numpy()
    labels_all = np.array(all_labels)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected + others unperturbed) =============")
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# Flip statistics for selected graphs
with torch.no_grad():
    orig_preds = []
    for data in dataset:
        x_in = data.x.to(DEVICE)
        logits = graph_model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        orig_preds.append(logits.cpu())
    orig_preds = torch.cat(orig_preds, dim=0).argmax(dim=1).numpy()

adv_sel_preds = preds_all[selected]
num_flips = int((orig_preds[selected] != adv_sel_preds).sum())
print(f"\nSelected graphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ============================================================
# Step 2: Stability Under Input Noise (on perturbed graphs)
# ============================================================
def stability_for_graph(idx, sigma, num_samples):
    """Compute avg L2 change in logits for noisy perturbations around perturbed graph."""
    data = dataset[idx]
    base_x = perturbed_map[idx].to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    batch = batch_for(base_x)

    with torch.no_grad():
        f_orig = graph_model(base_x, edge_index, batch).squeeze()

    diffs = []
    for _ in range(num_samples):
        noise = sigma * torch.randn_like(base_x).to(DEVICE)
        f_noisy = graph_model(base_x + noise, edge_index, batch).squeeze()
        diffs.append(torch.norm(f_noisy - f_orig).item())
    return float(np.mean(diffs))

print("\nComputing Stability Under Input Noise (this may take a while)...")
t0 = time.time()
per_sample_info = []  # (idx, label, stability, rel_err)
for i, idx in enumerate(selected):
    s_val = stability_for_graph(idx, NOISE_SIGMA, NUM_NOISE_SAMPLES)
    re_vals = [stability_for_graph(idx, NOISE_SIGMA, NUM_NOISE_SAMPLES) for _ in range(RELERR_RESAMPLES)]
    s_ref = float(np.mean(re_vals))
    rel_err = abs(s_val - s_ref) / (abs(s_ref) + 1e-12)
    per_sample_info.append((int(idx), int(labels_np[idx]), s_val, rel_err))
    if (i+1) % 10 == 0:
        print(f"  processed {i+1}/{len(selected)} graphs...")
t1 = time.time()
print(f"? Done SUIN computation. Time elapsed: {t1-t0:.1f}s")

# ============================================================
# Step 3: Aggregate and report
# ============================================================
clean_stats = [(i, s, e) for (i, l, s, e) in per_sample_info if l == 0]
troj_stats  = [(i, s, e) for (i, l, s, e) in per_sample_info if l == 1]

def aggs(stats):
    if not stats:
        return (0.0, 0.0, 0.0, 0.0)
    Ss = np.array([s for (_, s, _) in stats])
    Es = np.array([e for (_, _, e) in stats])
    return (Ss.mean(), Ss.std(), Es.mean(), Es.std())

cS_mean, cS_std, cE_mean, cE_std = aggs(clean_stats)
tS_mean, tS_std, tE_mean, tE_std = aggs(troj_stats)

print("\n--- Stability Under Input Noise (on perturbed graphs) ---")
print(f" Clean:  avg_stability={cS_mean:.4f} ± {cS_std:.4f}, avg_relerr={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_stability={tS_mean:.4f} ± {tS_std:.4f}, avg_relerr={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx, label, stability, rel_err)")
for p in per_sample_info[:6]:
    print(p)



--- Stability Under Input Noise (graph-level) ---
Selected graphs: clean=4, trojan=10

Accuracy: 42.86%
Precision: 0.4940, Recall: 0.4286, F1: 0.4540

Classification report:
              precision    recall  f1-score   support

       clean     0.1667    0.2500    0.2000         4
      trojan     0.6250    0.5000    0.5556        10

    accuracy                         0.4286        14
   macro avg     0.3958    0.3750    0.3778        14
weighted avg     0.4940    0.4286    0.4540        14

Confusion Matrix:
[[1 3]
 [5 5]]

Selected graphs: 14. Flipped after perturbation: 7 (50.00%).

Computing Stability Under Input Noise (this may take a while)...
  processed 10/14 graphs...
? Done SUIN computation. Time elapsed: 60.2s

--- Stability Under Input Noise (on perturbed graphs) ---
 Clean:  avg_stability=0.0304 ± 0.0131, avg_relerr=3.1787e-02 ± 3.9258e-02
 Trojan: avg_stability=0.0833 ± 0.0260, avg_relerr=1.3912e-02 ± 1.5433e-02

Sample preview (first 6): (idx, label, stability, rel_