#### Training Evaluation with HW2VEC taken from https://github.com/AICPS/hw2vec

In [1]:
# train_subgraph_gnn_fixed_hw2vec.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_mean_pool  # NOTE: switched to HW2VEC-style GIN

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 80
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Merge edge CSVs
# -----------------------
edge_files = [
    "GNNDatasets/node_edges.csv",
    "GNNDatasets/subgraph_edges_andxor.csv",
    "GNNDatasets/subgraph_edges_countermux.csv",
    "GNNDatasets/subgraph_edges_fsmor.csv",
]

edges_by_circuit = defaultdict(list)
for ef in edge_files:
    df = pd.read_csv(ef)
    for _, r in df.iterrows():
        edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
sub_df = pd.read_csv("GNNDatasets/subgraph.csv")
data_list, labels = [], []

for idx, row in tqdm(sub_df.iterrows(), total=len(sub_df), desc="Building subgraphs"):
    ckt = row["circuit_name"]
    lbl = int(row.get("label_subgraph", row.get("label", 0)))
    
    if ckt not in edges_by_circuit:
        continue
    
    # collect node uids from node.csv that belong to this circuit
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} subgraphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model (HW2VEC-style GIN)
# -----------------------
class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim, bias=True),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hid_dim, hid_dim, bias=True),
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)

class HW2VEC_GIN_Subgraph(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2, dropout=0.4):
        super().__init__()
        # Two GINConv layers with MLPs, BatchNorm, ReLU, and dropout (HW2VEC style)
        self.mlp1 = MLP(in_dim, hid_dim)
        self.conv1 = GINConv(self.mlp1, train_eps=True)
        self.bn1 = nn.BatchNorm1d(hid_dim)

        self.mlp2 = MLP(hid_dim, hid_dim)
        self.conv2 = GINConv(self.mlp2, train_eps=True)
        self.bn2 = nn.BatchNorm1d(hid_dim)

        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hid_dim, out_dim)
        nn.init.xavier_uniform_(self.head.weight)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = global_mean_pool(x, batch)
        return self.head(x)

model = HW2VEC_GIN_Subgraph(feat_dim, hid_dim=HID_DIM, out_dim=2, dropout=0.4).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Subgraph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Building subgraphs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 674/674 [00:33<00:00, 19.98it/s]


Built 674 subgraphs (usable)
Train: 471, Val: 101, Test: 102
Epoch 001 | Loss 0.8773 | Val 0.3861
Epoch 005 | Loss 0.6591 | Val 0.6238
Epoch 010 | Loss 0.6547 | Val 0.3762
Epoch 015 | Loss 0.6205 | Val 0.4356
Epoch 020 | Loss 0.5790 | Val 0.6139
Epoch 025 | Loss 0.5428 | Val 0.6337
Epoch 030 | Loss 0.5144 | Val 0.6337
Epoch 035 | Loss 0.4601 | Val 0.6535
Epoch 040 | Loss 0.4255 | Val 0.7129
Epoch 045 | Loss 0.4230 | Val 0.7426
Epoch 050 | Loss 0.3813 | Val 0.7030
Epoch 055 | Loss 0.3525 | Val 0.8614
Epoch 060 | Loss 0.3503 | Val 0.8515
Epoch 065 | Loss 0.3179 | Val 0.8614
Epoch 070 | Loss 0.2979 | Val 0.9010
Epoch 075 | Loss 0.2914 | Val 0.9010
Epoch 080 | Loss 0.2619 | Val 0.9208

Final Evaluation (Subgraph-Level)
Test Accuracy: 0.9510

Classification Report:
              precision    recall  f1-score   support

           0     0.7143    0.9091    0.8000        11
           1     0.9886    0.9560    0.9721        91

    accuracy                         0.9510       102
   macro av

#### All in One, same perturbation across all metric.

