#### Training Evaluation

In [23]:
# train_subgraph_gnn_fixed.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 80
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Merge edge CSVs
# -----------------------
edge_files = [
    "GNNDatasets/node_edges.csv",
    "GNNDatasets/subgraph_edges_andxor.csv",
    "GNNDatasets/subgraph_edges_countermux.csv",
    "GNNDatasets/subgraph_edges_fsmor.csv",
]

edges_by_circuit = defaultdict(list)
for ef in edge_files:
    df = pd.read_csv(ef)
    for _, r in df.iterrows():
        edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
sub_df = pd.read_csv("GNNDatasets/subgraph.csv")
data_list, labels = [], []

for idx, row in tqdm(sub_df.iterrows(), total=len(sub_df), desc="Building subgraphs"):
    ckt = row["circuit_name"]
    lbl = int(row.get("label_subgraph", row.get("label", 0)))
    
    if ckt not in edges_by_circuit:
        continue
    
    # collect node uids from node.csv that belong to this circuit
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} subgraphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model
# -----------------------
class SubgraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid_dim)
        self.conv2 = SAGEConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = SubgraphClassifier(feat_dim).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Subgraph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Building subgraphs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 674/674 [00:21<00:00, 32.00it/s]


Built 674 subgraphs (usable)
Train: 471, Val: 101, Test: 102
Epoch 001 | Loss 0.7464 | Val 0.2673
Epoch 005 | Loss 0.6193 | Val 0.6238
Epoch 010 | Loss 0.5896 | Val 0.5941
Epoch 015 | Loss 0.5101 | Val 0.6832
Epoch 020 | Loss 0.3784 | Val 0.8020
Epoch 025 | Loss 0.2963 | Val 0.8317
Epoch 030 | Loss 0.2182 | Val 0.8812
Epoch 035 | Loss 0.1549 | Val 0.9208
Epoch 040 | Loss 0.1265 | Val 0.9208
Epoch 045 | Loss 0.1028 | Val 0.9208
Epoch 050 | Loss 0.0897 | Val 0.9208
Epoch 055 | Loss 0.0840 | Val 0.9208
Epoch 060 | Loss 0.0851 | Val 0.9307
Epoch 065 | Loss 0.0641 | Val 0.9307
Epoch 070 | Loss 0.0713 | Val 0.9604
Epoch 075 | Loss 0.0604 | Val 0.9604
Epoch 080 | Loss 0.0645 | Val 0.9604

Final Evaluation (Subgraph-Level)
Test Accuracy: 0.9902

Classification Report:
              precision    recall  f1-score   support

           0     0.9167    1.0000    0.9565        11
           1     1.0000    0.9890    0.9945        91

    accuracy                         0.9902       102
   macro av

#### Jacobain 

In [4]:
# Subgraph-level: PGD-first ? evaluation ? Jacobian + FD relative error
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ---------------------- PARAMETERS ----------------------
PER_CLASS = 20        # # subgraphs per class to perturb (adjust)
EPSILON = 4.0         # L2 radius for PGD perturbation (on flattened subgraph features)
ALPHA = 1.0           # PGD step size (normalized)
NUM_ITERS = 30        # PGD iterations
FD_EPS = 1e-3         # finite-difference epsilon for Jacobian check
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ---------------------- Sanity check ----------------------
required = ["model", "test_loader", "DEVICE"]
for v in required:
    if v not in globals():
        raise RuntimeError(f"Required variable '{v}' not found in the environment.")

model.to(DEVICE)
model.eval()

# ---------------------- Get test dataset as list ----------------------
test_data_list = list(test_loader.dataset)   # preserve the same order used by DataLoader
test_labels = np.array([int(d.y.item()) for d in test_data_list])
n_test = len(test_data_list)
print(f"Test subgraphs: {n_test}; class counts: {np.bincount(test_labels)}")

# ---------------------- Select subgraphs to perturb (PER_CLASS per class) ----------------------
rng = np.random.default_rng(SEED)
selected_idxs = []
for cls in [0, 1]:
    idxs = np.where(test_labels == cls)[0]
    if len(idxs) == 0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected_idxs.extend(chosen.tolist())
selected_idxs = np.array(sorted(selected_idxs), dtype=np.int64)
print("Selected counts:", {0:int((test_labels[selected_idxs]==0).sum()), 1:int((test_labels[selected_idxs]==1).sum())})

# ---------------------- PGD perturbation for selected subgraphs ----------------------
perturbed_map = {}  # idx -> perturbed x tensor (on DEVICE)
print("\nRunning PGD perturbation on selected subgraphs (this may take a while)...")

for idx in selected_idxs:
    data = test_data_list[int(idx)]
    x_orig = data.x.detach().to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_orig.shape

    # initialize adv example (random direction scaled to EPSILON)
    delta = torch.randn_like(x_orig, device=DEVICE)
    delta = delta * (EPSILON / (delta.view(-1).norm() + 1e-12))
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    # batch vector for single-graph forward
    batch_zero = torch.zeros(n_nodes, dtype=torch.long, device=DEVICE)

    y_true = torch.tensor([int(data.y.item())], dtype=torch.long, device=DEVICE)

    for it in range(NUM_ITERS):
        # forward
        out = model(x_adv, edge_index, batch_zero)            # shape [1, C]
        loss = F.cross_entropy(out, y_true)
        # gradient wrt inputs
        grad_x = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        gnorm = grad_x.view(-1).norm().item()
        if gnorm == 0:
            break
        step = (ALPHA * grad_x) / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()
        # project to L2 ball around x_orig
        delta = x_adv - x_orig
        dnorm = delta.view(-1).norm().item()
        if dnorm > EPSILON:
            delta = delta * (EPSILON / (dnorm + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach().clone()

print("? PGD perturbation finished for selected subgraphs.")

# ---------------------- Evaluate model on full test set (perturbed selected + originals) ----------------------
y_true_list = []
y_pred_list = []

with torch.no_grad():
    for i, data in enumerate(test_data_list):
        x_in = perturbed_map[i] if i in perturbed_map else data.x.to(DEVICE)
        edge_index = data.edge_index.to(DEVICE)
        batch_zero = torch.zeros(x_in.size(0), dtype=torch.long, device=DEVICE)
        out = model(x_in, edge_index, batch_zero)         # [1, C]
        pred = int(out.argmax(dim=1).item())
        y_pred_list.append(pred)
        y_true_list.append(int(data.y.item()))

y_true_arr = np.array(y_true_list)
y_pred_arr = np.array(y_pred_list)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected + rest original) =============")
acc = (y_true_arr == y_pred_arr).mean()
prec, rec, f1, _ = precision_recall_fscore_support(y_true_arr, y_pred_arr, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(y_true_arr, y_pred_arr, labels=[0,1], target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(y_true_arr, y_pred_arr, labels=[0,1]))

# how many selected flipped?
orig_preds = []
adv_preds = []
with torch.no_grad():
    for idx in selected_idxs:
        data = test_data_list[int(idx)]
        edge_index = data.edge_index.to(DEVICE)
        batch_zero = torch.zeros(data.x.size(0), dtype=torch.long, device=DEVICE)
        # original
        out_orig = model(data.x.to(DEVICE), edge_index, batch_zero)
        orig_preds.append(int(out_orig.argmax(dim=1).item()))
        # adv
        out_adv = model(perturbed_map[int(idx)], edge_index, batch_zero)
        adv_preds.append(int(out_adv.argmax(dim=1).item()))
orig_preds = np.array(orig_preds); adv_preds = np.array(adv_preds)
num_flips = int((orig_preds != adv_preds).sum())
print(f"\nSelected subgraphs: {len(selected_idxs)}. Flipped after attack: {num_flips} ({100.0*num_flips/len(selected_idxs):.2f}%).")

# ---------------------- Jacobian & finite-difference relative error (computed AT the PERTURBED input) ----------------------
print("\nComputing Jacobian norms & FD relative error on the PERTURBED selected subgraphs...")
per_sample_info = []   # tuples: (idx, label, jacobian_fro_norm, fd_rel_error)
for idx in selected_idxs:
    data = test_data_list[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)  # perturbed input
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    batch_zero = torch.zeros(n_nodes, dtype=torch.long, device=DEVICE)

    # flatten input
    x_flat = x_adv.view(-1).detach().clone().requires_grad_(True)

    def f_flat(x_flat_input):
        x_mat = x_flat_input.view_as(x_adv)
        out = model(x_mat, edge_index, batch_zero)   # [1, C]
        return out.squeeze(0)                        # (C,)

    # Jacobian: shape (C, D) where D = n_nodes * feat_dim
    try:
        J = torch.autograd.functional.jacobian(f_flat, x_flat)    # (C, D)
    except RuntimeError:
        # fallback: compute per-output jacobian rows to avoid memory blowup (slower)
        C = int(model(x_adv, edge_index, batch_zero).shape[1])
        rows = []
        for out_i in range(C):
            def scalar_f(z, i=out_i):
                return f_flat(z)[i]
            r = torch.autograd.functional.jacobian(scalar_f, x_flat)
            rows.append(r.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    J = J.detach()
    jac_frob = float(torch.norm(J, p='fro').item())

    # finite-difference relative error
    delta_fd = FD_EPS * torch.randn_like(x_flat).to(DEVICE)
    pred_change = J @ delta_fd                          # (C,)
    f0 = f_flat(x_flat).detach()
    f0p = f_flat(x_flat + delta_fd).detach()
    actual_change = f0p - f0
    rel_err = float((torch.norm(pred_change - actual_change) / (torch.norm(actual_change) + 1e-8)).item())

    per_sample_info.append((int(idx), int(data.y.item()), jac_frob, rel_err))

# ---------------------- Aggregate & print summary ----------------------
arr = np.array([[i,l,j,r] for (i,l,j,r) in per_sample_info], dtype=object)
if arr.size:
    clean_vals = arr[arr[:,1]==0][:,2].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_vals  = arr[arr[:,1]==1][:,2].astype(float) if (arr[:,1]==1).any() else np.array([])
    clean_errs = arr[arr[:,1]==0][:,3].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_errs  = arr[arr[:,1]==1][:,3].astype(float) if (arr[:,1]==1).any() else np.array([])

    def mean_std(a): return (a.mean(), a.std()) if len(a)>0 else (0.0,0.0)

    c_mean, c_std = mean_std(clean_vals)
    t_mean, t_std = mean_std(troj_vals)
    ce_mean, ce_std = mean_std(clean_errs)
    te_mean, te_std = mean_std(troj_errs)

    print("\nJacobian norms & FD relative errors (aggregated):")
    print(f" Clean subgraphs:  avg_norm={c_mean:.4f} ± {c_std:.4f}, avg_FDrel={ce_mean:.4e} ± {ce_std:.4e}")
    print(f" Trojan subgraphs: avg_norm={t_mean:.4f} ± {t_std:.4f}, avg_FDrel={te_mean:.4e} ± {te_std:.4e}")

    print("\nSample preview (first 6): (idx,label,jacobian_frob,FD_rel_err)")
    for t in per_sample_info[:6]:
        print(t)
else:
    print("No Jacobian samples computed (selected set empty).")

print("\nDone. (Order: PGD perturbation -> full-test evaluation -> Jacobian computed at perturbed inputs.)")


Test subgraphs: 102; class counts: [11 91]
Selected counts: {0: 11, 1: 20}

Running PGD perturbation on selected subgraphs (this may take a while)...
? PGD perturbation finished for selected subgraphs.

Accuracy: 96.08%
Precision: 0.9651, Recall: 0.9608, F1: 0.9622

Classification report:
              precision    recall  f1-score   support

       clean     0.7692    0.9091    0.8333        11
      trojan     0.9888    0.9670    0.9778        91

    accuracy                         0.9608       102
   macro avg     0.8790    0.9381    0.9056       102
weighted avg     0.9651    0.9608    0.9622       102

Confusion Matrix:
[[10  1]
 [ 3 88]]

Selected subgraphs: 31. Flipped after attack: 3 (9.68%).

Computing Jacobian norms & FD relative error on the PERTURBED selected subgraphs...

Jacobian norms & FD relative errors (aggregated):
 Clean subgraphs:  avg_norm=0.4876 ± 0.1973, avg_FDrel=2.7351e-02 ± 5.7933e-02
 Trojan subgraphs: avg_norm=0.9912 ± 2.6177, avg_FDrel=7.6477e-01 ± 2.442

#### Local Lipschitz Constants

In [6]:
# ----------------------- Local Lipschitz (Subgraph-Level) -----------------------
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ----- parameters (tune if needed) -----
PER_CLASS = 20      # up to this many per class from test set (min with available)
FD_EPS = 1e-3
EPSILON = 5.0       # L2 budget for PGD (feature units)
ALPHA = 1.0         # PGD step size
NUM_ITERS = 30      # PGD iters
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# ----- quick sanity -----
if 'model' not in globals() or 'test_loader' not in globals():
    raise RuntimeError("Required variables `model` and `test_loader` must exist in the notebook.")

model.to(DEVICE)
model.eval()

# ----- get dataset (robust to DataLoader type) -----
dataset = test_loader.dataset
n_test = len(dataset)

# extract labels array
labels_np = np.array([int(d.y.item() if hasattr(d.y, 'item') else d.y[0].item()) for d in dataset])

# ----- select indices (PER_CLASS per class) -----
rng = np.random.default_rng(SEED)
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_np == cls)[0].tolist()
    if len(idxs) == 0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False).tolist()
    selected.extend(chosen)
selected = np.array(selected, dtype=np.int64)
print("Selected perturbation pool:", {0:int((labels_np[selected]==0).sum()), 1:int((labels_np[selected]==1).sum())})

# ----- perturb each selected subgraph using top-right singular vector init + PGD -----
perturbed_map = {}   # idx -> perturbed_x (on DEVICE)
orig_preds_list = []
adv_preds_list = []

print("\nRunning Lipschitz-directed PGD for selected subgraphs (this may take a while)...")
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)            # (num_nodes, feat_dim)
    edge_index = data.edge_index.to(DEVICE)
    y_true = torch.tensor([int(data.y.item())], dtype=torch.long, device=DEVICE)

    # helper: produce batch vector of zeros for single graph
    def batch_for(x):
        return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

    # define f_local that accepts x (num_nodes,feat) and returns logits (num_classes,)
    def f_local(x):
        return model(x, edge_index, batch_for(x)).squeeze(0)

    # Compute Jacobian at x_orig (shape: (num_classes, num_nodes, feat_dim))
    # NOTE: we compute Jacobian once per graph to get singular vector init
    try:
        J = torch.autograd.functional.jacobian(f_local, x_orig)    # (C, N, F)
    except RuntimeError:
        # if autograd fails due to graph size, compute per-output jac manually (slower)
        logits0 = f_local(x_orig).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(x, cidx=c):
                return f_local(x)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_orig)  # (N, F)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)  # (C, N, F)

    # flatten J to shape (C, d) where d = N * F
    C = J.shape[0]
    d = int(J.shape[1] * J.shape[2])
    J_flat = J.reshape(C, d)

    # get top-right singular vector v (length d)
    try:
        _, S, Vh = torch.linalg.svd(J_flat, full_matrices=False)
        v = Vh[0, :].detach().to(DEVICE)   # (d,)
    except RuntimeError:
        # fallback random direction
        v = torch.randn(d, device=DEVICE)
    if v.norm().item() > 0:
        v = v / (v.norm() + 1e-12)
    else:
        v = torch.randn_like(v).to(DEVICE); v = v / (v.norm() + 1e-12)

    # initialize adversarial x_adv (reshape v -> (N,F))
    v_mat = v.view(x_orig.shape[0], x_orig.shape[1])
    x_adv = (x_orig + 0.5 * EPSILON * v_mat).detach().clone().requires_grad_(True)

    # PGD loop maximize CE loss (untargeted)
    for it in range(NUM_ITERS):
        logits = model(x_adv, edge_index, batch_for(x_adv))
        loss = F.cross_entropy(logits, y_true)
        grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        gn = grad.view(-1)
        gnorm = gn.norm().item()
        if gnorm == 0:
            break
        step = ALPHA * (grad / (gnorm + 1e-12))
        x_adv = (x_adv + step).detach()
        # project to L2-ball around x_orig
        delta = (x_adv - x_orig).view(-1)
        dnorm = delta.norm().item()
        if dnorm > EPSILON:
            delta = delta * (EPSILON / (dnorm + 1e-12))
            x_adv = (x_orig + delta.view_as(x_orig)).detach()
        x_adv = x_adv.requires_grad_(True)

    # finalize
    x_adv_final = x_adv.detach()
    perturbed_map[int(idx)] = x_adv_final

    # store preds for flip statistics
    with torch.no_grad():
        p_orig = f_local(x_orig).argmax().item()
        p_adv  = f_local(x_adv_final).argmax().item()
    orig_preds_list.append(p_orig)
    adv_preds_list.append(p_adv)

