#### Node Classification: Trojan vs non-trojan. 

Below is a single, end‑to‑end GCN training script that:

reads GNNDatasets/node.csv + GNNDatasets/node_edges.csv

auto‑builds a multi‑circuit graph (nodes keyed by circuit_name::node)

one‑hot encodes gate_type, standardizes numeric features

adds zero‑feature pseudo nodes for edge‑only items (ASSIGN_*, DFF_*, PIs/POs)

trains a 2‑layer GCN with class‑weighted loss + early stopping

reports accuracy, classification report, confusion matrix

In [3]:
# train_gcn_node_fixed.py
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

NODE_CSV  = "GNNDatasets/node.csv"
EDGE_CSV  = "GNNDatasets/node_edges.csv"
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED)

# ----------------------------- Load nodes -----------------------------
nodes_df = pd.read_csv(NODE_CSV)

label_col = None
for cand in ["label", "is_trojan", "trojan", "target"]:
    if cand in nodes_df.columns:
        label_col = cand; break
if label_col is None:
    nodes_df["label"] = nodes_df["circuit_name"].astype(str).str.contains("__trojan_").astype(int)
    label_col = "label"

nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

feat_df = nodes_df.copy()
if "gate_type" in feat_df.columns:
    gate_oh = pd.get_dummies(feat_df["gate_type"], prefix="gt")
    feat_df = pd.concat([feat_df.drop(columns=["gate_type"]), gate_oh], axis=1)

exclude = {"uid","node","circuit_name",label_col}
num_cols = [c for c in feat_df.columns if c not in exclude and pd.api.types.is_numeric_dtype(feat_df[c])]
X = feat_df[num_cols].fillna(0.0).values.astype(np.float32)
y = nodes_df[label_col].values.astype(np.int64)

# ----------------------------- Load edges; add missing nodes -----------------------------
edges_df = pd.read_csv(EDGE_CSV)
edges_df["src_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["src"].astype(str)
edges_df["dst_uid"] = edges_df["circuit_name"].astype(str) + "::" + edges_df["dst"].astype(str)

known_uids = set(nodes_df["uid"])
edge_uids = set(edges_df["src_uid"]).union(set(edges_df["dst_uid"]))
missing = list(edge_uids - known_uids)

if missing:
    zero_row = np.zeros((1, X.shape[1]), dtype=np.float32)
    addX = np.repeat(zero_row, len(missing), axis=0)
    addY = -1*np.ones(len(missing), dtype=np.int64)
    add_df = pd.DataFrame({
        "uid": missing,
        "circuit_name": [u.split("::",1)[0] for u in missing],
        "node": [u.split("::",1)[1] for u in missing],
        label_col: addY
    })
    X = np.vstack([X, addX])
    y = np.concatenate([y, addY])
    nodes_df = pd.concat([nodes_df, add_df], ignore_index=True)

uid_to_idx = {u:i for i,u in enumerate(nodes_df["uid"].tolist())}
src_idx = edges_df["src_uid"].map(uid_to_idx).dropna().astype(int).values
dst_idx = edges_df["dst_uid"].map(uid_to_idx).dropna().astype(int).values
edge_index = np.stack([np.concatenate([src_idx, dst_idx]),
                       np.concatenate([dst_idx, src_idx])], axis=0)

# ----------------------------- Scale features -----------------------------
labeled_mask_np = (y >= 0)
scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[labeled_mask_np] = scaler.fit_transform(X_scaled[labeled_mask_np])
if (~labeled_mask_np).any():
    X_scaled[~labeled_mask_np] = (X_scaled[~labeled_mask_np] - scaler.mean_) / np.sqrt(scaler.var_ + 1e-8)

# ----------------------------- Splits -----------------------------
idx_all = np.where(labeled_mask_np)[0]
y_all = y[labeled_mask_np]

idx_train, idx_tmp, y_train, y_tmp = train_test_split(
    idx_all, y_all, test_size=0.30, random_state=SEED, stratify=y_all
)
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_tmp, y_tmp, test_size=0.50, random_state=SEED, stratify=y_tmp
)

# ----------------------------- Torch tensors (FIX: masks as torch.bool) -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_t = torch.from_numpy(X_scaled).to(device)
y_t = torch.from_numpy(y).to(device)

edge_index_t = torch.from_numpy(edge_index).long().to(device)

