1. Graph Convolutional Networks (GCN)
GCN is a basic method for applying deep learning to graph-structured data. It works by aggregating information from neighboring nodes to update each node's representation.

How does it work?

Each node collects features from its neighbors.

The features are averaged (or summed) with fixed weights that are predefined by the graph structure (for example, node degrees) rather than learned.

A neural network transforms the aggregated features.

🔹 Limitation of GCN:

Neighbor weights are fixed by the graph structure, so every neighbor is treated by the same rule.

The model cannot learn which neighbors are more important.
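
The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming the symmetric degree normalization commonly used for GCN layers; the function name `gcn_layer`, the toy 3-node path graph, and the all-ones weight matrix are hypothetical and not taken from this notebook.

```python
import numpy as np

def gcn_layer(adj, feats, weight):
    """One GCN-style layer: normalize, aggregate neighbors, transform, ReLU."""
    n = adj.shape[0]
    a_hat = adj + np.eye(n)                  # add self-loops so a node keeps its own features
    deg = a_hat.sum(axis=1)                  # degree of each node (including the self-loop)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Fixed, structure-derived weights: 1 / sqrt(deg_i * deg_j) for each edge (i, j)
    norm_adj = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(norm_adj @ feats @ weight, 0.0)

# Toy graph: 0 - 1 - 2 (undirected path)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
feats = np.eye(3)            # one-hot node features
weight = np.ones((3, 2))     # illustrative weight matrix (would be learned in practice)
out = gcn_layer(adj, feats, weight)
print(out)
```

Note that the aggregation weights depend only on `adj`; nothing about the node features can change how much a given neighbor contributes.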

2. Graph Attention Networks (GAT)
GAT improves GCN by introducing attention mechanisms, meaning it learns to focus more on important neighbors and less on irrelevant ones.

How does it work?

Instead of fixed weights, GAT computes dynamic attention scores between nodes.

The model learns which neighbors are more relevant for a given node.

The aggregation is now a weighted sum, where important neighbors contribute more.

🔹 Key Differences from GCN:

Dynamic Weights (Attention):

GCN → Fixed weights for all neighbors.

GAT → Learns weights based on node features.

Interpretability:

Attention scores tell us which neighbors are more important.

Flexibility:

Can handle varying importance of neighbors (e.g., in social networks, some friends influence you more than others).
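
For comparison with GCN, here is a minimal single-head attention sketch in NumPy. It assumes the usual GAT scoring form, a LeakyReLU applied to a learned vector dotted with the concatenated transformed features of a node and its neighbor, followed by a softmax over each node's neighborhood. The names `gat_layer` and `attn_vec` and the random toy parameters are illustrative only; they are not the model trained later in this notebook.

```python
import numpy as np

def gat_layer(adj, feats, weight, attn_vec, slope=0.2):
    """Single-head GAT-style aggregation; returns (new_features, attention_matrix)."""
    n = adj.shape[0]
    h = feats @ weight                               # transformed features, shape (N, F_out)
    f_out = h.shape[1]
    # a^T [h_i || h_j] splits into a term for node i plus a term for neighbor j
    src_score = h @ attn_vec[:f_out]
    dst_score = h @ attn_vec[f_out:]
    scores = src_score[:, None] + dst_score[None, :]
    scores = np.where(scores > 0, scores, slope * scores)   # LeakyReLU
    neighbor_mask = (adj + np.eye(n)) > 0                   # neighbors plus self-loop
    scores = np.where(neighbor_mask, scores, -np.inf)       # non-neighbors get zero attention
    scores = scores - scores.max(axis=1, keepdims=True)     # numerically stable softmax
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha @ h, alpha

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
feats = rng.normal(size=(3, 4))
weight = rng.normal(size=(4, 2))     # would be learned in practice
attn_vec = rng.normal(size=4)        # would be learned in practice (length 2 * F_out)
out, alpha = gat_layer(adj, feats, weight, attn_vec)
print(alpha.round(3))
```

Each row of `alpha` sums to 1 and can be read directly as "how much node i listens to node j", which is where the interpretability claim comes from.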



Download the dataset from "https://github.com/jpsety/verilog_benchmark_circuits" and place it in a folder named "verilog_benchmark_circuits".

#### Fault Injection

Adding faults randomly.

In [1]:
import os
import shutil
import re
import random
from pathlib import Path

# === CONFIGURATION === #
SRC_DIR = Path("verilog_benchmark_circuits")
FAULT_TYPES = ["stuck_at", "glitch", "bridging"]
OUT_DIRS = {ft: Path(ft) for ft in FAULT_TYPES}
CLEAN_DIR = Path("clean_netlists")
META_LOG = Path("fault_metadata.csv")
random.seed(42)  # reproducibility

# === CLEANUP === #
for d in OUT_DIRS.values():
    if d.exists():
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)

CLEAN_DIR.mkdir(exist_ok=True)
if META_LOG.exists():
    META_LOG.unlink()

# === REGEX to Extract Nets === #
net_decl_re = re.compile(r'\b(?:wire|reg)\s+([^;]+);')

# === Fault Injection Helper === #
def inject_into_module(verilog_text: str, injection: str) -> str:
    parts = verilog_text.rsplit("endmodule", maxsplit=1)
    if len(parts) != 2:
        raise ValueError(f"No endmodule found in {verilog_text[:50]}...")
    return parts[0] + injection + "\nendmodule" + parts[1]

# === MAIN LOOP === #
with open(META_LOG, "w") as metaf:
    metaf.write("filename,fault_type,net1,net2(optional)\n")

    for vfile in sorted(SRC_DIR.glob("*.v")):
        text = vfile.read_text()

        # Save clean copy
        shutil.copy(vfile, CLEAN_DIR / vfile.name)

        # Extract nets
        nets = []
        for m in net_decl_re.finditer(text):
            for tok in re.split(r'[\s,]+', m.group(1).strip()):
                if tok:
                    nets.append(tok)
        if not nets:
            continue

        # Determine fault count: 10% of nets, min 10, max 100 (never more than the nets available)
        k = min(len(nets), 100, max(10, int(0.1 * len(nets))))

        ### === STUCK-AT FAULTS === ###
        sa_nets = random.sample(nets, k)
        sa_lines = []
        for net_sa in sa_nets:
            val = random.choice(["1'b0", "1'b1"])
            sa_lines.append(f"  // __INJECTED_FAULT__ type=stuck_at net={net_sa} value={val}")
            sa_lines.append(f"  assign {net_sa} = {val};")
            metaf.write(f"{vfile.name},stuck_at,{net_sa},\n")
        sa_block = "\n".join(sa_lines)
        out_sa = OUT_DIRS["stuck_at"] / f"stuck_at_{vfile.name}"
        out_sa.write_text(inject_into_module(text, sa_block))

        ### === GLITCH FAULTS === ###
        gl_nets = random.sample(nets, k)
        gl_lines = []
        for net_gl in gl_nets:
            delay1 = random.randint(1, 20)
            delay2 = delay1 + random.randint(1, 10)
            gl_lines.append(f"  // __INJECTED_FAULT__ type=glitch net={net_gl} delays=({delay1},{delay2})")
            gl_lines.append(f"  initial begin")
            gl_lines.append(f"    #{delay1} {net_gl} = ~{net_gl};")
            gl_lines.append(f"    #{delay2} {net_gl} = ~{net_gl};")
            gl_lines.append(f"  end")
            metaf.write(f"{vfile.name},glitch,{net_gl},\n")
        gl_block = "\n".join(gl_lines)
        out_gl = OUT_DIRS["glitch"] / f"glitch_{vfile.name}"
        out_gl.write_text(inject_into_module(text, gl_block))

        ### === BRIDGING FAULTS === ###
        br_lines = []
        used_pairs = set()
        for _ in range(k):
            while True:
                n1, n2 = random.sample(nets, 2)
                if n1 != n2 and (n1, n2) not in used_pairs and (n2, n1) not in used_pairs:
                    used_pairs.add((n1, n2))
                    break
            br_lines.append(f"  // __INJECTED_FAULT__ type=bridging nets=({n1},{n2})")
            br_lines.append(f"  assign {n1} = {n2};")
            metaf.write(f"{vfile.name},bridging,{n1},{n2}\n")
        br_block = "\n".join(br_lines)
        out_br = OUT_DIRS["bridging"] / f"bridging_{vfile.name}"
        out_br.write_text(inject_into_module(text, br_block))

print("✅ Fault injection complete. Metadata saved to:", META_LOG)


✅ Fault injection complete. Metadata saved to: fault_metadata.csv
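
To make the effect of the cell above concrete, the sketch below applies the same "insert before the last `endmodule`" idea to a hypothetical two-inverter module. The helper name `inject_before_endmodule` and the toy netlist are assumptions made for illustration; the marker comment and `assign` line follow the format the injection cell writes.

```python
def inject_before_endmodule(verilog_text: str, injection: str) -> str:
    """Insert `injection` just before the last `endmodule` keyword."""
    head, sep, tail = verilog_text.rpartition("endmodule")
    if not sep:
        raise ValueError("no endmodule found")
    return head + injection + "\nendmodule" + tail

# Hypothetical toy netlist (not part of the benchmark set)
toy = (
    "module toy (a, y);\n"
    "  input a;\n"
    "  output y;\n"
    "  wire n1;\n"
    "  not g1 (n1, a);\n"
    "  not g2 (y, n1);\n"
    "endmodule\n"
)

# Same layout as the stuck-at lines produced by the injection cell
stuck_at = (
    "  // __INJECTED_FAULT__ type=stuck_at net=n1 value=1'b0\n"
    "  assign n1 = 1'b0;"
)

faulty = inject_before_endmodule(toy, stuck_at)
print(faulty)
```

In the result, `n1` is driven both by gate `g1` and by the injected `assign`, so the fault is expressed as an extra driver on the net rather than as a change to the gate itself. The parser in the next section skips `assign` statements when building the graph, so these injected lines affect the node labels, not the graph structure.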


#### Parser: Netlist to Graph

In [7]:
import re
import csv
import shutil
from pathlib import Path
from collections import defaultdict, deque

# ----------------- CONFIG -----------------
SRC_DIR    = Path("verilog_benchmark_circuits")  # original clean netlists
FAULT_DIRS = {"stuck_at": Path("stuck_at"),
              "glitch"  : Path("glitch"),
              "bridging": Path("bridging")}
OUT_NODE   = Path("nodes.csv")
OUT_EDGE   = Path("edges.csv")

# ----------------- REGEX HELPERS -----------------
COMMENT_RE        = re.compile(r'//.*')
BLOCK_COMMENT_RE  = re.compile(r'/\*.*?\*/', re.DOTALL)
MODULE_RE         = re.compile(r'\bmodule\b.*?\bendmodule\b', re.DOTALL)
PORT_DECL_RE      = re.compile(r'\b(input|output|inout)\b\s+([^;]+);', re.DOTALL)
NET_DECL_RE       = re.compile(r'\b(wire|reg)\b\s+([^;]+);', re.DOTALL)
# Instance like: <GateType> <InstName> ( ... ) ;   (allow multiline ports)
INST_RE           = re.compile(r'^\s*([A-Za-z_]\w*)\s+([A-Za-z_]\w*)\s*\(\s*(.*?)\s*\)\s*$', re.DOTALL | re.MULTILINE)
# Named port maps: .PORT(<anything not crossing ')'>)
PORT_MAP_RE       = re.compile(r'\.(\w+)\s*\(\s*([^)]+)\s*\)')
# Identifiers: escaped (\foo ) or standard
ID_RE             = re.compile(r'(\\[^\s]+\s|[A-Za-z_]\w*)')
# Replace (...) groups to flatten newlines within parens (for easier splitting)
PARENS_RE         = re.compile(r'\((.*?)\)', re.DOTALL)
# Constants / empties / concat
CONST_LIKE        = re.compile(r"^\s*(?:\d+'[bdhoBDHO][0-9xzXZ_]+|['\"].*['\"]|\{.*\}|\s*)$")

# Known output port names across common cell libs
OUT_PORTS = {"Y","ZN","Z","Q","O","CO","S","SUM","QN","QB","OUT"}

# Qualifiers we don't want to treat as nets
QUAL_STOP = {
    "signed","unsigned","wire","reg","logic","tri","supply0","supply1",
    "wand","wor","tri0","tri1"
}

# Stopwords to ignore when scanning fault-comments
FAULT_STOP = {
    "fault","force","initial","short","to","faults","glitch","bridging","bridge",
    "type","net","nets","value","delays","assign","__INJECTED_FAULT__","begin","end"
}

def strip_comments(txt: str) -> str:
    txt = BLOCK_COMMENT_RE.sub("", txt)
    return COMMENT_RE.sub("", txt)

def extract_module(txt: str) -> str:
    """
    Choose the module block with the highest count of instance-like lines,
    rather than the first occurrence.
    """
    candidates = list(MODULE_RE.finditer(txt))
    if not candidates:
        return ""
    best_block = None
    best_score = -1
    for m in candidates:
        block = m.group(0)
        # Count lines that look like "<type> <inst> (" as a heuristic
        score = len(re.findall(r'^\s*[A-Za-z_]\w*\s+[A-Za-z_]\w*\s*\(', block, flags=re.MULTILINE))
        if score > best_score:
            best_block, best_score = block, score
    return best_block

def find_ids(s: str):
    return [m.group(1).rstrip() for m in ID_RE.finditer(s)]

def parse_named_ports(s: str):
    """
    Parse .PORT(EXPR) pairs. Skip constants/concat/empty. For expressions,
    pick the first identifier-like token as the net name.
    """
    pm = {}
    for m in PORT_MAP_RE.finditer(s):
        p, raw = m.group(1), m.group(2).strip()
        if CONST_LIKE.match(raw):
            continue
        cand = ID_RE.search(raw)
        if cand:
            pm[p] = cand.group(1).rstrip()
    return pm

def extract_fault_nets(raw: str, known_nets: set):
    """
    Robustly find nets that were targeted by injected faults.

    Strategies:
     - comments containing keywords (fault / short / bridge / glitch / force)
     - assign <net> = 1'b0 / 1'b1   --> stuck-at
     - initial ... #<delay> <net> = ... --> glitch (with/without 'begin')
     - assign <n1> = <n2> near file tail (injection appended before endmodule) --> bridging

    Returns a set of net identifiers (strings) intersected with known_nets.
    """
    nets = set()
    lines = raw.splitlines()

    # 1) comment-keyword scan
    for ln in lines:
        low = ln.lower()
        if any(k in low for k in ("fault", "short", "bridge", "bridging", "glitch", "force")):
            for tok in find_ids(ln):
                if tok in FAULT_STOP:
                    continue
                nets.add(tok)

    # 2) assign to constant -> stuck-at
    for m in re.finditer(r'assign\s+([\w\[\]\\]+)\s*=\s*1\'b([01])', raw, flags=re.IGNORECASE):
        nets.add(m.group(1))

    # 3) glitch toggles: initial [begin] ... #delay NET = ...
    for m in re.finditer(r'initial\b(?:\s+begin)?[\s\S]*?#\s*\d+\s*([\w\[\]\\]+)\s*=', raw, flags=re.IGNORECASE):
        nets.add(m.group(1))

    # 4) assign net = net near tail (bridging)
    tail = "\n".join(lines[-300:])  # examine last ~300 lines for injected assigns
    for m in re.finditer(r'assign\s+([\w\[\]\\]+)\s*=\s*([\w\[\]\\]+)\s*;', tail, flags=re.IGNORECASE):
        n1, n2 = m.group(1), m.group(2)
        nets.add(n1); nets.add(n2)

    # 5) fallback keywords
    if not nets:
        for ln in lines:
            low = ln.lower()
            if "short" in low or "bridge" in low or "glitch" in low:
                for tok in find_ids(ln):
                    if tok in FAULT_STOP:
                        continue
                    nets.add(tok)

    # Keep only known nets
    return nets & known_nets

# ----------------- OUTPUT RESET -----------------
for f in (OUT_NODE, OUT_EDGE):
    if f.exists():
        f.unlink()

with OUT_NODE.open("w", newline="") as nf, OUT_EDGE.open("w", newline="") as ef:
    node_w = csv.writer(nf)
    edge_w = csv.writer(ef)

    node_w.writerow([
        "netlist_file","fault_type","node_id","gate_type",
        "in_degree","out_degree","total_degree",
        "path_depth","CC0","CC1","CO",
        "is_input","is_output","label"
    ])
    edge_w.writerow([
        "netlist_file","fault_type","src_node","dst_node","fan_out"
    ])

    def process_verilog_and_write(vfile_path: Path, fault_type: str):
        raw_txt = vfile_path.read_text()
        body = extract_module(strip_comments(raw_txt))
        if not body:
            return False
        fname = vfile_path.name
        # base name: if file is stuck_at_adder.v -> base = adder.v
        if fault_type != "clean" and fname.startswith(f"{fault_type}_"):
            base_name = fname[len(fault_type)+1:]
        else:
            base_name = fname

        # --- Parse IOs ---
        directions = {}
        for m in PORT_DECL_RE.finditer(body):
            d, nets_str = m.groups()
            for net in find_ids(nets_str):
                if net in QUAL_STOP:
                    continue
                directions[net] = d  # input/output/inout

        # --- Net declarations ---
        nets = set(directions)
        for m in NET_DECL_RE.finditer(body):
            for net in find_ids(m.group(2)):
                if net in QUAL_STOP:
                    continue
                nets.add(net)

        # --- Gate instances ---
        driver = {}
        loads  = defaultdict(list)
        gates  = {}
        # Flatten newlines within parentheses to simplify statement splitting
        tmp = PARENS_RE.sub(lambda x: x.group(0).replace('\n',' '), body)
        for stmt in tmp.split(';'):
            stmt = stmt.strip()
            if not stmt or "(" not in stmt:
                continue
            head = stmt.split()[0]
            if head in {"module","endmodule","input","output","inout","wire","reg","assign"}:
                continue
            m = INST_RE.match(stmt)
            if not m:
                continue
            gatetype, inst, ports_str = m.groups()
            gates[inst] = gatetype
            if "." in ports_str:
                pm = parse_named_ports(ports_str)
                for p, net in pm.items():
                    if p in OUT_PORTS:
                        driver[net] = inst
                    else:
                        loads[net].append(inst)
            else:
                tokens = find_ids(ports_str)
                if not tokens:
                    continue
                out, ins = tokens[0], tokens[1:]
                driver[out] = inst
                for net in ins:
                    loads[net].append(inst)

        nets |= set(driver) | set(loads)

        # --- Build graph structure (instance-level with PI_/PO_ sentinels) ---
        Gnodes = set(gates.keys())
        for net in nets:
            if net not in driver:
                Gnodes.add(f"PI_{net}")
            if not loads.get(net):
                Gnodes.add(f"PO_{net}")

        succ = defaultdict(list)
        pred = defaultdict(list)
        for net in nets:
            src = driver.get(net, f"PI_{net}")
            if loads.get(net):
                for dst in loads[net]:
                    succ[src].append(dst)
                    pred[dst].append(src)
            else:
                dst = f"PO_{net}"
                succ[src].append(dst)
                pred[dst].append(src)
        for n in Gnodes:
            succ.setdefault(n, [])
            pred.setdefault(n, [])

        # --- Path depth via BFS from PIs ---
        path_depth = {}
        dq = deque([n for n in Gnodes if n.startswith("PI_")])
        for n in dq:
            path_depth[n] = 0
        while dq:
            u = dq.popleft()
            for v in succ[u]:
                if v not in path_depth:
                    path_depth[v] = path_depth[u] + 1
                    dq.append(v)
        maxd = max(path_depth.values(), default=0)
        for n in Gnodes:
            path_depth.setdefault(n, maxd)

        # --- Fault nets detection ---
        if fault_type == "clean":
            fault_nets = set()
        else:
            known_nets = set(nets)
            fault_nets = extract_fault_nets(raw_txt, known_nets)

        # map fault_nets -> nodes: prefer driver[net] if present else PI_net
        fault_nodes = set()
        for net in fault_nets:
            if net in driver:
                fault_nodes.add(driver[net])
            else:
                fault_nodes.add(f"PI_{net}")

        # --- approx SCOAP ---
        def approx_scoap(n):
            typ = "PI" if n.startswith("PI_") else ("PO" if n.startswith("PO_") else gates.get(n,"unk"))
            d   = path_depth.get(n, maxd)
            cc0 = cc1 = max(1, 10 - d)
            co  = d + 1
            if typ == "PI": cc0 = cc1 = 1
            if typ == "PO": co = 1
            return cc0, cc1, co

        # --- Dump nodes ---
        for n in sorted(Gnodes):
            indeg = len(pred[n])
            outdeg = len(succ[n])
            deg = indeg + outdeg
            cc0, cc1, co = approx_scoap(n)
            is_in = int(n.startswith("PI_"))
            is_out = int(n.startswith("PO_"))
            lbl = int(n in fault_nodes)  # 1 for faulty node, 0 otherwise
            node_w.writerow([
                base_name, fault_type, n, ("PI" if is_in else ("PO" if is_out else gates.get(n,"unk"))),
                indeg, outdeg, deg,
                path_depth.get(n, maxd), cc0, cc1, co,
                is_in, is_out, lbl
            ])

        # --- Dump edges ---
        for net in sorted(nets):
            src = driver.get(net, f"PI_{net}")
            if loads.get(net):
                fout = len(loads[net])
                for dst in loads[net]:
                    edge_w.writerow([base_name, fault_type, src, dst, fout])
            else:
                fout = 0
                dst = f"PO_{net}"
                edge_w.writerow([base_name, fault_type, src, dst, fout])

        return True

    # ----------------- PROCESS CLEAN NETLISTS -----------------
    clean_count = 0
    for v in sorted(SRC_DIR.glob("*.v")):
        ok = process_verilog_and_write(v, "clean")
        if ok:
            clean_count += 1

    # ----------------- PROCESS FAULTED NETLISTS -----------------
    ft_counts = defaultdict(int)
    for ft, folder in FAULT_DIRS.items():
        if not folder.exists():
            continue
        for v in sorted(folder.glob("*.v")):
            ok = process_verilog_and_write(v, ft)
            if ok:
                ft_counts[ft] += 1

    print(f"Wrote clean netlists: {clean_count}")
    for ft, c in ft_counts.items():
        print(f"Wrote {c} files for fault_type={ft}")

print("nodes.csv and edges.csv regenerated.")


Wrote clean netlists: 30
Wrote 30 files for fault_type=stuck_at
Wrote 30 files for fault_type=glitch
Wrote 30 files for fault_type=bridging
nodes.csv and edges.csv regenerated.
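
The graph construction used by the parser can be illustrated on a tiny example. The sketch below is a simplified stand-in, assuming only positional primitive gates whose first port is the output; the regex `GATE_RE`, the function `toy_edges`, and the toy module are hypothetical and much less general than the parser cell above. It shows how each net becomes edges from its driver (or a `PI_` sentinel) to its loads (or a `PO_` sentinel).

```python
import re
from collections import defaultdict

TOY = """
module toy (a, b, y);
  input a, b;
  output y;
  wire n3;
  and g1 (n3, a, b);
  not g2 (y, n3);
endmodule
"""

# Simplified: primitive gates with positional ports, output first
GATE_RE = re.compile(
    r'^\s*(and|or|not|nand|nor|xor)\s+(\w+)\s*\(([^)]*)\)\s*;',
    re.MULTILINE,
)

def toy_edges(text):
    driver = {}                   # net -> instance that drives it
    loads = defaultdict(list)     # net -> instances that read it
    for _gate_type, inst, ports in GATE_RE.findall(text):
        nets = [p.strip() for p in ports.split(",")]
        driver[nets[0]] = inst
        for net in nets[1:]:
            loads[net].append(inst)
    edges = []
    for net in sorted(set(driver) | set(loads)):
        src = driver.get(net, f"PI_{net}")          # undriven net -> primary input
        dsts = loads.get(net) or [f"PO_{net}"]      # unloaded net -> primary output
        edges.extend((src, dst) for dst in dsts)
    return edges

edges = toy_edges(TOY)
for src, dst in edges:
    print(f"{src} -> {dst}")
```

Nodes are gate instances plus sentinels, and nets are turned into edges. This is why a fault on a net is mapped to the node that drives it (or to its `PI_` sentinel) when labels are assigned.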


Inspect the data gathered in the CSV files.

In [8]:
import pandas as pd

def inspect_csv(path):
    """
    Print file shape, and for each column:
      - numeric: min, max, number of unique values
      - non-numeric: number of unique values and sample uniques
    """
    df = pd.read_csv(path)
    print(f"\n=== Inspecting {path} ===")
    print(f"Shape: {df.shape}\n")
    for col in df.columns:
        series = df[col]
        print(f"Column: {col}")
        if pd.api.types.is_numeric_dtype(series):
            print(f"  dtype      : {series.dtype}")
            print(f"  min        : {series.min()}")
            print(f"  max        : {series.max()}")
            print(f"  unique vals: {series.nunique()}\n")
        else:
            uniques = series.dropna().unique()
            print(f"  dtype      : {series.dtype}")
            print(f"  unique vals: {len(uniques)}")
            # show up to 10 sample unique values
            print(f"  samples    : {uniques[:10]}\n")

if __name__ == "__main__":
    for fname in ["nodes.csv", "edges.csv"]:
        inspect_csv(fname)



=== Inspecting nodes.csv ===
Shape: (1746912, 14)

Column: netlist_file
  dtype      : object
  unique vals: 30
  samples    : ['adder.v' 'arbiter.v' 'bar.v' 'c1355.v' 'c17.v' 'c1908.v' 'c2670.v'
 'c3540.v' 'c432.v' 'c499.v']

Column: fault_type
  dtype      : object
  unique vals: 4
  samples    : ['clean' 'stuck_at' 'glitch' 'bridging']

Column: node_id
  dtype      : object
  unique vals: 117701
  samples    : ['PI_\\a[0]' 'PI_\\a[100]' 'PI_\\a[101]' 'PI_\\a[102]' 'PI_\\a[103]'
 'PI_\\a[104]' 'PI_\\a[105]' 'PI_\\a[106]' 'PI_\\a[107]' 'PI_\\a[108]']

Column: gate_type
  dtype      : object
  unique vals: 8
  samples    : ['PI' 'PO' 'not' 'and' 'or' 'nand' 'nor' 'xor']

Column: in_degree
  dtype      : int64
  min        : 0
  max        : 8
  unique vals: 7

Column: out_degree
  dtype      : int64
  min        : 0
  max        : 1667
  unique vals: 176

Column: total_degree
  dtype      : int64
  min        : 1
  max        : 1667
  unique vals: 177

Column: path_depth
  dtype      

In [9]:
import pandas as pd
df = pd.read_csv("nodes.csv")
print(df.fault_type.value_counts())
print("Labels per fault type:")
print(df.groupby("fault_type").label.sum())
# Example: show some rows for stuck_at with label==1
print(df[(df.fault_type=="stuck_at") & (df.label==1)].head(20))

nodes_all = pd.read_csv("nodes.csv")
filtered_nodes = nodes_all[(nodes_all['fault_type'] != 'stuck_at') & (nodes_all['label']==1)]
print(filtered_nodes.shape)

fault_type
clean       436728
stuck_at    436728
glitch      436728
bridging    436728
Name: count, dtype: int64
Labels per fault type:
fault_type
bridging    4532
clean          0
glitch      2376
stuck_at    2373
Name: label, dtype: int64
       netlist_file fault_type     node_id gate_type  in_degree  out_degree  \
436731      adder.v   stuck_at  PI_\a[102]        PI          0           1   
436737      adder.v   stuck_at  PI_\a[108]        PI          0           1   
436753      adder.v   stuck_at  PI_\a[122]        PI          0           1   
436774      adder.v   stuck_at   PI_\a[26]        PI          0           1   
436858      adder.v   stuck_at  PI_\b[101]        PI          0           1   
436906      adder.v   stuck_at    PI_\b[2]        PI          0           1   
436917      adder.v   stuck_at    PI_\b[3]        PI          0           1   
436927      adder.v   stuck_at   PI_\b[49]        PI          0           1   
436938      adder.v   stuck_at   PI_\b[59]      

Confirmation: Why No Clock/Reset Nets?
ISCAS-85:
 - These are purely combinational circuits.
 - No flip-flops, registers, clocks, or resets.
 - So our result (zero matches for clk, rst, etc.) is expected.

EPFL Benchmark Suite:
 - The EPFL suite is distributed as a combinational benchmark suite (arithmetic and random/control circuits).
 - The versions in the verilog_benchmark_circuits repo (from jpsety) are likewise combinational gate-level netlists.

Neither source contains clocks or registers, so no clock/reset nets show up.
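
A quick way to reproduce the "zero matches" check is a substring scan over net names. The sketch below is an assumption about how such a check could look, using the same hint strings that appear in the `net_class` logic of the clean-edge generation cell; the function name and the sample net names are hypothetical.

```python
CLOCK_RESET_HINTS = ("clk", "clock", "rst", "reset")

def find_clock_reset_nets(net_names):
    """Return the net names that contain any clock/reset hint (case-insensitive)."""
    return [n for n in net_names
            if any(hint in n.lower() for hint in CLOCK_RESET_HINTS)]

# Hypothetical net names in the style of the combinational benchmarks
combinational = ["N1", "N10", "\\a[0]", "n3"]
print(find_clock_reset_nets(combinational))

# Hypothetical sequential-style names, including a false positive
mixed = ["sys_clk", "RST_N", "data_in", "first"]
print(find_clock_reset_nets(mixed))
```

Because this is plain substring matching, a name such as `first` is also flagged (it contains `rst`), so any non-empty result should be reviewed by hand before being treated as a real clock or reset net.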

#### Generating the clean_nodes.csv

In [19]:
import re
import csv
import networkx as nx
from pathlib import Path
from collections import defaultdict, deque

# -------- CONFIGURATION --------
SRC_DIR     = Path("verilog_benchmark_circuits")
NODE_CSV    = Path("clean_nodes.csv")
EDGE_CSV    = Path("clean_edges.csv")

# -------- REGEX PATTERNS --------
COMMENT_RE    = re.compile(r'//.*')
MODULE_RE     = re.compile(r'\bmodule\b.*?\bendmodule\b', re.DOTALL)
PORT_DECL_RE  = re.compile(r'\b(input|output)\b\s+([^;]+);')
NET_DECL_RE   = re.compile(r'\b(wire|reg)\b\s+([^;]+);')
INST_RE       = re.compile(r'^\s*([A-Za-z_]\w*)\s+([A-Za-z_]\w*)\s*\(\s*(.*?)\s*\)\s*$', re.DOTALL)
PORT_MAP_RE   = re.compile(r'\.(\w+)\s*\(\s*([\w\[\]\\]+)\s*\)')
ID_RE         = re.compile(r'(\\[^\s]+\s|[A-Za-z_]\w*)')
OUT_PORTS     = {"Y","ZN","Z","Q","O","CO","S","SUM"}

# -------- HELPERS --------
def strip_comments(txt): return COMMENT_RE.sub("", txt)
def extract_module(txt):
    m = MODULE_RE.search(txt)  # search once and reuse the match
    return m.group(0) if m else ""
def find_ids(s): return [m.group(1).rstrip() for m in ID_RE.finditer(s)]
def parse_named_ports(s): return {m.group(1): m.group(2).rstrip() for m in PORT_MAP_RE.finditer(s)}

# -------- RESET OUTPUT --------
for f in (NODE_CSV, EDGE_CSV):
    if f.exists(): f.unlink()

with NODE_CSV.open("w", newline="") as nf, EDGE_CSV.open("w", newline="") as ef:
    node_w = csv.writer(nf)
    edge_w = csv.writer(ef)

    node_w.writerow([
        "netlist_file", "node_id", "gate_type",
        "in_degree", "out_degree", "total_degree",
        "path_depth", "CC0", "CC1", "CO",
        "is_input", "is_output", "label"
    ])
    edge_w.writerow([
        "netlist_file", "src_node", "dst_node", "fan_out", "net_class"
    ])

    for vfile in sorted(SRC_DIR.glob("*.v")):
        raw_txt = vfile.read_text()
        body    = extract_module(strip_comments(raw_txt))
        if not body: continue
        fname = vfile.name

        # --- Parse IOs ---
        directions = {}
        for m in PORT_DECL_RE.finditer(body):
            d, nets = m.groups()
            for net in find_ids(nets):
                directions[net] = d

        # --- Net declarations ---
        nets = set(directions)
        for m in NET_DECL_RE.finditer(body):
            for net in find_ids(m.group(2)):
                nets.add(net)

        # --- Gate instances ---
        driver = {}
        loads  = defaultdict(list)
        gates  = {}

        tmp = re.sub(r'\((.*?)\)', lambda x: x.group(0).replace('\n',' '), body, flags=re.DOTALL)
        for stmt in tmp.split(';'):
            stmt = stmt.strip()
            if not stmt or "(" not in stmt:
                continue
            if stmt.split()[0] in {"module","endmodule","input","output","wire","reg","assign"}:
                continue
            m = INST_RE.match(stmt)
            if not m: continue
            gatetype, inst, ports_str = m.groups()
            gates[inst] = gatetype
            if "." in ports_str:
                pm = parse_named_ports(ports_str)
                for p, net in pm.items():
                    if p in OUT_PORTS:
                        driver[net] = inst
                    else:
                        loads[net].append(inst)
            else:
                tokens = find_ids(ports_str)
                if tokens:
                    out, ins = tokens[0], tokens[1:]
                    driver[out] = inst
                    for net in ins:
                        loads[net].append(inst)

        nets |= set(driver) | set(loads)

        # --- Build graph ---
        G = nx.DiGraph()
        for inst, gt in gates.items():
            G.add_node(inst, gate_type=gt)

        for net in nets:
            if net not in driver:
                G.add_node(f"PI_{net}", gate_type="PI")
            if not loads.get(net):
                G.add_node(f"PO_{net}", gate_type="PO")
        for net in nets:
            src = driver.get(net, f"PI_{net}")
            for dst in loads.get(net, []):
                G.add_edge(src, dst, net=net)
            if not loads.get(net):
                G.add_edge(src, f"PO_{net}", net=net)

        # --- Path depth (BFS) ---
        path_depth = {}
        dq = deque([n for n in G if n.startswith("PI_")])
        for n in dq: path_depth[n] = 0
        while dq:
            u = dq.popleft()
            for v in G.successors(u):
                if v not in path_depth:
                    path_depth[v] = path_depth[u]+1
                    dq.append(v)
        maxd = max(path_depth.values(), default=0)
        for n in G:
            path_depth.setdefault(n, maxd)

        # --- Approximated SCOAP ---
        def approx_scoap(n):
            typ = G.nodes[n]["gate_type"]
            d   = path_depth[n]
            cc0 = cc1 = max(1, 10 - d)
            co  = d + 1
            if typ == "PI": cc0 = cc1 = 1
            if typ == "PO": co = 1
            return cc0, cc1, co

        # --- Dump nodes (all label=0 for clean) ---
        for n, data in G.nodes(data=True):
            indeg, outdeg = G.in_degree(n), G.out_degree(n)
            deg = indeg + outdeg
            cc0, cc1, co = approx_scoap(n)
            node_w.writerow([
                fname, n, data["gate_type"],
                indeg, outdeg, deg,
                path_depth[n], cc0, cc1, co,
                int(n.startswith("PI_")), int(n.startswith("PO_")),
                0  # label = 0 (no fault)
            ])

        # --- Dump edges ---
        for u, v, ed in G.edges(data=True):
            net = ed["net"]
            fout = len(loads.get(net, []))
            nclass = "clock" if any(t in net.lower() for t in ["clk", "clock", "rst", "reset"]) else "data"
            edge_w.writerow([
                fname, u, v, fout, nclass
            ])

print("Clean CSV generation complete:")
print(" - clean_nodes.csv")
print(" - clean_edges.csv")


Clean CSV generation complete:
 - clean_nodes.csv
 - clean_edges.csv
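
The `path_depth`, `CC0`, `CC1`, and `CO` columns come from a simple heuristic rather than a full SCOAP computation. The sketch below restates that heuristic on a hypothetical four-edge graph; `bfs_depth` and the toy adjacency are illustrative names, while the formula in `approx_scoap` mirrors the one in the cell above. Note that the BFS assigns each node its shortest distance from any primary input, which can be smaller than the longest-path logic level.

```python
from collections import deque

def bfs_depth(succ, sources):
    """Shortest hop count from any source node to every reachable node."""
    depth = {s: 0 for s in sources}
    queue = deque(sources)
    while queue:
        u = queue.popleft()
        for v in succ.get(u, []):
            if v not in depth:
                depth[v] = depth[u] + 1
                queue.append(v)
    return depth

def approx_scoap(kind, d):
    """Depth-based stand-in for SCOAP: controllability shrinks and observability grows with depth."""
    cc0 = cc1 = 1 if kind == "PI" else max(1, 10 - d)
    co = 1 if kind == "PO" else d + 1
    return cc0, cc1, co

# Hypothetical graph: PI_a, PI_b -> g1 (and) -> g2 (not) -> PO_y
succ = {"PI_a": ["g1"], "PI_b": ["g1"], "g1": ["g2"], "g2": ["PO_y"]}
kinds = {"PI_a": "PI", "PI_b": "PI", "g1": "and", "g2": "not", "PO_y": "PO"}

depth = bfs_depth(succ, ["PI_a", "PI_b"])
for node in ["PI_a", "g1", "g2", "PO_y"]:
    print(node, depth[node], approx_scoap(kinds[node], depth[node]))
```

Since all three values are functions of depth and node kind only, they carry little information beyond `path_depth` itself; this is consistent with `CC0` ranging from 1 to 9 in the inspection output.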


In [20]:
import pandas as pd

def inspect_csv(path):
    """
    Print file shape, and for each column:
      - numeric: min, max, number of unique values
      - non-numeric: number of unique values and sample uniques
    """
    df = pd.read_csv(path)
    print(f"\n=== Inspecting {path} ===")
    print(f"Shape: {df.shape}\n")
    for col in df.columns:
        series = df[col]
        print(f"Column: {col}")
        if pd.api.types.is_numeric_dtype(series):
            print(f"  dtype      : {series.dtype}")
            print(f"  min        : {series.min()}")
            print(f"  max        : {series.max()}")
            print(f"  unique vals: {series.nunique()}\n")
        else:
            uniques = series.dropna().unique()
            print(f"  dtype      : {series.dtype}")
            print(f"  unique vals: {len(uniques)}")
            # show up to 10 sample unique values
            print(f"  samples    : {uniques[:10]}\n")

if __name__ == "__main__":
    for fname in ["clean_nodes.csv", "clean_edges.csv"]:
        inspect_csv(fname)



=== Inspecting clean_nodes.csv ===
Shape: (436728, 13)

Column: netlist_file
  dtype      : object
  unique vals: 30
  samples    : ['adder.v' 'arbiter.v' 'bar.v' 'c1355.v' 'c17.v' 'c1908.v' 'c2670.v'
 'c3540.v' 'c432.v' 'c499.v']

Column: node_id
  dtype      : object
  unique vals: 117701
  samples    : ['g1' 'g2' 'g3' 'g4' 'g5' 'g6' 'g7' 'g8' 'g9' 'g10']

Column: gate_type
  dtype      : object
  unique vals: 8
  samples    : ['not' 'and' 'or' 'PI' 'PO' 'nand' 'nor' 'xor']

Column: in_degree
  dtype      : int64
  min        : 0
  max        : 8
  unique vals: 7

Column: out_degree
  dtype      : int64
  min        : 0
  max        : 1667
  unique vals: 176

Column: total_degree
  dtype      : int64
  min        : 1
  max        : 1667
  unique vals: 177

Column: path_depth
  dtype      : int64
  min        : 0
  max        : 40
  unique vals: 41

Column: CC0
  dtype      : int64
  min        : 1
  max        : 9
  unique vals: 9

Column: CC1
  dtype      : int64
  min        : 1
 

#### Training & Evaluating the GAT

Currently, clean_nodes.csv and clean_edges.csv are not used as input, in order to preserve class balance.
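
Even without the clean graphs, faulty nodes are a small minority: the label counts printed earlier give 2373 (stuck_at), 2376 (glitch), and 4532 (bridging) positives out of 436728 nodes per fault type. One common way to compensate is to weight the positive class by the negative-to-positive ratio. The sketch below only computes that ratio; the helper name `positive_class_weight` is hypothetical, and whether the training code in this notebook uses such a weight is an assumption, not something shown in the visible cells.

```python
def positive_class_weight(num_nodes: int, num_positive: int) -> float:
    """Ratio of negative to positive nodes, usable as a weight on the positive class."""
    if num_positive <= 0:
        raise ValueError("need at least one positive label")
    return (num_nodes - num_positive) / num_positive

# (total nodes, faulty nodes) per fault type, taken from the outputs printed above
counts = {
    "stuck_at": (436728, 2373),
    "glitch": (436728, 2376),
    "bridging": (436728, 4532),
}

weights = {ft: positive_class_weight(n, p) for ft, (n, p) in counts.items()}
for ft, w in weights.items():
    print(f"{ft}: {w:.1f}")
```

Adding the clean graphs would contribute another 436728 nodes with label 0 and no positives, which pushes these ratios even higher.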

In [None]:
import os
import math
import json
import random
import argparse
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
from typing import List, Dict, Tuple

# --------------------
# Reproducibility
# --------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --------------------
# Data loading and preprocessing
# --------------------
NUMERIC_NODE_FEATS = ["in_degree","out_degree","total_degree","path_depth","CC0","CC1","CO","is_input","is_output"]
EDGE_FEATS = ["fan_out"]
FAULT_TYPES = ["stuck_at", "bridging", "glitch", "clean"]

def load_graph_tables(nodes_path="nodes.csv", edges_path="edges.csv"):
    nodes = pd.read_csv(nodes_path)
    edges = pd.read_csv(edges_path)
    # Basic sanity
    assert set(["netlist_file","fault_type","node_id","gate_type","label"]).issubset(nodes.columns)
    assert set(["netlist_file","fault_type","src_node","dst_node","fan_out"]).issubset(edges.columns)
    return nodes, edges

def build_graph_index(nodes_df: pd.DataFrame) -> List[Tuple[str,str]]:
    # Graph = (netlist_file, fault_type)
    graphs = sorted(nodes_df[["netlist_file","fault_type"]].drop_duplicates().itertuples(index=False, name=None))
    return graphs

def fit_feature_encoders(nodes_df: pd.DataFrame, top_k_gates: int = 64):
    # Gate type vocab (top-K + other). Keep PI and PO as explicit gate types.
    gt_counts = Counter(nodes_df["gate_type"].astype(str).tolist())
    # Ensure PI and PO are kept
    forced = ["PI","PO"]
    most_common = [g for g,_ in gt_counts.most_common() if g not in forced]
    vocab = forced + most_common[:max(0, top_k_gates - len(forced))]
    gate_to_idx = {g:i for i,g in enumerate(vocab)}
    gate_other_idx = len(vocab)  # for unknowns
    meta = {
        "gate_vocab": vocab,
        "gate_other_idx": gate_other_idx,
        "numeric_means": {},
        "numeric_stds": {}
    }
    # Numeric means/stds are left empty here on purpose: main() fits them on the
    # TRAIN split only, so validation/test statistics never leak into the scaler.
    return gate_to_idx, gate_other_idx, meta

def compute_graphwise_positional(nodes_g: pd.DataFrame):
    # Positional encodings p_v: [norm_depth, is_input, is_output]
    # Normalize depth per graph
    max_depth = max(nodes_g["path_depth"].max(), 1)
    norm_depth = nodes_g["path_depth"] / max_depth
    p = np.stack([
        norm_depth.values.astype(np.float32),
        nodes_g["is_input"].values.astype(np.float32),
        nodes_g["is_output"].values.astype(np.float32),
    ], axis=1)
    return p

def one_hot_gate_type(gate_type: str, gate_to_idx: Dict[str,int], gate_other_idx: int, dim: int):
    idx = gate_to_idx.get(gate_type, gate_other_idx)
    v = np.zeros(dim, dtype=np.float32)
    v[idx] = 1.0
    return v

def standardize_features(train_arr: np.ndarray, arr: np.ndarray, eps: float = 1e-6):
    mean = train_arr.mean(axis=0)
    std = train_arr.std(axis=0)
    std = np.where(std < eps, 1.0, std)
    arr_std = (arr - mean) / std
    return arr_std, mean, std

def collate_graph(nodes_df: pd.DataFrame, edges_df: pd.DataFrame,
                  gkey: Tuple[str,str],
                  gate_to_idx, gate_other_idx, gate_feat_dim,
                  numeric_scaler=None):
    nf, ft = gkey
    n_g = nodes_df[(nodes_df["netlist_file"]==nf) & (nodes_df["fault_type"]==ft)].copy()
    e_g = edges_df[(edges_df["netlist_file"]==nf) & (edges_df["fault_type"]==ft)].copy()
    # Node index map
    node_ids = n_g["node_id"].astype(str).tolist()
    nidx = {n:i for i,n in enumerate(node_ids)}
    N = len(node_ids)

    # Node features
    gate_oh = np.stack([one_hot_gate_type(g, gate_to_idx, gate_other_idx, gate_feat_dim) for g in n_g["gate_type"].astype(str)], axis=0)
    numeric = n_g[NUMERIC_NODE_FEATS].values.astype(np.float32)
    # Positional encodings p_v
    pos = compute_graphwise_positional(n_g)  # shape (N, 3)

    # Standardize numeric features with provided scaler
    if numeric_scaler is not None:
        mean, std = numeric_scaler
        std = np.where(std < 1e-6, 1.0, std)
        numeric = (numeric - mean) / std

    x = np.concatenate([numeric, gate_oh], axis=1).astype(np.float32)

    # Edge index and edge features
    if len(e_g) > 0:
        src = e_g["src_node"].astype(str).map(nidx).values
        dst = e_g["dst_node"].astype(str).map(nidx).values
        # Filter edges with unknown nodes just in case
        mask = (~pd.isna(src)) & (~pd.isna(dst))
        src = src[mask].astype(np.int64)
        dst = dst[mask].astype(np.int64)
        E = len(src)
        edge_index = np.stack([src, dst], axis=0)
        # Edge features: fan_out -> normalized log1p
        fan_out = e_g.loc[mask, "fan_out"].values.astype(np.float32)
        edge_attr = np.log1p(fan_out)
        # Normalize per-graph to [0,1]
        if E > 0:
            e_min = edge_attr.min()
            e_max = edge_attr.max()
            if e_max > e_min:
                edge_attr = (edge_attr - e_min) / (e_max - e_min)
            else:
                edge_attr = np.zeros_like(edge_attr)
        edge_attr = edge_attr[:,None]
    else:
        edge_index = np.zeros((2,0), dtype=np.int64)
        edge_attr = np.zeros((0,1), dtype=np.float32)
        E = 0

    y = n_g["label"].values.astype(np.int64)  # 0/1
    return {
        "node_ids": node_ids,
        "x": x,
        "pos": pos,
        "edge_index": edge_index,
        "edge_attr": edge_attr,
        "y": y,
        "fault_type": ft,
        "netlist_file": nf
    }

# --------------------
# Model: Edge-aware, position-encoded, residual GAT
# --------------------
class EdgeAwareGATLayer(nn.Module):
    def __init__(self, in_dim, edge_dim, pos_dim, out_dim, heads=4, dropout=0.2, leaky_relu_neg=0.2):
        super().__init__()
        self.heads = heads
        self.out_dim = out_dim
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(leaky_relu_neg)
        # Per-head projections
        self.W = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=False) for _ in range(heads)])
        self.U = nn.ModuleList([nn.Linear(edge_dim, out_dim, bias=False) for _ in range(heads)])
        self.a = nn.ParameterList([nn.Parameter(torch.randn(2*out_dim + out_dim + 2*pos_dim)) for _ in range(heads)])
        self.norm = nn.LayerNorm(heads*out_dim)

    @staticmethod
    def segment_softmax(dst, logits):
        # Stable softmax over incoming edges per destination node
        E = logits.size(0)
        if E == 0:
            # No edges: nothing to normalise (torch.max fails on an empty segment)
            return logits
        device = logits.device
        # Sort edges by dst
        perm = torch.argsort(dst)
        dst_s = dst[perm]
        z_s = logits[perm]

        # boundaries where dst changes
        b = torch.cat([torch.tensor([0], device=device),
                       torch.nonzero(dst_s[1:] != dst_s[:-1], as_tuple=False).flatten()+1,
                       torch.tensor([E], device=device)])
        alphas = torch.empty_like(z_s)
        for i in range(b.numel()-1):
            s, e = b[i].item(), b[i+1].item()
            seg = z_s[s:e]
            m = torch.max(seg)
            exp = torch.exp(seg - m)
            denom = torch.sum(exp) + 1e-9
            alphas[s:e] = exp / denom
        # Unsort back
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(E, device=device)
        return alphas[inv]

    def forward(self, x, edge_index, edge_attr, pos):
        # x: (N, in_dim), edge_index: (2, E), edge_attr:(E, e_dim), pos:(N, pos_dim)
        N = x.size(0)
        src = edge_index[0]
        dst = edge_index[1]
        head_outs = []
        for h in range(self.heads):
            x_proj = self.W[h](x)              # (N, out_dim)
            e_proj = self.U[h](edge_attr)      # (E, out_dim)
            xu = x_proj[src]                   # (E, out_dim)
            xv = x_proj[dst]                   # (E, out_dim)
            pv = pos[dst]                      # (E, pos_dim)
            pu = pos[src]                      # (E, pos_dim)
            cat = torch.cat([xv, xu, e_proj, pv, pu], dim=1)  # (E, 2*d + d_e + 2*p)
            z = self.leaky_relu(torch.matmul(cat, self.a[h])) # (E,)
            alpha = self.segment_softmax(dst, z)              # (E,)
            alpha = self.dropout(alpha)

            # Aggregate to nodes: sum_u alpha_{vu} * x_proj[u]
            out_h = torch.zeros((N, self.out_dim), device=x.device, dtype=x.dtype)
            out_h.index_add_(0, dst, (alpha.unsqueeze(1) * xu))
            head_outs.append(out_h)
        H = torch.cat(head_outs, dim=1)  # (N, heads*out_dim)
        return self.norm(H)

class EdgeAwareGAT(nn.Module):
    def __init__(self, in_dim, edge_dim, pos_dim, hidden_dim=64, heads=4, num_layers=3, dropout=0.2, num_classes=2):
        super().__init__()
        self.layers = nn.ModuleList()
        self.res_proj = nn.ModuleList()
        self.drop = nn.Dropout(dropout)
        self.act = nn.ELU()
        # First layer
        self.layers.append(EdgeAwareGATLayer(in_dim, edge_dim, pos_dim, hidden_dim, heads=heads, dropout=dropout))
        self.res_proj.append(nn.Linear(in_dim, hidden_dim*heads, bias=False))
        # Hidden layers
        for _ in range(num_layers-1):
            self.layers.append(EdgeAwareGATLayer(hidden_dim*heads, edge_dim, pos_dim, hidden_dim, heads=heads, dropout=dropout))
            self.res_proj.append(nn.Identity())  # same dim residual
        self.final_norm = nn.LayerNorm(hidden_dim*heads)
        self.out_lin = nn.Linear(hidden_dim*heads, num_classes)

    def forward(self, x, edge_index, edge_attr, pos):
        h = x
        for i, gat in enumerate(self.layers):
            z = gat(h, edge_index, edge_attr, pos)
            # Activation -> dropout -> residual add -> LayerNorm
            # (the single final_norm instance is shared by every layer)
            res = self.res_proj[i](h)
            h = self.final_norm(res + self.drop(self.act(z)))
        logits = self.out_lin(h)
        return logits

# --------------------
# Training utilities
# --------------------
def compute_class_weights(y_list: List[np.ndarray]):
    y = np.concatenate(y_list, axis=0)
    counts = np.bincount(y, minlength=2).astype(np.float32)
    total = counts.sum()
    weights = total / (2.0 * np.maximum(counts, 1.0))
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, optimizer, criterion, graphs, device):
    model.train()
    total_loss = 0.0
    for g in graphs:
        x = torch.tensor(g["x"], dtype=torch.float32, device=device)
        pos = torch.tensor(g["pos"], dtype=torch.float32, device=device)
        ei = torch.tensor(g["edge_index"], dtype=torch.long, device=device)
        ea = torch.tensor(g["edge_attr"], dtype=torch.float32, device=device)
        y = torch.tensor(g["y"], dtype=torch.long, device=device)
        optimizer.zero_grad()
        logits = model(x, ei, ea, pos)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(graphs),1)

@torch.no_grad()
def predict_graph(model, g, device):
    model.eval()
    x = torch.tensor(g["x"], dtype=torch.float32, device=device)
    pos = torch.tensor(g["pos"], dtype=torch.float32, device=device)
    ei = torch.tensor(g["edge_index"], dtype=torch.long, device=device)
    ea = torch.tensor(g["edge_attr"], dtype=torch.float32, device=device)
    logits = model(x, ei, ea, pos)
    probs = F.softmax(logits, dim=1).cpu().numpy()
    preds = probs.argmax(axis=1)
    return preds, probs

def plot_confusion(y_true, y_pred, title, savepath=None):
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    plt.figure(figsize=(4.2,3.6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=["neg","pos"], yticklabels=["neg","pos"])
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    if savepath:
        plt.savefig(savepath, dpi=150)
    plt.show()
    plt.close()

# --------------------
# Main script
# --------------------
def main(args):
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    nodes, edges = load_graph_tables(args.nodes, args.edges)

    # Build graph keys and encoders
    graph_keys = build_graph_index(nodes)
    gate_to_idx, gate_other_idx, meta = fit_feature_encoders(nodes, top_k_gates=args.top_k_gates)
    gate_feat_dim = len(meta["gate_vocab"]) + 1  # +other

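    # Split at the graph level, stratified by fault type, so that no graph's nodes
    # are shared between splits. Variants of the same netlist_file under different
    # fault types are shuffled independently and may land in different splits.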
    keys_by_type = defaultdict(list)
    for (nf, ft) in graph_keys:
        keys_by_type[ft].append((nf, ft))

    train_keys, val_keys, test_keys = [], [], []
    rng = np.random.RandomState(args.seed)
    for ft, keys in keys_by_type.items():
        keys = keys.copy()
        rng.shuffle(keys)
        n = len(keys)
        n_train = int(n * args.train_ratio)
        n_val = int(n * args.val_ratio)
        train_keys += keys[:n_train]
        val_keys += keys[n_train:n_train+n_val]
        test_keys += keys[n_train+n_val:]

    # Fit numeric scalers on TRAIN graphs only
    train_numeric = []
    for nf, ft in train_keys:
        n_g = nodes[(nodes["netlist_file"]==nf) & (nodes["fault_type"]==ft)]
        train_numeric.append(n_g[NUMERIC_NODE_FEATS].values.astype(np.float32))
    train_numeric = np.concatenate(train_numeric, axis=0) if len(train_numeric)>0 else np.zeros((1,len(NUMERIC_NODE_FEATS)), dtype=np.float32)
    num_mean = train_numeric.mean(axis=0)
    num_std = train_numeric.std(axis=0)
    num_std[num_std < 1e-6] = 1.0

    # Build datasets
    def build_set(keys):
        gs = []
        for gk in keys:
            gs.append(collate_graph(nodes, edges, gk,
                                    gate_to_idx, gate_other_idx, gate_feat_dim,
                                    numeric_scaler=(num_mean, num_std)))
        return gs

    train_graphs = build_set(train_keys)
    val_graphs   = build_set(val_keys)
    test_graphs  = build_set(test_keys)

    # Model dims
    in_dim = len(NUMERIC_NODE_FEATS) + gate_feat_dim
    edge_dim = 1
    pos_dim = 3

    model = EdgeAwareGAT(in_dim=in_dim,
                         edge_dim=edge_dim,
                         pos_dim=pos_dim,
                         hidden_dim=args.hidden_dim,
                         heads=args.heads,
                         num_layers=args.layers,
                         dropout=args.dropout,
                         num_classes=2).to(device)

    # Class weights for imbalance (computed from TRAIN nodes)
    train_labels = [g["y"] for g in train_graphs]
    class_weights = compute_class_weights(train_labels).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Training loop
    best_val = float("inf")
    best_state = None
    for epoch in range(1, args.epochs+1):
        tr_loss = train_epoch(model, optimizer, criterion, train_graphs, device)
        # Simple validation loss
        with torch.no_grad():
            model.eval()
            val_loss = 0.0
            for g in val_graphs:
                x = torch.tensor(g["x"], dtype=torch.float32, device=device)
                pos = torch.tensor(g["pos"], dtype=torch.float32, device=device)
                ei = torch.tensor(g["edge_index"], dtype=torch.long, device=device)
                ea = torch.tensor(g["edge_attr"], dtype=torch.float32, device=device)
                y = torch.tensor(g["y"], dtype=torch.long, device=device)
                logits = model(x, ei, ea, pos)
                loss = criterion(logits, y)
                val_loss += loss.item()
            val_loss = val_loss / max(len(val_graphs),1)
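        # Track the best epoch only when a validation set exists; with no validation
        # graphs val_loss is always 0.0 and would pin the checkpoint to epoch 1.
        if len(val_graphs) > 0 and val_loss < best_val: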
            best_val = val_loss
            best_state = {k:v.cpu().clone() for k,v in model.state_dict().items()}
        if epoch % max(1, args.log_every) == 0:
            print(f"Epoch {epoch:03d} | train_loss={tr_loss:.4f} | val_loss={val_loss:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluation on test graphs
    all_true, all_pred = [], []
    all_ft_true, all_ft_pred = defaultdict(list), defaultdict(list)

    for g in test_graphs:
        preds, _ = predict_graph(model, g, device)
        y = g["y"]
        all_true.append(y)
        all_pred.append(preds)
        ft = g["fault_type"]
        # Only collect per-fault stats for target fault types (exclude clean)
        if ft in ["stuck_at","bridging","glitch"]:
            all_ft_true[ft].append(y)
            all_ft_pred[ft].append(preds)

    y_true = np.concatenate(all_true, axis=0) if len(all_true)>0 else np.array([], dtype=np.int64)
    y_pred = np.concatenate(all_pred, axis=0) if len(all_pred)>0 else np.array([], dtype=np.int64)

    # Overall metrics
    print("\n=== Overall evaluation (all test graphs) ===")
    if y_true.size > 0:
        print(classification_report(y_true, y_pred, digits=4))
        print(f"Overall accuracy: {accuracy_score(y_true, y_pred):.4f}")
        plot_confusion(y_true, y_pred, "Confusion Matrix - Overall", savepath="cm_overall.png")
    else:
        print("No test data available.")

    # Per-fault-type metrics
    for ft in ["stuck_at","bridging","glitch"]:
        print(f"\n=== Evaluation for fault_type={ft} ===")
        if ft in all_ft_true and len(all_ft_true[ft])>0:
            y_t = np.concatenate(all_ft_true[ft], axis=0)
            y_p = np.concatenate(all_ft_pred[ft], axis=0)
            print(classification_report(y_t, y_p, digits=4))
            print(f"Accuracy ({ft}): {accuracy_score(y_t, y_p):.4f}")
            plot_confusion(y_t, y_p, f"Confusion Matrix - {ft}", savepath=f"cm_{ft}.png")
        else:
            print(f"No test graphs for {ft}.")

    # Save training metadata
    meta = {
        "in_dim": in_dim,
        "edge_dim": edge_dim,
        "pos_dim": pos_dim,
        "gate_vocab": list(gate_to_idx.keys()),
        "train_graphs": len(train_graphs),
        "val_graphs": len(val_graphs),
        "test_graphs": len(test_graphs),
        "class_weights": class_weights.cpu().numpy().tolist()
    }
    with open("training_meta.json","w") as f:
        json.dump(meta, f, indent=2)
    torch.save(model.state_dict(), "edgeaware_gat.pt")
    print("\nSaved model to edgeaware_gat.pt and metadata to training_meta.json")
    print("Confusion matrices saved as cm_overall.png, cm_stuck_at.png, cm_bridging.png, cm_glitch.png")

if __name__ == "__main__":
    # =========================
    # JUPYTER NOTEBOOK FRIENDLY VERSION
    # =========================

    # 1) Parameter setup (edit these directly)
    nodes_path      = "nodes.csv"
    edges_path      = "edges.csv"
    epochs          = 50
    lr              = 2e-3
    weight_decay    = 1e-5
    hidden_dim      = 64
    heads           = 4
    layers          = 3
    dropout         = 0.2
    train_ratio     = 0.7
    val_ratio       = 0.1
    top_k_gates     = 64
    seed            = 42
    use_cpu         = False
    log_every       = 5

    # 2) Pack the parameters into a simple namespace (a stand-in for argparse)
    #    and pass it to main().

    class Args: pass
    args = Args()
    args.nodes = nodes_path
    args.edges = edges_path
    args.epochs = epochs
    args.lr = lr
    args.weight_decay = weight_decay
    args.hidden_dim = hidden_dim
    args.heads = heads
    args.layers = layers
    args.dropout = dropout
    args.train_ratio = train_ratio
    args.val_ratio = val_ratio
    args.top_k_gates = top_k_gates
    args.seed = seed
    args.cpu = use_cpu
    args.log_every = log_every
    main(args)
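
The attention weights in `EdgeAwareGATLayer` come from `segment_softmax`, which normalises edge scores separately for each destination node. A minimal plain-Python reference of that normalisation is sketched below; the name `segment_softmax_reference` and the toy edge list are illustrative assumptions, not part of the script. It is meant as a small oracle for checking the tensor version on tiny graphs, not as a replacement for it.

```python
import math
from collections import defaultdict

def segment_softmax_reference(dst, logits):
    """Softmax over the logits of edges that share a destination node.

    dst[i] is the destination node of edge i, logits[i] its raw attention score.
    Returns one weight per edge; the weights of each destination's incoming
    edges sum to 1.
    """
    if len(dst) != len(logits):
        raise ValueError("dst and logits must have the same length")
    # Group edge positions by destination node.
    groups = defaultdict(list)
    for i, d in enumerate(dst):
        groups[d].append(i)
    alphas = [0.0] * len(logits)
    for edge_ids in groups.values():
        m = max(logits[i] for i in edge_ids)  # subtract the max for stability
        exps = [math.exp(logits[i] - m) for i in edge_ids]
        denom = sum(exps)
        for i, e in zip(edge_ids, exps):
            alphas[i] = e / denom
    return alphas

# Toy graph: edges 0 and 2 point at node 1, edge 1 points at node 0.
weights = segment_softmax_reference([1, 0, 1], [0.0, 3.0, 0.0])
print(weights)  # [0.5, 1.0, 0.5]
```

Edge 1 is the only edge into node 0, so it receives weight 1.0 regardless of its score; the two equal-scored edges into node 1 share the weight evenly. With no edges at all the function returns an empty list.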

#### Trying a Simpler Architecture: GraphSAGE

In [2]:
# graphsage_baseline.py
import numpy as np
import pandas as pd
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from dgl import add_self_loop

# -------------------------------
# Load and preprocess data
# -------------------------------
nodes_df = pd.read_csv("nodes.csv")
edges_df = pd.read_csv("edges.csv")

# numeric feature columns used by this baseline
num_feats = ["total_degree", "path_depth", "CC0", "CC1", "CO"]

# standardize numeric columns with statistics computed over all nodes
# (the GAT script above fits its scaler on the training graphs only)
for c in num_feats:
    mean, std = nodes_df[c].mean(), nodes_df[c].std()
    nodes_df[c] = (nodes_df[c] - mean) / (std + 1e-6)

# gate type one-hot mapping
gate_types = sorted(nodes_df.gate_type.unique().tolist())
gate2idx = {g: i for i, g in enumerate(gate_types)}
NUM_GATES = len(gate_types)

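def build_dgl_graph(nodes, edges, netfile):
    # NOTE: rows are grouped by netlist_file only. If one netlist_file appears
    # under several fault types, all of its rows end up in the same graph.
    nd = nodes[nodes.netlist_file == netfile].copy()
    ed = edges[edges.netlist_file == netfile].copy()
    if nd.empty:
        return None
    node_ids = nd.node_id.values
    idx_map = {nid: i for i, nid in enumerate(node_ids)}
    # keep an edge only when BOTH endpoints are known, so src and dst stay aligned
    pairs = [(idx_map[u], idx_map[v]) for u, v in zip(ed.src_node, ed.dst_node)
             if u in idx_map and v in idx_map]
    src = [s for s, _ in pairs]
    dst = [d for _, d in pairs]
    g = dgl.graph((src, dst), num_nodes=len(node_ids))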
    # add self-loop to avoid 0-in-degree issues
    g = add_self_loop(g)
    onehot = np.zeros((len(nd), NUM_GATES), dtype=np.float32)
    for i, gt in enumerate(nd.gate_type):
        onehot[i, gate2idx[gt]] = 1.0
    nums = nd[num_feats].values.astype(np.float32)
    x = np.concatenate([onehot, nums], axis=1)
    g.ndata["x"] = torch.from_numpy(x)
    g.ndata["y"] = torch.from_numpy(nd.label.values.astype(np.float32))
    # tag the graph with the fault_type of its first row (used for per-type reporting)
    ft = nd.fault_type.iloc[0] if "fault_type" in nd.columns else "unknown"
    return g, ft

# build list of (graph, fault_type)
netlist_files = nodes_df.netlist_file.unique().tolist()
graphs_with_type = []
for nf in netlist_files:
    pair = build_dgl_graph(nodes_df, edges_df, nf)
    if pair is not None:
        graphs_with_type.append(pair)  # (graph, fault_type)

if len(graphs_with_type) == 0:
    raise SystemExit("No graphs found. Check nodes.csv / edges.csv")

# -------------------------------
# compute global pos_weight for BCEWithLogitsLoss
# -------------------------------
total_pos = int(nodes_df.label.sum())
total_neg = int(len(nodes_df) - total_pos)
if total_pos == 0:
    pos_weight_val = 1.0
else:
    pos_weight_val = float(total_neg) / float(max(total_pos, 1))
# clip extreme ratios to avoid numeric instability
pos_weight_val = float(min(pos_weight_val, 500.0))

print(f"Global positives: {total_pos}, negatives: {total_neg}, pos_weight={pos_weight_val:.3f}")

# -------------------------------
# GraphSAGE model
# -------------------------------
class GraphSAGEModel(nn.Module):
    def __init__(self, in_feats, hidden_feats=64, num_layers=2, dropout=0.2):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = in_feats if i == 0 else hidden_feats
            self.layers.append(SAGEConv(in_dim, hidden_feats, aggregator_type='mean'))
        self.dropout = nn.Dropout(dropout)
        self.pred = nn.Linear(hidden_feats, 1)

    def forward(self, g):
        h = g.ndata["x"]
        for layer in self.layers:
            h = layer(g, h)
            h = F.relu(h)
            h = self.dropout(h)
        return self.pred(h).squeeze(-1)

# -------------------------------
# training loop
# -------------------------------
def train(graphs_with_type, device="cpu", epochs=10, lr=1e-3):
    in_dim = graphs_with_type[0][0].ndata["x"].shape[1]
    model = GraphSAGEModel(in_dim).to(device)
    pos_weight = torch.tensor(pos_weight_val, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    for ep in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for g, _ in graphs_with_type:
            g = g.to(device)
            logits = model(g)
            y = g.ndata["y"].to(device)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += float(loss.item())
        avg_loss = total_loss / len(graphs_with_type)
        print(f"[Epoch {ep}] loss={avg_loss:.4f}")
    return model

# -------------------------------
# evaluation (overall + per-fault-type)
# -------------------------------
@torch.no_grad()
def evaluate(model, graphs_with_type, device="cpu"):
    model.eval()
    y_true_all = []
    y_pred_all = []
    ftypes_all = []

    for g, ftype in graphs_with_type:
        g = g.to(device)
        logits = model(g)
        probs = torch.sigmoid(logits).cpu().numpy()
        preds = (probs > 0.5).astype(int).tolist()
        labels = g.ndata["y"].cpu().numpy().astype(int).tolist()
        y_true_all.extend(labels)
        y_pred_all.extend(preds)
        ftypes_all.extend([ftype] * len(labels))

    # overall
    acc = 100.0 * accuracy_score(y_true_all, y_pred_all)
    print("\n=== Overall Classification Report ===")
    print(classification_report(y_true_all, y_pred_all, digits=4, zero_division=0))
    print("Confusion Matrix:")
    print(confusion_matrix(y_true_all, y_pred_all))
    print(f"Overall accuracy: {acc:.2f}%")

    # per fault type
    uniq_types = sorted(set(ftypes_all))
    for ft in uniq_types:
        idx = [i for i, f in enumerate(ftypes_all) if f == ft]
        if len(idx) == 0:
            continue
        yt = [y_true_all[i] for i in idx]
        yp = [y_pred_all[i] for i in idx]
        acc_ft = 100.0 * accuracy_score(yt, yp)
        print(f"\n--- Fault Type: {ft} ---")
        print(classification_report(yt, yp, digits=4, zero_division=0))
        print("Confusion Matrix:")
        print(confusion_matrix(yt, yp))
        print(f"Accuracy (fault type={ft}): {acc_ft:.2f}%")

# -------------------------------
# run (evaluation reuses the training graphs; there is no held-out split)
# -------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = train(graphs_with_type, device=device, epochs=10, lr=1e-3)
    evaluate(model, graphs_with_type, device=device)


Global positives: 9281, negatives: 1737631, pos_weight=187.225
[Epoch 1] loss=6.6194
[Epoch 2] loss=4.1068
[Epoch 3] loss=3.6614
[Epoch 4] loss=3.3925
[Epoch 5] loss=3.3549
[Epoch 6] loss=3.2031
[Epoch 7] loss=3.1964
[Epoch 8] loss=3.1170
[Epoch 9] loss=3.1024
[Epoch 10] loss=3.1056

=== Overall Classification Report ===
              precision    recall  f1-score   support

           0     0.9964    0.0551    0.1044   1737631
           1     0.0054    0.9625    0.0108      9281

    accuracy                         0.0599   1746912
   macro avg     0.5009    0.5088    0.0576   1746912
weighted avg     0.9911    0.0599    0.1039   1746912

Confusion Matrix:
[[  95673 1641958]
 [    348    8933]]
Overall accuracy: 5.99%

--- Fault Type: clean ---
              precision    recall  f1-score   support

           0     0.9964    0.0551    0.1044   1737631
           1     0.0054    0.9625    0.0108      9281

    accuracy                         0.0599   1746912
   macro avg     0.5009 
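
The headline numbers in the report above can be re-derived from the printed confusion matrix, which makes the failure mode easier to see. The short check below is an arithmetic sketch, not part of the original notebook; it only uses the four counts printed above.

```python
# Confusion matrix printed above: rows are true labels, columns are predictions.
tn, fp = 95673, 1641958
fn, tp = 348, 8933
total = tn + fp + fn + tp

precision = tp / (tp + fp)         # share of flagged nodes that are really faulty
recall = tp / (tp + fn)            # share of faulty nodes that were flagged
accuracy = (tp + tn) / total
flagged = (tp + fp) / total        # share of all nodes predicted positive
pos_weight = (tn + fp) / (tp + fn) # negatives / positives, as passed to the loss

print(f"precision={precision:.4f} recall={recall:.4f} accuracy={accuracy:.4f}")
print(f"flagged={flagged:.4f} pos_weight={pos_weight:.3f}")
```

This reproduces the 0.0054 precision, 0.9625 recall and 0.0599 accuracy of class 1 in the report, and the pos_weight of 187.225 printed before training. It also shows that roughly 94.5% of all nodes are predicted as faulty: with the positive class up-weighted by about 187x and a fixed 0.5 threshold, the model reaches high recall by flagging almost everything.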

In [5]:
# gcn_baseline.py
import numpy as np
import pandas as pd
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from dgl import add_self_loop

# -------------------------------
# Load and preprocess data
# -------------------------------
nodes_df = pd.read_csv("nodes.csv")
edges_df = pd.read_csv("edges.csv")

num_feats = ["total_degree", "path_depth", "CC0", "CC1", "CO"]
for c in num_feats:
    mean, std = nodes_df[c].mean(), nodes_df[c].std()
    nodes_df[c] = (nodes_df[c] - mean) / (std + 1e-6)

gate_types = sorted(nodes_df.gate_type.unique().tolist())
gate2idx = {g: i for i, g in enumerate(gate_types)}
NUM_GATES = len(gate_types)

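def build_dgl_graph(nodes, edges, netfile):
    # NOTE: rows are grouped by netlist_file only. If one netlist_file appears
    # under several fault types, all of its rows end up in the same graph.
    nd = nodes[nodes.netlist_file == netfile].copy()
    ed = edges[edges.netlist_file == netfile].copy()
    if nd.empty:
        return None
    node_ids = nd.node_id.values
    idx_map = {nid: i for i, nid in enumerate(node_ids)}
    # keep an edge only when BOTH endpoints are known, so src and dst stay aligned
    pairs = [(idx_map[u], idx_map[v]) for u, v in zip(ed.src_node, ed.dst_node)
             if u in idx_map and v in idx_map]
    src = [s for s, _ in pairs]
    dst = [d for _, d in pairs]
    g = dgl.graph((src, dst), num_nodes=len(node_ids))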
    g = add_self_loop(g)
    onehot = np.zeros((len(nd), NUM_GATES), dtype=np.float32)
    for i, gt in enumerate(nd.gate_type):
        onehot[i, gate2idx[gt]] = 1.0
    nums = nd[num_feats].values.astype(np.float32)
    x = np.concatenate([onehot, nums], axis=1)
    g.ndata["x"] = torch.from_numpy(x)
    g.ndata["y"] = torch.from_numpy(nd.label.values.astype(np.float32))
    ft = nd.fault_type.iloc[0] if "fault_type" in nd.columns else "unknown"
    return g, ft

netlist_files = nodes_df.netlist_file.unique().tolist()
graphs_with_type = []
for nf in netlist_files:
    pair = build_dgl_graph(nodes_df, edges_df, nf)
    if pair is not None:
        graphs_with_type.append(pair)

if len(graphs_with_type) == 0:
    raise SystemExit("No graphs found. Check nodes.csv / edges.csv")

# pos_weight
total_pos = int(nodes_df.label.sum())
total_neg = int(len(nodes_df) - total_pos)
pos_weight_val = float(total_neg) / float(max(total_pos, 1)) if total_pos > 0 else 1.0
pos_weight_val = float(min(pos_weight_val, 500.0))
print(f"Global positives: {total_pos}, negatives: {total_neg}, pos_weight={pos_weight_val:.3f}")

# -------------------------------
# GCN model
# -------------------------------
class SimpleGCN(nn.Module):
    def __init__(self, in_feats, hidden_feats=64, num_layers=2, dropout=0.2):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = in_feats if i == 0 else hidden_feats
            self.layers.append(GraphConv(in_dim, hidden_feats))
        self.dropout = nn.Dropout(dropout)
        self.pred = nn.Linear(hidden_feats, 1)

    def forward(self, g):
        h = g.ndata["x"]
        for layer in self.layers:
            h = layer(g, h)
            h = F.relu(h)
            h = self.dropout(h)
        return self.pred(h).squeeze(-1)

# -------------------------------
# training & evaluation (same pattern as the GraphSAGE baseline)
# -------------------------------
def train(graphs_with_type, device="cpu", epochs=10, lr=1e-3):
    in_dim = graphs_with_type[0][0].ndata["x"].shape[1]
    model = SimpleGCN(in_dim).to(device)
    pos_weight = torch.tensor(pos_weight_val, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    for ep in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for g, _ in graphs_with_type:
            g = g.to(device)
            logits = model(g)
            y = g.ndata["y"].to(device)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += float(loss.item())
        avg_loss = total_loss / len(graphs_with_type)
        print(f"[Epoch {ep}] loss={avg_loss:.4f}")
    return model

@torch.no_grad()
def evaluate(model, graphs_with_type, device="cpu"):
    model.eval()
    y_true_all = []
    y_pred_all = []
    ftypes_all = []
    for g, ftype in graphs_with_type:
        g = g.to(device)
        logits = model(g)
        probs = torch.sigmoid(logits).cpu().numpy()
        preds = (probs > 0.5).astype(int).tolist()
        labels = g.ndata["y"].cpu().numpy().astype(int).tolist()
        y_true_all.extend(labels)
        y_pred_all.extend(preds)
        ftypes_all.extend([ftype] * len(labels))

    acc = 100.0 * accuracy_score(y_true_all, y_pred_all)
    print("\n=== Overall Classification Report ===")
    print(classification_report(y_true_all, y_pred_all, digits=4, zero_division=0))
    print("Confusion Matrix:")
    print(confusion_matrix(y_true_all, y_pred_all))
    print(f"Overall accuracy: {acc:.2f}%")

    for ft in sorted(set(ftypes_all)):
        idx = [i for i, f in enumerate(ftypes_all) if f == ft]
        yt = [y_true_all[i] for i in idx]
        yp = [y_pred_all[i] for i in idx]
        acc_ft = 100.0 * accuracy_score(yt, yp)
        print(f"\n--- Fault Type: {ft} ---")
        print(classification_report(yt, yp, digits=4, zero_division=0))
        print("Confusion Matrix:")
        print(confusion_matrix(yt, yp))
        print(f"Accuracy (fault type={ft}): {acc_ft:.2f}%")

# -------------------------------
# run (evaluation reuses the training graphs; there is no held-out split)
# -------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = train(graphs_with_type, device=device, epochs=10, lr=1e-3)
    evaluate(model, graphs_with_type, device=device)


Global positives: 9281, negatives: 1737631, pos_weight=187.225
[Epoch 1] loss=6.7355
[Epoch 2] loss=5.0567
[Epoch 3] loss=4.1255
[Epoch 4] loss=3.8169
[Epoch 5] loss=3.6201
[Epoch 6] loss=3.5248
[Epoch 7] loss=3.4590
[Epoch 8] loss=3.4095
[Epoch 9] loss=3.3706
[Epoch 10] loss=3.3440

=== Overall Classification Report ===
              precision    recall  f1-score   support

           0     0.9967    0.0661    0.1240   1737631
           1     0.0055    0.9587    0.0108      9281

    accuracy                         0.0708   1746912
   macro avg     0.5011    0.5124    0.0674   1746912
weighted avg     0.9914    0.0708    0.1234   1746912

Confusion Matrix:
[[ 114856 1622775]
 [    383    8898]]
Overall accuracy: 7.08%

--- Fault Type: clean ---
              precision    recall  f1-score   support

           0     0.9967    0.0661    0.1240   1737631
           1     0.0055    0.9587    0.0108      9281

    accuracy                         0.0708   1746912
   macro avg     0.5011 
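
Both baselines above call `evaluate` on the same graphs they were trained on, so their reports measure fit rather than generalisation. One way to hold graphs out is a circuit-level split, sketched below under the assumption that `netlist_file` identifies the underlying circuit. The function name `split_by_circuit` and the toy keys are hypothetical, and unlike the split in the GAT script this sketch does not stratify by fault type.

```python
import random
from collections import defaultdict

def split_by_circuit(graph_keys, train_ratio=0.7, val_ratio=0.1, seed=42):
    """Split (netlist_file, fault_type) keys so each netlist_file is in one split only."""
    by_circuit = defaultdict(list)
    for netlist_file, fault_type in graph_keys:
        by_circuit[netlist_file].append((netlist_file, fault_type))
    circuits = sorted(by_circuit)  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(circuits)
    n_train = int(len(circuits) * train_ratio)
    n_val = int(len(circuits) * val_ratio)
    parts = (
        circuits[:n_train],
        circuits[n_train:n_train + n_val],
        circuits[n_train + n_val:],
    )
    # Expand each list of circuits back into its (netlist_file, fault_type) keys.
    return tuple([k for c in part for k in by_circuit[c]] for part in parts)

# Hypothetical keys: 10 circuits, each present in a clean and a stuck_at variant.
keys = [(f"c{i}.v", ft) for i in range(10) for ft in ("clean", "stuck_at")]
train_keys, val_keys, test_keys = split_by_circuit(keys)
print(len(train_keys), len(val_keys), len(test_keys))
```

Because the shuffle operates on circuits rather than on individual graphs, every fault variant of a circuit ends up in the same split, and the test report then reflects circuits the model has never seen.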