#### Training Evaluation similar to https://github.com/AICPS/hw2vec

This notebook switches the model to a HW2VEC-style GIN architecture (MLP + BatchNorm + ReLU with a learnable ε and sum aggregation), while keeping all preprocessing, splits, training loop, optimizer, class weights, and evaluation unchanged.

In [1]:
# train_gcn_node_fixed.py
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

NODE_CSV = "GNNDatasets/node.csv"
EDGE_CSV = "GNNDatasets/node_edges.csv"
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Column names that may carry the ground-truth label, in priority order.
LABEL_CANDIDATES = ("label", "is_trojan", "trojan", "target")

# ----------------------------- Load nodes -----------------------------
nodes_df = pd.read_csv(NODE_CSV)

# Use the first label-like column that exists in the node table.
label_col = next((c for c in LABEL_CANDIDATES if c in nodes_df.columns), None)
if label_col is None:
    # No explicit label column: mark a node as trojan (1) when its circuit
    # name contains the "__trojan_" marker, otherwise clean (0).
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    label_col = "label"

# Unique node key "<circuit_name>::<node>"; edges are joined to nodes on it.
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

feat_df = nodes_df.copy()
if "gate_type" in feat_df.columns:
    # One-hot encode the categorical gate type into "gt_*" columns.
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

# Keep identifiers and *every* label-like column out of the feature matrix.
# Excluding only the selected label column would let a second label-like
# column (e.g. "is_trojan" next to "label") leak the target into X.
exclude = {"uid", "node", "circuit_name", label_col, *LABEL_CANDIDATES}
num_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
X = feat_df[num_cols].fillna(0.0).values.astype(np.float32)
y = nodes_df[label_col].values.astype(np.int64)

# ----------------------------- Load edges; add missing nodes -----------------------------
edges_df = pd.read_csv(EDGE_CSV)
# Build the same "<circuit_name>::<node>" key used for nodes_df["uid"].
edges_df["src_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["src"].astype(str)
edges_df["dst_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["dst"].astype(str)

known_uids = set(nodes_df["uid"])
edge_uids = set(edges_df["src_uid"]).union(set(edges_df["dst_uid"]))
# Sorted so placeholder rows receive the same indices on every run; the
# iteration order of a set of strings depends on Python's hash seed.
missing = sorted(edge_uids - known_uids)

if missing:
    # Edge endpoints absent from the node table become placeholder nodes with
    # all-zero features and label -1 (treated as unlabeled downstream).
    zero_row = np.zeros((1, X.shape[1]), dtype=np.float32)
    addX = np.repeat(zero_row, len(missing), axis=0)
    addY = -1*np.ones(len(missing), dtype=np.int64)
    add_df = pd.DataFrame({
        "uid": missing,
        "circuit_name": [u.split("::",1)[0] for u in missing],
        "node": [u.split("::",1)[1] for u in missing],
        label_col: addY
    })
    X = np.vstack([X, addX])
    y = np.concatenate([y, addY])
    nodes_df = pd.concat([nodes_df, add_df], ignore_index=True)

uid_to_idx = {u:i for i,u in enumerate(nodes_df["uid"].tolist())}
src_mapped = edges_df["src_uid"].map(uid_to_idx)
dst_mapped = edges_df["dst_uid"].map(uid_to_idx)
# Drop an edge only as a whole pair. Filtering src and dst independently
# would shift one column against the other if a single endpoint were unmapped.
both_known = src_mapped.notna() & dst_mapped.notna()
src_idx = src_mapped[both_known].astype(int).values
dst_idx = dst_mapped[both_known].astype(int).values
# Symmetrize: every edge is stored in both directions.
edge_index = np.stack([np.concatenate([src_idx, dst_idx]),
                       np.concatenate([dst_idx, src_idx])], axis=0)

# ----------------------------- Scale features -----------------------------
# Standardization statistics come from labeled nodes only (label >= 0), so the
# zero-feature placeholder rows do not distort the mean/variance.
labeled_mask_np = (y >= 0)
scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[labeled_mask_np] = scaler.fit_transform(X_scaled[labeled_mask_np])
if (~labeled_mask_np).any():
    # Reuse the fitted scaler so unlabeled rows get exactly the same transform
    # as labeled ones. Dividing by sqrt(var_ + 1e-8) by hand disagrees with
    # StandardScaler on constant columns (it uses a scale of 1 there), which
    # blew those features up by a factor of 1e4 for unlabeled rows.
    X_scaled[~labeled_mask_np] = scaler.transform(X_scaled[~labeled_mask_np])

# ----------------------------- Splits -----------------------------
# Only labeled rows (y >= 0) take part in the split; rows with label -1 are
# never assigned to train/val/test.
idx_all = np.where(labeled_mask_np)[0]
y_all = y[labeled_mask_np]

# 70% train / 30% held out, stratified by label.
idx_train, idx_tmp, y_train, y_tmp = train_test_split(
    idx_all, y_all, test_size=0.30, random_state=SEED, stratify=y_all
)
# The held-out 30% is halved into validation and test, again stratified.
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_tmp, y_tmp, test_size=0.50, random_state=SEED, stratify=y_tmp
)

# ----------------------------- Torch tensors (FIX: masks as torch.bool) -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Features and labels for every node (labeled and unlabeled) live on `device`.
X_t = torch.from_numpy(X_scaled).to(device)
y_t = torch.from_numpy(y).to(device)

edge_index_t = torch.from_numpy(edge_index).long().to(device)

# Boolean masks over all nodes, set to True at the indices of each split.
train_mask_t = torch.zeros(len(y), dtype=torch.bool, device=device); train_mask_t[idx_train] = True
val_mask_t   = torch.zeros(len(y), dtype=torch.bool, device=device); val_mask_t[idx_val]   = True
test_mask_t  = torch.zeros(len(y), dtype=torch.bool, device=device); test_mask_t[idx_test]  = True
# True where the node has a real label (y >= 0).
labeled_mask_t = torch.from_numpy(labeled_mask_np).to(device)

# ----------------------------- Build GCN adjacency -----------------------------
def build_adj(num_nodes, edge_index):
    """Build a degree-normalized sparse adjacency with self-loops.

    A self-loop is appended for every node, each entry (i, j) is weighted by
    deg(i)^-1/2 * deg(j)^-1/2, and the result is returned as a coalesced
    sparse COO tensor of shape (num_nodes, num_nodes). Degrees are counted
    from the source row of the edge list after the self-loops are added.
    Repeated (i, j) pairs are summed by the final coalesce.

    Args:
        num_nodes: number of nodes in the graph.
        edge_index: integer tensor of shape (2, E) holding (source, target) pairs.
    """
    diag = torch.arange(num_nodes, device=edge_index.device)
    rows = torch.cat([edge_index[0], diag])
    cols = torch.cat([edge_index[1], diag])
    degree = torch.bincount(rows, minlength=num_nodes).float()
    # clamp guards the inverse square root against a zero degree.
    inv_sqrt_degree = degree.clamp(min=1).pow(-0.5)
    weights = inv_sqrt_degree[rows] * inv_sqrt_degree[cols]
    adjacency = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), weights, (num_nodes, num_nodes)
    )
    return adjacency.coalesce()