train_mask_t = torch.zeros(len(y), dtype=torch.bool, device=device); train_mask_t[idx_train] = True
val_mask_t   = torch.zeros(len(y), dtype=torch.bool, device=device); val_mask_t[idx_val]   = True
test_mask_t  = torch.zeros(len(y), dtype=torch.bool, device=device); test_mask_t[idx_test]  = True
labeled_mask_t = torch.from_numpy(labeled_mask_np).to(device)

# ----------------------------- Build GCN adjacency -----------------------------
def build_adj(num_nodes, edge_index):
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    ei = torch.cat([edge_index, torch.stack([self_loops, self_loops])], dim=1)
    deg = torch.bincount(ei[0], minlength=num_nodes).float()
    deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
    w = deg_inv_sqrt[ei[0]] * deg_inv_sqrt[ei[1]]
    A = torch.sparse_coo_tensor(ei, w, (num_nodes, num_nodes))
    return A.coalesce()

A_t = build_adj(X_t.size(0), edge_index_t)

# ----------------------------- Model -----------------------------
class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.lin.weight)
    def forward(self, x, adj):
        x = self.dropout(x)
        x = torch.sparse.mm(adj, x)
        x = self.lin(x)
        return x

class GCN(nn.Module):
    def __init__(self, in_dim, hid_dim=96, out_dim=2, dropout=0.35):
        super().__init__()
        self.g1 = GCNLayer(in_dim, hid_dim, dropout)
        self.g2 = GCNLayer(hid_dim, out_dim, dropout)
        self.do = nn.Dropout(dropout)
    def forward(self, x, adj):
        x = self.g1(x, adj); x = F.relu(x); x = self.do(x)
        x = self.g2(x, adj)
        return x

model = GCN(in_dim=X_t.size(1), hid_dim=96, out_dim=2, dropout=0.35).to(device)

# ----------------------------- Loss, optimizer -----------------------------
train_labels = y_t[train_mask_t]
classes, counts = torch.unique(train_labels, return_counts=True)
num_pos = counts[classes==1].item() if (classes==1).any() else 1
num_neg = counts[classes==0].item() if (classes==0).any() else 1
weight_pos = (num_neg + num_pos) / (2.0 * num_pos)
weight_neg = (num_neg + num_pos) / (2.0 * num_neg)
class_weights = torch.tensor([weight_neg, weight_pos], dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=5e-4)

# ----------------------------- Training -----------------------------
def evaluate(mask_t):
    model.eval()
    with torch.no_grad():
        logits = model(X_t, A_t)
        pred = logits.argmax(dim=1)
        msk = mask_t & (y_t >= 0)
        if msk.sum() == 0: return 0.0
        return (pred[msk] == y_t[msk]).float().mean().item()

best_val, best_state = -1.0, None
patience, patience_cnt = 20, 0
EPOCHS = 300

for epoch in range(1, EPOCHS+1):
    model.train()
    optimizer.zero_grad()
    logits = model(X_t, A_t)
    loss = criterion(logits[train_mask_t], y_t[train_mask_t])
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()

    if epoch % 10 == 0 or epoch == 1:
        val_acc = evaluate(val_mask_t)
        test_acc = evaluate(test_mask_t)
        print(f"Epoch {epoch:03d} | Loss {loss.item():.4f} | Val {val_acc:.4f} | Test {test_acc:.4f}")
        if val_acc > best_val + 1e-4:
            best_val = val_acc
            best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("Early stopping."); break

if best_state is not None:
    model.load_state_dict(best_state)

# ----------------------------- Final eval -----------------------------
model.eval()
with torch.no_grad():
    logits = model(X_t, A_t)
    preds = logits.argmax(dim=1)

msk = (test_mask_t & (y_t >= 0)).cpu().numpy()
y_true = y_t.cpu().numpy()[msk]
y_pred = preds.cpu().numpy()[msk]

acc = (y_true == y_pred).mean()
print("\nFinal Evaluation (Node-Level)")
print("=============================")
print(f"Test Accuracy: {acc:.4f}\n")