print("? Finished perturbations.")

# ----- Evaluate full test set using perturbed subgraphs for selected indices -----
print("\n============= Robustness Evaluation (Full Test Set) =============")
preds_all = []
labels_all = []
with torch.no_grad():
    for i, data in enumerate(dataset):
        x = data.x.detach().clone().to(DEVICE)
        edge_index = data.edge_index.to(DEVICE)
        lab = int(data.y.item())
        if int(i) in perturbed_map:
            x_eval = perturbed_map[int(i)]
        else:
            x_eval = x
        logits = model(x_eval, edge_index, batch_for(x_eval))
        preds_all.append(int(logits.argmax(dim=1).item()))
        labels_all.append(lab)

preds_all = np.array(preds_all)
labels_all = np.array(labels_all)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification Report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# flip statistics
num_flips = int((np.array(orig_preds_list) != np.array(adv_preds_list)).sum())
print(f"\nSelected subgraphs: {len(selected)}. Flipped after attack: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ----- Now compute Local Lipschitz constants and FD relative error ON THE PERTURBED SUBGRAPHS -----
print("\nComputing Local Lipschitz constants + FD relative errors (on perturbed subgraphs)...")
per_sample_info = []   # tuples (dataset_idx, label, sigma_max, fd_rel_err)
for i, idx in enumerate(selected):
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)
    label = int(data.y.item())

    # define f_local on perturbed input
    def f_local_pert(x):
        return model(x, edge_index, batch_for(x)).squeeze(0)

    # jacobian shape (C, N, F)
    try:
        J = torch.autograd.functional.jacobian(f_local_pert, x_adv).detach()   # (C, N, F)
    except RuntimeError:
        # fallback per-output jac (slower)
        logits0 = f_local_pert(x_adv).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(x, cidx=c):
                return f_local_pert(x)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_adv)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    C = J.shape[0]
    d = int(J.shape[1] * J.shape[2])
    J_flat = J.reshape(C, d)

    # spectral norm via SVD (largest singular value)
    try:
        _, S, _ = torch.linalg.svd(J_flat, full_matrices=False)
        sigma_max = float(S[0].item())
    except RuntimeError:
        # fallback via eigen on JJT
        JJT = (J_flat @ J_flat.T).cpu().numpy()
        eigvals = np.linalg.eigvalsh(JJT)
        sigma_max = float(np.sqrt(max(eigvals.max(), 0.0)))

    # finite-difference relative error
    delta_fd = FD_EPS * torch.randn_like(x_adv).to(DEVICE).view(-1)   # length d
    pred_change = (J_flat @ delta_fd)                 # (C,)
    f0 = f_local_pert(x_adv).detach()
    f0p = f_local_pert((x_adv + delta_fd.view_as(x_adv))).detach()
    actual_change = f0p - f0
    fd_rel_err = (torch.norm(pred_change - actual_change) / (torch.norm(actual_change) + 1e-8)).item()

    per_sample_info.append((int(idx), label, float(sigma_max), float(fd_rel_err)))

# aggregate and print
clean_stats = [p for p in per_sample_info if p[1] == 0]
troj_stats  = [p for p in per_sample_info if p[1] == 1]

def aggs(stats):
    if not stats:
        return (0.0, 0.0, 0.0, 0.0)
    Ls = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = aggs(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = aggs(troj_stats)

print("\n--- Local Lipschitz Constants (on perturbed subgraphs) ---")
print(f" Clean:  avg_L={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_L={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,L,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)


Selected perturbation pool: {0: 11, 1: 20}

Running Lipschitz-directed PGD for selected subgraphs (this may take a while)...
? Finished perturbations.

Accuracy: 96.08%
Precision: 0.9651, Recall: 0.9608, F1: 0.9622

Classification Report:
              precision    recall  f1-score   support

       clean     0.7692    0.9091    0.8333        11
      trojan     0.9888    0.9670    0.9778        91

    accuracy                         0.9608       102
   macro avg     0.8790    0.9381    0.9056       102
weighted avg     0.9651    0.9608    0.9622       102

Confusion Matrix:
[[10  1]
 [ 3 88]]

Selected subgraphs: 31. Flipped after attack: 3 (9.68%).

Computing Local Lipschitz constants + FD relative errors (on perturbed subgraphs)...

--- Local Lipschitz Constants (on perturbed subgraphs) ---
 Clean:  avg_L=0.4934 ± 0.2302, avg_FDrel=1.4172e-02 ± 1.3608e-02
 Trojan: avg_L=1.0035 ± 2.6309, avg_FDrel=7.5908e-01 ± 2.5349e+00

Sample preview (first 6): (idx,label,L,FD_rel_err)
(14, 0, 0

#### Hessian-based Curvature Analysis

In [8]:
# ----------------------- Hessian-Based Curvature (Subgraph-Level) -----------------------
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ----- parameters -----
PER_CLASS = 20        # subgraphs per class
FD_EPS = 5e-3         # finite diff epsilon
TRIALS_PER_GRAPH = 5  # #trials for FD error
PERT_P = 10.0          # L2 magnitude for Hessian-aligned perturbation
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

if 'model' not in globals() or 'test_loader' not in globals():
    raise RuntimeError("Need `model` and `test_loader` in environment.")

model.to(DEVICE)
model.eval()

dataset = test_loader.dataset
n_test = len(dataset)
labels_np = np.array([int(d.y.item()) for d in dataset])

# ----- select PER_CLASS subgraphs per class -----
rng = np.random.default_rng(SEED)
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_np == cls)[0].tolist()
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False).tolist()
    selected.extend(chosen)
selected = np.array(selected, dtype=np.int64)
print(f"Selected pool: clean={int((labels_np[selected]==0).sum())}, trojan={int((labels_np[selected]==1).sum())}")

# ----- base predictions -----
def batch_for(x): return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)
with torch.no_grad():
    base_preds = []
    for data in dataset:
        logits = model(data.x.to(DEVICE), data.edge_index.to(DEVICE), batch_for(data.x.to(DEVICE)))
        base_preds.append(int(logits.argmax()))
base_preds = np.array(base_preds)