# NOTE(review): this normalized adjacency is never consumed -- A_t is rebound
# to the output of build_adj_unnorm further down, before the model is created.
A_t = build_adj(X_t.size(0), edge_index_t)

# ----------------------------- Model -----------------------------
# ----------------------------- Build unnormalized adjacency (GIN-style sum aggregation) -----------------------------
def build_adj_unnorm(num_nodes, edge_index):
    """Build an unnormalized sparse adjacency with self-loops for sum aggregation.

    Every listed edge and one self-loop per node contributes a value of 1.
    The tensor is coalesced, so an (i, j) pair that appears k times in
    `edge_index` ends up with value k.

    Args:
        num_nodes: number of nodes in the graph.
        edge_index: integer tensor of shape (2, E) holding (source, target) pairs.

    Returns:
        Coalesced sparse COO tensor of shape (num_nodes, num_nodes).
    """
    dev = edge_index.device
    diag = torch.arange(num_nodes, device=dev)
    indices = torch.cat((edge_index, diag.repeat(2, 1)), dim=1)
    ones = torch.ones(indices.size(1), device=dev)
    return torch.sparse_coo_tensor(indices, ones, (num_nodes, num_nodes)).coalesce()

# Unit-weight adjacency (with self-loops) that the model, training loop and
# evaluation below all receive; this rebinding replaces the earlier A_t.
A_t = build_adj_unnorm(X_t.size(0), edge_index_t)

# ----------------------------- HW2VEC-style GIN Model for node classification -----------------------------
class GINLayer(nn.Module):
    """GIN-style layer: sparse sum aggregation followed by a two-layer MLP.

    The forward pass computes ``dropout(mlp(adj @ x + eps * x))`` where `eps`
    is a learnable scalar and the MLP is Linear -> BatchNorm -> ReLU -> Linear.

    Args:
        in_dim: input feature width.
        out_dim: output (and hidden) feature width.
        dropout: dropout probability applied to the layer output.
        eps_init: initial value of the learnable epsilon.
    """

    def __init__(self, in_dim, out_dim, dropout=0.35, eps_init=0.0):
        super().__init__()
        first = nn.Linear(in_dim, out_dim, bias=True)
        norm = nn.BatchNorm1d(out_dim)
        second = nn.Linear(out_dim, out_dim, bias=True)
        self.mlp = nn.Sequential(first, norm, nn.ReLU(inplace=True), second)
        self.dropout = nn.Dropout(dropout)
        # Scalar weight on the node's own features, learned with the rest.
        self.eps = nn.Parameter(torch.tensor(eps_init, dtype=torch.float32))
        # Xavier-uniform weights and zero biases for both linear maps.
        for linear in (first, second):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x, adj_unnorm):
        """Apply the layer to node features `x` using sparse adjacency `adj_unnorm`."""
        # With a unit self-loop in the adjacency, adj @ x already contains x
        # once, so adding eps * x yields (1 + eps) * x + sum over neighbors.
        combined = torch.sparse.mm(adj_unnorm, x) + self.eps * x
        return self.dropout(self.mlp(combined))