print("Classification Report:")
print(classification_report(y_true, y_pred, labels=[0,1], target_names=["clean","trojan"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=[0,1]))


Epoch 001 | Loss 0.8937 | Val 0.4607 | Test 0.4631
Epoch 010 | Loss 0.6115 | Val 0.9505 | Test 0.9496
Epoch 020 | Loss 0.4193 | Val 0.9883 | Test 0.9885
Epoch 030 | Loss 0.2920 | Val 0.9926 | Test 0.9931
Epoch 040 | Loss 0.2113 | Val 0.9981 | Test 0.9981
Epoch 050 | Loss 0.1559 | Val 0.9994 | Test 0.9995
Epoch 060 | Loss 0.1176 | Val 0.9999 | Test 0.9998
Epoch 070 | Loss 0.0936 | Val 0.9999 | Test 0.9998
Epoch 080 | Loss 0.0780 | Val 1.0000 | Test 1.0000
Epoch 090 | Loss 0.0668 | Val 1.0000 | Test 1.0000
Epoch 100 | Loss 0.0553 | Val 1.0000 | Test 1.0000
Epoch 110 | Loss 0.0493 | Val 1.0000 | Test 1.0000
Epoch 120 | Loss 0.0435 | Val 1.0000 | Test 1.0000
Epoch 130 | Loss 0.0407 | Val 1.0000 | Test 1.0000
Epoch 140 | Loss 0.0371 | Val 1.0000 | Test 1.0000
Epoch 150 | Loss 0.0346 | Val 1.0000 | Test 1.0000
Epoch 160 | Loss 0.0329 | Val 1.0000 | Test 1.0000
Epoch 170 | Loss 0.0301 | Val 1.0000 | Test 1.0000
Epoch 180 | Loss 0.0283 | Val 1.0000 | Test 1.0000
Epoch 190 | Loss 0.0273 | Val 1

#### Subgraph Classification

Below is a complete, end‑to‑end script for subgraph‑level classification. It:

Loads GNNDatasets/subgraph.csv, GNNDatasets/node.csv, and GNNDatasets/node_edges.csv.

Reconstructs per‑circuit graphs (NetworkX) and extracts each subgraph as the K‑hop ego graph around center_node (uses K=2, same as before).

Builds a PyTorch‑Geometric Data object per subgraph with node features taken from node.csv (numeric features + one‑hot gate type); missing nodes get zero features.

Trains a GraphSAGE model (two layers) with global mean pooling for graph classification (label_subgraph: 0/1).

Prints accuracy, classification report, and confusion matrix.

In [5]:
# train_subgraph_gnn_fixed.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 80
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Merge edge CSVs
# -----------------------
edge_files = [
    "GNNDatasets/node_edges.csv",
    "GNNDatasets/subgraph_edges_andxor.csv",
    "GNNDatasets/subgraph_edges_countermux.csv",
    "GNNDatasets/subgraph_edges_fsmor.csv",
]

edges_by_circuit = defaultdict(list)
for ef in edge_files:
    df = pd.read_csv(ef)
    for _, r in df.iterrows():
        edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
sub_df = pd.read_csv("GNNDatasets/subgraph.csv")
data_list, labels = [], []

for idx, row in tqdm(sub_df.iterrows(), total=len(sub_df), desc="Building subgraphs"):
    ckt = row["circuit_name"]
    lbl = int(row.get("label_subgraph", row.get("label", 0)))
    
    if ckt not in edges_by_circuit:
        continue
    
    # collect node uids from node.csv that belong to this circuit
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} subgraphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model
# -----------------------
class SubgraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid_dim)
        self.conv2 = SAGEConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = SubgraphClassifier(feat_dim).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Subgraph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Building subgraphs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 674/674 [00:19<00:00, 34.49it/s]


Built 674 subgraphs (usable)
Train: 471, Val: 101, Test: 102
Epoch 001 | Loss 0.7464 | Val 0.2673
Epoch 005 | Loss 0.6193 | Val 0.6238
Epoch 010 | Loss 0.5896 | Val 0.5941
Epoch 015 | Loss 0.5101 | Val 0.6832
Epoch 020 | Loss 0.3784 | Val 0.8020
Epoch 025 | Loss 0.2963 | Val 0.8317
Epoch 030 | Loss 0.2182 | Val 0.8812
Epoch 035 | Loss 0.1549 | Val 0.9208
Epoch 040 | Loss 0.1265 | Val 0.9208
Epoch 045 | Loss 0.1028 | Val 0.9208
Epoch 050 | Loss 0.0897 | Val 0.9208
Epoch 055 | Loss 0.0840 | Val 0.9208
Epoch 060 | Loss 0.0851 | Val 0.9307
Epoch 065 | Loss 0.0641 | Val 0.9307
Epoch 070 | Loss 0.0713 | Val 0.9604
Epoch 075 | Loss 0.0604 | Val 0.9604
Epoch 080 | Loss 0.0645 | Val 0.9604