# ----- compute gradient norms + FD relative errors -----
per_sample_info = []  # (idx, label, lambda_max, avg_rel_err)

print("\nComputing Hessian curvature proxies on selected subgraphs...")
for idx in selected:
    data = dataset[int(idx)]
    x0 = data.x.detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)

    # predicted class at x0
    with torch.no_grad():
        logits = model(x0, edge_index, batch_for(x0))
    pred_class = int(logits.argmax().item())

    def h(x):
        return F.log_softmax(model(x, edge_index, batch_for(x)).squeeze(0), dim=0)[pred_class]

    h0 = h(x0)
    g = torch.autograd.grad(h0, x0, retain_graph=False, create_graph=False)[0].detach()

    lambda_max = float(g.norm(p=2).item() ** 2)

    # FD relative error
    errs = []
    for _ in range(TRIALS_PER_GRAPH):
        delta = FD_EPS * torch.randn_like(x0).to(DEVICE)
        gt_delta = float((g * delta).sum().item())
        pred_second = 0.5 * (gt_delta ** 2)
        actual_second = float((h(x0 + delta) - h0 - (g * delta).sum()).item())
        rel_err = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        errs.append(rel_err)
    avg_rel_err = float(np.mean(errs))

    per_sample_info.append((int(idx), int(data.y.item()), lambda_max, avg_rel_err))

# ----- build Hessian-aligned perturbations -----
print("\nConstructing Hessian-aligned perturbations...")
perturbed_map = {}
for (idx, label, lambda_val, avg_rel_err) in per_sample_info:
    data = dataset[int(idx)]
    x0 = data.x.detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)

    def h(x):
        return F.log_softmax(model(x, edge_index, batch_for(x)).squeeze(0), dim=0)[int(data.y.item())]

    g = torch.autograd.grad(h(x0), x0, retain_graph=False, create_graph=False)[0].detach()
    gnorm = g.norm().item()
    if gnorm < 1e-12:
        dir_vec = torch.randn_like(x0).to(DEVICE)
    else:
        dir_vec = - g / (gnorm + 1e-12)

    delta = (PERT_P * dir_vec).detach()
    perturbed_map[int(idx)] = (x0 + delta).detach()

# ----- evaluate on full test set -----
print("\n================ Robustness Evaluation (Full Test Set: perturbed+clean) ===============")
preds_all, labels_all = [], []
with torch.no_grad():
    for i, data in enumerate(dataset):
        x = data.x.to(DEVICE)
        if int(i) in perturbed_map:
            x_eval = perturbed_map[int(i)]
        else:
            x_eval = x
        logits = model(x_eval, data.edge_index.to(DEVICE), batch_for(x_eval))
        preds_all.append(int(logits.argmax()))
        labels_all.append(int(data.y.item()))

preds_all = np.array(preds_all)
labels_all = np.array(labels_all)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification Report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# ----- flip stats -----
orig_sel_preds = base_preds[selected]
adv_sel_preds = np.array([preds_all[i] for i in selected])
num_flips = int((orig_sel_preds != adv_sel_preds).sum())
print(f"\nSelected subgraphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ----- aggregate curvature stats -----
clean_stats = [p for p in per_sample_info if p[1] == 0]
troj_stats  = [p for p in per_sample_info if p[1] == 1]

def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)

print("\n--- Hessian Curvature Stats (grad outer-product proxy) ---")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,lambda,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)


Selected pool: clean=11, trojan=20

Computing Hessian curvature proxies on selected subgraphs...

Constructing Hessian-aligned perturbations...

Accuracy: 93.14%
Precision: 0.9289, Recall: 0.9314, F1: 0.9299

Classification Report:
              precision    recall  f1-score   support

       clean     0.7000    0.6364    0.6667        11
      trojan     0.9565    0.9670    0.9617        91

    accuracy                         0.9314       102
   macro avg     0.8283    0.8017    0.8142       102
weighted avg     0.9289    0.9314    0.9299       102

Confusion Matrix:
[[ 7  4]
 [ 3 88]]

Selected subgraphs: 31. Flipped after perturbation: 6 (19.35%).

--- Hessian Curvature Stats (grad outer-product proxy) ---
 Clean:  avg_lambda=0.0001 ± 0.0003, avg_FDrel=9.4980e-01 ± 3.6951e-02
 Trojan: avg_lambda=0.0001 ± 0.0003, avg_FDrel=5.0204e-01 ± 4.5488e-01

Sample preview (first 6): (idx,label,lambda,FD_rel_err)
(14, 0, 3.5693579686088897e-06, 0.9636353130872353)
(33, 0, 1.0263525662213891e-

#### Prediction Margin

In [14]:
# ----------------------- Prediction Margin (Subgraph-Level) -----------------------
import torch, numpy as np, torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# ----- parameters -----
PER_CLASS = 20        # subgraphs per class
EPSILON = 5.0         # perturbation budget (strong)
ALPHA = 0.4           # PGD step size
NUM_ITERS = 15        # PGD iterations
FD_EPS = 1e-3         # finite-difference perturbation
SEED = 42

torch.manual_seed(SEED); np.random.seed(SEED)

if 'model' not in globals() or 'test_loader' not in globals():
    raise RuntimeError("Need `model` and `test_loader` in environment.")

model.to(DEVICE)
model.eval()

dataset = test_loader.dataset
n_test = len(dataset)
labels_np = np.array([int(d.y.item()) for d in dataset])

# ----- select PER_CLASS subgraphs per class -----
rng = np.random.default_rng(SEED)
selected = []
for cls in [0, 1]:
    idxs = np.where(labels_np == cls)[0].tolist()
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False).tolist()
    selected.extend(chosen)
selected = np.array(selected, dtype=np.int64)
print(f"Selected pool: clean={int((labels_np[selected]==0).sum())}, trojan={int((labels_np[selected]==1).sum())}")

# helper for batching
def batch_for(x): 
    return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

# ----- base predictions -----
with torch.no_grad():
    base_preds = []
    for data in dataset:
        logits = model(data.x.to(DEVICE), data.edge_index.to(DEVICE), batch_for(data.x.to(DEVICE)))
        base_preds.append(int(logits.argmax()))
base_preds = np.array(base_preds)

# ----- adversarial perturbations (PGD) -----
perturbed_map = {}

