#### Training Evaluation with GraphSAINT taken from https://github.com/DfX-NYUAD/TrojanSAINT

In [2]:
# train_trojansaint_node_fixed.py
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

NODE_CSV  = "GNNDatasets/node.csv"
EDGE_CSV  = "GNNDatasets/node_edges.csv"
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED)

# ----------------------------- Load nodes -----------------------------
nodes_df = pd.read_csv(NODE_CSV)

label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        label_col = cand; break
if label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

feat_df = nodes_df.copy()
if "gate_type" in feat_df.columns:
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name",label_col}
num_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
X = feat_df[num_cols].fillna(0.0).values.astype(np.float32)
y = nodes_df[label_col].values.astype(np.int64)

# ----------------------------- Load edges; add missing nodes -----------------------------
edges_df = pd.read_csv(EDGE_CSV)
edges_df["src_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["src"].astype(str)
edges_df["dst_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["dst"].astype(str)

known_uids = set(nodes_df["uid"])
edge_uids = set(edges_df["src_uid"]).union(set(edges_df["dst_uid"]))
missing = list(edge_uids - known_uids)

if missing:
    zero_row = np.zeros((1, X.shape[1]), dtype=np.float32)
    addX = np.repeat(zero_row, len(missing), axis=0)
    addY = -1*np.ones(len(missing), dtype=np.int64)
    add_df = pd.DataFrame({
        "uid": missing,
        "circuit_name": [u.split("::",1)[0] for u in missing],
        "node": [u.split("::",1)[1] for u in missing],
        label_col: addY
    })
    X = np.vstack([X, addX])
    y = np.concatenate([y, addY])
    nodes_df = pd.concat([nodes_df, add_df], ignore_index=True)

uid_to_idx = {u:i for i,u in enumerate(nodes_df["uid"].tolist())}
src_idx = edges_df["src_uid"].map(uid_to_idx).dropna().astype(int).values
dst_idx = edges_df["dst_uid"].map(uid_to_idx).dropna().astype(int).values
edge_index = np.stack([np.concatenate([src_idx, dst_idx]),
                       np.concatenate([dst_idx, src_idx])], axis=0)

num_nodes = X.shape(0) if callable(getattr(X, "shape", None)) else X.shape[0]

# ----------------------------- Scale features -----------------------------
labeled_mask_np = (y >= 0)
scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[labeled_mask_np] = scaler.fit_transform(X_scaled[labeled_mask_np])
if (~labeled_mask_np).any():
    X_scaled[~labeled_mask_np] = (X_scaled[~labeled_mask_np] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# ----------------------------- Splits -----------------------------
idx_all = np.where(labeled_mask_np)[0]
y_all = y[labeled_mask_np]

idx_train, idx_tmp, y_train, y_tmp = train_test_split(
    idx_all, y_all, test_size=0.30, random_state=SEED, stratify=y_all
)
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_tmp, y_tmp, test_size=0.50, random_state=SEED, stratify=y_tmp
)

# ----------------------------- Torch tensors -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_t = torch.from_numpy(X_scaled).to(device)
y_t = torch.from_numpy(y).to(device)
edge_index_t = torch.from_numpy(edge_index).long().to(device)

train_mask_t = torch.zeros(len(y), dtype=torch.bool, device=device); train_mask_t[idx_train] = True
val_mask_t   = torch.zeros(len(y), dtype=torch.bool, device=device); val_mask_t[idx_val]   = True
test_mask_t  = torch.zeros(len(y), dtype=torch.bool, device=device); test_mask_t[idx_test] = True
labeled_mask_t = torch.from_numpy(labeled_mask_np).to(device)

# ----------------------------- Utility: degree, neighbors -----------------------------
def degrees(num_nodes, ei):
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    return deg

def coalesce_edge_index(ei):
    # simple coalesce: unique columns of ei (2 x E)
    u = ei[0]* (ei[0].max()+1) + ei[1]
    uniq, idx = torch.unique(u, return_inverse=True)
    ei0 = uniq // (ei[0].max()+1)
    ei1 = uniq %  (ei[0].max()+1)
    return torch.stack([ei0, ei1], dim=0)

# ----------------------------- GraphSAINT-style sampler -----------------------------
class GraphSAINTSampler:
    def __init__(self, num_nodes, edge_index, train_mask, batch_size=2000, num_steps=1,
                 bias_by_degree=True, device='cpu'):
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.train_mask = train_mask
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.device = device

        deg = torch.bincount(edge_index[0], minlength=num_nodes).float()
        deg = deg + 1e-6
        base_prob = deg / deg.sum() if bias_by_degree else torch.full(
            (num_nodes,), 1.0/num_nodes, device=edge_index.device
        )

        self.sample_space = torch.where(train_mask)[0]
        p_space = base_prob[self.sample_space]
        self.p_space = p_space / p_space.sum()
        self.base_prob = base_prob

    def __iter__(self):
        return self

    def next_seeds(self):
        replace = self.sample_space.numel() < self.batch_size
        idx = torch.multinomial(self.p_space, self.batch_size, replacement=replace)
        return self.sample_space[idx]

    def _induce_subgraph(self, seeds):
        ei = self.edge_index
        mask = (ei[0].unsqueeze(1) == seeds.unsqueeze(0)).any(dim=1) | \
               (ei[1].unsqueeze(1) == seeds.unsqueeze(0)).any(dim=1)
        sub_ei = ei[:, mask]
        nodes = torch.unique(torch.cat([sub_ei[0], sub_ei[1], seeds], dim=0))

        nid_map = -torch.ones(self.num_nodes, dtype=torch.long, device=ei.device)
        nid_map[nodes] = torch.arange(nodes.numel(), device=ei.device)
        sub_ei = nid_map[sub_ei]

        if sub_ei.numel() == 0:
            sub_ei = torch.stack([torch.tensor([0], device=ei.device), torch.tensor([0], device=ei.device)], dim=0)
            nodes = nodes[:1]

        u = sub_ei[0] * nodes.numel() + sub_ei[1]
        uniq = torch.unique(u)
        sub_ei0 = uniq // nodes.numel()
        sub_ei1 = uniq %  nodes.numel()
        sub_ei = torch.stack([sub_ei0, sub_ei1], dim=0)

        self_loops = torch.arange(nodes.numel(), device=ei.device)
        sub_ei = torch.cat([sub_ei, torch.stack([self_loops, self_loops])], dim=1)

        deg = torch.bincount(sub_ei[0], minlength=nodes.numel()).float()
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        w = deg_inv_sqrt[sub_ei[0]] * deg_inv_sqrt[sub_ei[1]]
        A = torch.sparse_coo_tensor(sub_ei, w, (nodes.numel(), nodes.numel())).coalesce()
        return nodes, A

    def __next__(self):
        seeds = self.next_seeds()
        nodes, A = self._induce_subgraph(seeds)
        p = torch.clamp(self.base_prob[nodes], min=1e-8)
        norm_loss = (1.0 / p) / (1.0 / p).mean()
        return nodes, A, norm_loss


# ----------------------------- TrojanSAINT (GraphSAINT-style) model -----------------------------
class SAINTGCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0, bias=True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.bn = nn.BatchNorm1d(out_dim)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.lin.weight)

    def forward(self, x, adj):
        # GCN propagation on sampled subgraph
        x = torch.sparse.mm(adj, x)
        x = self.lin(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        return x

class TrojanSAINT(nn.Module):
    def __init__(self, in_dim, hid_dim=96, out_dim=2, dropout=0.35):
        super().__init__()
        self.g1 = SAINTGCNLayer(in_dim, hid_dim, dropout=dropout)
        self.g2 = nn.Linear(hid_dim, out_dim, bias=True)
        nn.init.xavier_uniform_(self.g2.weight)

    def forward(self, x, adj):
        x = self.g1(x, adj)
        x = torch.sparse.mm(adj, x)  # last propagation (linear head after propagation)
        x = self.g2(x)
        return x

model = TrojanSAINT(in_dim=X_t.size(1), hid_dim=96, out_dim=2, dropout=0.35).to(device)

# ----------------------------- Class weights, loss, optimizer -----------------------------
train_labels = y_t[train_mask_t]
classes, counts = torch.unique(train_labels, return_counts=True)
num_pos = counts[classes==1].item() if (classes==1).any() else 1
num_neg = counts[classes==0].item() if (classes==0).any() else 1
weight_pos = (num_neg + num_pos) / (2.0 * num_pos)
weight_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights = torch.tensor([weight_neg, weight_pos], dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(reduction='none', weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=5e-4)

# ----------------------------- Sampler -----------------------------
num_nodes = X_t.size(0)
sampler = GraphSAINTSampler(
    num_nodes=num_nodes,
    edge_index=edge_index_t,
    train_mask=train_mask_t & (y_t >= 0),
    batch_size=min(4000, (train_mask_t & (y_t >= 0)).sum().item()),
    device=device
)

# ----------------------------- Evaluation (full-graph) -----------------------------
def build_full_adj(num_nodes, edge_index):
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    ei = torch.cat([edge_index, torch.stack([self_loops, self_loops])], dim=1)
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
    w = deg_inv_sqrt[ei[0]] * deg_inv_sqrt[ei[1]]
    A = torch.sparse_coo_tensor(ei, w, (num_nodes, num_nodes))
    return A.coalesce()

A_full = build_full_adj(X_t.size(0), edge_index_t)

@torch.no_grad()
def evaluate(mask_t):
    model.eval()
    logits = model(X_t, A_full)
    pred = logits.argmax(dim=1)
    msk = mask_t & (y_t >= 0)
    if msk.sum() == 0: return 0.0
    return (pred[msk] == y_t[msk]).float().mean().item()

# ----------------------------- Training (GraphSAINT-style) -----------------------------
best_val, best_state = -1.0, None
patience, patience_cnt = 20, 0
EPOCHS = 300
steps_per_epoch = max(1, int(np.ceil((train_mask_t & (y_t >= 0)).sum().item() / sampler.batch_size)))

for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_loss = 0.0
    for _ in range(steps_per_epoch):
        nodes, A_b, norm_loss = next(sampler)
        x_b = X_t[nodes]
        y_b = y_t[nodes]
        train_b = train_mask_t[nodes] & (y_b >= 0)

        if train_b.sum() == 0:
            continue

        logits_b = model(x_b, A_b)
        loss_raw = criterion(logits_b[train_b], y_b[train_b])
        # GraphSAINT normalization proxy: weight per-node loss by norm_loss
        loss = (loss_raw * norm_loss[train_b]).mean()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 0 or epoch == 1:
        val_acc = evaluate(val_mask_t)
        test_acc = evaluate(test_mask_t)
        print(f"Epoch {epoch:03d} | Loss {epoch_loss/steps_per_epoch:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("Early stopping."); break

if best_state is not None:
    model.load_state_dict(best_state)

# ----------------------------- Final eval -----------------------------
model.eval()
with torch.no_grad():
    logits = model(X_t, A_full)
    preds = logits.argmax(dim=1)

msk = (test_mask_t & (y_t >= 0)).cpu().numpy()
y_true = y_t.cpu().numpy()[msk]
y_pred = preds.cpu().numpy()[msk]

acc = (y_true == y_pred).mean()
print("\nFinal Evaluation (Node-Level, TrojanSAINT/GraphSAINT-style)")
print("============================================================")
print(f"Test Accuracy: {acc:.4f}\n")

print("Classification Report:")
print(classification_report(y_true, y_pred, labels=[0,1], target_names=["clean","trojan"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=[0,1]))


Epoch 001 | Loss 0.2149 | Val 0.2494 | Test 0.2495
Epoch 010 | Loss 0.0014 | Val 0.2497 | Test 0.2500
Epoch 020 | Loss 0.0007 | Val 0.3833 | Test 0.3831
Epoch 030 | Loss 0.0006 | Val 0.3490 | Test 0.3488
Epoch 040 | Loss 0.0006 | Val 0.2675 | Test 0.2687
Epoch 050 | Loss 0.0025 | Val 0.8438 | Test 0.8493
Epoch 060 | Loss 0.0005 | Val 0.9770 | Test 0.9776
Epoch 070 | Loss 0.0005 | Val 0.9625 | Test 0.9623
Epoch 080 | Loss 0.0015 | Val 0.7825 | Test 0.7851
Epoch 090 | Loss 0.0004 | Val 0.3912 | Test 0.3925
Epoch 100 | Loss 0.0004 | Val 0.2504 | Test 0.2509
Epoch 110 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 120 | Loss 0.0004 | Val 0.2783 | Test 0.2783
Epoch 130 | Loss 0.0004 | Val 0.2757 | Test 0.2754
Epoch 140 | Loss 0.0019 | Val 0.2494 | Test 0.2495
Epoch 150 | Loss 0.0020 | Val 0.2494 | Test 0.2495
Epoch 160 | Loss 0.0019 | Val 0.2494 | Test 0.2495
Epoch 170 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 180 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 190 | Loss 0.0004 | Val 0

#### All in One, same perturbation across all metric.

In [13]:
# =========================
# Robustness PGD + Jacobian (Node-level, TrojanSAINT)
# =========================
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

PER_CLASS     = 50
EPSILON_PGD   = 1.5
ALPHA_PGD     = 0.3
NUM_ITERS_PGD = 20
FD_EPS        = 1e-3
SEED          = 42
torch.manual_seed(SEED); np.random.seed(SEED)
device = X_t.device

# -------------------------
# Select test nodes
# -------------------------
test_mask_np = test_mask_t.cpu().numpy() & (y_t.cpu().numpy() >= 0)
y_np = y_t.cpu().numpy()
idx_test = np.where(test_mask_np)[0]

sel_clean  = np.random.choice(idx_test[y_np[idx_test]==0],
                              size=min(PER_CLASS, (y_np[idx_test]==0).sum()),
                              replace=False)
sel_trojan = np.random.choice(idx_test[y_np[idx_test]==1],
                              size=min(PER_CLASS, (y_np[idx_test]==1).sum()),
                              replace=False)
selected = np.concatenate([sel_clean, sel_trojan])
print(f"Selected {len(sel_clean)} clean and {len(sel_trojan)} trojan nodes from test set.")

# -------------------------
# PGD perturbation per node
# -------------------------
perturbed_feats = X_t.clone().detach()
orig_preds, adv_preds = [], []

for idx in selected:
    x_adv = X_t.clone().detach()
    x_adv.requires_grad_(True)

    for it in range(NUM_ITERS_PGD):
        logits = model(x_adv, A_full)
        loss = F.cross_entropy(logits[idx:idx+1], y_t[idx:idx+1])
        grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        g_node = grad[idx]
        step = ALPHA_PGD * g_node / (g_node.norm() + 1e-12)
        x_adv = x_adv.detach()
        x_adv[idx] = (x_adv[idx] + step).detach()
        # project back into L2 ball
        diff = x_adv[idx] - X_t[idx]
        if diff.norm() > EPSILON_PGD:
            diff = diff * (EPSILON_PGD / (diff.norm() + 1e-12))
            x_adv[idx] = X_t[idx] + diff
        x_adv.requires_grad_(True)

    perturbed_feats[idx] = x_adv[idx].detach()

    with torch.no_grad():
        orig_preds.append(int(model(X_t, A_full)[idx].argmax().item()))
        adv_preds.append(int(model(perturbed_feats, A_full)[idx].argmax().item()))

orig_preds, adv_preds = np.array(orig_preds), np.array(adv_preds)
num_flips = (orig_preds != adv_preds).sum()
print(f"Perturbation success: {num_flips}/{len(selected)} "
      f"({100.0*num_flips/len(selected):.2f}%)")

# -------------------------
# Evaluate perturbed selected nodes
# -------------------------
labels_sel = y_np[selected]
with torch.no_grad():
    logits_adv = model(perturbed_feats, A_full)[selected]
    preds_adv = logits_adv.argmax(dim=1).cpu().numpy()

acc = (preds_adv == labels_sel).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_sel, preds_adv, average='weighted', zero_division=0)

print("\nPerformance on Perturbed Selected Nodes")
print("---------------------------------------")
print(f"Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")
print("Classification Report:")
print(classification_report(labels_sel, preds_adv, labels=[0,1], target_names=["clean","trojan"], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_sel, preds_adv, labels=[0,1]))

# -------------------------
# Jacobian Frobenius norm example
# -------------------------
jacobian_vals = []
for idx in selected[:5]:  # sample a few nodes (Jacobian is expensive)
    x_in = perturbed_feats.clone().detach().requires_grad_(True)
    def f_node(z):
        return model(z, A_full)[idx]
    J = torch.autograd.functional.jacobian(f_node, x_in)
    # shape: (C, N, F) ? flatten node dimension
    if J.ndim == 3:
        J = J[:, idx, :]  # only keep Jacobian wrt that nodes features
    frob = torch.norm(J, p='fro').item()
    jacobian_vals.append((idx, frob))

print("\nJacobian Frobenius Norm (sampled nodes):")
for idx, val in jacobian_vals:
    print(f" Node {idx}: {val:.4f}")


Selected 50 clean and 50 trojan nodes from test set.
Perturbation success: 45/100 (45.00%)

Performance on Perturbed Selected Nodes
---------------------------------------
Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000, F1: 0.3333
Classification Report:
              precision    recall  f1-score   support

       clean     0.5000    1.0000    0.6667        50
      trojan     0.0000    0.0000    0.0000        50

    accuracy                         0.5000       100
   macro avg     0.2500    0.5000    0.3333       100
weighted avg     0.2500    0.5000    0.3333       100

Confusion Matrix:
[[50  0]
 [50  0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Jacobian Frobenius Norm (sampled nodes):
 Node 14263: 1.8702
 Node 23954: 1.8375
 Node 35912: 1.3144
 Node 9652: 1.8866
 Node 31500: 1.0735


In [14]:
# ================================
# Robustness Metrics (Node-Level, TrojanSAINT)
# ================================
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ------------- Parameters -------------
PER_CLASS     = 50
EPSILON_PGD   = 1.5
ALPHA_PGD     = 0.3
NUM_ITERS_PGD = 20
FD_EPS        = 1e-3
ARR_INITIAL_EPS  = 1e-3
ARR_GROW1        = 1.25
ARR_GROW2        = 1.4
ARR_MAX_EPS      = 5.0
ARR_BS_ITERS     = 5
ARR_TRIALS       = 3
STAB_SIGMA       = 0.5
STAB_SAMPLES     = 5
SEED          = 42
torch.manual_seed(SEED); np.random.seed(SEED)
device = X_t.device


# ------------- Utility: stats -------------
def mean_std(a):
    return (np.mean(a).item() if len(a) else 0.0,
            np.std(a).item() if len(a) else 0.0)

def print_stats(name, clean_vals, troj_vals, clean_errs, troj_errs):
    c_mean, c_std = mean_std(clean_vals)
    t_mean, t_std = mean_std(troj_vals)
    ce_mean, ce_std = mean_std(clean_errs)
    te_mean, te_std = mean_std(troj_errs)
    print(f"\n{name} (on perturbed selected nodes):")
    print(f" Clean:  avg={c_mean:.4f} ± {c_std:.4f}, avg_relerr={ce_mean:.4e} ± {ce_std:.4e}")
    print(f" Trojan: avg={t_mean:.4f} ± {t_std:.4f}, avg_relerr={te_mean:.4e} ± {te_std:.4e}")

# ------------- Metric 1: Jacobian -------------
jac_vals, jac_errs, labels = [], [], []
for idx in selected:
    x_in = perturbed_feats.clone().detach().requires_grad_(True)
    def f_node(z): return model(z, A_full)[idx]
    J = torch.autograd.functional.jacobian(f_node, x_in)
    if J.ndim == 3:
        J = J[:, idx, :]
    jac_frob = torch.norm(J, p='fro').item()
    delta_fd = FD_EPS * torch.randn_like(x_in[idx])
    pred_change = J @ delta_fd
    f0 = f_node(x_in).detach()
    f1 = f_node(x_in + (delta_fd.unsqueeze(0) * (torch.arange(len(x_in))==idx).float().to(device)[:,None]))
    actual_change = (f1 - f0).detach()
    rel_err = (torch.norm(pred_change - actual_change)/ (torch.norm(actual_change)+1e-12)).item()
    jac_vals.append(jac_frob); jac_errs.append(rel_err); labels.append(y_np[idx])

print_stats("Jacobian Sensitivity", 
            [jac_vals[i] for i in range(len(labels)) if labels[i]==0],
            [jac_vals[i] for i in range(len(labels)) if labels[i]==1],
            [jac_errs[i] for i in range(len(labels)) if labels[i]==0],
            [jac_errs[i] for i in range(len(labels)) if labels[i]==1])

# ------------- Metric 2: Local Lipschitz -------------
lip_vals, lip_errs, labels = [], [], []
for idx in selected:
    x_in = perturbed_feats.clone().detach().requires_grad_(True)
    def f_node(z): return model(z, A_full)[idx]
    J = torch.autograd.functional.jacobian(f_node, x_in)
    if J.ndim == 3: J = J[:, idx, :]
    U, S, V = torch.linalg.svd(J, full_matrices=False)
    sigma_max = S[0].item()
    delta_fd = FD_EPS * torch.randn_like(x_in[idx])
    pred_change = J @ delta_fd
    f0 = f_node(x_in).detach()
    f1 = f_node(x_in + (delta_fd.unsqueeze(0) * (torch.arange(len(x_in))==idx).float().to(device)[:,None]))
    actual_change = (f1 - f0).detach()
    rel_err = (torch.norm(pred_change - actual_change)/ (torch.norm(actual_change)+1e-12)).item()
    lip_vals.append(sigma_max); lip_errs.append(rel_err); labels.append(y_np[idx])

print_stats("Local Lipschitz", 
            [lip_vals[i] for i in range(len(labels)) if labels[i]==0],
            [lip_vals[i] for i in range(len(labels)) if labels[i]==1],
            [lip_errs[i] for i in range(len(labels)) if labels[i]==0],
            [lip_errs[i] for i in range(len(labels)) if labels[i]==1])

# ------------- Metric 3: Hessian Proxy -------------
hess_vals, hess_errs, labels = [], [], []
for idx in selected:
    x_in = perturbed_feats.clone().detach().requires_grad_(True)
    logits = model(x_in, A_full)[idx]
    pred_class = logits.argmax().item()
    logp = F.log_softmax(logits, dim=0)[pred_class]
    g = torch.autograd.grad(logp, x_in, retain_graph=False, create_graph=False)[0][idx]
    lambda_proxy = g.norm().item()**2
    rels = []
    for _ in range(3):
        delta = FD_EPS * torch.randn_like(x_in[idx])
        gt_delta = (g * delta).sum().item()
        logits_p = model(x_in.clone().detach().index_add(0, torch.tensor([idx], device=device), delta.unsqueeze(0)), A_full)[idx]
        logp_p = F.log_softmax(logits_p, dim=0)[pred_class]
        actual_second = (logp_p - logp).item() - gt_delta
        pred_second = 0.5*(gt_delta**2)
        rels.append(abs(pred_second-actual_second)/(abs(actual_second)+1e-12))
    hess_vals.append(lambda_proxy); hess_errs.append(np.mean(rels)); labels.append(y_np[idx])

print_stats("Hessian Curvature", 
            [hess_vals[i] for i in range(len(labels)) if labels[i]==0],
            [hess_vals[i] for i in range(len(labels)) if labels[i]==1],
            [hess_errs[i] for i in range(len(labels)) if labels[i]==0],
            [hess_errs[i] for i in range(len(labels)) if labels[i]==1])

# ------------- Metric 4: Prediction Margin -------------
margin_vals, margin_errs, labels = [], [], []
for idx in selected:
    logits = model(perturbed_feats, A_full)[idx]
    pred_class = logits.argmax().item()
    margin = logits[pred_class].item() - logits[[j for j in range(len(logits)) if j!=pred_class]].max().item()
    delta = FD_EPS * torch.randn_like(perturbed_feats[idx])
    logits_p = model(perturbed_feats.clone().detach().index_add(0, torch.tensor([idx], device=device), delta.unsqueeze(0)), A_full)[idx]
    margin_p = logits_p[pred_class].item() - logits_p[[j for j in range(len(logits_p)) if j!=pred_class]].max().item()
    rel_err = abs(margin-margin_p)/(abs(margin_p)+1e-12)
    margin_vals.append(margin); margin_errs.append(rel_err); labels.append(y_np[idx])

print_stats("Prediction Margin", 
            [margin_vals[i] for i in range(len(labels)) if labels[i]==0],
            [margin_vals[i] for i in range(len(labels)) if labels[i]==1],
            [margin_errs[i] for i in range(len(labels)) if labels[i]==0],
            [margin_errs[i] for i in range(len(labels)) if labels[i]==1])

# ------------- Metric 5: ARR -------------
def adversarial_radius(idx, growth=ARR_GROW1):
    x0 = perturbed_feats.clone().detach()
    y0 = model(x0, A_full)[idx].argmax().item()
    d = torch.randn_like(x0[idx]); d /= d.norm()+1e-12
    eps = ARR_INITIAL_EPS
    while eps<ARR_MAX_EPS and model(x0.clone().detach().index_add(0, torch.tensor([idx], device=device), (eps*d).unsqueeze(0)), A_full)[idx].argmax().item()==y0:
        eps*=growth
    return eps

arr_vals, arr_errs, labels = [], [], []
for idx in selected:
    r1 = adversarial_radius(idx, growth=ARR_GROW1)
    r2 = adversarial_radius(idx, growth=ARR_GROW2)
    rel = abs(r1-r2)/(abs(r2)+1e-12)
    arr_vals.append(r1); arr_errs.append(rel); labels.append(y_np[idx])

print_stats("Adversarial Robustness Radius", 
            [arr_vals[i] for i in range(len(labels)) if labels[i]==0],
            [arr_vals[i] for i in range(len(labels)) if labels[i]==1],
            [arr_errs[i] for i in range(len(labels)) if labels[i]==0],
            [arr_errs[i] for i in range(len(labels)) if labels[i]==1])

# ------------- Metric 6: Stability under Noise -------------
stab_vals, stab_errs, labels = [], [], []
for idx in selected:
    base = model(perturbed_feats, A_full)[idx]
    diffs = []
    for _ in range(STAB_SAMPLES):
        noise = STAB_SIGMA*torch.randn_like(perturbed_feats[idx])
        x_noisy = perturbed_feats.clone().detach()
        x_noisy[idx]+=noise
        f_noisy = model(x_noisy, A_full)[idx]
        diffs.append(torch.norm(f_noisy-base).item())
    val = np.mean(diffs)
    re_vals = []
    for _ in range(3):
        diffs2 = []
        for _ in range(STAB_SAMPLES):
            noise = STAB_SIGMA*torch.randn_like(perturbed_feats[idx])
            x_noisy = perturbed_feats.clone().detach()
            x_noisy[idx]+=noise
            f_noisy = model(x_noisy, A_full)[idx]
            diffs2.append(torch.norm(f_noisy-base).item())
        re_vals.append(np.mean(diffs2))
    ref = np.mean(re_vals)
    rel_err = abs(val-ref)/(abs(ref)+1e-12)
    stab_vals.append(val); stab_errs.append(rel_err); labels.append(y_np[idx])

print_stats("Stability under Noise", 
            [stab_vals[i] for i in range(len(labels)) if labels[i]==0],
            [stab_vals[i] for i in range(len(labels)) if labels[i]==1],
            [stab_errs[i] for i in range(len(labels)) if labels[i]==0],
            [stab_errs[i] for i in range(len(labels)) if labels[i]==1])



Jacobian Sensitivity (on perturbed selected nodes):
 Clean:  avg=1.5936 ± 0.3236, avg_relerr=7.5815e-04 ± 2.4729e-03
 Trojan: avg=4.7731 ± 1.9465, avg_relerr=2.5976e-04 ± 7.8128e-04

Local Lipschitz (on perturbed selected nodes):
 Clean:  avg=1.5932 ± 0.3235, avg_relerr=1.5952e-03 ± 3.6699e-03
 Trojan: avg=4.7710 ± 1.9470, avg_relerr=6.2384e-04 ± 1.8397e-03

Hessian Curvature (on perturbed selected nodes):
 Clean:  avg=0.0010 ± 0.0023, avg_relerr=1.0013e+00 ± 7.1515e-03
 Trojan: avg=0.1581 ± 0.5224, avg_relerr=1.0614e+00 ± 1.6674e-01

Prediction Margin (on perturbed selected nodes):
 Clean:  avg=5.9187 ± 1.3137, avg_relerr=3.2050e-04 ± 2.4754e-04
 Trojan: avg=3.7113 ± 1.4298, avg_relerr=1.4169e-03 ± 1.1791e-03

Adversarial Robustness Radius (on perturbed selected nodes):
 Clean:  avg=6.0185 ± 0.0000, avg_relerr=4.4652e-02 ± 0.0000e+00
 Trojan: avg=5.5198 ± 1.2059, avg_relerr=1.4842e-01 ± 2.2524e-01

Stability under Noise (on perturbed selected nodes):
 Clean:  avg=0.5819 ± 0.2441, avg