Final Evaluation (Subgraph-Level)
Test Accuracy: 0.9902

Classification Report:
              precision    recall  f1-score   support

           0     0.9167    1.0000    0.9565        11
           1     1.0000    0.9890    0.9945        91

    accuracy                         0.9902       102
   macro av

#### Graph Classification

At this level we want:
- Input: GNNDatasets/graph.csv (circuit labels)
- Edges: use GNNDatasets/graph_edges.csv (and optionally node_edges/subgraph_edges, but graph_edges.csv should already be the merged top-level edge file).
- Nodes: GNNDatasets/node.csv still provides features.

We’ll build each circuit as one graph, then classify whether it is clean or trojaned.
- Uses GINConv (better for graph classification than vanilla GCN/GraphSAGE).
- Loads graph-level edges (graph_edges.csv).
- Uses graph.csv for circuit labels.
- Includes class-weighted loss (handles class imbalance).
- Prints classification report + confusion matrix just like node/subgraph levels.

In [6]:
# train_graph_gnn.py
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, global_mean_pool

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
EPOCHS = 100
LR = 1e-3
HID_DIM = 64

# -----------------------
# Load node features
# -----------------------
nodes_df = pd.read_csv("GNNDatasets/node.csv")
nodes_df["uid"] = nodes_df["circuit_name"].astype(str) + "::" + nodes_df["node"].astype(str)

# one-hot gate_type
if "gate_type" in nodes_df.columns:
    gate_ohe = pd.get_dummies(nodes_df["gate_type"], prefix="gt")
    nodes_feat_df = pd.concat([nodes_df.drop(columns=["gate_type"]), gate_ohe], axis=1)
else:
    nodes_feat_df = nodes_df.copy()

meta_cols = {"uid","node","circuit_name","label","label_node","label_graph","label_subgraph","folder"}
num_cols = [c for c in nodes_feat_df.columns if c not in meta_cols and pd.api.types.is_numeric_dtype(nodes_feat_df[c])]

uid_to_feat = {}
for _, r in nodes_feat_df.iterrows():
    uid_to_feat[r["uid"]] = r[num_cols].astype(float).values

scaler = StandardScaler().fit(np.stack(list(uid_to_feat.values())))
for k in list(uid_to_feat.keys()):
    uid_to_feat[k] = scaler.transform(uid_to_feat[k].reshape(1,-1)).reshape(-1)

feat_dim = len(uid_to_feat[list(uid_to_feat.keys())[0]])

# -----------------------
# Load graph-level labels
# -----------------------
graph_df = pd.read_csv("GNNDatasets/graph.csv")
graph_labels = {r["circuit_name"]: int(r.get("label_graph", r.get("label", 0))) for _, r in graph_df.iterrows()}

# -----------------------
# Load graph-level edges
# -----------------------
edges_df = pd.read_csv("GNNDatasets/graph_edges.csv")
edges_by_circuit = defaultdict(list)
for _, r in edges_df.iterrows():
    edges_by_circuit[r["circuit_name"]].append((r["src"], r["dst"]))

# -----------------------
# Build dataset
# -----------------------
data_list, labels = [], []

for ckt, lbl in graph_labels.items():
    if ckt not in edges_by_circuit:
        continue
    
    # collect node list
    sub_nodes = nodes_df[nodes_df["circuit_name"]==ckt]["node"].tolist()
    if not sub_nodes: 
        continue
    
    uid_map = {n:i for i,n in enumerate(sub_nodes)}
    x_list = []
    for n in sub_nodes:
        uid = f"{ckt}::{n}"
        if uid in uid_to_feat:
            x_list.append(uid_to_feat[uid])
        else:
            x_list.append(np.zeros(feat_dim))
    x = torch.tensor(np.vstack(x_list), dtype=torch.float)
    
    # build edge_index
    edge_idx = [[], []]
    for u,v in edges_by_circuit[ckt]:
        if u in uid_map and v in uid_map:
            edge_idx[0].append(uid_map[u]); edge_idx[1].append(uid_map[v])
            edge_idx[0].append(uid_map[v]); edge_idx[1].append(uid_map[u])
    if not edge_idx[0]:
        continue
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([lbl], dtype=torch.long))
    data.circuit_name = ckt
    data_list.append(data)
    labels.append(lbl)

