#### Node Classification: Trojan vs non-trojan. 

Below is a single, end‑to‑end GCN training script that:

reads GNNDatasets/node.csv + GNNDatasets/node_edges.csv

auto‑builds a multi‑circuit graph (nodes keyed by circuit_name::node)

one‑hot encodes gate_type, standardizes numeric features

adds zero‑feature pseudo nodes for edge‑only items (ASSIGN_*, DFF_*, PIs/POs)

trains a 2‑layer GCN with class‑weighted loss + early stopping

reports accuracy, classification report, confusion matrix

In [3]:
# train_gcn_node_fixed.py
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

NODE_CSV  = "GNNDatasets/node.csv"
EDGE_CSV  = "GNNDatasets/node_edges.csv"
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED)

# ----------------------------- Load nodes -----------------------------
nodes_df = pd.read_csv(NODE_CSV)

label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        label_col = cand; break
if label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

feat_df = nodes_df.copy()
if "gate_type" in feat_df.columns:
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name",label_col}
num_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
X = feat_df[num_cols].fillna(0.0).values.astype(np.float32)
y = nodes_df[label_col].values.astype(np.int64)

# ----------------------------- Load edges; add missing nodes -----------------------------
edges_df = pd.read_csv(EDGE_CSV)
edges_df["src_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["src"].astype(str)
edges_df["dst_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["dst"].astype(str)

known_uids = set(nodes_df["uid"])
edge_uids = set(edges_df["src_uid"]).union(set(edges_df["dst_uid"]))
missing = list(edge_uids - known_uids)

if missing:
    zero_row = np.zeros((1, X.shape[1]), dtype=np.float32)
    addX = np.repeat(zero_row, len(missing), axis=0)
    addY = -1*np.ones(len(missing), dtype=np.int64)
    add_df = pd.DataFrame({
        "uid": missing,
        "circuit_name": [u.split("::",1)[0] for u in missing],
        "node": [u.split("::",1)[1] for u in missing],
        label_col: addY
    })
    X = np.vstack([X, addX])
    y = np.concatenate([y, addY])
    nodes_df = pd.concat([nodes_df, add_df], ignore_index=True)

uid_to_idx = {u:i for i,u in enumerate(nodes_df["uid"].tolist())}
src_idx = edges_df["src_uid"].map(uid_to_idx).dropna().astype(int).values
dst_idx = edges_df["dst_uid"].map(uid_to_idx).dropna().astype(int).values
edge_index = np.stack([np.concatenate([src_idx, dst_idx]),
                       np.concatenate([dst_idx, src_idx])], axis=0)

# ----------------------------- Scale features -----------------------------
labeled_mask_np = (y >= 0)
scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[labeled_mask_np] = scaler.fit_transform(X_scaled[labeled_mask_np])
if (~labeled_mask_np).any():
    X_scaled[~labeled_mask_np] = (X_scaled[~labeled_mask_np] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# ----------------------------- Splits -----------------------------
idx_all = np.where(labeled_mask_np)[0]
y_all = y[labeled_mask_np]

idx_train, idx_tmp, y_train, y_tmp = train_test_split(
    idx_all, y_all, test_size=0.30, random_state=SEED, stratify=y_all
)
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_tmp, y_tmp, test_size=0.50, random_state=SEED, stratify=y_tmp
)

# ----------------------------- Torch tensors (FIX: masks as torch.bool) -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_t = torch.from_numpy(X_scaled).to(device)
y_t = torch.from_numpy(y).to(device)

edge_index_t = torch.from_numpy(edge_index).long().to(device)

train_mask_t = torch.zeros(len(y), dtype=torch.bool, device=device); train_mask_t[idx_train] = True
val_mask_t   = torch.zeros(len(y), dtype=torch.bool, device=device); val_mask_t[idx_val]   = True
test_mask_t  = torch.zeros(len(y), dtype=torch.bool, device=device); test_mask_t[idx_test]  = True
labeled_mask_t = torch.from_numpy(labeled_mask_np).to(device)

# ----------------------------- Build GCN adjacency -----------------------------
def build_adj(num_nodes, edge_index):
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    ei = torch.cat([edge_index, torch.stack([self_loops, self_loops])], dim=1)
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
    w = deg_inv_sqrt[ei[0]] * deg_inv_sqrt[ei[1]]
    A = torch.sparse_coo_tensor(ei, w, (num_nodes, num_nodes))
    return A.coalesce()

A_t = build_adj(X_t.size(0), edge_index_t)

# ----------------------------- Model -----------------------------
class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.lin.weight)
    def forward(self, x, adj):
        x = self.dropout(x)
        x = torch.sparse.mm(adj, x)
        x = self.lin(x)
        return x

class GCN(nn.Module):
    def __init__(self, in_dim, hid_dim=96, out_dim=2, dropout=0.35):
        super().__init__()
        self.g1 = GCNLayer(in_dim, hid_dim, dropout)
        self.g2 = GCNLayer(hid_dim, out_dim, dropout)
        self.do = nn.Dropout(dropout)
    def forward(self, x, adj):
        x = self.g1(x, adj); x = F.relu(x); x = self.do(x)
        x = self.g2(x, adj)
        return x

model = GCN(in_dim=X_t.size(1), hid_dim=96, out_dim=2, dropout=0.35).to(device)

# ----------------------------- Loss, optimizer -----------------------------
train_labels = y_t[train_mask_t]
classes, counts = torch.unique(train_labels, return_counts=True)
num_pos = counts[classes==1].item() if (classes==1).any() else 1
num_neg = counts[classes==0].item() if (classes==0).any() else 1
weight_pos = (num_neg + num_pos) / (2.0 * num_pos)
weight_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights = torch.tensor([weight_neg, weight_pos], dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=5e-4)

# ----------------------------- Training -----------------------------
def evaluate(mask_t):
    model.eval()
    with torch.no_grad():
        logits = model(X_t, A_t)
        pred = logits.argmax(dim=1)
        msk = mask_t & (y_t >= 0)
        if msk.sum() == 0: return 0.0
        return (pred[msk] == y_t[msk]).float().mean().item()

best_val, best_state = -1.0, None
patience, patience_cnt = 20, 0
EPOCHS = 300

for epoch in range(1, EPOCHS+1):
    model.train()
    optimizer.zero_grad()
    logits = model(X_t, A_t)
    loss = criterion(logits[train_mask_t], y_t[train_mask_t])
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()

    if epoch % 10 == 0 or epoch == 1:
        val_acc = evaluate(val_mask_t)
        test_acc = evaluate(test_mask_t)
        print(f"Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("Early stopping."); break

if best_state is not None:
    model.load_state_dict(best_state)

# ----------------------------- Final eval -----------------------------
model.eval()
with torch.no_grad():
    logits = model(X_t, A_t)
    preds = logits.argmax(dim=1)

msk = (test_mask_t & (y_t >= 0)).cpu().numpy()
y_true = y_t.cpu().numpy()[msk]
y_pred = preds.cpu().numpy()[msk]

acc = (y_true == y_pred).mean()
print("\nFinal Evaluation (Node-Level)")
print("=============================")
print(f"Test Accuracy: {acc:.4f}\n")

print("Classification Report:")
print(classification_report(y_true, y_pred, labels=[0,1], target_names=["clean","trojan"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=[0,1]))


Epoch 001 | Loss 0.8937 | Val 0.4607 | Test 0.4631
Epoch 010 | Loss 0.6115 | Val 0.9505 | Test 0.9496
Epoch 020 | Loss 0.4193 | Val 0.9883 | Test 0.9885
Epoch 030 | Loss 0.2920 | Val 0.9926 | Test 0.9931
Epoch 040 | Loss 0.2113 | Val 0.9981 | Test 0.9981
Epoch 050 | Loss 0.1559 | Val 0.9994 | Test 0.9995
Epoch 060 | Loss 0.1176 | Val 0.9999 | Test 0.9998
Epoch 070 | Loss 0.0936 | Val 0.9999 | Test 0.9998
Epoch 080 | Loss 0.0780 | Val 1.0000 | Test 1.0000
Epoch 090 | Loss 0.0668 | Val 1.0000 | Test 1.0000
Epoch 100 | Loss 0.0553 | Val 1.0000 | Test 1.0000
Epoch 110 | Loss 0.0493 | Val 1.0000 | Test 1.0000
Epoch 120 | Loss 0.0435 | Val 1.0000 | Test 1.0000
Epoch 130 | Loss 0.0407 | Val 1.0000 | Test 1.0000
Epoch 140 | Loss 0.0371 | Val 1.0000 | Test 1.0000
Epoch 150 | Loss 0.0346 | Val 1.0000 | Test 1.0000
Epoch 160 | Loss 0.0329 | Val 1.0000 | Test 1.0000
Epoch 170 | Loss 0.0301 | Val 1.0000 | Test 1.0000
Epoch 180 | Loss 0.0283 | Val 1.0000 | Test 1.0000
Epoch 190 | Loss 0.0273 | Val 1

#### Subgraph Classification

Below is a complete, end‑to‑end script for subgraph‑level classification. It:

Loads GNNDatasets/subgraph.csv, GNNDatasets/node.csv, and GNNDatasets/node_edges.csv.

Reconstructs per‑circuit graphs (NetworkX) and extracts each subgraph as the K‑hop ego graph around center_node (uses K=2, same as before).

Builds a PyTorch‑Geometric Data object per subgraph with node features taken from node.csv (numeric features + one‑hot gate type); missing nodes get zero features.

Trains a GraphSAGE model (two layers) with global mean pooling for graph classification (label_subgraph: 0/1).

Prints accuracy, classification report, and confusion matrix.

#### Subgraph Classification

In [5]:
# train_subgraph_gnn_fixed.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 80
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Merge edge CSVs
# -----------------------
edge_files = [
    "GNNDatasets/node_edges.csv",
    "GNNDatasets/subgraph_edges_andxor.csv",
    "GNNDatasets/subgraph_edges_countermux.csv",
    "GNNDatasets/subgraph_edges_fsmor.csv",
]

edges_by_circuit = defaultdict(list)
for ef in edge_files:
    df = pd.read_csv(ef)
    for _, r in df.iterrows():
        edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
sub_df = pd.read_csv("GNNDatasets/subgraph.csv")
data_list, labels = [], []

for idx, row in tqdm(sub_df.iterrows(), total=len(sub_df), desc="Building subgraphs"):
    ckt = row["circuit_name"]
    lbl = int(row.get("label_subgraph", row.get("label", 0)))
    
    if ckt not in edges_by_circuit:
        continue
    
    # collect node uids from node.csv that belong to this circuit
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} subgraphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model
# -----------------------
class SubgraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid_dim)
        self.conv2 = SAGEConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = SubgraphClassifier(feat_dim).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Subgraph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Building subgraphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 674/674 [00:19<00:00, 34.49it/s]


Built 674 subgraphs (usable)
Train: 471, Val: 101, Test: 102
Epoch 001 | Loss 0.7464 | Val 0.2673
Epoch 005 | Loss 0.6193 | Val 0.6238
Epoch 010 | Loss 0.5896 | Val 0.5941
Epoch 015 | Loss 0.5101 | Val 0.6832
Epoch 020 | Loss 0.3784 | Val 0.8020
Epoch 025 | Loss 0.2963 | Val 0.8317
Epoch 030 | Loss 0.2182 | Val 0.8812
Epoch 035 | Loss 0.1549 | Val 0.9208
Epoch 040 | Loss 0.1265 | Val 0.9208
Epoch 045 | Loss 0.1028 | Val 0.9208
Epoch 050 | Loss 0.0897 | Val 0.9208
Epoch 055 | Loss 0.0840 | Val 0.9208
Epoch 060 | Loss 0.0851 | Val 0.9307
Epoch 065 | Loss 0.0641 | Val 0.9307
Epoch 070 | Loss 0.0713 | Val 0.9604
Epoch 075 | Loss 0.0604 | Val 0.9604
Epoch 080 | Loss 0.0645 | Val 0.9604

Final Evaluation (Subgraph-Level)
Test Accuracy: 0.9902

Classification Report:
              precision    recall  f1-score   support

           0     0.9167    1.0000    0.9565        11
           1     1.0000    0.9890    0.9945        91

    accuracy                         0.9902       102
   macro av

#### Graph Classification

At this level we want:
- Input: GNNDatasets/graph.csv (circuit labels)
- Edges: use GNNDatasets/graph_edges.csv (and optionally node_edges/subgraph_edges, but graph_edges.csv should already be the merged top-level edge file).
- Nodes: GNNDatasets/node.csv still provides features.

We’ll build each circuit as one graph, then classify whether it is clean or trojaned.
- Uses GINConv (better for graph classification than vanilla GCN/GraphSAGE).
- Loads graph-level edges (graph_edges.csv).
- Uses graph.csv for circuit labels.
- Includes class-weighted loss (handles class imbalance).
- Prints classification report + confusion matrix just like node/subgraph levels.

In [6]:
# train_graph_gnn.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_mean_pool

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
EPOCHS = 100
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Load graph-level labels
# -----------------------
graph_df = pd.read_csv("GNNDatasets/graph.csv")
graph_labels = {r["circuit_name"]: int(r.get("label_graph", r.get("label", 0))) for _, r in graph_df.iterrows()}

# -----------------------
# Load graph-level edges
# -----------------------
edges_df = pd.read_csv("GNNDatasets/graph_edges.csv")
edges_by_circuit = defaultdict(list)
for _, r in edges_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
data_list, labels = [], []

for ckt, lbl in graph_labels.items():
    if ckt not in edges_by_circuit:
        continue
    
    # collect node list
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} graphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model (GIN for graphs)
# -----------------------
class GraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(in_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv1 = GINConv(nn1)
        nn2 = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv2 = GINConv(nn2)
        self.lin = nn.Linear(hid_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = GraphClassifier(feat_dim).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Graph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Built 92 graphs (usable)
Train: 64, Val: 14, Test: 14




Epoch 001 | Loss 0.7626 | Val 0.7143
Epoch 005 | Loss 0.6486 | Val 0.2857
Epoch 010 | Loss 0.6627 | Val 0.5000
Epoch 015 | Loss 0.6499 | Val 0.2857
Epoch 020 | Loss 0.6382 | Val 0.2143
Epoch 025 | Loss 0.6565 | Val 0.3571
Epoch 030 | Loss 0.6186 | Val 0.2143
Epoch 035 | Loss 0.6348 | Val 0.1429
Epoch 040 | Loss 0.5834 | Val 0.2143
Epoch 045 | Loss 0.5532 | Val 0.2143
Epoch 050 | Loss 0.5080 | Val 0.2143
Epoch 055 | Loss 0.4923 | Val 0.2143
Epoch 060 | Loss 0.4683 | Val 0.3571
Epoch 065 | Loss 0.4470 | Val 0.3571
Epoch 070 | Loss 0.4055 | Val 0.4286
Epoch 075 | Loss 0.3791 | Val 0.6429
Epoch 080 | Loss 0.3507 | Val 0.7143
Epoch 085 | Loss 0.3147 | Val 0.7857
Epoch 090 | Loss 0.2813 | Val 0.7857
Epoch 095 | Loss 0.2523 | Val 0.7857
Epoch 100 | Loss 0.2401 | Val 0.7857

Final Evaluation (Graph-Level)
Test Accuracy: 0.5714

Classification Report:
              precision    recall  f1-score   support

           0     0.3333    0.5000    0.4000         4
           1     0.7500    0.6000   

An improved graph‑level training pipeline that combines two tactics to boost real performance on a small dataset:
- Stronger regularization + data balancing: oversample the minority class in the training split and use DropEdge augmentation during training.
- Model ensemble: train several models (different seeds and architectures — GIN, GraphSAGE, GCN) and average their predicted probabilities on the same test set.

This approach reduces variance on small datasets and makes decisions more robust than a single model. Save as train_graph_gnn_ensemble.py and run in the same folder as your GNNDatasets/* CSVs.

Quick notes on what this script does and why it should help: 
- Oversampling balances the small training set so the model sees enough positive (trojan) circuits during training.
- DropEdge is a light graph augmentation which acts like dropout for edges — it regularizes models and helps generalization on small graphs.
- Ensemble of different architectures & seeds reduces model variance and smooths out model-specific biases that were likely causing your earlier instability.
- Early stopping avoids overfitting on a very small validation set.

In [1]:
# train_graph_gnn_ensemble.py
import os
import random
import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, SAGEConv, GCNConv, global_mean_pool

# -----------------------
# Config
# -----------------------
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); random.seed(RANDOM_SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N_ENSEMBLE = 5
BATCH_SIZE = 8
EPOCHS = 200
LR = 1e-3
HID_DIM = 64
DROPOUT_PROB = 0.4
EDGE_DROPOUT = 0.15  # DropEdge probability during training
PATIENCE = 30

# -----------------------
# Load node features (same scheme as before)
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type (if present)
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

# Fit scaler on all nodes
all_feats = np.stack(list(uid_to_feat.values()))
scaler = StandardScaler().fit(all_feats)
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)
feat_dim = all_feats.shape[1]

# -----------------------
# Load graph labels
# -----------------------
graph_df = pd.read_csv("GNNDatasets/graph.csv")
graph_labels = {r["circuit_name"]: int(r.get("label_graph", r.get("label", 0))) for _, r in graph_df.iterrows()}

# -----------------------
# Load merged graph edges
# -----------------------
edges_df = pd.read_csv("GNNDatasets/graph_edges.csv")
edges_by_circuit = defaultdict(list)
for _, r in edges_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset (Data objects per circuit)
# -----------------------
data_list = []
labels = []
for ckt, lbl in graph_labels.items():
    if ckt not in edges_by_circuit:
        continue
    nodes_in_ckt = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not nodes_in_ckt:
        continue
    uid_map = {n:i for i,n in enumerate(nodes_in_ckt)}
    x_list = []
    for n in nodes_in_ckt:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)

    # build edge_index (undirected)
    srcs, dsts = [], []
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            srcs.extend([uid_map[u], uid_map[v]])
            dsts.extend([uid_map[v], uid_map[u]])
    if len(srcs) == 0:
        continue
    edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} graphs (usable)")

if len(data_list) < 10:
    raise SystemExit("Too few graphs to train; consider expanding dataset or lowering filters.")

labels = np.array(labels)

# -----------------------
# Create one fixed stratified split (user wanted reproducible test set)
# -----------------------
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3,
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_dataset = [data_list[i] for i in train_idx]
val_dataset   = [data_list[i] for i in val_idx]
test_dataset  = [data_list[i] for i in test_idx]

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# -----------------------
# Oversample minority class in training set by simple replicate-with-replacement
# -----------------------
def oversample_dataset(dataset):
    ys = np.array([int(d.y.item()) for d in dataset])
    unique, counts = np.unique(ys, return_counts=True)
    if len(unique) == 1:
        return dataset  # nothing to do
    maxc = counts.max()
    new_list = []
    for cls in unique:
        cls_inds = [i for i,y in enumerate(ys) if y==cls]
        times = maxc // len(cls_inds)
        rem = maxc % len(cls_inds)
        for _ in range(times):
            for i in cls_inds:
                new_list.append(dataset[i])
        sel = np.random.choice(cls_inds, size=rem, replace=False)
        for i in sel:
            new_list.append(dataset[i])
    random.shuffle(new_list)
    return new_list

train_dataset_bal = oversample_dataset(train_dataset)
print(f"After oversampling train size: {len(train_dataset_bal)}")

# -----------------------
# Utility: edge dropout
# -----------------------
def drop_edges(edge_index, p):
    if p <= 0.0:
        return edge_index
    E = edge_index.size(1)
    mask = (torch.rand(E, device=edge_index.device) > p)
    # if mask is all false, keep at least one
    if mask.sum() == 0:
        idx = torch.randint(0, E, (1,), device=edge_index.device)
        mask[idx] = True
    return edge_index[:, mask]

# -----------------------
# Models (three types)
# -----------------------
class GINClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2, dropout=DROPOUT_PROB):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(in_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv1 = GINConv(nn1)
        nn2 = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv2 = GINConv(nn2)
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

class SAGEClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2, dropout=DROPOUT_PROB):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid_dim)
        self.conv2 = SAGEConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index)); x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

class GCNClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2, dropout=DROPOUT_PROB):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index)); x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

archs = ["GIN", "SAGE", "GCN", "GIN", "SAGE"]  # cycle through

# -----------------------
# Training helpers
# -----------------------
def train_single_model(model, train_dataset, val_dataset, seed, edge_drop=EDGE_DROPOUT):
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)
    # class weights from train_dataset
    ytrain = np.array([int(d.y.item()) for d in train_dataset])
    unique, counts = np.unique(ytrain, return_counts=True)
    weights = None
    if len(unique) == 2:
        w_pos = (counts.sum() / counts[1]) / 2.0
        w_neg = (counts.sum() / counts[0]) / 2.0
        weights = torch.tensor([w_neg, w_pos], dtype=torch.float32, device=DEVICE)
    else:
        weights = torch.tensor([1.0, 1.0], device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=weights)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    best_val = -1.0; best_state = None; patience_cnt = 0
    for epoch in range(1, EPOCHS+1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(DEVICE)
            # apply edge dropout on per-batch edge_index
            ei = drop_edges(batch.edge_index, edge_drop)
            optimizer.zero_grad()
            out = model(batch.x, ei, batch.batch)
            loss = criterion(out, batch.y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs
        avg_loss = total_loss / len(train_dataset)
        # validation
        model.eval()
        with torch.no_grad():
            ys=[]; ps=[]
            for b in val_loader:
                b = b.to(DEVICE)
                out = model(b.x, b.edge_index, b.batch)
                p = out.argmax(dim=1).cpu().numpy()
                ys.extend(b.y.cpu().numpy()); ps.extend(p)
            if len(ys)>0:
                val_acc = accuracy_score(ys, ps)
            else:
                val_acc = 0.0
        # early stopping
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= PATIENCE:
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def predict_probs(model, dataset):
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    probs_all = []
    ys = []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            logits = model(b.x, b.edge_index, b.batch)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            probs_all.append(probs)
            ys.extend(b.y.cpu().numpy())
    probs_all = np.vstack(probs_all)
    return np.array(ys), probs_all

# -----------------------
# Ensemble training
# -----------------------
test_probs_accum = None
true_test_y = None
per_model_test_accs = []

for m in range(N_ENSEMBLE):
    arch = archs[m % len(archs)]
    if arch == "GIN":
        model = GINClassifier(feat_dim, hid_dim=HID_DIM, out_dim=2)
    elif arch == "SAGE":
        model = SAGEClassifier(feat_dim, hid_dim=HID_DIM, out_dim=2)
    else:
        model = GCNClassifier(feat_dim, hid_dim=HID_DIM, out_dim=2)

    # use the balanced training dataset (oversampled) for every member
    trained = train_single_model(model, train_dataset_bal, val_dataset, seed=RANDOM_SEED + m, edge_drop=EDGE_DROPOUT)
    ys_test, probs_test = predict_probs(trained, test_dataset)
    preds = probs_test.argmax(axis=1)
    acc = accuracy_score(ys_test, preds)
    per_model_test_accs.append(acc)
    print(f"[Model {m+1}/{N_ENSEMBLE}] arch={arch} seed={RANDOM_SEED+m} test_acc={acc:.4f}")

    if test_probs_accum is None:
        test_probs_accum = probs_test.copy()
        true_test_y = ys_test.copy()
    else:
        # ensure test ordering aligns (it should: same test_dataset order)
        test_probs_accum += probs_test

# Average probabilities
test_probs_avg = test_probs_accum / float(N_ENSEMBLE)
test_preds_avg = test_probs_avg.argmax(axis=1)

# -----------------------
# Final report
# -----------------------
print("\nPer-model test accuracies:", per_model_test_accs)
print("\nEnsembled Test Results")
print("=======================")
print("Accuracy:", accuracy_score(true_test_y, test_preds_avg))
print("\nClassification Report:")
print(classification_report(true_test_y, test_preds_avg, digits=4))
print("\nConfusion Matrix:")
print(confusion_matrix(true_test_y, test_preds_avg))




Built 92 graphs (usable)
Train: 64, Val: 14, Test: 14
After oversampling train size: 96
[Model 1/5] arch=GIN seed=42 test_acc=0.5714
[Model 2/5] arch=SAGE seed=43 test_acc=0.4286
[Model 3/5] arch=GCN seed=44 test_acc=0.5714
[Model 4/5] arch=GIN seed=45 test_acc=0.7143
[Model 5/5] arch=SAGE seed=46 test_acc=0.5000

Per-model test accuracies: [0.5714285714285714, 0.42857142857142855, 0.5714285714285714, 0.7142857142857143, 0.5]

Ensembled Test Results
Accuracy: 0.6428571428571429

Classification Report:
              precision    recall  f1-score   support

           0     0.4000    0.5000    0.4444         4
           1     0.7778    0.7000    0.7368        10

    accuracy                         0.6429        14
   macro avg     0.5889    0.6000    0.5906        14
weighted avg     0.6698    0.6429    0.6533        14


Confusion Matrix:
[[2 2]
 [3 7]]


New Strategy — Pretraining + Fine-tuning: 
- Pretrain GNN encoders at node-level (where you already reached ~99% accuracy).
- Then use that encoder as initialization for graph classification — instead of starting from scratch.
- This leverages the fact that node embeddings already learned discriminative trojan vs. non-trojan features.

Here’s the end-to-end pipeline:
- Load node-level pretrained GCN weights
- Reuse encoder layers (dropout + convolution stack)
- Add a graph pooling head (e.g., global mean pooling) + classifier
- Fine-tune on graph-level dataset (GNNDatasets/graph.csv, GNNDatasets/graph_edges.csv)

In [5]:
# transfer_pretrain_node_finetune_graph.py
# Node pretraining (GCNConv) -> Graph fine-tune (reuse exact conv weights; no mapping)
# Safe: circuits in graph test set are excluded from node pretraining (prevents leakage).

import os, random, math
from collections import defaultdict
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# ----------------- Config -----------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NODE_CSV = "GNNDatasets/node.csv"
NODE_EDGE_CSV = "GNNDatasets/node_edges.csv"
GRAPH_CSV = "GNNDatasets/graph.csv"
GRAPH_EDGE_CSV = "GNNDatasets/graph_edges.csv"

PRETRAIN_CHECKPOINT = "node_gcn_pretrained.pth"

# Hyperparams (tweakable)
HID_DIM = 64
PRETRAIN_LR = 1e-3
PRETRAIN_EPOCHS = 100
PRETRAIN_BATCH = None   # full-graph pretraining uses adjacency; we do whole-graph training (no batch)

FT_HEAD_LR = 1e-3
FT_FINETUNE_LR = 5e-4
FT_EPOCHS_HEAD = 60
FT_EPOCHS_FINETUNE = 120
BATCH_SIZE = 16
EARLY_STOPPING = 30
DROPOUT = 0.35

# ----------------- Load graph-level labels and split circuits -----------------
graph_df = pd.read_csv(GRAPH_CSV)
# find graph label column
graph_label_col = None
for cand in ["label_graph", "label", "is_trojan", "trojan"]:
    if cand in graph_df.columns:
        graph_label_col = cand; break
if graph_label_col is None:
    graph_df["label_graph"] = graph_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    graph_label_col = "label_graph"

circuits = graph_df["circuit_name"].tolist()
graph_labels = [int(x) for x in graph_df[graph_label_col].tolist()]

# stratified split of circuits for final graph-level evaluation
train_circuits, temp_circuits, y_train_c, y_temp_c = train_test_split(
    circuits, graph_labels, test_size=0.30, random_state=SEED, stratify=graph_labels
)
val_circuits, test_circuits, y_val_c, y_test_c = train_test_split(
    temp_circuits, y_temp_c, test_size=0.50, random_state=SEED, stratify=y_temp_c
)

print(f"Circuits split -> train:{len(train_circuits)} val:{len(val_circuits)} test:{len(test_circuits)}")

# ----------------- Prepare node data for pretraining (exclude test circuits!) -----------------
nodes_df = pd.read_csv(NODE_CSV)
edges_df = pd.read_csv(NODE_EDGE_CSV)

# identify node label column
node_label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        node_label_col = cand; break
if node_label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    node_label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# Pretrain set circuits: combine train + val circuits (we exclude graph test circuits)
pretrain_circuits = set(train_circuits + val_circuits)
print(f"Pretraining will use nodes from {len(pretrain_circuits)} circuits (train+val) and exclude {len(test_circuits)} test circuits.")

# Filter nodes/edges for pretraining graph
nodes_pretrain_df = nodes_df[nodes_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)
edges_pretrain_df = edges_df[edges_df["circuit_name"].isin(pretrain_circuits)].reset_index(drop=True)

# feature columns (numeric) — exclude metadata
feat_df = nodes_pretrain_df.copy()
if "gate_type" in feat_df.columns:
    # one-hot encode gate_type
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name", node_label_col}
feature_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
if len(feature_cols) == 0:
    raise SystemExit("No numeric feature columns found in node CSV. Please include features.")

# Build X_all for pretraining nodes and mapping
X_pre = feat_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
y_pre = nodes_pretrain_df[node_label_col].to_numpy(dtype=np.int64)
uids_pre = nodes_pretrain_df["uid"].tolist()
uid_to_idx_pre = {u: i for i,u in enumerate(uids_pre)}

# Some edges may include nodes not present in nodes_pretrain_df (rare) — filter edges
def map_uid(signal, circuit):
    return f"{circuit}::{signal}"

edge_src_uids = edges_pretrain_df.apply(lambda r: map_uid(r["src"], r["circuit_name"]), axis=1)
edge_dst_uids = edges_pretrain_df.apply(lambda r: map_uid(r["dst"], r["circuit_name"]), axis=1)

edge_src_idx = edge_src_uids.map(uid_to_idx_pre).dropna().astype(int).values
edge_dst_idx = edge_dst_uids.map(uid_to_idx_pre).dropna().astype(int).values

if len(edge_src_idx) == 0:
    raise SystemExit("No edges left after filtering to pretrain circuits; check your GNNDatasets files.")

edge_index_pre = np.stack([np.concatenate([edge_src_idx, edge_dst_idx]),
                           np.concatenate([edge_dst_idx, edge_src_idx])], axis=0)  # undirected

# Scale features using labeled nodes only (within pretrain set)
# ensure labels are available
labeled_mask_pre = (y_pre >= 0)
scaler = StandardScaler()
X_pre_scaled = X_pre.copy()
if labeled_mask_pre.sum() == 0:
    raise SystemExit("No labeled nodes in pretraining set.")
X_pre_scaled[labeled_mask_pre] = scaler.fit_transform(X_pre_scaled[labeled_mask_pre])
X_pre_scaled[~labeled_mask_pre] = (X_pre_scaled[~labeled_mask_pre] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# Convert to torch
X_pre_t = torch.from_numpy(X_pre_scaled).to(DEVICE)
y_pre_t = torch.from_numpy(y_pre).to(DEVICE)
edge_index_pre_t = torch.from_numpy(edge_index_pre).long().to(DEVICE)

# Create train/val split (node-level) for early stopping on pretraining
idx_nodes = np.where(labeled_mask_pre)[0]
y_nodes = y_pre[labeled_mask_pre]
n_train_nodes, n_tmp_nodes = train_test_split(idx_nodes, test_size=0.30, random_state=SEED, stratify=y_nodes)
n_val_nodes, n_test_nodes = train_test_split(n_tmp_nodes, test_size=0.50, random_state=SEED,
                                             stratify=y_pre[n_tmp_nodes])

train_mask_nodes = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE); train_mask_nodes[n_train_nodes] = True
val_mask_nodes   = torch.zeros(len(y_pre), dtype=torch.bool, device=DEVICE);   val_mask_nodes[n_val_nodes] = True
# note: we won't use node_test_nodes later

# ----------------- Node GCN pretraining (PyG-style conv) -----------------
class NodeGCN(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.head = nn.Linear(hid_dim, 2)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return self.head(x)  # per-node logits

node_model = NodeGCN(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM).to(DEVICE)

# class weights on node-level training labels
train_labels_nodes = y_pre_t[train_mask_nodes]
classes, counts = torch.unique(train_labels_nodes, return_counts=True)
num_pos = int((train_labels_nodes==1).sum().item()) if (train_labels_nodes==1).any() else 1
num_neg = int((train_labels_nodes==0).sum().item()) if (train_labels_nodes==0).any() else 1
w_pos = (num_neg + num_pos) / (2.0 * num_pos)
w_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights_nodes = torch.tensor([w_neg, w_pos], dtype=torch.float32, device=DEVICE)
crit_node = nn.CrossEntropyLoss(weight=class_weights_nodes)
opt_node = torch.optim.Adam(node_model.parameters(), lr=PRETRAIN_LR, weight_decay=5e-4)

# Pretraining loop
best_val = -1.0; best_state = None; patience_cnt = 0
print("Node pretraining sizes (labeled): train=%d val=%d" % (train_mask_nodes.sum().item(), val_mask_nodes.sum().item()))
for epoch in range(1, PRETRAIN_EPOCHS+1):
    node_model.train()
    opt_node.zero_grad()
    logits_nodes = node_model(X_pre_t, edge_index_pre_t)
    loss = crit_node(logits_nodes[train_mask_nodes], y_pre_t[train_mask_nodes])
    loss.backward()
    opt_node.step()
    if epoch % 5 == 0 or epoch == 1:
        node_model.eval()
        with torch.no_grad():
            logits_val = node_model(X_pre_t, edge_index_pre_t)
            preds_val = logits_val.argmax(dim=1)
            if val_mask_nodes.sum() > 0:
                val_acc = (preds_val[val_mask_nodes] == y_pre_t[val_mask_nodes]).float().mean().item()
            else:
                val_acc = 0.0
        print(f"Pretrain Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f}")
        if val_acc > best_val + 1e-5:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in node_model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= EARLY_STOPPING:
                print("Early stopping node pretrain.")
                break

if best_state is not None:
    node_model.load_state_dict(best_state)
torch.save(node_model.state_dict(), PRETRAIN_CHECKPOINT)
print("Saved node-level model checkpoint:", PRETRAIN_CHECKPOINT)

# ----------------- Build graph-level dataset (per-circuit Data objects) -----------------
print("\nBuilding graph-level Data objects for fine-tuning (train/val/test circuits) ...")
# load full nodes/edges for graphs (use nodes_df, edges_df from earlier)
nodes_full_df = pd.read_csv(NODE_CSV)
edges_full_df = pd.read_csv(GRAPH_EDGE_CSV)

# Prepare feature scaler: use the scaler fitted during pretraining (we already have StandardScaler scaler)
# For nodes not seen in pretrain set, we'll apply same scaler transform using scaler.mean_/var_
def node_uid(circuit, node):
    return f"{circuit}::{node}"

# build uid->feature map for all nodes (apply scaler to full dataset)
feat_full_df = nodes_full_df.copy()
if "gate_type" in feat_full_df.columns:
    gate_oh = pd.get_dummies(feat_full_df["gate_type"], prefix="gt")
    feat_full_df = pd.concat([feat_full_df.drop(columns=["gate_type"]), gate_oh], axis=1)
# ensure feature columns compatible: if full has extra gate dummies or missing ones compared to pretrain, align
for col in feature_cols:
    if col not in feat_full_df.columns:
        feat_full_df[col] = 0.0
feat_full_df = feat_full_df[["circuit_name","node"] + feature_cols]

# scale using pretrain scaler (note: scaler was fitted on pretrain labeled nodes' features)
full_X = feat_full_df[feature_cols].fillna(0.0).to_numpy(dtype=np.float32)
full_X_scaled = (full_X - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)  # consistent transform
feat_full_df["scaled_feat"] = list(full_X_scaled.tolist())

# build edges by circuit
edges_full_df["src_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["src"].astype(str)
edges_full_df["dst_uid"] = edges_full_df["circuit_name"].astype(str) + "::" + edges_full_df["dst"].astype(str)
edges_by_circuit = defaultdict(list)
for _, r in edges_full_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# build per-circuit Data (only circuits present in graph_labels)
graph_data_list = []
graph_names = []
graph_target = []
for _, row in graph_df.iterrows():
    ckt = row["circuit_name"]
    lbl = int(row[graph_label_col])
    # nodes of this circuit
    sub_nodes = feat_full_df[feat_full_df["circuit_name"]==ckt]
    if sub_nodes.shape[0] == 0: 
        continue
    node_names = sub_nodes["node"].tolist()
    uid_map = {n:i for i,n in enumerate(node_names)}
    X_nodes = np.vstack(sub_nodes["scaled_feat"].values).astype(np.float32)
    # build edge_index
    srcs, dsts = [], []
    if ckt in edges_by_circuit:
        for u,v in edges_by_circuit[ckt]:
            if u in uid_map and v in uid_map:
                srcs.extend([uid_map[u], uid_map[v]])
                dsts.extend([uid_map[v], uid_map[u]])
    if len(srcs) == 0:
        # skip graphs without edges (unlikely)
        continue
    edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
    data = Data(x=torch.tensor(X_nodes, dtype=torch.float), edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    graph_data_list.append(data)
    graph_names.append(ckt)
    graph_target.append(lbl)

print(f"Built {len(graph_data_list)} graphs.")

# Create train/val/test lists by circuit split we made earlier
def filter_by_circuit(list_data, circuits_set):
    idxs = [i for i,d in enumerate(list_data) if d.circuit_name in circuits_set]
    return [list_data[i] for i in idxs]

train_graphs = filter_by_circuit(graph_data_list, set(train_circuits))
val_graphs   = filter_by_circuit(graph_data_list, set(val_circuits))
test_graphs  = filter_by_circuit(graph_data_list, set(test_circuits))

print(f"Graph counts -> train: {len(train_graphs)}, val: {len(val_graphs)}, test: {len(test_graphs)}")

# ----------------- Graph classifier reusing conv layers from node_model -----------------
class GraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, dropout=DROPOUT):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hid_dim, 2)
    def forward(self, x, edge_index, batch):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.conv2(h, edge_index)
        g = global_mean_pool(h, batch)
        return self.classifier(g)

graph_model = GraphClassifier(in_dim=X_pre_t.shape[1], hid_dim=HID_DIM).to(DEVICE)

# Copy conv weights from node_model -> graph_model (shapes align because both use GCNConv)
node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")
# node_state contains keys: conv1.lin.weight, conv1.lin.bias? Check keys
for k_src, v_src in node_state.items():
    if "conv1" in k_src and "weight" in k_src:
        # copy to graph_model conv1 weight if exists
        if k_src in graph_model.state_dict() and graph_model.state_dict()[k_src].shape == v_src.shape:
            graph_model.state_dict()[k_src].copy_(v_src)
    if "conv2" in k_src and "weight" in k_src:
        if k_src in graph_model.state_dict() and graph_model.state_dict()[k_src].shape == v_src.shape:
            graph_model.state_dict()[k_src].copy_(v_src)

from torch_geometric.loader import DataLoader

# --- replace the function definition (rename it) ---
def train_graphs_fn(model, train_set, val_set, test_set, freeze_encoder=True):
    # dataloaders
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    if freeze_encoder:
        for p in model.conv1.parameters(): p.requires_grad = False
        for p in model.conv2.parameters(): p.requires_grad = False
    else:
        for p in model.parameters(): p.requires_grad = True

    # class weights
    ytrain = np.array([int(d.y.item()) for d in train_set])
    if len(np.unique(ytrain)) == 2:
        counts = np.bincount(ytrain); w = torch.tensor([ (counts.sum()/counts[0]), (counts.sum()/counts[1]) ], dtype=torch.float32).to(DEVICE)
        criterion = nn.CrossEntropyLoss(weight=w)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=FT_HEAD_LR if freeze_encoder else FT_FINETUNE_LR, weight_decay=5e-4)
    best_val = -1.0; best_state = None; pcount = 0

    for epoch in range(1, FT_EPOCHS_HEAD + 1 if freeze_encoder else FT_EPOCHS_FINETUNE + 1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            logits = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(logits, batch.y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs

        # validation
        model.eval()
        ys, ps = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                out = model(batch.x, batch.edge_index, batch.batch)
                preds = out.argmax(dim=1)
                ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
        val_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
        if epoch % 5 == 0 or epoch == 1:
            print(f"Fine-tune ({'frozen' if freeze_encoder else 'all'}) Epoch {epoch:03d} | Val Acc {val_acc:.4f} | AvgLoss {total_loss / max(1,len(train_set)):.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            pcount = 0
        else:
            pcount += 1
            if pcount >= EARLY_STOPPING:
                print("Early stopping fine-tune stage.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # test eval
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch)
            preds = out.argmax(dim=1)
            ys.extend(batch.y.cpu().numpy()); ps.extend(preds.cpu().numpy())
    test_acc = accuracy_score(ys, ps) if len(ys)>0 else 0.0
    return test_acc, ys, ps

# Stage 1: freeze encoder, train head
print("\nStage 1: freeze encoder and train classifier head")
for p in graph_model.conv1.parameters(): p.requires_grad = False
for p in graph_model.conv2.parameters(): p.requires_grad = False
# Stage 1 call:
test_acc_1, ys1, ps1 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=True)
print("After head training -> Test Acc: {:.4f}".format(test_acc_1))

# Stage 2: unfreeze and fine-tune all (lower lr)
print("\nStage 2: unfreeze encoder and fine-tune entire model")
for p in graph_model.parameters(): p.requires_grad = True
# update global lr for finetune
FT_HEAD_LR = FT_FINETUNE_LR
# Stage 2 call:
test_acc_2, ys2, ps2 = train_graphs_fn(graph_model, train_graphs, val_graphs, test_graphs, freeze_encoder=False)
print("After full fine-tune -> Test Acc: {:.4f}".format(test_acc_2))

# Final report
final_ys, final_ps = (ys2, ps2) if len(ys2)>0 else (ys1, ps1)
print("\nFinal Evaluation (Graph-level after transfer)")
print("=============================================")
print(f"Test Accuracy: {accuracy_score(final_ys, final_ps):.4f}\n")
print("Classification Report:")
print(classification_report(final_ys, final_ps, digits=4))
print("Confusion Matrix:")
print(confusion_matrix(final_ys, final_ps))


Circuits split -> train:64 val:14 test:14
Pretraining will use nodes from 78 circuits (train+val) and exclude 14 test circuits.
Node pretraining sizes (labeled): train=130706 val=28008
Pretrain Epoch 001 | Loss 0.7781 | Val 0.6488
Pretrain Epoch 005 | Loss 0.6883 | Val 0.6977
Pretrain Epoch 010 | Loss 0.5939 | Val 0.9208
Pretrain Epoch 015 | Loss 0.5102 | Val 0.9453
Pretrain Epoch 020 | Loss 0.4307 | Val 0.9594
Pretrain Epoch 025 | Loss 0.3524 | Val 0.9863
Pretrain Epoch 030 | Loss 0.2770 | Val 0.9936
Pretrain Epoch 035 | Loss 0.2106 | Val 0.9944
Pretrain Epoch 040 | Loss 0.1548 | Val 0.9961
Pretrain Epoch 045 | Loss 0.1138 | Val 0.9991
Pretrain Epoch 050 | Loss 0.0838 | Val 0.9996
Pretrain Epoch 055 | Loss 0.0618 | Val 0.9997
Pretrain Epoch 060 | Loss 0.0478 | Val 0.9998
Pretrain Epoch 065 | Loss 0.0380 | Val 0.9998
Pretrain Epoch 070 | Loss 0.0316 | Val 0.9998
Pretrain Epoch 075 | Loss 0.0256 | Val 0.9998
Pretrain Epoch 080 | Loss 0.0221 | Val 0.9998
Pretrain Epoch 085 | Loss 0.0190 

  node_state = torch.load(PRETRAIN_CHECKPOINT, map_location="cpu")


Fine-tune (frozen) Epoch 001 | Val Acc 1.0000 | AvgLoss 0.5945
Fine-tune (frozen) Epoch 005 | Val Acc 1.0000 | AvgLoss 0.4577
Fine-tune (frozen) Epoch 010 | Val Acc 1.0000 | AvgLoss 0.2530
Fine-tune (frozen) Epoch 015 | Val Acc 1.0000 | AvgLoss 0.1850
Fine-tune (frozen) Epoch 020 | Val Acc 1.0000 | AvgLoss 0.1366
Fine-tune (frozen) Epoch 025 | Val Acc 1.0000 | AvgLoss 0.0988
Fine-tune (frozen) Epoch 030 | Val Acc 1.0000 | AvgLoss 0.1019
Early stopping fine-tune stage.
After head training -> Test Acc: 0.9286

Stage 2: unfreeze encoder and fine-tune entire model
Fine-tune (all) Epoch 001 | Val Acc 1.0000 | AvgLoss 0.5179
Fine-tune (all) Epoch 005 | Val Acc 1.0000 | AvgLoss 0.3262
Fine-tune (all) Epoch 010 | Val Acc 1.0000 | AvgLoss 0.2033
Fine-tune (all) Epoch 015 | Val Acc 1.0000 | AvgLoss 0.1076
Fine-tune (all) Epoch 020 | Val Acc 1.0000 | AvgLoss 0.0639
Fine-tune (all) Epoch 025 | Val Acc 1.0000 | AvgLoss 0.0353
Fine-tune (all) Epoch 030 | Val Acc 1.0000 | AvgLoss 0.0246
Early stoppin