In [2]:
# =====================================================================
# Subgraph-level: unified end-to-end robustness evaluation reusing one PGD
# =====================================================================
import time
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# -------------------------
# Parameters (tweak here)
# -------------------------
PER_CLASS        = 20        # number of subgraphs per class to perturb
EPSILON_PGD      = 4.0       # L2 radius for the PGD perturbation applied once and reused
ALPHA_PGD        = 1.0       # PGD step size
NUM_ITERS_PGD    = 30        # PGD iterations
FD_EPS           = 1e-3      # finite-difference epsilon
ARR_INITIAL_EPS  = 1e-3      # initial epsilon for ARR search
ARR_GROW1        = 1.25
ARR_GROW2        = 1.4
ARR_MAX_EPS      = 20.0
ARR_BS_ITERS     = 10
ARR_TRIALS       = 6
STAB_SIGMA       = 0.5       # stability noise sigma
STAB_SAMPLES     = 20
STAB_RELERR_RPTS = 5
SEED             = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# -------------------------
# Sanity
# -------------------------
required = ["model", "test_loader", "DEVICE"]
for r in required:
    if r not in globals():
        raise RuntimeError(f"Required variable '{r}' not found in the environment (must be defined earlier).")

model.to(DEVICE)
model.eval()

# -------------------------
# Extract test dataset list & labels (preserve loader order)
# -------------------------
dataset = list(test_loader.dataset) if hasattr(test_loader, "dataset") else list(test_loader)
n_test = len(dataset)
labels_all = np.array([int(d.y.item()) for d in dataset])
print(f"Total test subgraphs: {n_test}; class counts: {np.bincount(labels_all)}")

# -------------------------
# Select indices (same pool for all metrics)
# -------------------------
rng = np.random.default_rng(SEED)
selected = []
for cls in [0,1]:
    idxs = np.where(labels_all == cls)[0]
    if len(idxs)==0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected.extend(chosen.tolist())
selected = np.array(sorted(selected), dtype=np.int64)
print("Selected perturbation pool counts:", {0:int((labels_all[selected]==0).sum()), 1:int((labels_all[selected]==1).sum())})

# small helper for single-graph batch vectors
def batch_for(x):
    return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

# -------------------------
# Build one shared PGD perturbation map (perturbed_map)
# - perturb each selected subgraph's node features (x) via PGD (L2 constrained)
# - store perturbed_map[idx] -> perturbed x (on DEVICE)
# -------------------------
perturbed_map = {}
orig_preds_selected = []
adv_preds_selected = []