print(f"Built {len(data_list)} graphs (usable)")

# -----------------------
# Split
# -----------------------
labels = np.array(labels)
idxs = np.arange(len(data_list))
train_idx, temp_idx, y_train, y_temp = train_test_split(idxs, labels, test_size=0.3, 
                                                        stratify=labels, random_state=RANDOM_SEED)
val_idx, test_idx, y_val, y_test = train_test_split(temp_idx, y_temp, test_size=0.5,
                                                    stratify=y_temp, random_state=RANDOM_SEED)

train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader([data_list[i] for i in val_idx], batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader([data_list[i] for i in test_idx], batch_size=BATCH_SIZE, shuffle=False)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

# -----------------------
# Model (GIN for graphs)
# -----------------------
class GraphClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim=HID_DIM, out_dim=2):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(in_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv1 = GINConv(nn1)
        nn2 = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        self.conv2 = GINConv(nn2)
        self.lin = nn.Linear(hid_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = GraphClassifier(feat_dim).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)

# class weights
cls_counts = np.bincount(labels)
w = torch.tensor([cls_counts.sum()/c for c in cls_counts], dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=w)

# -----------------------
# Training loop
# -----------------------
def evaluate(loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for b in loader:
            b = b.to(DEVICE)
            out = model(b.x, b.edge_index, b.batch)
            p = out.argmax(dim=1).cpu().numpy()
            ys.extend(b.y.cpu().numpy())
            preds.extend(p)
    return np.array(ys), np.array(preds)

best_val, best_state = -1, None
for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss=0
    for b in train_loader:
        b = b.to(DEVICE)
        optimizer.zero_grad()
        out = model(b.x, b.edge_index, b.batch)
        loss = criterion(out, b.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch%5==0 or epoch==1:
        yv,pv = evaluate(val_loader)
        acc = accuracy_score(yv,pv)
        print(f"Epoch {epoch:03d} | Loss {total_loss/len(train_loader):.4f} | Val {acc:.4f}")
        if acc>best_val: 
            best_val=acc; best_state=model.state_dict().copy()

# load best
if best_state: model.load_state_dict(best_state)

# -----------------------
# Final test
# -----------------------
yt,pt = evaluate(test_loader)
print("\nFinal Evaluation (Graph-Level)")
print("=================================")
print(f"Test Accuracy: {accuracy_score(yt,pt):.4f}\n")
print("Classification Report:")
print(classification_report(yt,pt,digits=4))
print("Confusion Matrix:")
print(confusion_matrix(yt,pt))


Built 92 graphs (usable)
Train: 64, Val: 14, Test: 14




Epoch 001 | Loss 0.7626 | Val 0.7143
Epoch 005 | Loss 0.6486 | Val 0.2857
Epoch 010 | Loss 0.6627 | Val 0.5000
Epoch 015 | Loss 0.6499 | Val 0.2857
Epoch 020 | Loss 0.6382 | Val 0.2143
Epoch 025 | Loss 0.6565 | Val 0.3571
Epoch 030 | Loss 0.6186 | Val 0.2143
Epoch 035 | Loss 0.6348 | Val 0.1429
Epoch 040 | Loss 0.5834 | Val 0.2143
Epoch 045 | Loss 0.5532 | Val 0.2143
Epoch 050 | Loss 0.5080 | Val 0.2143
Epoch 055 | Loss 0.4923 | Val 0.2143
Epoch 060 | Loss 0.4683 | Val 0.3571
Epoch 065 | Loss 0.4470 | Val 0.3571
Epoch 070 | Loss 0.4055 | Val 0.4286
Epoch 075 | Loss 0.3791 | Val 0.6429
Epoch 080 | Loss 0.3507 | Val 0.7143
Epoch 085 | Loss 0.3147 | Val 0.7857
Epoch 090 | Loss 0.2813 | Val 0.7857
Epoch 095 | Loss 0.2523 | Val 0.7857
Epoch 100 | Loss 0.2401 | Val 0.7857

Final Evaluation (Graph-Level)
Test Accuracy: 0.5714

Classification Report:
              precision    recall  f1-score   support

           0     0.3333    0.5000    0.4000         4
           1     0.7500    0.6000   