class HW2VEC_GIN_Node(nn.Module):
    """Node classifier: two stacked GINLayer blocks followed by a linear head.

    Args:
        in_dim: input feature width.
        hid_dim: width of both GIN layers.
        out_dim: number of output logits per node.
        dropout: dropout probability passed to each GIN layer.
        eps_init: initial epsilon passed to each GIN layer.
    """

    def __init__(self, in_dim, hid_dim=96, out_dim=2, dropout=0.35, eps_init=0.0):
        super().__init__()
        shared = dict(dropout=dropout, eps_init=eps_init)
        self.g1 = GINLayer(in_dim, hid_dim, **shared)
        self.g2 = GINLayer(hid_dim, hid_dim, **shared)
        self.lin_out = nn.Linear(hid_dim, out_dim, bias=True)
        # Same initialization scheme as inside the GIN layers.
        nn.init.xavier_uniform_(self.lin_out.weight)
        nn.init.zeros_(self.lin_out.bias)

    def forward(self, x, adj_unnorm):
        """Return per-node logits for features `x` and sparse adjacency `adj_unnorm`."""
        hidden = x
        for conv in (self.g1, self.g2):
            hidden = conv(hidden, adj_unnorm)
        return self.lin_out(hidden)

# Replace your previous model with HW2VEC-style GIN
# Input width follows the feature matrix; two output logits (class 0 / class 1).
model = HW2VEC_GIN_Node(in_dim=X_t.size(1), hid_dim=96, out_dim=2, dropout=0.35, eps_init=0.0).to(device)


# ----------------------------- Loss, optimizer -----------------------------
# Class weights are derived from the training split only.
train_labels = y_t[train_mask_t]
classes, counts = torch.unique(train_labels, return_counts=True)
# Fall back to a count of 1 when a class is absent from the training labels,
# which keeps the divisions below finite.
num_pos = counts[classes==1].item() if (classes==1).any() else 1
num_neg = counts[classes==0].item() if (classes==0).any() else 1
# Balanced weighting: total / (2 * class_count), so the rarer class weighs more.
weight_pos = (num_neg + num_pos) / (2.0 * num_pos)
weight_neg = (num_neg + num_pos) / (2.0 * num_neg)
# Index 0 -> class 0 weight, index 1 -> class 1 weight.
class_weights = torch.tensor([weight_neg, weight_pos], dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=5e-4)

