#### Training Evaluation similar to https://github.com/DfX-NYUAD/TrojanSAINT

In [2]:
# train_trojansaint_node_fixed.py
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

NODE_CSV  = "GNNDatasets/node.csv"
EDGE_CSV  = "GNNDatasets/node_edges.csv"
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED)

# ----------------------------- Load nodes -----------------------------
nodes_df = pd.read_csv(NODE_CSV)

label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        label_col = cand; break
if label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

feat_df = nodes_df.copy()
if "gate_type" in feat_df.columns:
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name",label_col}
num_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
X = feat_df[num_cols].fillna(0.0).values.astype(np.float32)
y = nodes_df[label_col].values.astype(np.int64)

# ----------------------------- Load edges; add missing nodes -----------------------------
edges_df = pd.read_csv(EDGE_CSV)
edges_df["src_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["src"].astype(str)
edges_df["dst_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["dst"].astype(str)

known_uids = set(nodes_df["uid"])
edge_uids = set(edges_df["src_uid"]).union(set(edges_df["dst_uid"]))
missing = list(edge_uids - known_uids)

if missing:
    zero_row = np.zeros((1, X.shape[1]), dtype=np.float32)
    addX = np.repeat(zero_row, len(missing), axis=0)
    addY = -1*np.ones(len(missing), dtype=np.int64)
    add_df = pd.DataFrame({
        "uid": missing,
        "circuit_name": [u.split("::",1)[0] for u in missing],
        "node": [u.split("::",1)[1] for u in missing],
        label_col: addY
    })
    X = np.vstack([X, addX])
    y = np.concatenate([y, addY])
    nodes_df = pd.concat([nodes_df, add_df], ignore_index=True)

uid_to_idx = {u:i for i,u in enumerate(nodes_df["uid"].tolist())}
src_idx = edges_df["src_uid"].map(uid_to_idx).dropna().astype(int).values
dst_idx = edges_df["dst_uid"].map(uid_to_idx).dropna().astype(int).values
edge_index = np.stack([np.concatenate([src_idx, dst_idx]),
                       np.concatenate([dst_idx, src_idx])], axis=0)

num_nodes = X.shape(0) if callable(getattr(X, "shape", None)) else X.shape[0]

# ----------------------------- Scale features -----------------------------
labeled_mask_np = (y >= 0)
scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[labeled_mask_np] = scaler.fit_transform(X_scaled[labeled_mask_np])
if (~labeled_mask_np).any():
    X_scaled[~labeled_mask_np] = (X_scaled[~labeled_mask_np] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# ----------------------------- Splits -----------------------------
idx_all = np.where(labeled_mask_np)[0]
y_all = y[labeled_mask_np]

idx_train, idx_tmp, y_train, y_tmp = train_test_split(
    idx_all, y_all, test_size=0.30, random_state=SEED, stratify=y_all
)
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_tmp, y_tmp, test_size=0.50, random_state=SEED, stratify=y_tmp
)

# ----------------------------- Torch tensors -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_t = torch.from_numpy(X_scaled).to(device)
y_t = torch.from_numpy(y).to(device)
edge_index_t = torch.from_numpy(edge_index).long().to(device)

train_mask_t = torch.zeros(len(y), dtype=torch.bool, device=device); train_mask_t[idx_train] = True
val_mask_t   = torch.zeros(len(y), dtype=torch.bool, device=device); val_mask_t[idx_val]   = True
test_mask_t  = torch.zeros(len(y), dtype=torch.bool, device=device); test_mask_t[idx_test] = True
labeled_mask_t = torch.from_numpy(labeled_mask_np).to(device)

# ----------------------------- Utility: degree, neighbors -----------------------------
def degrees(num_nodes, ei):
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    return deg

def coalesce_edge_index(ei):
    # simple coalesce: unique columns of ei (2 x E)
    u = ei[0]* (ei[0].max()+1) + ei[1]
    uniq, idx = torch.unique(u, return_inverse=True)
    ei0 = uniq // (ei[0].max()+1)
    ei1 = uniq %  (ei[0].max()+1)
    return torch.stack([ei0, ei1], dim=0)

# ----------------------------- GraphSAINT-style sampler -----------------------------
class GraphSAINTSampler:
    """
    Lightweight GraphSAINT-style node sampler:
    - Sample a set of seed nodes (biased by degree for stability)
    - Induce 1-hop subgraph
    - Provide normalization weights to approximately debias sampling
    """
    def __init__(self, num_nodes, edge_index, train_mask, batch_size=2000, num_steps=1,
                 bias_by_degree=True, device='cpu'):
        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.train_mask = train_mask
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.device = device

        deg = torch.bincount(edge_index[0], minlength=num_nodes).float()
        deg = deg + 1e-6
        base_prob = deg / deg.sum() if bias_by_degree else torch.full((num_nodes,), 1.0/num_nodes, device=edge_index.device)

        # Only sample from labeled training nodes
        self.sample_space = torch.where(train_mask)[0]
        p_space = base_prob[self.sample_space]
        self.p_space = p_space / p_space.sum()

        self.base_prob = base_prob  # keep global probs for normalization

    def __iter__(self):
        return self

    def next_seeds(self):
        replace = self.sample_space.numel() < self.batch_size
        idx = torch.multinomial(self.p_space, self.batch_size, replacement=replace)
        return self.sample_space[idx]

    def _induce_subgraph(self, seeds):
        ei = self.edge_index
        # keep edges if either endpoint is in seeds (1-hop expansion)
        mask = (ei[0].unsqueeze(1) == seeds.unsqueeze(0)).any(dim=1) | \
               (ei[1].unsqueeze(1) == seeds.unsqueeze(0)).any(dim=1)
        sub_ei = ei[:, mask]

        # nodes present in subgraph
        nodes = torch.unique(torch.cat([sub_ei[0], sub_ei[1], seeds], dim=0))

        # map global -> local ids
        nid_map = -torch.ones(self.num_nodes, dtype=torch.long, device=ei.device)
        nid_map[nodes] = torch.arange(nodes.numel(), device=ei.device)
        sub_ei = nid_map[sub_ei]

        # coalesce and normalize (GCN)
        # ensure no negative indices (can happen if sub_ei was empty)
        if sub_ei.numel() == 0:
            # create a trivial self-loop to avoid empty adj
            sub_ei = torch.stack([torch.tensor([0], device=ei.device), torch.tensor([0], device=ei.device)], dim=0)
            nodes = nodes[:1]

        # coalesce
        u = sub_ei[0] * nodes.numel() + sub_ei[1]
        uniq = torch.unique(u)
        sub_ei0 = uniq // nodes.numel()
        sub_ei1 = uniq %  nodes.numel()
        sub_ei = torch.stack([sub_ei0, sub_ei1], dim=0)

        # add self-loops
        self_loops = torch.arange(nodes.numel(), device=ei.device)
        sub_ei = torch.cat([sub_ei, torch.stack([self_loops, self_loops])], dim=1)

        deg = torch.bincount(sub_ei[0], minlength=nodes.numel()).float()
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        w = deg_inv_sqrt[sub_ei[0]] * deg_inv_sqrt[sub_ei[1]]
        A = torch.sparse_coo_tensor(sub_ei, w, (nodes.numel(), nodes.numel())).coalesce()
        return nodes, A

    def __next__(self):
        seeds = self.next_seeds()
        nodes, A = self._induce_subgraph(seeds)
        # importance weights (GraphSAINT normalization proxy)
        p = torch.clamp(self.base_prob[nodes], min=1e-8)
        norm_loss = (1.0 / p) / (1.0 / p).mean()
        return nodes, A, norm_loss

# ----------------------------- TrojanSAINT (GraphSAINT-style) model -----------------------------
class SAINTGCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0, bias=True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.bn = nn.BatchNorm1d(out_dim)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.lin.weight)

    def forward(self, x, adj):
        # GCN propagation on sampled subgraph
        x = torch.sparse.mm(adj, x)
        x = self.lin(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        return x

class TrojanSAINT(nn.Module):
    def __init__(self, in_dim, hid_dim=96, out_dim=2, dropout=0.35):
        super().__init__()
        self.g1 = SAINTGCNLayer(in_dim, hid_dim, dropout=dropout)
        self.g2 = nn.Linear(hid_dim, out_dim, bias=True)
        nn.init.xavier_uniform_(self.g2.weight)

    def forward(self, x, adj):
        x = self.g1(x, adj)
        x = torch.sparse.mm(adj, x)  # last propagation (linear head after propagation)
        x = self.g2(x)
        return x

model = TrojanSAINT(in_dim=X_t.size(1), hid_dim=96, out_dim=2, dropout=0.35).to(device)

# ----------------------------- Class weights, loss, optimizer -----------------------------
train_labels = y_t[train_mask_t]
classes, counts = torch.unique(train_labels, return_counts=True)
num_pos = counts[classes==1].item() if (classes==1).any() else 1
num_neg = counts[classes==0].item() if (classes==0).any() else 1
weight_pos = (num_neg + num_pos) / (2.0 * num_pos)
weight_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights = torch.tensor([weight_neg, weight_pos], dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(reduction='none', weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=5e-4)

# ----------------------------- Sampler -----------------------------
num_nodes = X_t.size(0)
sampler = GraphSAINTSampler(
    num_nodes=num_nodes,
    edge_index=edge_index_t,
    train_mask=train_mask_t & (y_t >= 0),
    batch_size=min(4000, (train_mask_t & (y_t >= 0)).sum().item()),
    device=device
)

# ----------------------------- Evaluation (full-graph) -----------------------------
def build_full_adj(num_nodes, edge_index):
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    ei = torch.cat([edge_index, torch.stack([self_loops, self_loops])], dim=1)
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
    w = deg_inv_sqrt[ei[0]] * deg_inv_sqrt[ei[1]]
    A = torch.sparse_coo_tensor(ei, w, (num_nodes, num_nodes))
    return A.coalesce()

A_full = build_full_adj(X_t.size(0), edge_index_t)

@torch.no_grad()
def evaluate(mask_t):
    model.eval()
    logits = model(X_t, A_full)
    pred = logits.argmax(dim=1)
    msk = mask_t & (y_t >= 0)
    if msk.sum() == 0: return 0.0
    return (pred[msk] == y_t[msk]).float().mean().item()

# ----------------------------- Training (GraphSAINT-style) -----------------------------
best_val, best_state = -1.0, None
patience, patience_cnt = 20, 0
EPOCHS = 300
steps_per_epoch = max(1, int(np.ceil((train_mask_t & (y_t >= 0)).sum().item() / sampler.batch_size)))

for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_loss = 0.0
    for _ in range(steps_per_epoch):
        nodes, A_b, norm_loss = next(sampler)
        x_b = X_t[nodes]
        y_b = y_t[nodes]
        train_b = train_mask_t[nodes] & (y_b >= 0)
    
        if train_b.sum() == 0:
            continue
    
        logits_b = model(x_b, A_b)
        loss_raw = criterion(logits_b[train_b], y_b[train_b])
        loss = (loss_raw * norm_loss[train_b]).mean()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 0 or epoch == 1:
        val_acc = evaluate(val_mask_t)
        test_acc = evaluate(test_mask_t)
        print(f"Epoch {epoch:03d} | Loss {epoch_loss/steps_per_epoch:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("Early stopping."); break

if best_state is not None:
    model.load_state_dict(best_state)

# ----------------------------- Final eval -----------------------------
model.eval()
with torch.no_grad():
    logits = model(X_t, A_full)
    preds = logits.argmax(dim=1)

msk = (test_mask_t & (y_t >= 0)).cpu().numpy()
y_true = y_t.cpu().numpy()[msk]
y_pred = preds.cpu().numpy()[msk]

acc = (y_true == y_pred).mean()
print("\nFinal Evaluation (Node-Level, TrojanSAINT/GraphSAINT-style)")
print("============================================================")
print(f"Test Accuracy: {acc:.4f}\n")

print("Classification Report:")
print(classification_report(y_true, y_pred, labels=[0,1], target_names=["clean","trojan"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=[0,1]))


Epoch 001 | Loss 0.2149 | Val 0.2494 | Test 0.2495
Epoch 010 | Loss 0.0014 | Val 0.2497 | Test 0.2500
Epoch 020 | Loss 0.0007 | Val 0.3832 | Test 0.3831
Epoch 030 | Loss 0.0006 | Val 0.3490 | Test 0.3488
Epoch 040 | Loss 0.0006 | Val 0.2675 | Test 0.2687
Epoch 050 | Loss 0.0025 | Val 0.8437 | Test 0.8493
Epoch 060 | Loss 0.0005 | Val 0.9770 | Test 0.9776
Epoch 070 | Loss 0.0005 | Val 0.9625 | Test 0.9622
Epoch 080 | Loss 0.0015 | Val 0.7824 | Test 0.7851
Epoch 090 | Loss 0.0004 | Val 0.3913 | Test 0.3925
Epoch 100 | Loss 0.0004 | Val 0.2505 | Test 0.2509
Epoch 110 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 120 | Loss 0.0004 | Val 0.2783 | Test 0.2783
Epoch 130 | Loss 0.0004 | Val 0.2757 | Test 0.2754
Epoch 140 | Loss 0.0019 | Val 0.2494 | Test 0.2495
Epoch 150 | Loss 0.0022 | Val 0.2494 | Test 0.2495
Epoch 160 | Loss 0.0019 | Val 0.2494 | Test 0.2495
Epoch 170 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 180 | Loss 0.0004 | Val 0.2494 | Test 0.2495
Epoch 190 | Loss 0.0004 | Val 0

In [5]:
# ================================
# Unified Robustness Evaluation
# ================================
import torch, numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ---------------- Parameters ----------------
PER_CLASS = 100
EPSILON   = 5.0     # L2 budget for PGD
ALPHA     = 1.0     # PGD step size
NUM_ITERS = 40      # PGD iterations
FD_EPS    = 1e-3    # finite-difference epsilon
SEED      = 42

torch.manual_seed(SEED); np.random.seed(SEED)
A_t = A_full
required_vars = ["model","X_t","A_t","y_t","test_mask_t","device"]
for v in required_vars:
    if v not in globals():
        raise RuntimeError(f"Required var '{v}' not found.")

model.to(device); model.eval()
labels_np = y_t.cpu().numpy()

# ---------------- Node Selection ----------------
test_indices = np.where(test_mask_t.cpu().numpy())[0]
rng = np.random.default_rng(SEED)
selected_nodes = []
for cls in [0,1]:
    idxs = [int(i) for i in test_indices if labels_np[i]==cls]
    chosen = rng.choice(idxs, size=min(PER_CLASS, len(idxs)), replace=False)
    selected_nodes.extend(chosen)
selected_nodes = np.array(selected_nodes, dtype=np.int64)

print(f"Selected: clean={int((labels_np[selected_nodes]==0).sum())}, "
      f"trojan={int((labels_np[selected_nodes]==1).sum())}")

# ---------------- Shared PGD Perturbations ----------------
perturbed_X = X_t.clone().detach().to(device)
for node_idx in selected_nodes:
    node_idx = int(node_idx)
    x_orig = X_t[node_idx].detach().clone().to(device)
    x_adv = (x_orig + 1e-3*torch.randn_like(x_orig)).detach().requires_grad_(True)

    for _ in range(NUM_ITERS):
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x_adv
        logits = model(X_mod, A_t)
        loss = F.cross_entropy(logits[node_idx].unsqueeze(0), y_t[node_idx].unsqueeze(0))
        grad_x = torch.autograd.grad(loss, x_adv)[0]
        if grad_x.norm().item()==0: break
        step = ALPHA * grad_x / (grad_x.norm() + 1e-12)
        x_adv = (x_adv + step).detach()
        delta = x_adv - x_orig
        if delta.norm() > EPSILON:
            delta = delta * (EPSILON/(delta.norm()+1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)
    perturbed_X[node_idx] = x_adv.detach()
print("? Shared PGD perturbations done.")

# ---------------- Eval Helper ----------------
def evaluate_model(name, perturbed_X, selected_nodes):
    with torch.no_grad():
        # Predictions on original and perturbed inputs
        orig_logits = model(X_t, A_t)
        pert_logits = model(perturbed_X, A_t)

        orig_preds = orig_logits.argmax(dim=1).cpu().numpy()
        pert_preds = pert_logits.argmax(dim=1).cpu().numpy()

    # Restrict to selected perturbed samples only
    sel_idx = np.array(selected_nodes)
    sel_labels = labels_np[sel_idx]
    sel_orig_preds = orig_preds[sel_idx]
    sel_pert_preds = pert_preds[sel_idx]

    # Flip count first
    flips = (sel_orig_preds != sel_pert_preds).sum()
    print(f"\n=== Robustness Eval ({name}) ===")
    print(f"Flipped {flips}/{len(sel_idx)} ({100*flips/len(sel_idx):.2f}%)")

    # Accuracy, precision, recall, F1 on perturbed subset only
    acc = (sel_pert_preds == sel_labels).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(
        sel_labels, sel_pert_preds, average="weighted", zero_division=0)

    print(f"Accuracy={acc*100:.2f} | Precision={prec:.4f} | Recall={rec:.4f} | F1={f1:.4f}")
    print("Confusion Matrix:")
    print(confusion_matrix(sel_labels, sel_pert_preds, labels=[0, 1]))
    print("Classification Report:")
    print(classification_report(sel_labels, sel_pert_preds,
                                target_names=["clean", "trojan"], digits=4))

    return pert_logits


# ---------------- Metric 1: Jacobian Sensitivity ----------------
jac_info = []
for node_idx in selected_nodes:
    x0 = perturbed_X[node_idx].detach().clone().requires_grad_(True)
    def f_local(x):
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x
        return model(X_mod, A_t)[node_idx]
    J = torch.autograd.functional.jacobian(f_local, x0)
    jac_norm = torch.norm(J, p='fro').item()
    delta_fd = FD_EPS * torch.randn_like(x0)
    pred_change = J.mv(delta_fd)
    f0, f0p = f_local(x0).detach(), f_local(x0+delta_fd).detach()
    actual_change = f0p - f0
    rel_err = (torch.norm(pred_change-actual_change)/(torch.norm(actual_change)+1e-8)).item()
    jac_info.append((int(labels_np[node_idx]), jac_norm, rel_err))
print("\nJacobian Sensitivity:")
for cls in [0,1]:
    vals = [j[1] for j in jac_info if j[0]==cls]
    errs = [j[2] for j in jac_info if j[0]==cls]
    print(f" Class {cls}: norm={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Jacobian", perturbed_X, selected_nodes)

# ---------------- Metric 2: Lipschitz (Spectral Norm) ----------------
lip_info = []
for node_idx in selected_nodes:
    x0 = perturbed_X[node_idx].detach().clone().requires_grad_(True)
    def f_node(x):
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = x
        return model(X_mod, A_t)[node_idx]
    J = torch.autograd.functional.jacobian(f_node, x0).detach()
    U, S, Vh = torch.linalg.svd(J, full_matrices=False)
    sigma_max = S[0].item()
    delta_fd = FD_EPS*torch.randn_like(x0)
    pred_change = J.mv(delta_fd)
    f0, f0p = f_node(x0).detach(), f_node(x0+delta_fd).detach()
    actual_change = f0p-f0
    rel_err = (torch.norm(pred_change-actual_change)/(torch.norm(actual_change)+1e-8)).item()
    lip_info.append((int(labels_np[node_idx]), sigma_max, rel_err))
print("\nLipschitz Constant:")
for cls in [0,1]:
    vals = [j[1] for j in lip_info if j[0]==cls]
    errs = [j[2] for j in lip_info if j[0]==cls]
    print(f" Class {cls}: L={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Lipschitz", perturbed_X, selected_nodes)

Selected: clean=100, trojan=100
? Shared PGD perturbations done.

Jacobian Sensitivity:
 Class 0: norm=0.4569±0.1086, relerr=1.8763e-03±4.0839e-03
 Class 1: norm=7.4574±0.0800, relerr=1.6980e-03±7.2592e-03

=== Robustness Eval (Jacobian) ===
Flipped 144/200 (72.00%)
Accuracy=27.50 | Precision=0.1774 | Recall=0.2750 | F1=0.2157
Confusion Matrix:
[[ 55  45]
 [100   0]]
Classification Report:
              precision    recall  f1-score   support

       clean     0.3548    0.5500    0.4314       100
      trojan     0.0000    0.0000    0.0000       100

    accuracy                         0.2750       200
   macro avg     0.1774    0.2750    0.2157       200
weighted avg     0.1774    0.2750    0.2157       200


Lipschitz Constant:
 Class 0: L=0.4566±0.1085, relerr=1.8539e-03±5.3200e-03
 Class 1: L=7.4548±0.0799, relerr=1.1322e-03±2.0651e-03

=== Robustness Eval (Lipschitz) ===
Flipped 144/200 (72.00%)
Accuracy=27.50 | Precision=0.1774 | Recall=0.2750 | F1=0.2157
Confusion Matrix:
[[ 55

tensor([[ 4.9928, -5.0827],
        [ 4.9975, -4.9917],
        [ 4.9780, -4.9694],
        ...,
        [ 4.7599, -4.9413],
        [ 4.0829, -4.2264],
        [ 4.0829, -4.2264]])

In [7]:
# =========================
# Hessian-Based Curvature (grad outer-product) for node-level Trojan detection
# =========================
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# -------------------- Parameters --------------------
PER_CLASS = 100          # 100 nodes per class (clean/trojan)
FD_EPS = 5e-3            # finite-diff epsilon for relative-error check
TRIALS_PER_NODE = 10     # average trials per node for relative error
PERT_P = 6.0             # L2 magnitude for final Hessian-aligned perturbation (tuneable)
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

model.to(device)
model.eval()

# -------------------- Class names --------------------
class_names = ["clean", "trojan"]


# -------------------- Helper: compute g(x) --------------------
def compute_gradient(node_idx):
    """
    Returns gradient g = ?_x log p(y_hat|x) at node_idx.
    """
    x0 = X_t[node_idx].detach().clone().to(device).requires_grad_(True)

    # Forward pass with x0 replacing features of node_idx
    X_mod = X_t.clone().detach().to(device)
    X_mod[node_idx] = x0
    logits = model(X_mod, A_t)[node_idx]

    # Use predicted class
    pred_class = logits.argmax().item()
    logp = F.log_softmax(logits, dim=0)
    loss = logp[pred_class]

    g = torch.autograd.grad(loss, x0, retain_graph=False, create_graph=False, allow_unused=False)[0]
    return x0.detach(), g.detach(), pred_class

# -------------------- Storage --------------------
per_sample_info = []   # (node_idx, label, lambda_max, avg_rel_error)

print("\nComputing Hessian curvature proxy for selected nodes...")

for node_idx in selected_nodes:
    node_idx = int(node_idx)
    label = int(labels_np[node_idx])

    x0, g, pred_class = compute_gradient(node_idx)
    if g is None:
        lambda_max = 0.0
        avg_rel_err = 0.0
    else:
        # curvature proxy = ||g||^2
        lambda_max = float(g.norm(p=2).item() ** 2)

        # relative error by finite-difference
        rel_errs = []
        for _ in range(TRIALS_PER_NODE):
            delta = FD_EPS * torch.randn_like(x0).to(device)
            gt_delta = torch.dot(g, delta).item()
            pred_second = 0.5 * (gt_delta ** 2)

            # recompute logits at perturbed input
            X_mod = X_t.clone().detach().to(device)
            X_mod[node_idx] = x0 + delta
            logits_p = model(X_mod, A_t)[node_idx]
            logp_p = F.log_softmax(logits_p, dim=0)
            actual_second = float((logp_p[pred_class] - F.log_softmax(model(X_t, A_t)[node_idx], dim=0)[pred_class]).item() - torch.dot(g, delta).item())

            rel_error = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
            rel_errs.append(rel_error)

        avg_rel_err = float(np.mean(rel_errs))

    per_sample_info.append((node_idx, label, lambda_max, avg_rel_err))

# -------------------- Aggregate stats --------------------
clean_stats = [t for t in per_sample_info if t[1]==0]
troj_stats  = [t for t in per_sample_info if t[1]==1]

def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)

print("\nAggregated Hessian curvature stats:")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,lambda,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)

evaluate_model("Margin", perturbed_X, selected_nodes)

# ---------------- Metric 4: Prediction Margin ----------------
margin_info = []
for node_idx in selected_nodes:
    logits = model(perturbed_X, A_t)[node_idx]
    pred_class = logits.argmax().item()
    margin = logits[pred_class].item() - logits[[j for j in range(len(logits)) if j!=pred_class]].max().item()
    delta = FD_EPS*torch.randn_like(perturbed_X[node_idx])
    logits_p = model(perturbed_X.clone().detach(), A_t)[node_idx]
    margin_p = logits_p[pred_class].item() - logits_p[[j for j in range(len(logits_p)) if j!=pred_class]].max().item()
    rel_err = abs(margin-margin_p)/(abs(margin_p)+1e-12)
    margin_info.append((int(labels_np[node_idx]), margin, rel_err))
print("\nPrediction Margin:")
for cls in [0,1]:
    vals = [j[1] for j in margin_info if j[0]==cls]
    errs = [j[2] for j in margin_info if j[0]==cls]
    print(f" Class {cls}: margin={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Margin", perturbed_X, selected_nodes)

# ---------------- Metric 5: ARR ----------------
# (kept simplified: min perturbation until flip)
def adversarial_radius(node_idx):
    x0 = perturbed_X[node_idx].detach().clone()
    base_pred = int(model(perturbed_X, A_t)[node_idx].argmax().item())
    eps, growth = 1e-3, 1.2
    while eps < 20:
        x_try = x0 + eps*torch.randn_like(x0)
        with torch.no_grad():
            pred = int(model(perturbed_X.clone().detach(), A_t)[node_idx].argmax().item())
        if pred != base_pred: return eps
        eps *= growth
    return 20.0

arr_info = []
for n in selected_nodes:
    arr_val = adversarial_radius(n)
    # finite-difference style perturbation for ARR
    delta = FD_EPS * torch.randn_like(perturbed_X[n])
    arr_val_p = adversarial_radius(n)  # here you could recompute with perturbed input if desired
    rel_err = abs(arr_val - arr_val_p) / (abs(arr_val_p) + 1e-12)
    arr_info.append((int(labels_np[n]), arr_val, rel_err))

print("\nAdversarial Robustness Radius:")
for cls in [0,1]:
    vals = [j[1] for j in arr_info if j[0] == cls]
    errs = [j[2] for j in arr_info if j[0] == cls]
    print(f" Class {cls}: radius={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("ARR", perturbed_X, selected_nodes)

# ---------------- Metric 6: Stability ----------------
stability_info = []
for node_idx in selected_nodes:
    base_logits = model(perturbed_X, A_t)[node_idx].detach()
    diffs = []
    for _ in range(10):
        noise = 0.05 * torch.randn_like(perturbed_X[node_idx])
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = perturbed_X[node_idx] + noise
        with torch.no_grad():
            logits_n = model(X_mod, A_t)[node_idx]
        diffs.append(torch.norm(logits_n - base_logits).item())
    stability_val = np.mean(diffs)
    # finite-difference style perturbation for stability
    noise_fd = 0.05 * torch.randn_like(perturbed_X[node_idx])
    X_fd = perturbed_X.clone().detach()
    X_fd[node_idx] = perturbed_X[node_idx] + noise_fd
    with torch.no_grad():
        logits_fd = model(X_fd, A_t)[node_idx]
    diffs_fd = [torch.norm(logits_fd - base_logits).item()]
    stability_val_p = np.mean(diffs_fd)
    rel_err = abs(stability_val - stability_val_p) / (abs(stability_val_p) + 1e-12)
    stability_info.append((int(labels_np[node_idx]), stability_val, rel_err))

print("\nStability Under Noise:")
for cls in [0,1]:
    vals = [j[1] for j in stability_info if j[0] == cls]
    errs = [j[2] for j in stability_info if j[0] == cls]
    print(f" Class {cls}: stability={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Stability", perturbed_X, selected_nodes)





Computing Hessian curvature proxy for selected nodes...

Aggregated Hessian curvature stats:
 Clean:  avg_lambda=0.0000 ± 0.0000, avg_FDrel=7.2122e-01 ± 8.1502e-02
 Trojan: avg_lambda=0.1053 ± 0.1613, avg_FDrel=1.1916e+00 ± 1.7923e-01

Sample preview (first 6): (idx,label,lambda,FD_rel_err)
(26119, 0, 5.3058623615239424e-09, 0.5967704660381343)
(57231, 0, 2.919435396550133e-08, 0.6549227663929634)
(40297, 0, 3.253715851850403e-08, 0.7398776462650221)
(27843, 0, 7.068591906242526e-07, 0.7594926884346475)
(26863, 0, 1.433692691525624e-06, 0.7370227599597732)
(46828, 0, 3.9822156324595236e-08, 0.7163436675705489)

=== Robustness Eval (Margin) ===
Flipped 144/200 (72.00%)
Accuracy=27.50 | Precision=0.1774 | Recall=0.2750 | F1=0.2157
Confusion Matrix:
[[ 55  45]
 [100   0]]
Classification Report:
              precision    recall  f1-score   support

       clean     0.3548    0.5500    0.4314       100
      trojan     0.0000    0.0000    0.0000       100

    accuracy                    

KeyboardInterrupt: 

In [8]:

# ---------------- Metric 4: Prediction Margin ----------------
margin_info = []
for node_idx in selected_nodes:
    logits = model(perturbed_X, A_t)[node_idx]
    pred_class = logits.argmax().item()
    margin = logits[pred_class].item() - logits[[j for j in range(len(logits)) if j!=pred_class]].max().item()
    delta = FD_EPS*torch.randn_like(perturbed_X[node_idx])
    logits_p = model(perturbed_X.clone().detach(), A_t)[node_idx]
    margin_p = logits_p[pred_class].item() - logits_p[[j for j in range(len(logits_p)) if j!=pred_class]].max().item()
    rel_err = abs(margin-margin_p)/(abs(margin_p)+1e-12)
    margin_info.append((int(labels_np[node_idx]), margin, rel_err))
print("\nPrediction Margin:")
for cls in [0,1]:
    vals = [j[1] for j in margin_info if j[0]==cls]
    errs = [j[2] for j in margin_info if j[0]==cls]
    print(f" Class {cls}: margin={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Margin", perturbed_X, selected_nodes)

# ---------------- Metric 5: ARR ----------------
# (kept simplified: min perturbation until flip)
def adversarial_radius(node_idx):
    x0 = perturbed_X[node_idx].detach().clone()
    base_pred = int(model(perturbed_X, A_t)[node_idx].argmax().item())
    eps, growth = 1e-3, 1.2
    while eps < 20:
        x_try = x0 + eps*torch.randn_like(x0)
        with torch.no_grad():
            pred = int(model(perturbed_X.clone().detach(), A_t)[node_idx].argmax().item())
        if pred != base_pred: return eps
        eps *= growth
    return 20.0

arr_info = []
for n in selected_nodes:
    arr_val = adversarial_radius(n)
    # finite-difference style perturbation for ARR
    delta = FD_EPS * torch.randn_like(perturbed_X[n])
    arr_val_p = adversarial_radius(n)  # here you could recompute with perturbed input if desired
    rel_err = abs(arr_val - arr_val_p) / (abs(arr_val_p) + 1e-12)
    arr_info.append((int(labels_np[n]), arr_val, rel_err))

print("\nAdversarial Robustness Radius:")
for cls in [0,1]:
    vals = [j[1] for j in arr_info if j[0] == cls]
    errs = [j[2] for j in arr_info if j[0] == cls]
    print(f" Class {cls}: radius={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("ARR", perturbed_X, selected_nodes)

# ---------------- Metric 6: Stability ----------------
stability_info = []
for node_idx in selected_nodes:
    base_logits = model(perturbed_X, A_t)[node_idx].detach()
    diffs = []
    for _ in range(10):
        noise = 0.05 * torch.randn_like(perturbed_X[node_idx])
        X_mod = perturbed_X.clone().detach()
        X_mod[node_idx] = perturbed_X[node_idx] + noise
        with torch.no_grad():
            logits_n = model(X_mod, A_t)[node_idx]
        diffs.append(torch.norm(logits_n - base_logits).item())
    stability_val = np.mean(diffs)
    # finite-difference style perturbation for stability
    noise_fd = 0.05 * torch.randn_like(perturbed_X[node_idx])
    X_fd = perturbed_X.clone().detach()
    X_fd[node_idx] = perturbed_X[node_idx] + noise_fd
    with torch.no_grad():
        logits_fd = model(X_fd, A_t)[node_idx]
    diffs_fd = [torch.norm(logits_fd - base_logits).item()]
    stability_val_p = np.mean(diffs_fd)
    rel_err = abs(stability_val - stability_val_p) / (abs(stability_val_p) + 1e-12)
    stability_info.append((int(labels_np[node_idx]), stability_val, rel_err))

print("\nStability Under Noise:")
for cls in [0,1]:
    vals = [j[1] for j in stability_info if j[0] == cls]
    errs = [j[2] for j in stability_info if j[0] == cls]
    print(f" Class {cls}: stability={np.mean(vals):.4f}±{np.std(vals):.4f}, relerr={np.mean(errs):.4e}±{np.std(errs):.4e}")

evaluate_model("Stability", perturbed_X, selected_nodes)




Prediction Margin:
 Class 0: margin=1.7415±1.5443, relerr=0.0000e+00±0.0000e+00
 Class 1: margin=38.8759±1.5805, relerr=0.0000e+00±0.0000e+00

=== Robustness Eval (Margin) ===
Flipped 144/200 (72.00%)
Accuracy=27.50 | Precision=0.1774 | Recall=0.2750 | F1=0.2157
Confusion Matrix:
[[ 55  45]
 [100   0]]
Classification Report:
              precision    recall  f1-score   support

       clean     0.3548    0.5500    0.4314       100
      trojan     0.0000    0.0000    0.0000       100

    accuracy                         0.2750       200
   macro avg     0.1774    0.2750    0.2157       200
weighted avg     0.1774    0.2750    0.2157       200


Adversarial Robustness Radius:
 Class 0: radius=20.0000±0.0000, relerr=0.0000e+00±0.0000e+00
 Class 1: radius=20.0000±0.0000, relerr=0.0000e+00±0.0000e+00

=== Robustness Eval (ARR) ===
Flipped 144/200 (72.00%)
Accuracy=27.50 | Precision=0.1774 | Recall=0.2750 | F1=0.2157
Confusion Matrix:
[[ 55  45]
 [100   0]]
Classification Report:
       

tensor([[ 4.9928, -5.0827],
        [ 4.9975, -4.9917],
        [ 4.9780, -4.9694],
        ...,
        [ 4.7599, -4.9413],
        [ 4.0829, -4.2264],
        [ 4.0829, -4.2264]])