print("\n--- Step A: Creating shared PGD perturbations for selected subgraphs ---")
t0 = time.time()
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)    # (N_nodes, feat_dim)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_orig.shape
    y_true = torch.tensor([int(data.y.item())], dtype=torch.long, device=DEVICE)

    # random initialization inside L2-ball (on flattened features)
    delta = torch.randn_like(x_orig, device=DEVICE)
    delta = delta * (EPSILON_PGD / (delta.view(-1).norm() + 1e-12))
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    # PGD loop (maximize CE loss to cause misclassification)
    for it in range(NUM_ITERS_PGD):
        out = model(x_adv, edge_index, batch_for(x_adv))           # [1, C]
        loss = F.cross_entropy(out, y_true)
        grad_x = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        gnorm = grad_x.view(-1).norm().item()
        if gnorm == 0:
            break
        step = (ALPHA_PGD * grad_x) / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()
        # project back to L2 ball (flattened)
        delta = x_adv - x_orig
        dnorm = delta.view(-1).norm().item()
        if dnorm > EPSILON_PGD:
            delta = delta * (EPSILON_PGD / (dnorm + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    x_adv_final = x_adv.detach().clone()
    perturbed_map[int(idx)] = x_adv_final

    # store flip stats
    with torch.no_grad():
        out_orig = model(x_orig, edge_index, batch_for(x_orig))
        out_adv  = model(x_adv_final, edge_index, batch_for(x_adv_final))
        orig_preds_selected.append(int(out_orig.argmax(dim=1).item()))
        adv_preds_selected.append(int(out_adv.argmax(dim=1).item()))

t1 = time.time()
print(f"? PGD perturbations done. Time: {t1-t0:.1f}s")
orig_preds_selected = np.array(orig_preds_selected)
adv_preds_selected  = np.array(adv_preds_selected)
num_flips = int((orig_preds_selected != adv_preds_selected).sum())
print(f"Selected subgraphs: {len(selected)}. Flipped after PGD: {num_flips} ({100.0 * num_flips / len(selected):.2f}%).")

# Also compute and show perturbed-only classifier behavior (for selected set)
with torch.no_grad():
    labels_sel = labels_all[selected]
    preds_perturbed = []
    for idx in selected:
        data = dataset[int(idx)]
        logits = model(perturbed_map[int(idx)], data.edge_index.to(DEVICE), batch_for(perturbed_map[int(idx)]))
        preds_perturbed.append(int(logits.argmax(dim=1).item()))
    preds_perturbed = np.array(preds_perturbed)

print("\nClassification on PERTURBED samples only (selected set):")
acc_sel = (preds_perturbed == labels_sel).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_sel, preds_perturbed, average='weighted', zero_division=0)
print(f"Accuracy (perturbed selected): {acc_sel*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")
print("Classification report (perturbed selected):")
print(classification_report(labels_sel, preds_perturbed, target_names=['clean','trojan'], digits=4))
print("Confusion matrix (perturbed selected):")
print(confusion_matrix(labels_sel, preds_perturbed, labels=[0,1]))

# -------------------------------------------------------------------------
# Utility: evaluate on perturbed selected set only (used by metrics to print)
# -------------------------------------------------------------------------
def eval_perturbed_selected(perturbed_map, selected_idxs):
    with torch.no_grad():
        y_true = []
        y_pred = []
        for idx in selected_idxs:
            data = dataset[int(idx)]
            lab = int(data.y.item())
            x_eval = perturbed_map[int(idx)]
            logits = model(x_eval, data.edge_index.to(DEVICE), batch_for(x_eval))
            y_true.append(lab)
            y_pred.append(int(logits.argmax(dim=1).item()))
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    acc = (y_true == y_pred).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    return y_true, y_pred, acc, prec, rec, f1

# -------------------------
# Metric 1: Jacobian Frobenius norm + FD relative error
# (computed at the perturbed inputs; flattened features)
# -------------------------
print("\n\n================ Metric: Jacobian Frobenius Norm + FD Relative Error (on perturbed inputs) ================")
per_sample_jac = []   # (idx,label, jac_frob, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    d = n_nodes * feat_dim

    # flatten
    x_flat = x_adv.view(-1).detach().clone().requires_grad_(True)

    def f_flat(x_in):
        x_mat = x_in.view_as(x_adv)
        out = model(x_mat, edge_index, batch_for(x_mat))   # [1, C]
        return out.squeeze(0)                              # (C,)

    # Jacobian: shape (C, d)
    try:
        J = torch.autograd.functional.jacobian(f_flat, x_flat)   # (C, d)
    except RuntimeError:
        # fallback: compute per-output jac rows
        logits0 = f_flat(x_flat).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(z, cidx=c):
                return f_flat(z)[cidx]
            r = torch.autograd.functional.jacobian(scalar_f, x_flat)
            rows.append(r.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    J = J.detach()
    jac_frob = float(torch.norm(J, p='fro').item())

    # FD relative error
    delta_fd = FD_EPS * torch.randn(d, device=DEVICE)
    pred_change = (J @ delta_fd)                      # (C,)
    f0 = f_flat(x_flat).detach()
    f0p = f_flat((x_flat + delta_fd)).detach()
    actual_change = f0p - f0
    if torch.norm(actual_change).item() == 0:
        rel_err = 0.0
    else:
        rel_err = float(torch.norm(pred_change - actual_change).item() / (torch.norm(actual_change).item() + 1e-8))

    per_sample_jac.append((int(idx), int(data.y.item()), jac_frob, rel_err))

# aggregate & print
arr = np.array([[i,l,j,r] for (i,l,j,r) in per_sample_jac], dtype=object)
if arr.size:
    clean_vals = arr[arr[:,1]==0][:,2].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_vals  = arr[arr[:,1]==1][:,2].astype(float) if (arr[:,1]==1).any() else np.array([])
    clean_errs = arr[arr[:,1]==0][:,3].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_errs  = arr[arr[:,1]==1][:,3].astype(float) if (arr[:,1]==1).any() else np.array([])

    def mean_std(a): return (a.mean(), a.std()) if len(a)>0 else (0.0,0.0)

    c_mean, c_std = mean_std(clean_vals)
    t_mean, t_std = mean_std(troj_vals)
    ce_mean, ce_std = mean_std(clean_errs)
    te_mean, te_std = mean_std(troj_errs)

    print("\nJacobian Frobenius norms & FD relative errors (aggregated) ON PERTURBED SAMPLES:")
    print(f" Clean subgraphs:  avg_norm={c_mean:.4f} ± {c_std:.4f}, avg_FDrel={ce_mean:.4e} ± {ce_std:.4e}")
    print(f" Trojan subgraphs: avg_norm={t_mean:.4f} ± {t_std:.4f}, avg_FDrel={te_mean:.4e} ± {te_std:.4e}")
    print("\nSample preview (first 6): (idx,label,jacobian_frob,FD_rel_err)")
    for t in per_sample_jac[:6]:
        print(t)
else:
    print("No Jacobian samples computed.")

# Print classification performance on perturbed selected set only
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Jacobian stage (re-used perturbed set):")
print(f"Accuracy: {acc_sel*100:.2f}%")
print(f"Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print("Classification report:")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print("Confusion matrix:")
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 2: Local Lipschitz (spectral norm of Jacobian) + FD relative error
# (Compute at perturbed inputs; flatten J to (C,d) and compute top singular value)
# -------------------------
print("\n\n================ Metric: Local Lipschitz (spectral norm) + FD Relative Error ================")
per_sample_lip = []  # (idx,label,sigma_max, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    d = n_nodes * feat_dim

    def f_local(x):
        return model(x, edge_index, batch_for(x)).squeeze(0)

    # Jacobian shape (C, N, F)
    try:
        J = torch.autograd.functional.jacobian(f_local, x_adv).detach()   # (C, N, F)
    except RuntimeError:
        # fallback per-output (slower)
        logits0 = f_local(x_adv).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(z, cidx=c):
                return f_local(z)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_adv)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    C = J.shape[0]
    J_flat = J.reshape(C, d)   # (C, d)

    # spectral norm via SVD (largest singular value)
    try:
        _, S, _ = torch.linalg.svd(J_flat, full_matrices=False)
        sigma_max = float(S[0].item())
    except RuntimeError:
        # fallback via eigen
        JJT = (J_flat @ J_flat.T).cpu().numpy()
        eigvals = np.linalg.eigvalsh(JJT)
        sigma_max = float(np.sqrt(max(eigvals.max(), 0.0)))

    # FD relative error
    delta_fd = FD_EPS * torch.randn(d, device=DEVICE)
    pred_change = (J_flat @ delta_fd)               # (C,)
    f0 = f_local(x_adv).detach()
    f0p = f_local(x_adv + delta_fd.view_as(x_adv)).detach()
    actual_change = f0p - f0
    if torch.norm(actual_change).item() == 0:
        fd_rel = 0.0
    else:
        fd_rel = float(torch.norm(pred_change - actual_change).item() / (torch.norm(actual_change).item() + 1e-8))

    per_sample_lip.append((int(idx), int(data.y.item()), float(sigma_max), float(fd_rel)))

# aggregate & print
clean_stats = [p for p in per_sample_lip if p[1]==0]
troj_stats  = [p for p in per_sample_lip if p[1]==1]
def aggs(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = aggs(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = aggs(troj_stats)

print("\nLocal Lipschitz (spectral) & FD relative errors (on PERTURBED samples):")
print(f" Clean:  avg_L={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_L={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,sigma_max,fd_rel_err)")
for p in per_sample_lip[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Local Lipschitz stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 3: Hessian curvature proxy (||g||^2) + FD relative error
# (Compute gradient of log-prob for predicted class at perturbed inputs)
# -------------------------
print("\n\n================ Metric: Hessian curvature proxy (||g||^2) + FD Relative Error ================")
TRIALS_HESS_FD = 5
per_sample_hess = []   # (idx,label, lambda_proxy=||g||^2, avg_fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)

    # forward & predicted class (allow grad)
    logits = model(x_adv, edge_index, batch_for(x_adv)).squeeze(0)
    pred_class = int(logits.argmax().item())

    # compute log-prob of predicted class
    logp = F.log_softmax(logits, dim=0)[pred_class]
    # gradient wrt x_adv
    g = torch.autograd.grad(logp, x_adv, retain_graph=False, create_graph=False, allow_unused=False)[0].detach()
    lambda_proxy = float(g.norm(p=2).item() ** 2)

    # FD relative errors (multiple small deltas)
    rels = []
    for _ in range(TRIALS_HESS_FD):
        delta = FD_EPS * torch.randn_like(x_adv).to(DEVICE)
        gt_delta = float((g * delta).sum().item())
        pred_second = 0.5 * (gt_delta ** 2)
        # recompute logp at perturbed input
        logits_p = model((x_adv + delta).detach(), edge_index, batch_for(x_adv + delta)).squeeze(0)
        logp_p = F.log_softmax(logits_p, dim=0)[pred_class]
        actual_second = float((logp_p - F.log_softmax(logits, dim=0)[pred_class]).item() - gt_delta)
        rel_err = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        rels.append(rel_err)
    avg_rel_err = float(np.mean(rels))
    per_sample_hess.append((int(idx), int(data.y.item()), lambda_proxy, avg_rel_err))

# aggregate & print
clean_stats = [t for t in per_sample_hess if t[1]==0]
troj_stats  = [t for t in per_sample_hess if t[1]==1]
def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)
print("\nHessian-curvature proxy (||g||^2) & FD relative errors (on PERTURBED samples):")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,lambda,FD_rel_err)")
for p in per_sample_hess[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Hessian stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 4: Prediction Margin + FD relative error (on perturbed inputs)
# -------------------------
print("\n\n================ Metric: Prediction Margin + FD Relative Error ================")
per_sample_margin = []   # (idx,label, margin, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)

    with torch.no_grad():
        logits = model(x_adv, edge_index, batch_for(x_adv)).squeeze(0)
    pred_class = int(logits.argmax().item())
    pred_logit = float(logits[pred_class].item())
    other_logits = logits.clone()
    other_logits[pred_class] = -float('inf')
    second_max = float(other_logits.max().item())
    margin = pred_logit - second_max

    # FD check
    delta = FD_EPS * torch.randn_like(x_adv).to(DEVICE)
    with torch.no_grad():
        logits_p = model(x_adv + delta, edge_index, batch_for(x_adv)).squeeze(0)
    pred_logit_p = float(logits_p[pred_class].item())
    other_logits_p = logits_p.clone()
    other_logits_p[pred_class] = -float('inf')
    second_max_p = float(other_logits_p.max().item())
    margin_p = pred_logit_p - second_max_p
    rel_err = abs(margin - margin_p) / (abs(margin_p) + 1e-12)
    per_sample_margin.append((int(idx), int(data.y.item()), float(margin), float(rel_err)))

# aggregate & print
clean_stats = [p for p in per_sample_margin if p[1]==0]
troj_stats  = [p for p in per_sample_margin if p[1]==1]
def aggs_m(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ms = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ms.mean(), Ms.std(), Es.mean(), Es.std())

cM_mean, cM_std, cE_mean, cE_std = aggs_m(clean_stats)
tM_mean, tM_std, tE_mean, tE_std = aggs_m(troj_stats)
print("\nPrediction Margin stats (on PERTURBED samples):")
print(f" Clean:  avg_margin={cM_mean:.4f} ± {cM_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_margin={tM_mean:.4f} ± {tM_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,margin,fd_rel_err)")
for p in per_sample_margin[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Prediction Margin stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 5: Adversarial Robustness Radius (ARR) around perturbed points
# -------------------------
print("\n\n================ Metric: Adversarial Robustness Radius (ARR) ================")
def f_for_subgraph(x_tensor, data):
    with torch.no_grad():
        out = model(x_tensor, data.edge_index.to(DEVICE), batch_for(x_tensor))
    return out.squeeze(0)

def adversarial_radius_for_subgraph(data, x0, initial_epsilon=ARR_INITIAL_EPS, growth_factor=ARR_GROW1,
                                    max_epsilon=ARR_MAX_EPS, bs_iters=ARR_BS_ITERS, num_trials=ARR_TRIALS):
    """Estimate minimal L2 norm that flips prediction around x0."""
    x0 = x0.clone().detach().to(DEVICE)
    with torch.no_grad():
        base_out = model(x0, data.edge_index.to(DEVICE), batch_for(x0))
        y0 = int(base_out.argmax().item())

    def is_same(x):
        out = f_for_subgraph(x, data)
        return int(out.argmax().item()) == y0

    radii = []
    for _ in range(num_trials):
        d = torch.randn_like(x0).to(DEVICE)
        d = d / (d.view(-1).norm() + 1e-12)
        eps = initial_epsilon
        while eps < max_epsilon and is_same(x0 + eps * d):
            eps *= growth_factor
        if eps >= max_epsilon:
            radii.append(float(max_epsilon))
            continue
        low, high = eps / growth_factor, eps
        for _ in range(bs_iters):
            mid = 0.5 * (low + high)
            if is_same(x0 + mid * d):
                low = mid
            else:
                high = mid
        radii.append(float(high))
    return float(min(radii))

def adversarial_radius_relerr(data, x0):
    r1 = adversarial_radius_for_subgraph(data, x0, growth_factor=ARR_GROW1, num_trials=ARR_TRIALS)
    r2 = adversarial_radius_for_subgraph(data, x0, growth_factor=ARR_GROW2, num_trials=ARR_TRIALS)
    rel_err = abs(r1 - r2) / (abs(r2) + 1e-12)
    return r1, rel_err

class_names = ['clean','trojan']
class_adv_radius = {cn: [] for cn in class_names}
class_rel_errors = {cn: [] for cn in class_names}
all_rads = []; all_relerrs = []

t0 = time.time()
for i, idx in enumerate(selected):
    data = dataset[int(idx)]
    x0 = perturbed_map[int(idx)].clone().detach().to(DEVICE)
    r, rel = adversarial_radius_relerr(data, x0)
    lab = int(data.y.item())
    class_adv_radius[class_names[lab]].append(r)
    class_rel_errors[class_names[lab]].append(rel)
    all_rads.append(r); all_relerrs.append(rel)
    if (i+1)%10 == 0:
        print(f" processed {i+1}/{len(selected)} ...")
t1 = time.time()
print(f"? ARR computation done. Time elapsed: {t1-t0:.1f}s")

# reporting ARR
print("\nARR (Adversarial Robustness Radius) Stats (on perturbed selected samples):")
for cn in class_names:
    vals = class_adv_radius[cn]
    errs = class_rel_errors[cn]
    if vals:
        print(f" {cn:6s}: avg_radius={np.mean(vals):.4f} ± {np.std(vals):.4f}, avg_relerr={np.mean(errs):.4e} ± {np.std(errs):.4e}")
    else:
        print(f" {cn:6s}: -")
print(f" Overall: avg_radius={np.mean(all_rads):.4f} ± {np.std(all_rads):.4f}; avg_relerr={np.mean(all_relerrs):.4e} ± {np.std(all_relerrs):.4e}")

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - ARR stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 6: Stability Under Input Noise (SUIN) on perturbed subgraphs
# -------------------------
print("\n\n================ Metric: Stability Under Input Noise (SUIN) ================")
def stability_for_subgraph(idx, sigma, num_samples):
    data = dataset[int(idx)]
    base_x = perturbed_map[int(idx)].to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    batch = batch_for(base_x)
    with torch.no_grad():
        f_orig = model(base_x, edge_index, batch).squeeze(0)
    diffs = []
    for _ in range(num_samples):
        noise = sigma * torch.randn_like(base_x).to(DEVICE)
        f_noisy = model(base_x + noise, edge_index, batch).squeeze(0)
        diffs.append(torch.norm(f_noisy - f_orig).item())
    return float(np.mean(diffs))

t0 = time.time()
per_sample_stab = []  # (idx,label, stability, rel_err)
for i, idx in enumerate(selected):
    s_val = stability_for_subgraph(idx, STAB_SIGMA, STAB_SAMPLES)
    # relative error by repeating
    re_vals = [stability_for_subgraph(idx, STAB_SIGMA, STAB_SAMPLES) for _ in range(STAB_RELERR_RPTS)]
    s_ref = float(np.mean(re_vals))
    rel_err = abs(s_val - s_ref) / (abs(s_ref) + 1e-12)
    per_sample_stab.append((int(idx), int(labels_all[idx]), float(s_val), float(rel_err)))
    if (i+1) % 10 == 0:
        print(f" processed {i+1}/{len(selected)} ...")
t1 = time.time()
print(f"? SUIN done. Time elapsed: {t1-t0:.1f}s")

# aggregate & print
clean_stats = [p for p in per_sample_stab if p[1]==0]
troj_stats  = [p for p in per_sample_stab if p[1]==1]
def aggs_s(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ss = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ss.mean(), Ss.std(), Es.mean(), Es.std())

cS_mean, cS_std, cE_mean, cE_std = aggs_s(clean_stats)
tS_mean, tS_std, tE_mean, tE_std = aggs_s(troj_stats)
print("\nStability Under Input Noise (on perturbed selected samples):")
print(f" Clean: avg_stability={cS_mean:.4f} ± {cS_std:.4f}, avg_relerr={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan:avg_stability={tS_mean:.4f} ± {tS_std:.4f}, avg_relerr={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,stability,rel_err)")
for p in per_sample_stab[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - SUIN stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

print("\n\nAll metrics computed on the SAME selected subgraphs and the SAME PGD perturbations (perturbed_map).")
print("You can adjust PER_CLASS, EPSILON_PGD, ALPHA_PGD, NUM_ITERS_PGD to change attack strength.")


Total test subgraphs: 102; class counts: [11 91]
Selected perturbation pool counts: {0: 11, 1: 20}

--- Step A: Creating shared PGD perturbations for selected subgraphs ---
? PGD perturbations done. Time: 101.2s
Selected subgraphs: 31. Flipped after PGD: 2 (6.45%).

Classification on PERTURBED samples only (selected set):
Accuracy (perturbed selected): 83.87%
Precision: 0.8369, Recall: 0.8387, F1: 0.8368
Classification report (perturbed selected):
              precision    recall  f1-score   support

       clean     0.8000    0.7273    0.7619        11
      trojan     0.8571    0.9000    0.8780        20

    accuracy                         0.8387        31
   macro avg     0.8286    0.8136    0.8200        31
weighted avg     0.8369    0.8387    0.8368        31

Confusion matrix (perturbed selected):
[[ 8  3]
 [ 2 18]]



Jacobian Frobenius norms & FD relative errors (aggregated) ON PERTURBED SAMPLES:
 Clean subgraphs:  avg_norm=0.5544 ± 1.2676, avg_FDrel=1.2866e-01 ± 1.5905e-01