# ----------------------------- Training -----------------------------
def evaluate(mask_t):
    """Return the accuracy of the global `model` on the nodes selected by `mask_t`.

    Puts the model in eval mode, runs one full-graph forward pass without
    gradients, and compares argmax predictions with `y_t` on nodes that are
    both in `mask_t` and labeled (y_t >= 0). Returns 0.0 when no node
    qualifies. The model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        predicted = model(X_t, A_t).argmax(dim=1)
        keep = mask_t & (y_t >= 0)
        if keep.sum() == 0:
            return 0.0
        hits = predicted[keep] == y_t[keep]
        return hits.float().mean().item()

# Best validation accuracy so far and a CPU copy of the matching weights.
best_val, best_state = -1.0, None
# NOTE: patience_cnt is incremented once per validation check (epoch 1 and
# every 10th epoch), so `patience` counts checks, not epochs.
patience, patience_cnt = 20, 0
EPOCHS = 300

for epoch in range(1, EPOCHS+1):
    # evaluate() switches the model to eval mode, so train mode is restored
    # at the start of every epoch.
    model.train()
    optimizer.zero_grad()
    # Full-batch forward over the whole graph; the loss uses training nodes only.
    logits = model(X_t, A_t)
    loss = criterion(logits[train_mask_t], y_t[train_mask_t])
    loss.backward()
    # Clip the global gradient norm to 2.0 before the optimizer step.
    nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()

    if epoch % 10 == 0 or epoch == 1:
        val_acc = evaluate(val_mask_t)
        # Test accuracy is printed for monitoring only; checkpoint selection
        # below depends on validation accuracy alone.
        test_acc = evaluate(test_mask_t)
        print(f"Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}")
        # Require an improvement of more than 1e-4 to count as a new best.
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("Early stopping."); break

# Restore the weights from the best validation check, if one was recorded.
if best_state is not None:
    model.load_state_dict(best_state)

# ----------------------------- Final eval -----------------------------
# One gradient-free forward pass over the whole graph with the restored weights.
model.eval()
with torch.no_grad():
    logits = model(X_t, A_t)
    preds = logits.argmax(dim=1)

# Restrict the report to test nodes that carry a real label (y >= 0).
msk = (test_mask_t & (y_t >= 0)).cpu().numpy()
y_true = y_t.cpu().numpy()[msk]
y_pred = preds.cpu().numpy()[msk]

acc = (y_true == y_pred).mean()
print("\nFinal Evaluation (Node-Level)")
print("=============================")
print(f"Test Accuracy: {acc:.4f}\n")

# labels=[0,1] pins the row order: class 0 is reported as "clean", class 1 as "trojan".
print("Classification Report:")
print(classification_report(y_true, y_pred, labels=[0,1], target_names=["clean","trojan"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=[0,1]))

Epoch 001 | Loss 1.1374 | Val 0.2816 | Test 0.2836
Epoch 010 | Loss 0.0668 | Val 0.9928 | Test 0.9918
Epoch 020 | Loss 0.0134 | Val 0.9832 | Test 0.9849
Epoch 030 | Loss 0.0082 | Val 0.9999 | Test 0.9999
Epoch 040 | Loss 0.0048 | Val 1.0000 | Test 1.0000
Epoch 050 | Loss 0.0031 | Val 1.0000 | Test 0.9999
Epoch 060 | Loss 0.0020 | Val 1.0000 | Test 0.9999
Epoch 070 | Loss 0.0013 | Val 1.0000 | Test 1.0000
Epoch 080 | Loss 0.0010 | Val 1.0000 | Test 1.0000
Epoch 090 | Loss 0.0008 | Val 1.0000 | Test 1.0000
Epoch 100 | Loss 0.0006 | Val 1.0000 | Test 1.0000
Epoch 110 | Loss 0.0007 | Val 1.0000 | Test 1.0000
Epoch 120 | Loss 0.0005 | Val 1.0000 | Test 1.0000
Epoch 130 | Loss 0.0005 | Val 1.0000 | Test 1.0000
Epoch 140 | Loss 0.0005 | Val 1.0000 | Test 1.0000
Epoch 150 | Loss 0.0005 | Val 1.0000 | Test 1.0000
Epoch 160 | Loss 0.0005 | Val 1.0000 | Test 1.0000
Epoch 170 | Loss 0.0004 | Val 1.0000 | Test 1.0000
Epoch 180 | Loss 0.0004 | Val 1.0000 | Test 1.0000
Epoch 190 | Loss 0.0004 | Val 1

In [5]:
# ================================
# Unified Robustness Evaluation
# ================================
import torch, numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ---------------- Parameters ----------------
PER_CLASS = 100     # max number of test nodes sampled per class
EPSILON   = 5.0     # L2 budget for PGD
ALPHA     = 1.0     # PGD step size
NUM_ITERS = 40      # PGD iterations
FD_EPS    = 1e-3    # finite-difference epsilon
SEED      = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# This cell relies on objects created by the training cell; fail early with a
# clear message if any of them is missing from the notebook namespace.
required_vars = ["model","X_t","A_t","y_t","test_mask_t","device"]
for v in required_vars:
    if v not in globals():
        raise RuntimeError(f"Required var '{v}' not found.")

model.to(device); model.eval()
labels_np = y_t.cpu().numpy()

# ---------------- Node Selection ----------------
# Sample up to PER_CLASS test nodes for each of the classes 0 and 1, without
# replacement, using a dedicated NumPy generator seeded with SEED.
test_indices = np.where(test_mask_t.cpu().numpy())[0]
rng = np.random.default_rng(SEED)
selected_nodes = []
for cls in [0,1]:
    idxs = [int(i) for i in test_indices if labels_np[i]==cls]
    chosen = rng.choice(idxs, size=min(PER_CLASS, len(idxs)), replace=False)
    selected_nodes.extend(chosen)
selected_nodes = np.array(selected_nodes, dtype=np.int64)

print(f"Selected: clean={int((labels_np[selected_nodes]==0).sum())}, "
      f"trojan={int((labels_np[selected_nodes]==1).sum())}")

# ---------------- Shared PGD Perturbations ----------------
# Only the features of the selected nodes are attacked; the adjacency is untouched.
# NOTE: nodes are attacked one after another on the same `perturbed_X`, so each
# attack sees the adversarial features already written for earlier nodes.
perturbed_X = X_t.clone().detach().to(device)
for node_idx in selected_nodes:
    node_idx = int(node_idx)
    x_orig = X_t[node_idx].detach().clone().to(device)
    # Small random start around the clean features.
    x_adv = (x_orig + 1e-3*torch.randn_like(x_orig)).detach().requires_grad_(True)

    for _ in range(NUM_ITERS):
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x_adv
        logits = model(X_mod, A_t)
        # Untargeted attack: ascend the cross-entropy of this node's own label.
        loss = F.cross_entropy(logits[node_idx].unsqueeze(0), y_t[node_idx].unsqueeze(0))
        grad_x = torch.autograd.grad(loss, x_adv)[0]
        # A zero gradient gives no ascent direction; stop attacking this node.
        if grad_x.norm().item()==0: break
        # Step of L2 length ALPHA along the normalized gradient.
        step = ALPHA * grad_x / (grad_x.norm() + 1e-12)
        x_adv = (x_adv + step).detach()
        # Project back onto the L2 ball of radius EPSILON around x_orig.
        delta = x_adv - x_orig
        if delta.norm() > EPSILON:
            delta = delta * (EPSILON/(delta.norm()+1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)
    perturbed_X[node_idx] = x_adv.detach()
print("? Shared PGD perturbations done.")

# ---------------- Eval Helper ----------------
def evaluate_model(name, perturbed_X, selected_nodes):
    """Print robustness statistics for the selected nodes and return perturbed logits.

    Runs the global `model` on the clean features `X_t` and on `perturbed_X`
    (both with the global adjacency `A_t`), then reports, for the selected
    nodes only: how many predictions changed, accuracy, weighted
    precision/recall/F1, the confusion matrix and the classification report.

    Args:
        name: label printed in the section header; it does not influence
            what is computed.
        perturbed_X: feature tensor with the same layout as `X_t`.
        selected_nodes: iterable of node indices to score.

    Returns:
        Logits of the model on `perturbed_X` for every node in the graph.
    """
    with torch.no_grad():
        # Predictions on original and perturbed inputs
        orig_preds = model(X_t, A_t).argmax(dim=1).cpu().numpy()
        pert_logits = model(perturbed_X, A_t)
        pert_preds = pert_logits.argmax(dim=1).cpu().numpy()

    # Explicit integer dtype so that an empty selection is still a valid index array.
    sel_idx = np.asarray(selected_nodes, dtype=np.int64)

    print(f"\n=== Robustness Eval ({name}) ===")
    if sel_idx.size == 0:
        # Nothing to score; avoid dividing by zero and calling sklearn on empty input.
        print("No nodes selected; skipping metrics.")
        return pert_logits

    # Restrict to selected perturbed samples only
    sel_labels = labels_np[sel_idx]
    sel_orig_preds = orig_preds[sel_idx]
    sel_pert_preds = pert_preds[sel_idx]

    # Flip count first
    flips = int((sel_orig_preds != sel_pert_preds).sum())
    print(f"Flipped {flips}/{len(sel_idx)} ({100*flips/len(sel_idx):.2f}%)")

    # Accuracy, precision, recall, F1 on perturbed subset only
    acc = (sel_pert_preds == sel_labels).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(
        sel_labels, sel_pert_preds, average="weighted", zero_division=0)

    print(f"Accuracy={acc*100:.2f} | Precision={prec:.4f} | Recall={rec:.4f} | F1={f1:.4f}")
    print("Confusion Matrix:")
    print(confusion_matrix(sel_labels, sel_pert_preds, labels=[0, 1]))
    print("Classification Report:")
    # labels=[0, 1] keeps the two target names valid even when only one class
    # shows up among labels and predictions; without it sklearn raises because
    # the number of observed classes no longer matches target_names.
    print(classification_report(sel_labels, sel_pert_preds, labels=[0, 1],
                                target_names=["clean", "trojan"], digits=4,
                                zero_division=0))

    return pert_logits


# ---------------- Metric 1: Jacobian Sensitivity ----------------
# For each selected node: Frobenius norm of the Jacobian of that node's logits
# with respect to its own (PGD-perturbed) feature vector, plus a check of how
# well the linearization J @ delta predicts the actual logit change.
jac_info = []
for node_idx in selected_nodes:
    x0 = perturbed_X[node_idx].detach().clone().requires_grad_(True)
    def f_local(x):
        # Logits of node_idx when only its own feature row is replaced by x.
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x
        return model(X_mod, A_t)[node_idx]
    J = torch.autograd.functional.jacobian(f_local, x0)
    jac_norm = torch.norm(J, p='fro').item()
    # Random probe scaled by FD_EPS.
    delta_fd = FD_EPS * torch.randn_like(x0)
    pred_change = J.mv(delta_fd)
    f0, f0p = f_local(x0).detach(), f_local(x0+delta_fd).detach()
    actual_change = f0p - f0
    # Relative gap between the first-order prediction and the observed change.
    rel_err = (torch.norm(pred_change-actual_change)/(torch.norm(actual_change)+1e-8)).item()
    jac_info.append((int(labels_np[node_idx]), jac_norm, rel_err))
print("\nJacobian Sensitivity:")
for cls in [0,1]:
    vals = [j[1] for j in jac_info if j[0]==cls]
    errs = [j[2] for j in jac_info if j[0]==cls]
    print(f" Class {cls}: norm={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

# NOTE: this scores the shared PGD-perturbed features; "Jacobian" is only the
# label printed in the report header.
evaluate_model("Jacobian", perturbed_X, selected_nodes)

# ---------------- Metric 2: Lipschitz (Spectral Norm) ----------------
# Local Lipschitz estimate per selected node: the largest singular value of the
# Jacobian of the node's logits with respect to its own feature vector.
lip_info = []
for node_idx in selected_nodes:
    x0 = perturbed_X[node_idx].detach().clone().requires_grad_(True)
    def f_node(x):
        # Logits of node_idx when only its own feature row is replaced by x.
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x
        return model(X_mod, A_t)[node_idx]
    J = torch.autograd.functional.jacobian(f_node, x0).detach()
    # Only the singular values are used below; U and Vh are not read.
    U, S, Vh = torch.linalg.svd(J, full_matrices=False)
    # First singular value returned by the SVD.
    sigma_max = S[0].item()
    delta_fd = FD_EPS*torch.randn_like(x0)
    pred_change = J.mv(delta_fd)
    f0, f0p = f_node(x0).detach(), f_node(x0+delta_fd).detach()
    actual_change = f0p-f0
    # Relative gap between the first-order prediction and the observed change.
    rel_err = (torch.norm(pred_change-actual_change)/(torch.norm(actual_change)+1e-8)).item()
    lip_info.append((int(labels_np[node_idx]), sigma_max, rel_err))
print("\nLipschitz Constant:")
for cls in [0,1]:
    vals = [j[1] for j in lip_info if j[0]==cls]
    errs = [j[2] for j in lip_info if j[0]==cls]
    print(f" Class {cls}: L={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

# NOTE: this scores the shared PGD-perturbed features; "Lipschitz" is only the
# label printed in the report header.
evaluate_model("Lipschitz", perturbed_X, selected_nodes)

Selected: clean=100, trojan=100
? Shared PGD perturbations done.

Jacobian Sensitivity:
 Class 0: norm=1.5677±0.6970, relerr=3.7685e-03±7.6597e-03
 Class 1: norm=1.3663±0.1219, relerr=2.3876e-02±1.8866e-01

=== Robustness Eval (Jacobian) ===
Flipped 93/200 (46.50%)
Accuracy=53.50 | Precision=0.5350 | Recall=0.5350 | F1=0.5350
Confusion Matrix:
[[53 47]
 [46 54]]
Classification Report:
              precision    recall  f1-score   support

       clean     0.5354    0.5300    0.5327       100
      trojan     0.5347    0.5400    0.5373       100

    accuracy                         0.5350       200
   macro avg     0.5350    0.5350    0.5350       200
weighted avg     0.5350    0.5350    0.5350       200


Lipschitz Constant:
 Class 0: L=1.5646±0.6972, relerr=6.9201e-03±2.4854e-02
 Class 1: L=1.3654±0.1218, relerr=1.7527e-02±1.2781e-01

=== Robustness Eval (Lipschitz) ===
Flipped 93/200 (46.50%)
Accuracy=53.50 | Precision=0.5350 | Recall=0.5350 | F1=0.5350
Confusion Matrix:
[[53 47]
 [

tensor([[ 11.1726, -11.6568],
        [  9.0651,  -9.4224],
        [  9.9202, -10.1688],
        ...,
        [  5.3161,  -5.7546],
        [  5.3161,  -5.7546],
        [  3.2757,  -3.6181]])

In [6]:
# =========================
# Hessian-Based Curvature (grad outer-product) for node-level Trojan detection
# =========================
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# -------------------- Parameters --------------------
# NOTE: FD_EPS and SEED rebind the notebook-level names set by the previous cell.
PER_CLASS = 100          # 100 nodes per class (clean/trojan)
FD_EPS = 5e-3            # finite-diff epsilon for relative-error check
TRIALS_PER_NODE = 10     # average trials per node for relative error
PERT_P = 6.0             # L2 magnitude for final Hessian-aligned perturbation (tuneable)
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# Inference mode: dropout is disabled and BatchNorm uses its running statistics.
model.to(device)
model.eval()

# -------------------- Class names --------------------
class_names = ["clean", "trojan"]


# -------------------- Helper: compute g(x) --------------------
def compute_gradient(node_idx):
    """
    Gradient of the predicted-class log-probability w.r.t. one node's features.

    Evaluates g = ∇_x log p(y_hat | x) at the clean feature vector of
    `node_idx`, where y_hat is the class the model currently predicts for that
    node and all other nodes keep their features from `X_t`.

    Returns a tuple (x0, g, pred_class): the detached feature vector, the
    detached gradient with the same shape, and the predicted class index.
    """
    probe = X_t[node_idx].detach().clone().to(device).requires_grad_(True)

    # Full-graph forward pass with the differentiable copy spliced into its row.
    features = X_t.clone().detach().to(device)
    features[node_idx] = probe
    node_logits = model(features, A_t)[node_idx]

    # Differentiate the log-probability of the model's own prediction.
    pred_class = node_logits.argmax().item()
    objective = F.log_softmax(node_logits, dim=0)[pred_class]

    (grad,) = torch.autograd.grad(objective, probe)
    return probe.detach(), grad.detach(), pred_class

# -------------------- Storage --------------------
per_sample_info = []   # (node_idx, label, lambda_max, avg_rel_error)

print("\nComputing Hessian curvature proxy for selected nodes...")

# The clean forward pass does not depend on the node or the trial, so it is
# run once here (without autograd) instead of once per finite-difference trial.
with torch.no_grad():
    clean_logits_all = model(X_t, A_t)

for node_idx in selected_nodes:
    node_idx = int(node_idx)
    label = int(labels_np[node_idx])

    # compute_gradient always returns a gradient tensor (autograd raises rather
    # than yielding None), so no None branch is needed here.
    x0, g, pred_class = compute_gradient(node_idx)

    # curvature proxy = ||g||^2, the only nonzero eigenvalue of g g^T
    lambda_max = float(g.norm(p=2).item() ** 2)

    # Clean log-probability of the predicted class for this node.
    base_logp = F.log_softmax(clean_logits_all[node_idx], dim=0)[pred_class]

    # relative error by finite-difference
    rel_errs = []
    for _ in range(TRIALS_PER_NODE):
        delta = FD_EPS * torch.randn_like(x0).to(device)
        gt_delta = torch.dot(g, delta).item()
        # Second-order term predicted by the outer-product proxy: 0.5 * (g . delta)^2
        pred_second = 0.5 * (gt_delta ** 2)

        # recompute logits at perturbed input
        X_mod = X_t.clone().detach().to(device)
        X_mod[node_idx] = x0 + delta
        with torch.no_grad():
            logits_p = model(X_mod, A_t)[node_idx]
        logp_p = F.log_softmax(logits_p, dim=0)
        # Observed change minus its first-order part = second-order remainder.
        actual_second = float((logp_p[pred_class] - base_logp).item() - gt_delta)

        rel_error = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        rel_errs.append(rel_error)

    avg_rel_err = float(np.mean(rel_errs))

    per_sample_info.append((node_idx, label, lambda_max, avg_rel_err))

# -------------------- Aggregate stats --------------------
# Split the per-node tuples by ground-truth label (0 = clean, 1 = trojan).
clean_stats = [t for t in per_sample_info if t[1]==0]
troj_stats  = [t for t in per_sample_info if t[1]==1]

def summarize(stats):
    """Return (mean, std) of field 2 followed by (mean, std) of field 3.

    `stats` is a sequence of tuples whose entry 2 is the curvature value and
    entry 3 the finite-difference relative error. An empty sequence yields
    four zeros.
    """
    if not stats:
        return (0.0, 0.0, 0.0, 0.0)
    lambdas, errors = (np.array([row[col] for row in stats]) for col in (2, 3))
    return (lambdas.mean(), lambdas.std(), errors.mean(), errors.std())

# Per-class mean/std of the curvature proxy and of the finite-difference error.
cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)

print("\nAggregated Hessian curvature stats:")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,lambda,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)

# This report closes the Hessian section, so it carries the Hessian label.
# It was previously printed as "Margin", duplicating the header of the
# prediction-margin section that follows.
evaluate_model("Hessian", perturbed_X, selected_nodes)

# ---------------- Metric 4: Prediction Margin ----------------
def _top_margin(node_logits, cls):
    """Gap between the logit of class `cls` and the largest other logit."""
    others = torch.cat([node_logits[:cls], node_logits[cls + 1:]])
    return node_logits[cls].item() - others.max().item()


margin_info = []
# The unperturbed forward pass is identical for every node: run it once.
with torch.no_grad():
    margin_base_logits = model(perturbed_X, A_t)

for node_idx in selected_nodes:
    node_idx = int(node_idx)
    logits = margin_base_logits[node_idx]
    pred_class = logits.argmax().item()
    margin = _top_margin(logits, pred_class)

    # Finite-difference probe: nudge this node's features by `delta` and
    # re-measure the margin of the same class. The probe used to be drawn but
    # never applied, which made margin_p equal margin and rel_err always 0.
    delta = FD_EPS*torch.randn_like(perturbed_X[node_idx])
    X_mod = perturbed_X.clone().detach()
    X_mod[node_idx] = perturbed_X[node_idx] + delta
    with torch.no_grad():
        logits_p = model(X_mod, A_t)[node_idx]
    margin_p = _top_margin(logits_p, pred_class)

    rel_err = abs(margin-margin_p)/(abs(margin_p)+1e-12)
    margin_info.append((int(labels_np[node_idx]), margin, rel_err))
print("\nPrediction Margin:")
for cls in [0,1]:
    vals = [j[1] for j in margin_info if j[0]==cls]
    errs = [j[2] for j in margin_info if j[0]==cls]
    print(f" Class {cls}: margin={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Margin", perturbed_X, selected_nodes)

# ---------------- Metric 5: ARR ----------------
# (kept simplified: min perturbation until flip)
def adversarial_radius(node_idx, x_start=None, eps_init=1e-3, growth=1.2, eps_max=20.0):
    """Estimate the smallest random-noise scale that flips a node's prediction.

    Starting from `eps_init`, Gaussian noise of scale `eps` is added to the
    node's features (one random draw per scale) and the node is re-classified;
    `eps` grows by the factor `growth` until the prediction differs from the
    unperturbed one or `eps_max` is reached. All other nodes keep their
    features from the global `perturbed_X`.

    Note that `eps` is the per-coordinate noise scale: the L2 norm of the
    applied noise is roughly eps * sqrt(feature_dim). The search is random, so
    repeated calls can return different values.

    Args:
        node_idx: index of the node to probe.
        x_start: optional feature vector to search around; defaults to the
            node's row in `perturbed_X`.
        eps_init: first noise scale tried.
        growth: multiplicative step between successive scales.
        eps_max: scale at which the search gives up.

    Returns:
        The first scale that changed the prediction, or `eps_max` if none did.
    """
    node_idx = int(node_idx)
    center = perturbed_X[node_idx] if x_start is None else x_start
    x0 = center.detach().clone()
    X_mod = perturbed_X.clone().detach()
    with torch.no_grad():
        X_mod[node_idx] = x0
        base_pred = int(model(X_mod, A_t)[node_idx].argmax().item())
        eps = eps_init
        while eps < eps_max:
            # Feed the noisy features to the model. Previously the noisy vector
            # was built but the unmodified graph was evaluated, so the
            # prediction could never change and the search always hit eps_max.
            X_mod[node_idx] = x0 + eps*torch.randn_like(x0)
            pred = int(model(X_mod, A_t)[node_idx].argmax().item())
            if pred != base_pred:
                return eps
            eps *= growth
    return eps_max

arr_info = []
for n in selected_nodes:
    n = int(n)
    arr_val = adversarial_radius(n)
    # finite-difference style perturbation for ARR: repeat the search around a
    # slightly shifted starting point and compare the two radii. Because the
    # search itself is random, rel_err also reflects run-to-run variation.
    delta = FD_EPS * torch.randn_like(perturbed_X[n])
    arr_val_p = adversarial_radius(n, x_start=perturbed_X[n] + delta)
    rel_err = abs(arr_val - arr_val_p) / (abs(arr_val_p) + 1e-12)
    arr_info.append((int(labels_np[n]), arr_val, rel_err))

print("\nAdversarial Robustness Radius:")
for cls in [0,1]:
    vals = [j[1] for j in arr_info if j[0] == cls]
    errs = [j[2] for j in arr_info if j[0] == cls]
    print(f" Class {cls}: radius={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("ARR", perturbed_X, selected_nodes)

# ---------------- Metric 6: Stability ----------------
# Average L2 change of a node's logits when Gaussian noise of scale 0.05 is
# added to its own features, measured around the PGD-perturbed inputs.
stability_info = []
# The reference logits do not depend on the node being probed: one forward
# pass without autograd replaces one graph-building pass per node.
with torch.no_grad():
    stability_base_logits = model(perturbed_X, A_t)

for node_idx in selected_nodes:
    node_idx = int(node_idx)
    base_logits = stability_base_logits[node_idx]
    diffs = []
    for _ in range(10):
        noise = 0.05 * torch.randn_like(perturbed_X[node_idx])
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = perturbed_X[node_idx] + noise
        with torch.no_grad():
            logits_n = model(X_mod, A_t)[node_idx]
        diffs.append(torch.norm(logits_n - base_logits).item())
    stability_val = np.mean(diffs)
    # finite-difference style perturbation for stability: a single extra noisy
    # trial, compared against the 10-trial average above.
    noise_fd = 0.05 * torch.randn_like(perturbed_X[node_idx])
    X_fd = perturbed_X.clone().detach()
    X_fd[node_idx] = perturbed_X[node_idx] + noise_fd
    with torch.no_grad():
        logits_fd = model(X_fd, A_t)[node_idx]
    diffs_fd = [torch.norm(logits_fd - base_logits).item()]
    stability_val_p = np.mean(diffs_fd)
    rel_err = abs(stability_val - stability_val_p) / (abs(stability_val_p) + 1e-12)
    stability_info.append((int(labels_np[node_idx]), stability_val, rel_err))

print("\nStability Under Noise:")
for cls in [0,1]:
    vals = [j[1] for j in stability_info if j[0] == cls]
    errs = [j[2] for j in stability_info if j[0] == cls]
    print(f" Class {cls}: stability={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Stability", perturbed_X, selected_nodes)





Computing Hessian curvature proxy for selected nodes...

Aggregated Hessian curvature stats:
 Clean:  avg_lambda=0.0023 ± 0.0162, avg_FDrel=3.4178e-01 ± 3.4251e-01
 Trojan: avg_lambda=0.0082 ± 0.0364, avg_FDrel=7.5314e-01 ± 1.3846e-01

Sample preview (first 6): (idx,label,lambda,FD_rel_err)
(26119, 0, 3.0676810870633265e-26, 5.7264076863641344e-08)
(57231, 0, 1.8918313797500821e-19, 0.00019066434799292788)
(40297, 0, 4.435684259553356e-12, 0.4839210733958579)
(27843, 0, 1.9685782326908182e-12, 0.3691970972634096)
(26863, 0, 8.190093011877588e-10, 0.776914613113268)
(46828, 0, 3.023158947433883e-15, 0.03561590148276072)

=== Robustness Eval (Margin) ===
Flipped 93/200 (46.50%)
Accuracy=53.50 | Precision=0.5350 | Recall=0.5350 | F1=0.5350
Confusion Matrix:
[[53 47]
 [46 54]]
Classification Report:
              precision    recall  f1-score   support

       clean     0.5354    0.5300    0.5327       100
      trojan     0.5347    0.5400    0.5373       100

    accuracy                

tensor([[ 11.1726, -11.6568],
        [  9.0651,  -9.4224],
        [  9.9202, -10.1688],
        ...,
        [  5.3161,  -5.7546],
        [  5.3161,  -5.7546],
        [  3.2757,  -3.6181]])