print("\nRunning PGD perturbations on selected subgraphs...")
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    y_scalar = int(data.y.item())
    target = torch.tensor([y_scalar], dtype=torch.long, device=DEVICE)  # shape (1,)

    # random init inside epsilon-ball
    delta = torch.randn_like(x_orig).to(DEVICE)
    delta = EPSILON * delta / (delta.norm() + 1e-12)
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    for it in range(NUM_ITERS):
        logits = model(x_adv, edge_index, batch_for(x_adv))  # shape (1, C)
        loss = F.cross_entropy(logits, target)               # target shape (1,)
        grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]

        step = ALPHA * grad / (grad.norm() + 1e-12)
        x_adv = (x_adv + step).detach()
        delta = x_adv - x_orig
        if delta.norm() > EPSILON:
            delta = delta * (EPSILON / (delta.norm() + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach()

print("? Finished perturbations.")

# ----- evaluate on full test set (perturbed + clean) -----
print("\n================ Robustness Evaluation (Full Test Set: perturbed+clean) ===============")
preds_all, labels_all = [], []
with torch.no_grad():
    for i, data in enumerate(dataset):
        x = data.x.to(DEVICE)
        if int(i) in perturbed_map:
            x_eval = perturbed_map[int(i)]
        else:
            x_eval = x
        logits = model(x_eval, data.edge_index.to(DEVICE), batch_for(x_eval))
        preds_all.append(int(logits.argmax()))
        labels_all.append(int(data.y.item()))

preds_all = np.array(preds_all)
labels_all = np.array(labels_all)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification Report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# ----- flip stats -----
orig_sel_preds = base_preds[selected]
adv_sel_preds = np.array([preds_all[i] for i in selected])
num_flips = int((orig_sel_preds != adv_sel_preds).sum())
print(f"\nSelected subgraphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ----- compute prediction margin + FD relative error -----
per_sample_info = []  # (idx, label, margin, FD_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_eval = perturbed_map[int(idx)]
    edge_index = data.edge_index.to(DEVICE)

    with torch.no_grad():
        logits = model(x_eval, edge_index, batch_for(x_eval)).squeeze(0)
    pred_class = int(logits.argmax().item())
    pred_logit = logits[pred_class].item()
    other_logits = logits.clone()
    other_logits[pred_class] = -float('inf')
    second_max = other_logits.max().item()
    margin = pred_logit - second_max

    # finite difference check
    delta = FD_EPS * torch.randn_like(x_eval).to(DEVICE)
    with torch.no_grad():
        logits_p = model(x_eval + delta, edge_index, batch_for(x_eval)).squeeze(0)
    pred_logit_p = logits_p[pred_class].item()
    other_logits_p = logits_p.clone()
    other_logits_p[pred_class] = -float('inf')
    second_max_p = other_logits_p.max().item()
    margin_p = pred_logit_p - second_max_p

    rel_err = abs(margin - margin_p) / (abs(margin_p) + 1e-12)
    per_sample_info.append((int(idx), int(data.y.item()), float(margin), float(rel_err)))

# ----- aggregate stats -----
clean_stats = [p for p in per_sample_info if p[1] == 0]
troj_stats  = [p for p in per_sample_info if p[1] == 1]

def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ms = np.array([s[2] for s in stats])
    Es = np.array([s[3] for s in stats])
    return (Ms.mean(), Ms.std(), Es.mean(), Es.std())

cM_mean, cM_std, cE_mean, cE_std = summarize(clean_stats)
tM_mean, tM_std, tE_mean, tE_std = summarize(troj_stats)

print("\n--- Prediction Margin Stats (on perturbed subgraphs) ---")
print(f" Clean:  avg_margin={cM_mean:.4f} ± {cM_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_margin={tM_mean:.4f} ± {tM_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx,label,margin,FD_rel_err)")
for p in per_sample_info[:6]:
    print(p)


Selected pool: clean=11, trojan=20

Running PGD perturbations on selected subgraphs...
? Finished perturbations.

Accuracy: 96.08%
Precision: 0.9651, Recall: 0.9608, F1: 0.9622

Classification Report:
              precision    recall  f1-score   support

       clean     0.7692    0.9091    0.8333        11
      trojan     0.9888    0.9670    0.9778        91

    accuracy                         0.9608       102
   macro avg     0.8790    0.9381    0.9056       102
weighted avg     0.9651    0.9608    0.9622       102

Confusion Matrix:
[[10  1]
 [ 3 88]]

Selected subgraphs: 31. Flipped after perturbation: 3 (9.68%).

--- Prediction Margin Stats (on perturbed subgraphs) ---
 Clean:  avg_margin=2.8802 ± 0.9243, avg_FDrel=2.6496e-04 ± 4.0346e-04
 Trojan: avg_margin=88.2754 ± 267.9658, avg_FDrel=6.3873e-05 ± 8.6440e-05

Sample preview (first 6): (idx,label,margin,FD_rel_err)
(14, 0, 3.277017593383789, 0.00011118161021237533)
(33, 0, 3.455471992492676, 8.099638524306989e-05)
(63, 0, 2.

#### Adversarial Robustness Radius

In [17]:
# ----------------------- Adversarial Robustness Radius (Subgraph-Level) -----------------------
import time

# ----- helpers -----
def f_for_subgraph(x_tensor, data):
    """Return logits for given subgraph with node features replaced by x_tensor."""
    x_tensor = x_tensor.to(DEVICE)
    with torch.no_grad():
        out = model(x_tensor, data.edge_index.to(DEVICE), batch_for(x_tensor))
    return out.squeeze(0)

def adversarial_radius_for_subgraph(data, x0, initial_epsilon=1e-3, growth_factor=1.25,
                                    max_epsilon=20.0, bs_iters=10, num_trials=6):
    """Estimate minimal perturbation radius that flips prediction (around perturbed point)."""
    with torch.no_grad():
        base_out = model(x0, data.edge_index.to(DEVICE), batch_for(x0))
        y0 = int(base_out.argmax().item())

    def is_same(x):
        out = f_for_subgraph(x, data)
        return int(out.argmax().item()) == y0

    radii = []
    for _ in range(num_trials):
        d = torch.randn_like(x0)
        d = d / (d.norm() + 1e-12)

        eps = initial_epsilon
        while eps < max_epsilon and is_same(x0 + eps * d):
            eps *= growth_factor

        if eps >= max_epsilon:
            candidate = max_epsilon
        else:
            low, high = eps / growth_factor, eps
            for _ in range(bs_iters):
                mid = 0.5 * (low + high)
                if is_same(x0 + mid * d):
                    low = mid
                else:
                    high = mid
            candidate = float(high)
        radii.append(candidate)
    return float(min(radii))

def adversarial_radius_relerr(data, x0):
    r1 = adversarial_radius_for_subgraph(data, x0, growth_factor=2.25, bs_iters=10, num_trials=6)
    r2 = adversarial_radius_for_subgraph(data, x0, growth_factor=1.4, bs_iters=12, num_trials=6)
    rel_err = abs(r1 - r2) / (abs(r2) + 1e-12)
    return r1, rel_err

# ----- ARR computation on selected perturbed subgraphs -----
class_names = ['clean', 'trojan']
class_adv_radius = {cn: [] for cn in class_names}
class_rel_errors = {cn: [] for cn in class_names}
all_radii, all_rel_errs = [], []

t0 = time.time()
print("\nComputing Adversarial Robustness Radius (ARR) for selected perturbed subgraphs...")
for i, idx in enumerate(selected):
    idx = int(idx)
    data = dataset[idx]
    label = int(data.y.item())
    cn = class_names[label]

    x0 = perturbed_map[idx].clone().detach().to(DEVICE)  # start from perturbed features
    r, rel = adversarial_radius_relerr(data, x0)

    class_adv_radius[cn].append(r)
    class_rel_errors[cn].append(rel)
    all_radii.append(r)
    all_rel_errs.append(rel)

    if (i+1) % 20 == 0:
        print(f"  processed {i+1}/{len(selected)} subgraphs...")

t1 = time.time()
print(f"? Done ARR computation. Time elapsed: {t1-t0:.1f}s")

# ----- Reporting ARR aggregates -----
print("\n--- Adversarial Robustness Radius (ARR) Stats ---")
print("{:<10s} {:>14s} {:>22s}".format("Class", "Avg Radius ± Std", "Avg Rel. Error ± Std"))
print("-"*52)
for cn in class_names:
    if class_adv_radius[cn]:
        print("{:<10s} {:>7.4f} ± {:<7.4f} {:>14.4e} ± {:<10.4e}".format(
            cn, np.mean(class_adv_radius[cn]), np.std(class_adv_radius[cn]),
            np.mean(class_rel_errors[cn]), np.std(class_rel_errors[cn])
        ))
    else:
        print("{:<10s} {:>10s}".format(cn, "-"))

print("\nOverall ARR: Avg Radius: {:.4f} ± {:.4f}".format(np.mean(all_radii), np.std(all_radii)))
print("Overall ARR: Avg Relative Error: {:.4e} ± {:.4e}".format(np.mean(all_rel_errs), np.std(all_rel_errs)))



Computing Adversarial Robustness Radius (ARR) for selected perturbed subgraphs...
  processed 20/31 subgraphs...
? Done ARR computation. Time elapsed: 23.6s

--- Adversarial Robustness Radius (ARR) Stats ---
Class      Avg Radius ± Std   Avg Rel. Error ± Std
----------------------------------------------------
clean      20.0000 ± 0.0000      0.0000e+00 ± 0.0000e+00
trojan     20.0000 ± 0.0000      0.0000e+00 ± 0.0000e+00

Overall ARR: Avg Radius: 20.0000 ± 0.0000
Overall ARR: Avg Relative Error: 0.0000e+00 ± 0.0000e+00


In [18]:
# ------------------ Evaluate model on full test set (200 perturbed + rest unperturbed) ------------------
model.eval()
with torch.no_grad():
    all_logits = []
    all_labels = []
    for i, data in enumerate(dataset):
        x_in = perturbed_map[i].to(DEVICE) if i in perturbed_map else data.x.to(DEVICE)
        logits = model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        all_logits.append(logits.cpu())
        all_labels.append(int(data.y.item()))

    all_logits = torch.cat(all_logits, dim=0)
    preds_all = all_logits.argmax(dim=1).numpy()
    labels_all = np.array(all_labels)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected subgraphs + others unperturbed) =============")
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# ------------------ Flip statistics for selected subgraphs ------------------
with torch.no_grad():
    orig_preds = []
    for data in dataset:
        x_in = data.x.to(DEVICE)
        logits = model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        orig_preds.append(logits.cpu())
    orig_preds = torch.cat(orig_preds, dim=0).argmax(dim=1).numpy()

adv_sel_preds = preds_all[selected]
num_flips = int((orig_preds[selected] != adv_sel_preds).sum())
print(f"\nSelected subgraphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0 * num_flips/len(selected):.2f}%).")



Accuracy: 96.08%
Precision: 0.9651, Recall: 0.9608, F1: 0.9622

Classification report:
              precision    recall  f1-score   support

       clean     0.7692    0.9091    0.8333        11
      trojan     0.9888    0.9670    0.9778        91

    accuracy                         0.9608       102
   macro avg     0.8790    0.9381    0.9056       102
weighted avg     0.9651    0.9608    0.9622       102

Confusion Matrix:
[[10  1]
 [ 3 88]]

Selected subgraphs: 31. Flipped after perturbation: 3 (9.68%).


#### Stability Under Input Noise

In [20]:
# ============================================================
# Stability Under Input Noise (SUIN) - Subgraph classification
# ============================================================
import torch
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import time

# --- Parameters ---
NOISE_SIGMA        = 0.5   # Gaussian noise stddev for stability metric
NUM_NOISE_SAMPLES  = 20     # Monte Carlo samples per subgraph
RELERR_RESAMPLES   = 5      # repeats to estimate relative error

print("\n--- Stability Under Input Noise (PGD-first, then metric) ---")
print(f"Selected subgraphs: clean={(labels_np[selected]==0).sum()}, trojan={(labels_np[selected]==1).sum()}")

# ============================================================
# Step 1: Evaluate model on full dataset (perturbed + original)
# ============================================================
model.eval()
with torch.no_grad():
    all_logits, all_labels = [], []
    for i, data in enumerate(dataset):
        x_in = perturbed_map[i].to(DEVICE) if i in perturbed_map else data.x.to(DEVICE)
        logits = model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        all_logits.append(logits.cpu())
        all_labels.append(int(data.y.item()))

    all_logits = torch.cat(all_logits, dim=0)
    preds_all  = all_logits.argmax(dim=1).numpy()
    labels_all = np.array(all_labels)

acc = (preds_all == labels_all).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_all, preds_all, average='weighted', zero_division=0)

print("\n============= Robustness Evaluation (Full Test Set: perturbed selected + others unperturbed) =============")
print(f"Accuracy: {acc*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}\n")
print("Classification report:")
print(classification_report(labels_all, preds_all, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix:")
print(confusion_matrix(labels_all, preds_all, labels=[0,1]))

# Flip statistics for selected subgraphs
with torch.no_grad():
    orig_preds = []
    for data in dataset:
        x_in = data.x.to(DEVICE)
        logits = model(x_in, data.edge_index.to(DEVICE), batch_for(x_in))
        orig_preds.append(logits.cpu())
    orig_preds = torch.cat(orig_preds, dim=0).argmax(dim=1).numpy()

adv_sel_preds = preds_all[selected]
num_flips = int((orig_preds[selected] != adv_sel_preds).sum())
print(f"\nSelected subgraphs: {len(selected)}. Flipped after perturbation: {num_flips} ({100.0*num_flips/len(selected):.2f}%).")

# ============================================================
# Step 2: Stability Under Input Noise (on perturbed subgraphs)
# ============================================================
def stability_for_subgraph(idx, sigma, num_samples):
    """Compute avg L2 change in logits for noisy perturbations around perturbed subgraph."""
    data = dataset[idx]
    base_x = perturbed_map[idx].to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    batch = batch_for(base_x)

    with torch.no_grad():
        f_orig = model(base_x, edge_index, batch).squeeze()

    diffs = []
    for _ in range(num_samples):
        noise = sigma * torch.randn_like(base_x).to(DEVICE)
        f_noisy = model(base_x + noise, edge_index, batch).squeeze()
        diffs.append(torch.norm(f_noisy - f_orig).item())
    return float(np.mean(diffs))

print("\nComputing Stability Under Input Noise (this may take a while)...")
t0 = time.time()
per_sample_info = []  # (idx, label, stability, rel_err)
for i, idx in enumerate(selected):
    s_val = stability_for_subgraph(idx, NOISE_SIGMA, NUM_NOISE_SAMPLES)
    re_vals = [stability_for_subgraph(idx, NOISE_SIGMA, NUM_NOISE_SAMPLES) for _ in range(RELERR_RESAMPLES)]
    s_ref = float(np.mean(re_vals))
    rel_err = abs(s_val - s_ref) / (abs(s_ref) + 1e-12)
    per_sample_info.append((int(idx), int(labels_np[idx]), s_val, rel_err))
    if (i+1) % 10 == 0:
        print(f"  processed {i+1}/{len(selected)} subgraphs...")
t1 = time.time()
print(f"Done SUIN computation. Time elapsed: {t1-t0:.1f}s")

# ============================================================
# Step 3: Aggregate and report
# ============================================================
clean_stats = [(i, s, e) for (i, l, s, e) in per_sample_info if l == 0]
troj_stats  = [(i, s, e) for (i, l, s, e) in per_sample_info if l == 1]

def aggs(stats):
    if not stats:
        return (0.0, 0.0, 0.0, 0.0)
    Ss = np.array([s for (_, s, _) in stats])
    Es = np.array([e for (_, _, e) in stats])
    return (Ss.mean(), Ss.std(), Es.mean(), Es.std())

cS_mean, cS_std, cE_mean, cE_std = aggs(clean_stats)
tS_mean, tS_std, tE_mean, tE_std = aggs(troj_stats)

print("\n--- Stability Under Input Noise (on perturbed subgraphs) ---")
print(f" Clean:  avg_stability={cS_mean:.4f} ± {cS_std:.4f}, avg_relerr={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_stability={tS_mean:.4f} ± {tS_std:.4f}, avg_relerr={tE_mean:.4e} ± {tE_std:.4e}")

print("\nSample preview (first 6): (idx, label, stability, rel_err)")
for p in per_sample_info[:6]:
    print(p)



--- Stability Under Input Noise (PGD-first, then metric) ---
Selected subgraphs: clean=11, trojan=20

Accuracy: 96.08%
Precision: 0.9651, Recall: 0.9608, F1: 0.9622

Classification report:
              precision    recall  f1-score   support

       clean     0.7692    0.9091    0.8333        11
      trojan     0.9888    0.9670    0.9778        91

    accuracy                         0.9608       102
   macro avg     0.8790    0.9381    0.9056       102
weighted avg     0.9651    0.9608    0.9622       102

Confusion Matrix:
[[10  1]
 [ 3 88]]

Selected subgraphs: 31. Flipped after perturbation: 3 (9.68%).

Computing Stability Under Input Noise (this may take a while)...
  processed 10/31 subgraphs...
  processed 20/31 subgraphs...
  processed 30/31 subgraphs...
Done SUIN computation. Time elapsed: 17.2s

--- Stability Under Input Noise (on perturbed subgraphs) ---
 Clean:  avg_stability=4.4768 ± 1.8360, avg_relerr=2.5896e-02 ± 1.7661e-02
 Trojan: avg_stability=5.4528 ± 4.0126, avg

#### All in One, same perturbation across all metric.

In [21]:
# =====================================================================
# Subgraph-level: unified end-to-end robustness evaluation reusing one PGD
# =====================================================================
import time
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# -------------------------
# Parameters (tweak here)
# -------------------------
PER_CLASS        = 20        # number of subgraphs per class to perturb
EPSILON_PGD      = 4.0       # L2 radius for the PGD perturbation applied once and reused
ALPHA_PGD        = 1.0       # PGD step size
NUM_ITERS_PGD    = 30        # PGD iterations
FD_EPS           = 1e-3      # finite-difference epsilon
ARR_INITIAL_EPS  = 1e-3      # initial epsilon for ARR search
ARR_GROW1        = 1.25
ARR_GROW2        = 1.4
ARR_MAX_EPS      = 20.0
ARR_BS_ITERS     = 10
ARR_TRIALS       = 6
STAB_SIGMA       = 0.5       # stability noise sigma
STAB_SAMPLES     = 20
STAB_RELERR_RPTS = 5
SEED             = 42

torch.manual_seed(SEED); np.random.seed(SEED)

# -------------------------
# Sanity
# -------------------------
required = ["model", "test_loader", "DEVICE"]
for r in required:
    if r not in globals():
        raise RuntimeError(f"Required variable '{r}' not found in the environment (must be defined earlier).")

model.to(DEVICE)
model.eval()

# -------------------------
# Extract test dataset list & labels (preserve loader order)
# -------------------------
dataset = list(test_loader.dataset) if hasattr(test_loader, "dataset") else list(test_loader)
n_test = len(dataset)
labels_all = np.array([int(d.y.item()) for d in dataset])
print(f"Total test subgraphs: {n_test}; class counts: {np.bincount(labels_all)}")

# -------------------------
# Select indices (same pool for all metrics)
# -------------------------
rng = np.random.default_rng(SEED)
selected = []
for cls in [0,1]:
    idxs = np.where(labels_all == cls)[0]
    if len(idxs)==0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected.extend(chosen.tolist())
selected = np.array(sorted(selected), dtype=np.int64)
print("Selected perturbation pool counts:", {0:int((labels_all[selected]==0).sum()), 1:int((labels_all[selected]==1).sum())})

# small helper for single-graph batch vectors
def batch_for(x):
    return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

# -------------------------
# Build one shared PGD perturbation map (perturbed_map)
# - perturb each selected subgraph's node features (x) via PGD (L2 constrained)
# - store perturbed_map[idx] -> perturbed x (on DEVICE)
# -------------------------
perturbed_map = {}
orig_preds_selected = []
adv_preds_selected = []

print("\n--- Step A: Creating shared PGD perturbations for selected subgraphs ---")
t0 = time.time()
for idx in selected:
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)    # (N_nodes, feat_dim)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_orig.shape
    y_true = torch.tensor([int(data.y.item())], dtype=torch.long, device=DEVICE)

    # random initialization inside L2-ball (on flattened features)
    delta = torch.randn_like(x_orig, device=DEVICE)
    delta = delta * (EPSILON_PGD / (delta.view(-1).norm() + 1e-12))
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)

    # PGD loop (maximize CE loss to cause misclassification)
    for it in range(NUM_ITERS_PGD):
        out = model(x_adv, edge_index, batch_for(x_adv))           # [1, C]
        loss = F.cross_entropy(out, y_true)
        grad_x = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        gnorm = grad_x.view(-1).norm().item()
        if gnorm == 0:
            break
        step = (ALPHA_PGD * grad_x) / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()
        # project back to L2 ball (flattened)
        delta = x_adv - x_orig
        dnorm = delta.view(-1).norm().item()
        if dnorm > EPSILON_PGD:
            delta = delta * (EPSILON_PGD / (dnorm + 1e-12))
            x_adv = (x_orig + delta).detach()
        x_adv = x_adv.requires_grad_(True)

    x_adv_final = x_adv.detach().clone()
    perturbed_map[int(idx)] = x_adv_final

    # store flip stats
    with torch.no_grad():
        out_orig = model(x_orig, edge_index, batch_for(x_orig))
        out_adv  = model(x_adv_final, edge_index, batch_for(x_adv_final))
        orig_preds_selected.append(int(out_orig.argmax(dim=1).item()))
        adv_preds_selected.append(int(out_adv.argmax(dim=1).item()))

t1 = time.time()
print(f"? PGD perturbations done. Time: {t1-t0:.1f}s")
orig_preds_selected = np.array(orig_preds_selected)
adv_preds_selected  = np.array(adv_preds_selected)
num_flips = int((orig_preds_selected != adv_preds_selected).sum())
print(f"Selected subgraphs: {len(selected)}. Flipped after PGD: {num_flips} ({100.0 * num_flips / len(selected):.2f}%).")

# Also compute and show perturbed-only classifier behavior (for selected set)
with torch.no_grad():
    labels_sel = labels_all[selected]
    preds_perturbed = []
    for idx in selected:
        data = dataset[int(idx)]
        logits = model(perturbed_map[int(idx)], data.edge_index.to(DEVICE), batch_for(perturbed_map[int(idx)]))
        preds_perturbed.append(int(logits.argmax(dim=1).item()))
    preds_perturbed = np.array(preds_perturbed)

print("\nClassification on PERTURBED samples only (selected set):")
acc_sel = (preds_perturbed == labels_sel).mean()
prec, rec, f1, _ = precision_recall_fscore_support(labels_sel, preds_perturbed, average='weighted', zero_division=0)
print(f"Accuracy (perturbed selected): {acc_sel*100:.2f}%")
print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")
print("Classification report (perturbed selected):")
print(classification_report(labels_sel, preds_perturbed, target_names=['clean','trojan'], digits=4))
print("Confusion matrix (perturbed selected):")
print(confusion_matrix(labels_sel, preds_perturbed, labels=[0,1]))

# -------------------------------------------------------------------------
# Utility: evaluate on perturbed selected set only (used by metrics to print)
# -------------------------------------------------------------------------
def eval_perturbed_selected(perturbed_map, selected_idxs):
    with torch.no_grad():
        y_true = []
        y_pred = []
        for idx in selected_idxs:
            data = dataset[int(idx)]
            lab = int(data.y.item())
            x_eval = perturbed_map[int(idx)]
            logits = model(x_eval, data.edge_index.to(DEVICE), batch_for(x_eval))
            y_true.append(lab)
            y_pred.append(int(logits.argmax(dim=1).item()))
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    acc = (y_true == y_pred).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    return y_true, y_pred, acc, prec, rec, f1

# -------------------------
# Metric 1: Jacobian Frobenius norm + FD relative error
# (computed at the perturbed inputs; flattened features)
# -------------------------
print("\n\n================ Metric: Jacobian Frobenius Norm + FD Relative Error (on perturbed inputs) ================")
per_sample_jac = []   # (idx,label, jac_frob, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    d = n_nodes * feat_dim

    # flatten
    x_flat = x_adv.view(-1).detach().clone().requires_grad_(True)

    def f_flat(x_in):
        x_mat = x_in.view_as(x_adv)
        out = model(x_mat, edge_index, batch_for(x_mat))   # [1, C]
        return out.squeeze(0)                              # (C,)

    # Jacobian: shape (C, d)
    try:
        J = torch.autograd.functional.jacobian(f_flat, x_flat)   # (C, d)
    except RuntimeError:
        # fallback: compute per-output jac rows
        logits0 = f_flat(x_flat).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(z, cidx=c):
                return f_flat(z)[cidx]
            r = torch.autograd.functional.jacobian(scalar_f, x_flat)
            rows.append(r.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    J = J.detach()
    jac_frob = float(torch.norm(J, p='fro').item())

    # FD relative error
    delta_fd = FD_EPS * torch.randn(d, device=DEVICE)
    pred_change = (J @ delta_fd)                      # (C,)
    f0 = f_flat(x_flat).detach()
    f0p = f_flat((x_flat + delta_fd)).detach()
    actual_change = f0p - f0
    if torch.norm(actual_change).item() == 0:
        rel_err = 0.0
    else:
        rel_err = float(torch.norm(pred_change - actual_change).item() / (torch.norm(actual_change).item() + 1e-8))

    per_sample_jac.append((int(idx), int(data.y.item()), jac_frob, rel_err))

# aggregate & print
arr = np.array([[i,l,j,r] for (i,l,j,r) in per_sample_jac], dtype=object)
if arr.size:
    clean_vals = arr[arr[:,1]==0][:,2].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_vals  = arr[arr[:,1]==1][:,2].astype(float) if (arr[:,1]==1).any() else np.array([])
    clean_errs = arr[arr[:,1]==0][:,3].astype(float) if (arr[:,1]==0).any() else np.array([])
    troj_errs  = arr[arr[:,1]==1][:,3].astype(float) if (arr[:,1]==1).any() else np.array([])

    def mean_std(a): return (a.mean(), a.std()) if len(a)>0 else (0.0,0.0)

    c_mean, c_std = mean_std(clean_vals)
    t_mean, t_std = mean_std(troj_vals)
    ce_mean, ce_std = mean_std(clean_errs)
    te_mean, te_std = mean_std(troj_errs)

    print("\nJacobian Frobenius norms & FD relative errors (aggregated) ON PERTURBED SAMPLES:")
    print(f" Clean subgraphs:  avg_norm={c_mean:.4f} ± {c_std:.4f}, avg_FDrel={ce_mean:.4e} ± {ce_std:.4e}")
    print(f" Trojan subgraphs: avg_norm={t_mean:.4f} ± {t_std:.4f}, avg_FDrel={te_mean:.4e} ± {te_std:.4e}")
    print("\nSample preview (first 6): (idx,label,jacobian_frob,FD_rel_err)")
    for t in per_sample_jac[:6]:
        print(t)
else:
    print("No Jacobian samples computed.")

# Print classification performance on perturbed selected set only
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Jacobian stage (re-used perturbed set):")
print(f"Accuracy: {acc_sel*100:.2f}%")
print(f"Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print("Classification report:")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print("Confusion matrix:")
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 2: Local Lipschitz (spectral norm of Jacobian) + FD relative error
# (Compute at perturbed inputs; flatten J to (C,d) and compute top singular value)
# -------------------------
print("\n\n================ Metric: Local Lipschitz (spectral norm) + FD Relative Error ================")
per_sample_lip = []  # (idx,label,sigma_max, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    d = n_nodes * feat_dim

    def f_local(x):
        return model(x, edge_index, batch_for(x)).squeeze(0)

    # Jacobian shape (C, N, F)
    try:
        J = torch.autograd.functional.jacobian(f_local, x_adv).detach()   # (C, N, F)
    except RuntimeError:
        # fallback per-output (slower)
        logits0 = f_local(x_adv).detach()
        C = logits0.shape[0]
        rows = []
        for c in range(C):
            def scalar_f(z, cidx=c):
                return f_local(z)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x_adv)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)

    C = J.shape[0]
    J_flat = J.reshape(C, d)   # (C, d)

    # spectral norm via SVD (largest singular value)
    try:
        _, S, _ = torch.linalg.svd(J_flat, full_matrices=False)
        sigma_max = float(S[0].item())
    except RuntimeError:
        # fallback via eigen
        JJT = (J_flat @ J_flat.T).cpu().numpy()
        eigvals = np.linalg.eigvalsh(JJT)
        sigma_max = float(np.sqrt(max(eigvals.max(), 0.0)))

    # FD relative error
    delta_fd = FD_EPS * torch.randn(d, device=DEVICE)
    pred_change = (J_flat @ delta_fd)               # (C,)
    f0 = f_local(x_adv).detach()
    f0p = f_local(x_adv + delta_fd.view_as(x_adv)).detach()
    actual_change = f0p - f0
    if torch.norm(actual_change).item() == 0:
        fd_rel = 0.0
    else:
        fd_rel = float(torch.norm(pred_change - actual_change).item() / (torch.norm(actual_change).item() + 1e-8))

    per_sample_lip.append((int(idx), int(data.y.item()), float(sigma_max), float(fd_rel)))

# aggregate & print
clean_stats = [p for p in per_sample_lip if p[1]==0]
troj_stats  = [p for p in per_sample_lip if p[1]==1]
def aggs(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = aggs(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = aggs(troj_stats)

print("\nLocal Lipschitz (spectral) & FD relative errors (on PERTURBED samples):")
print(f" Clean:  avg_L={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_L={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,sigma_max,fd_rel_err)")
for p in per_sample_lip[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Local Lipschitz stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 3: Hessian curvature proxy (||g||^2) + FD relative error
# (Compute gradient of log-prob for predicted class at perturbed inputs)
# -------------------------
print("\n\n================ Metric: Hessian curvature proxy (||g||^2) + FD Relative Error ================")
TRIALS_HESS_FD = 5
per_sample_hess = []   # (idx,label, lambda_proxy=||g||^2, avg_fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE).requires_grad_(True)
    edge_index = data.edge_index.to(DEVICE)

    # forward & predicted class (allow grad)
    logits = model(x_adv, edge_index, batch_for(x_adv)).squeeze(0)
    pred_class = int(logits.argmax().item())

    # compute log-prob of predicted class
    logp = F.log_softmax(logits, dim=0)[pred_class]
    # gradient wrt x_adv
    g = torch.autograd.grad(logp, x_adv, retain_graph=False, create_graph=False, allow_unused=False)[0].detach()
    lambda_proxy = float(g.norm(p=2).item() ** 2)

    # FD relative errors (multiple small deltas)
    rels = []
    for _ in range(TRIALS_HESS_FD):
        delta = FD_EPS * torch.randn_like(x_adv).to(DEVICE)
        gt_delta = float((g * delta).sum().item())
        pred_second = 0.5 * (gt_delta ** 2)
        # recompute logp at perturbed input
        logits_p = model((x_adv + delta).detach(), edge_index, batch_for(x_adv + delta)).squeeze(0)
        logp_p = F.log_softmax(logits_p, dim=0)[pred_class]
        actual_second = float((logp_p - F.log_softmax(logits, dim=0)[pred_class]).item() - gt_delta)
        rel_err = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        rels.append(rel_err)
    avg_rel_err = float(np.mean(rels))
    per_sample_hess.append((int(idx), int(data.y.item()), lambda_proxy, avg_rel_err))

# aggregate & print
clean_stats = [t for t in per_sample_hess if t[1]==0]
troj_stats  = [t for t in per_sample_hess if t[1]==1]
def summarize(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ls = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ls.mean(), Ls.std(), Es.mean(), Es.std())

cL_mean, cL_std, cE_mean, cE_std = summarize(clean_stats)
tL_mean, tL_std, tE_mean, tE_std = summarize(troj_stats)
print("\nHessian-curvature proxy (||g||^2) & FD relative errors (on PERTURBED samples):")
print(f" Clean:  avg_lambda={cL_mean:.4f} ± {cL_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_lambda={tL_mean:.4f} ± {tL_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,lambda,FD_rel_err)")
for p in per_sample_hess[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Hessian stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 4: Prediction Margin + FD relative error (on perturbed inputs)
# -------------------------
print("\n\n================ Metric: Prediction Margin + FD Relative Error ================")
per_sample_margin = []   # (idx,label, margin, fd_rel_err)
for idx in selected:
    data = dataset[int(idx)]
    x_adv = perturbed_map[int(idx)].detach().clone().to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)

    with torch.no_grad():
        logits = model(x_adv, edge_index, batch_for(x_adv)).squeeze(0)
    pred_class = int(logits.argmax().item())
    pred_logit = float(logits[pred_class].item())
    other_logits = logits.clone()
    other_logits[pred_class] = -float('inf')
    second_max = float(other_logits.max().item())
    margin = pred_logit - second_max

    # FD check
    delta = FD_EPS * torch.randn_like(x_adv).to(DEVICE)
    with torch.no_grad():
        logits_p = model(x_adv + delta, edge_index, batch_for(x_adv)).squeeze(0)
    pred_logit_p = float(logits_p[pred_class].item())
    other_logits_p = logits_p.clone()
    other_logits_p[pred_class] = -float('inf')
    second_max_p = float(other_logits_p.max().item())
    margin_p = pred_logit_p - second_max_p
    rel_err = abs(margin - margin_p) / (abs(margin_p) + 1e-12)
    per_sample_margin.append((int(idx), int(data.y.item()), float(margin), float(rel_err)))

# aggregate & print
clean_stats = [p for p in per_sample_margin if p[1]==0]
troj_stats  = [p for p in per_sample_margin if p[1]==1]
def aggs_m(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ms = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ms.mean(), Ms.std(), Es.mean(), Es.std())

cM_mean, cM_std, cE_mean, cE_std = aggs_m(clean_stats)
tM_mean, tM_std, tE_mean, tE_std = aggs_m(troj_stats)
print("\nPrediction Margin stats (on PERTURBED samples):")
print(f" Clean:  avg_margin={cM_mean:.4f} ± {cM_std:.4f}, avg_FDrel={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan: avg_margin={tM_mean:.4f} ± {tM_std:.4f}, avg_FDrel={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,margin,fd_rel_err)")
for p in per_sample_margin[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - Prediction Margin stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 5: Adversarial Robustness Radius (ARR) around perturbed points
# -------------------------
print("\n\n================ Metric: Adversarial Robustness Radius (ARR) ================")
def f_for_subgraph(x_tensor, data):
    with torch.no_grad():
        out = model(x_tensor, data.edge_index.to(DEVICE), batch_for(x_tensor))
    return out.squeeze(0)

def adversarial_radius_for_subgraph(data, x0, initial_epsilon=ARR_INITIAL_EPS, growth_factor=ARR_GROW1,
                                    max_epsilon=ARR_MAX_EPS, bs_iters=ARR_BS_ITERS, num_trials=ARR_TRIALS):
    """Estimate minimal L2 norm that flips prediction around x0."""
    x0 = x0.clone().detach().to(DEVICE)
    with torch.no_grad():
        base_out = model(x0, data.edge_index.to(DEVICE), batch_for(x0))
        y0 = int(base_out.argmax().item())

    def is_same(x):
        out = f_for_subgraph(x, data)
        return int(out.argmax().item()) == y0

    radii = []
    for _ in range(num_trials):
        d = torch.randn_like(x0).to(DEVICE)
        d = d / (d.view(-1).norm() + 1e-12)
        eps = initial_epsilon
        while eps < max_epsilon and is_same(x0 + eps * d):
            eps *= growth_factor
        if eps >= max_epsilon:
            radii.append(float(max_epsilon))
            continue
        low, high = eps / growth_factor, eps
        for _ in range(bs_iters):
            mid = 0.5 * (low + high)
            if is_same(x0 + mid * d):
                low = mid
            else:
                high = mid
        radii.append(float(high))
    return float(min(radii))

def adversarial_radius_relerr(data, x0):
    r1 = adversarial_radius_for_subgraph(data, x0, growth_factor=ARR_GROW1, num_trials=ARR_TRIALS)
    r2 = adversarial_radius_for_subgraph(data, x0, growth_factor=ARR_GROW2, num_trials=ARR_TRIALS)
    rel_err = abs(r1 - r2) / (abs(r2) + 1e-12)
    return r1, rel_err

class_names = ['clean','trojan']
class_adv_radius = {cn: [] for cn in class_names}
class_rel_errors = {cn: [] for cn in class_names}
all_rads = []; all_relerrs = []

t0 = time.time()
for i, idx in enumerate(selected):
    data = dataset[int(idx)]
    x0 = perturbed_map[int(idx)].clone().detach().to(DEVICE)
    r, rel = adversarial_radius_relerr(data, x0)
    lab = int(data.y.item())
    class_adv_radius[class_names[lab]].append(r)
    class_rel_errors[class_names[lab]].append(rel)
    all_rads.append(r); all_relerrs.append(rel)
    if (i+1)%10 == 0:
        print(f" processed {i+1}/{len(selected)} ...")
t1 = time.time()
print(f"? ARR computation done. Time elapsed: {t1-t0:.1f}s")

# reporting ARR
print("\nARR (Adversarial Robustness Radius) Stats (on perturbed selected samples):")
for cn in class_names:
    vals = class_adv_radius[cn]
    errs = class_rel_errors[cn]
    if vals:
        print(f" {cn:6s}: avg_radius={np.mean(vals):.4f} ± {np.std(vals):.4f}, avg_relerr={np.mean(errs):.4e} ± {np.std(errs):.4e}")
    else:
        print(f" {cn:6s}: -")
print(f" Overall: avg_radius={np.mean(all_rads):.4f} ± {np.std(all_rads):.4f}; avg_relerr={np.mean(all_relerrs):.4e} ± {np.std(all_relerrs):.4e}")

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - ARR stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# -------------------------
# Metric 6: Stability Under Input Noise (SUIN) on perturbed subgraphs
# -------------------------
print("\n\n================ Metric: Stability Under Input Noise (SUIN) ================")
def stability_for_subgraph(idx, sigma, num_samples):
    data = dataset[int(idx)]
    base_x = perturbed_map[int(idx)].to(DEVICE)
    edge_index = data.edge_index.to(DEVICE)
    batch = batch_for(base_x)
    with torch.no_grad():
        f_orig = model(base_x, edge_index, batch).squeeze(0)
    diffs = []
    for _ in range(num_samples):
        noise = sigma * torch.randn_like(base_x).to(DEVICE)
        f_noisy = model(base_x + noise, edge_index, batch).squeeze(0)
        diffs.append(torch.norm(f_noisy - f_orig).item())
    return float(np.mean(diffs))

t0 = time.time()
per_sample_stab = []  # (idx,label, stability, rel_err)
for i, idx in enumerate(selected):
    s_val = stability_for_subgraph(idx, STAB_SIGMA, STAB_SAMPLES)
    # relative error by repeating
    re_vals = [stability_for_subgraph(idx, STAB_SIGMA, STAB_SAMPLES) for _ in range(STAB_RELERR_RPTS)]
    s_ref = float(np.mean(re_vals))
    rel_err = abs(s_val - s_ref) / (abs(s_ref) + 1e-12)
    per_sample_stab.append((int(idx), int(labels_all[idx]), float(s_val), float(rel_err)))
    if (i+1) % 10 == 0:
        print(f" processed {i+1}/{len(selected)} ...")
t1 = time.time()
print(f"? SUIN done. Time elapsed: {t1-t0:.1f}s")

# aggregate & print
clean_stats = [p for p in per_sample_stab if p[1]==0]
troj_stats  = [p for p in per_sample_stab if p[1]==1]
def aggs_s(stats):
    if not stats: return (0.0,0.0,0.0,0.0)
    Ss = np.array([s[2] for s in stats]); Es = np.array([s[3] for s in stats])
    return (Ss.mean(), Ss.std(), Es.mean(), Es.std())

cS_mean, cS_std, cE_mean, cE_std = aggs_s(clean_stats)
tS_mean, tS_std, tE_mean, tE_std = aggs_s(troj_stats)
print("\nStability Under Input Noise (on perturbed selected samples):")
print(f" Clean: avg_stability={cS_mean:.4f} ± {cS_std:.4f}, avg_relerr={cE_mean:.4e} ± {cE_std:.4e}")
print(f" Trojan:avg_stability={tS_mean:.4f} ± {tS_std:.4f}, avg_relerr={tE_mean:.4e} ± {tE_std:.4e}")
print("\nSample preview (first 6): (idx,label,stability,rel_err)")
for p in per_sample_stab[:6]:
    print(p)

# classification on perturbed selected
y_true_sel, y_pred_sel, acc_sel, prec_sel, rec_sel, f1_sel = (*eval_perturbed_selected(perturbed_map, selected),)
print("\nClassification (perturbed selected) - SUIN stage:")
print(f"Accuracy: {acc_sel*100:.2f}% | Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

print("\n\nAll metrics computed on the SAME selected subgraphs and the SAME PGD perturbations (perturbed_map).")
print("You can adjust PER_CLASS, EPSILON_PGD, ALPHA_PGD, NUM_ITERS_PGD to change attack strength.")


Total test subgraphs: 102; class counts: [11 91]
Selected perturbation pool counts: {0: 11, 1: 20}

--- Step A: Creating shared PGD perturbations for selected subgraphs ---
? PGD perturbations done. Time: 15.0s
Selected subgraphs: 31. Flipped after PGD: 3 (9.68%).

Classification on PERTURBED samples only (selected set):
Accuracy (perturbed selected): 90.32%
Precision: 0.9069, Recall: 0.9032, F1: 0.9041
Classification report (perturbed selected):
              precision    recall  f1-score   support

       clean     0.8333    0.9091    0.8696        11
      trojan     0.9474    0.9000    0.9231        20

    accuracy                         0.9032        31
   macro avg     0.8904    0.9045    0.8963        31
weighted avg     0.9069    0.9032    0.9041        31

Confusion matrix (perturbed selected):
[[10  1]
 [ 2 18]]



Jacobian Frobenius norms & FD relative errors (aggregated) ON PERTURBED SAMPLES:
 Clean subgraphs:  avg_norm=0.4876 ± 0.1973, avg_FDrel=2.7351e-02 ± 5.7933e-02
 

#### Discard the below run. Not to be used in the paper.

In [None]:
# Graph-level robustness evaluation (perturb -> compute metrics -> evaluate perturbed set)
# Requires: `model`, `test_loader`, `DEVICE` to already exist in the environment.
import time
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, accuracy_score

# ---------------------- Tuneable parameters ----------------------
PER_CLASS        = 20       # choose up to this many graphs per class to perturb (adjust to dataset size)
EPSILON_PGD      = 5.0      # L2 radius for PGD perturbation (applied to flattened features)
ALPHA_PGD        = 1.0      # PGD step scaling (we normalize gradient each step)
NUM_ITERS_PGD    = 30       # PGD iterations for creating perturbations
FD_EPS           = 1e-3     # finite-difference epsilon for Jacobian/fd checks
ARR_TRIALS       = 6        # trials per sample for ARR estimation
ARR_BS_ITERS     = 10
ARR_GROW1        = 1.25
ARR_GROW2        = 1.4
ARR_MAX_EPS      = 50.0
NOISE_SIGMA      = 0.1      # for stability metric (std dev of Gaussian noise)
NUM_NOISE_SAMPLES= 20
HESS_TRIALS      = 5        # FD trials for Hessian second-order relative error
SEED             = 42

torch.manual_seed(SEED)
np.random.seed(SEED)

# ---------------------- Sanity checks (use variables from training run) ----------------------
required = ["model", "test_loader", "DEVICE"]
missing = [r for r in required if r not in globals()]
if missing:
    raise RuntimeError(f"Required variables missing in current environment: {missing}.\n"
                       "Make sure you executed your train/eval notebook first so model, test_loader, DEVICE exist.")

model.to(DEVICE)
model.eval()

# get dataset list (preserve original ordering)
dataset = list(test_loader.dataset)
n_test = len(dataset)
labels_all = np.array([int(d.y.item()) for d in dataset])
print(f"Test graphs: {n_test}; class counts: {np.bincount(labels_all)}")

# ---------------------- Select samples (PER_CLASS per class) ----------------------
rng = np.random.default_rng(SEED)
selected_idxs = []
for cls in sorted(np.unique(labels_all)):
    idxs = np.where(labels_all == cls)[0]
    if len(idxs) == 0:
        continue
    k = min(PER_CLASS, len(idxs))
    chosen = rng.choice(idxs, size=k, replace=False)
    selected_idxs.extend(chosen.tolist())
selected_idxs = np.array(sorted(selected_idxs), dtype=np.int64)
print("Selected perturbation pool:", {int(cls): int((labels_all[selected_idxs]==cls).sum()) for cls in np.unique(labels_all)})

if len(selected_idxs) == 0:
    raise RuntimeError("No samples selected - check PER_CLASS and the test split.")

# helper to produce batch vector for single graph
def batch_for(x): 
    return torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)

# ---------------------- Save original predictions for selected samples ----------------------
orig_preds = {}
with torch.no_grad():
    for idx in selected_idxs:
        data = dataset[int(idx)]
        x = data.x.to(DEVICE)
        out = model(x, data.edge_index.to(DEVICE), batch_for(x))
        orig_preds[int(idx)] = int(out.argmax(dim=1).item())

# ---------------------- Create one shared set of perturbations via PGD ----------------------
print("\n--- Creating PGD perturbations for selected graphs (shared across metrics) ---")
perturbed_map = {}   # idx -> perturbed x tensor (on DEVICE)
t0 = time.time()
for count, idx in enumerate(selected_idxs, start=1):
    data = dataset[int(idx)]
    x_orig = data.x.detach().clone().to(DEVICE)            # shape (N, F)
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_orig.shape
    flat_dim = n_nodes * feat_dim

    # initialize at random direction scaled to EPSILON_PGD
    delta = torch.randn_like(x_orig, device=DEVICE)
    delta = delta * (EPSILON_PGD / (delta.view(-1).norm() + 1e-12))
    x_adv = (x_orig + delta).detach().clone().requires_grad_(True)
    y_true = torch.tensor([int(data.y.item())], dtype=torch.long, device=DEVICE)

    # PGD loop: maximize CE (untargeted)
    for it in range(NUM_ITERS_PGD):
        out = model(x_adv, edge_index, batch_for(x_adv))   # shape [1, C]
        loss = F.cross_entropy(out, y_true)
        grad_x = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]
        gnorm = grad_x.view(-1).norm().item()
        if gnorm == 0.0:
            break
        step = (ALPHA_PGD * grad_x) / (gnorm + 1e-12)
        x_adv = (x_adv + step).detach()
        # project back to L2 ball
        delta_now = (x_adv - x_orig).view(-1)
        dnorm = delta_now.norm().item()
        if dnorm > EPSILON_PGD:
            delta_now = delta_now * (EPSILON_PGD / (dnorm + 1e-12))
            x_adv = (x_orig + delta_now.view_as(x_orig)).detach()
        x_adv = x_adv.requires_grad_(True)

    perturbed_map[int(idx)] = x_adv.detach().clone()
    if count % 10 == 0 or count == len(selected_idxs):
        print(f"  perturbed {count}/{len(selected_idxs)}")

t1 = time.time()
print(f"Finished PGD perturbations in {t1-t0:.1f}s. Perturbed samples: {len(perturbed_map)}")

# ---------------------- Flip statistics after perturbation (selected samples only) ----------------------
adv_preds = {}
with torch.no_grad():
    for idx in selected_idxs:
        data = dataset[int(idx)]
        x_adv = perturbed_map[int(idx)].to(DEVICE)
        out_adv = model(x_adv, data.edge_index.to(DEVICE), batch_for(x_adv))
        adv_preds[int(idx)] = int(out_adv.argmax(dim=1).item())

orig_array = np.array([orig_preds[int(i)] for i in selected_idxs])
adv_array  = np.array([adv_preds[int(i)] for i in selected_idxs])
num_flips = int((orig_array != adv_array).sum())
print(f"\nSelected graphs: {len(selected_idxs)}. Flipped after perturbation: {num_flips} ({100.0 * num_flips/len(selected_idxs):.2f}%).")

# ---------------------- Metric computation ON THE PERTURBED INPUTS ----------------------
print("\n--- Computing metrics at PERTURBED inputs (this can be slow) ---")
per_sample_records = []  # list of dicts, one per selected idx

def safe_jacobian(f, x):
    """Try to compute full jacobian; fallback to per-output loop on failure."""
    try:
        J = torch.autograd.functional.jacobian(f, x)   # shape (C, D)
    except RuntimeError:
        # fallback: compute per-output jacobian rows
        out0 = f(x)
        C = int(out0.shape[0])
        rows = []
        for c in range(C):
            def scalar_f(z, cidx=c):
                return f(z)[cidx]
            row = torch.autograd.functional.jacobian(scalar_f, x)
            rows.append(row.unsqueeze(0))
        J = torch.cat(rows, dim=0)
    return J

for i, idx in enumerate(selected_idxs):
    idx = int(idx)
    data = dataset[idx]
    x_adv = perturbed_map[idx].detach().clone().to(DEVICE).requires_grad_(True)   # perturbed baseline
    edge_index = data.edge_index.to(DEVICE)
    n_nodes, feat_dim = x_adv.shape
    d = n_nodes * feat_dim

    # flatten view and safe wrapper
    x_flat = x_adv.view(-1).detach().clone().requires_grad_(True)

    def f_flat(z):
        x_mat = z.view_as(x_adv)
        out = model(x_mat, edge_index, batch_for(x_mat))
        # model returns shape [1, C]; squeeze to (C,)
        return out.squeeze(0)

    # Jacobian J (C, D)
    try:
        J = safe_jacobian(f_flat, x_flat).detach()   # (C, D)
    except Exception as e:
        print(f"  Warning: Jacobian failed for idx={idx}: {e}")
        continue

    # Jacobian Frobenius norm
    jac_fro = float(torch.norm(J, p='fro').item())

    # finite-difference relative error for Jacobian (single trial)
    delta_fd = FD_EPS * torch.randn(d, device=DEVICE)
    pred_change = (J @ delta_fd)                      # (C,)
    f0 = f_flat(x_flat).detach()
    f0p = f_flat(x_flat + delta_fd).detach()
    actual_change = (f0p - f0)
    fd_rel_err_jac = float((torch.norm(pred_change - actual_change) / (torch.norm(actual_change) + 1e-8)).item())

    # Local Lipschitz (spectral norm of J)
    try:
        # prefer SVD on CPU if large to reduce CUDA memory pressure
        U, S, Vh = torch.linalg.svd(J, full_matrices=False)
        sigma_max = float(S[0].item())
    except RuntimeError:
        # fallback eigen on J @ J^T
        Jcpu = J.cpu()
        JJT = (Jcpu @ Jcpu.T).numpy()
        eigvals = np.linalg.eigvalsh(JJT)
        sigma_max = float(np.sqrt(max(eigvals.max(), 0.0)))

    # For Lipschitz we can reuse fd_rel_err_jac (since it's based on J)
    fd_rel_err_lip = fd_rel_err_jac

    # Hessian curvature proxy (||grad_x log p_pred||^2)
    # define h(x_flat) returns log-prob of predicted class at baseline x_flat
    with torch.no_grad():
        logits_baseline = f_flat(x_flat).detach()
    pred_class = int(logits_baseline.argmax().item())

    def h_flat(z):
        x_mat = z.view_as(x_adv)
        out = model(x_mat, edge_index, batch_for(x_mat)).squeeze(0)
        logp = F.log_softmax(out, dim=0)
        return logp[pred_class]

    # compute g = grad h(x_flat) w.r.t x_flat (allow_unused in case graph doesn't use inputs)
    h0 = h_flat(x_flat)
    g = torch.autograd.grad(h0, x_flat, retain_graph=False, create_graph=False, allow_unused=True)[0]
    if g is None:
        g = torch.zeros_like(x_flat)
    lambda_proxy = float((g.norm().item())**2)

    # Hessian FD relative error (average across trials)
    h_fd_errs = []
    for t in range(HESS_TRIALS):
        delta = FD_EPS * torch.randn_like(x_flat).to(DEVICE)
        gt_delta = float(torch.dot(g, delta).item())
        pred_second = 0.5 * (gt_delta ** 2)
        actual_second = float((h_flat(x_flat + delta) - h0 - torch.dot(g, delta)).item())
        rel_err = abs(pred_second - actual_second) / (abs(actual_second) + 1e-8)
        h_fd_errs.append(rel_err)
    fd_rel_err_hess = float(np.mean(h_fd_errs))

    # Prediction margin (on perturbed point) using logits (not probs)
    logits = f_flat(x_flat).detach()
    pred_c = int(logits.argmax().item())
    logit_pred = float(logits[pred_c].item())
    other_logits = logits.clone()
    other_logits[pred_c] = -float('inf')
    second_logit = float(other_logits.max().item())
    margin_val = float(logit_pred - second_logit)

    # FD relative error for margin (single trial)
    delta = FD_EPS * torch.randn_like(x_flat).to(DEVICE)
    logits_p = f_flat(x_flat + delta).detach()
    pred_c_p = int(logits_p.argmax().item())
    # compute margin w.r.t the same pred_c (we want marginal change of that predicted class)
    pred_logit_p = float(logits_p[pred_c].item()) if pred_c < logits_p.shape[0] else float(logits_p.max().item())
    other_p = logits_p.clone()
    other_p[pred_c] = -float('inf')
    second_p = float(other_p.max().item())
    margin_p = float(pred_logit_p - second_p)
    fd_rel_err_margin = abs(margin_val - margin_p) / (abs(margin_p) + 1e-12)

    # ARR: adversarial radius around perturbed point (estimate)
    def is_same_label_at(x_flat_candidate, base_label):
        x_mat = x_flat_candidate.view_as(x_adv)
        with torch.no_grad():
            out = model(x_mat, edge_index, batch_for(x_mat))
            return int(out.argmax(dim=1).item()) == base_label

    # minimal radius estimation along random directions (binary search)
    def adversarial_radius_single(x0_flat, base_label, initial_eps=1e-3, growth=ARR_GROW1,
                                  max_eps=ARR_MAX_EPS, bs_iters=ARR_BS_ITERS):
        # returns candidate radius (float)
        eps = initial_eps
        # expand until flip or cap
        while eps < max_eps:
            direction = torch.randn_like(x0_flat).to(DEVICE)
            direction = direction / (direction.norm() + 1e-12)
            if not is_same_label_at(x0_flat + eps * direction, base_label):
                # found flip at eps along this direction: binary search between eps/growth and eps
                low, high = eps / growth, eps
                for _ in range(bs_iters):
                    mid = 0.5 * (low + high)
                    if is_same_label_at(x0_flat + mid * direction, base_label):
                        low = mid
                    else:
                        high = mid
                return float(high)
            eps *= growth
        return float(max_eps)

    # combine multiple trials with two growth parameters to compute relative error
    base_label = int(pred_c)
    radii = []
    for trial in range(ARR_TRIALS):
        # try with growth1 then growth2 for relerr
        r1 = adversarial_radius_single(x_flat, base_label, growth=ARR_GROW1)
        r2 = adversarial_radius_single(x_flat, base_label, growth=ARR_GROW2)
        r = float(min(r1, r2))
        rel = abs(r1 - r2) / (abs(r2) + 1e-12)
        radii.append((r, rel))
    radii_vals = [rv for rv,_ in radii]
    rels_vals = [rel for _,rel in radii]
    adv_radius_est = float(np.min(radii_vals))
    adv_radius_relerr = float(np.mean(rels_vals))

    # Stability under input noise (averaged norm difference in logits)
    f_orig = f_flat(x_flat).detach()
    noise_diffs = []
    for _ in range(NUM_NOISE_SAMPLES):
        noise = NOISE_SIGMA * torch.randn_like(x_flat).to(DEVICE)
        f_noisy = f_flat(x_flat + noise).detach()
        noise_diffs.append(float(torch.norm(f_noisy - f_orig).item()))
    stability_val = float(np.mean(noise_diffs))
    # relative error for stability (resampling)
    revals = []
    for _ in range(3):
        revals.append(float(np.mean([float(torch.norm(f_flat(x_flat + NOISE_SIGMA * torch.randn_like(x_flat).to(DEVICE)) - f_orig).item()) for __ in range(int(NUM_NOISE_SAMPLES/2) or 1)])))
    stability_relerr = abs(stability_val - float(np.mean(revals))) / (abs(float(np.mean(revals))) + 1e-12)

    # collect per-sample record
    rec = {
        "idx": idx,
        "label": int(data.y.item()),
        "jac_fro": jac_fro,
        "fd_rel_jac": fd_rel_err_jac,
        "sigma_max": sigma_max,
        "fd_rel_lip": fd_rel_err_lip,
        "lambda_proxy": lambda_proxy,
        "fd_rel_hess": fd_rel_err_hess,
        "margin": margin_val,
        "fd_rel_margin": fd_rel_err_margin,
        "adv_radius": adv_radius_est,
        "adv_radius_relerr": adv_radius_relerr,
        "stability": stability_val,
        "stability_relerr": stability_relerr,
        "orig_pred": orig_preds[idx],
        "adv_pred": adv_preds[idx]
    }
    per_sample_records.append(rec)

    if (i+1) % 10 == 0 or (i+1) == len(selected_idxs):
        print(f"  computed metrics for {i+1}/{len(selected_idxs)} samples")

# ---------------------- Aggregate & print per-class stats ----------------------
import math
def mean_std(arr):
    if len(arr)==0:
        return (0.0, 0.0)
    a = np.array(arr, dtype=float)
    return (float(a.mean()), float(a.std()))

records = per_sample_records
if not records:
    raise RuntimeError("No per-sample records were computed.")

# group by label
records_by_label = {0: [], 1: []}
for r in records:
    records_by_label[int(r["label"])].append(r)

print("\n=== Aggregated metrics on PERTURBED selected samples (mean ± std) ===")
metrics_to_print = [
    ("jac_fro", "Jacobian Frobenius norm"),
    ("fd_rel_jac", "Jacobian FD relative error"),
    ("sigma_max", "Local Lipschitz (spectral)"),
    ("fd_rel_lip", "Lipschitz FD rel err"),
    ("lambda_proxy", "Hessian curvature proxy (||grad logp||^2)"),
    ("fd_rel_hess", "Hessian FD rel err"),
    ("margin", "Prediction margin (logit diff)"),
    ("fd_rel_margin", "Margin FD rel err"),
    ("adv_radius", "Adversarial Robustness Radius (est)"),
    ("adv_radius_relerr", "ARR rel err"),
    ("stability", "Stability under input noise"),
    ("stability_relerr", "Stability rel err"),
]

for cls in sorted(records_by_label.keys()):
    recs = records_by_label[cls]
    print(f"\nClass {cls} (n={len(recs)})")
    for key, pretty in metrics_to_print:
        vals = [r[key] for r in recs if not (isinstance(r[key], float) and (math.isinf(r[key]) or math.isnan(r[key])))]
        m, s = mean_std(vals) if vals else (float('nan'), float('nan'))
        print(f"  {pretty:40s}: {m:.4e} ± {s:.4e}")

# preview first 6 per-sample entries
print("\nSample preview (first 6 records):")
for r in records[:6]:
    print(r)

# ---------------------- Evaluate model on PERTURBED samples only ----------------------
y_true_sel = np.array([int(dataset[int(idx)].y.item()) for idx in selected_idxs])
y_pred_sel = np.array([int(per["adv_pred"]) for per in per_sample_records])

acc_sel = accuracy_score(y_true_sel, y_pred_sel)
prec_sel, rec_sel, f1_sel, _ = precision_recall_fscore_support(y_true_sel, y_pred_sel, average='weighted', zero_division=0)

print("\n=== Evaluation on PERTURBED samples only ===")
print(f"Selected perturbed samples: {len(selected_idxs)}")
print(f"Accuracy: {acc_sel*100:.2f}%")
print(f"Precision: {prec_sel:.4f}, Recall: {rec_sel:.4f}, F1: {f1_sel:.4f}\n")
print("Classification report (perturbed samples):")
print(classification_report(y_true_sel, y_pred_sel, target_names=['clean','trojan'], digits=4))
print("Confusion Matrix (perturbed samples):")
print(confusion_matrix(y_true_sel, y_pred_sel, labels=[0,1]))

# perturbation success (flips) breakdown by class
orig = np.array([orig_preds[int(i)] for i in selected_idxs])
adv  = np.array([adv_preds[int(i)] for i in selected_idxs])
flips_total = (orig != adv).sum()
print(f"\nPerturbation success (flips): {flips_total}/{len(selected_idxs)} = {100.0*flips_total/len(selected_idxs):.2f}%")
for cls in [0,1]:
    idxs_cls = [j for j,r in enumerate(selected_idxs) if labels_all[r]==cls]
    if len(idxs_cls)==0:
        continue
    flips_cls = (orig[idxs_cls] != adv[idxs_cls]).sum()
    print(f"  class {cls}: {flips_cls}/{len(idxs_cls)} = {100.0*flips_cls/len(idxs_cls):.2f}%")

print("\nDone.")


Test graphs: 102; class counts: [11 91]
Selected perturbation pool: {0: 11, 1: 20}

--- Creating PGD perturbations for selected graphs (shared across metrics) ---
  perturbed 10/31
  perturbed 20/31
  perturbed 30/31
  perturbed 31/31
Finished PGD perturbations in 49.0s. Perturbed samples: 31

Selected graphs: 31. Flipped after perturbation: 3 (9.68%).

--- Computing metrics at PERTURBED inputs (this can be slow) ---
  computed metrics for 10/31 samples
