# PRB-GraphSAGE: Resource Management Reconstruction GNN  
## CPSC 8810 — ML4G Term Project
### Authored by Ryan Barker

This notebook implements a minimum viable **PRB-GraphSAGE** model: an offline graph neural network, operating at the user-equipment (UE) level, that reconstructs resource-management behavior from noisy proportional-fair (PF) scheduler logs recorded per transmission time interval (TTI). Using the existing TTI-Trust Anomaly Detector + Classifier preprocessing pipeline, each sliding window of TTIs is converted into a **UE-contention graph** whose nodes represent the UEs tracked in the run (one per `UE*_rb` column) and whose edges capture co-scheduling relationships, i.e. contention for physical resource blocks (PRBs). Node features summarize PRB usage statistics, fairness/PF surrogates, and starvation patterns, while a compact **GraphSAGE → global pooling → MLP** architecture performs **graph-level classification** to distinguish benign scheduling behavior from PF-based PRB starvation indicative of attack conditions. This MVP focuses on methodology rather than headline accuracy, acknowledging the dataset’s timing noise and simulator artifacts, and serves as a prototype for future training on GPU-native, slot-accurate Aerial/cuMAC traces.

> <span style="color:red; font-weight:bold">Disclaimer — Prototype PRB-GraphSAGE Results</span>  
>  
> The PRB-GraphSAGE results in this notebook were obtained **only on a small, noisy prototype dataset** derived from the original ~30 GB OpenAirInterface simulator logs. Because of known timing noise in the low-latency experiment that generated the data, label corruption, and severe class imbalance, **the model was deliberately *not* tuned or optimized against the full data volume**.  
>  
> These experiments are intended to validate the **model architecture and end-to-end graph construction pipeline only**. The reported metrics (including the best configuration  
> `hidden_dim=32, num_layers=1, dropout=0.0, conv_type="sage", aggr="mean", lr=2e-4` with `F1 = 0.0`) **should not be interpreted as meaningful performance claims for PRB starvation detection in real RAN systems**.  
>  
> A proper hyperparameter search and performance evaluation will be conducted **only after more accurate PF ledgers are collected from the new NVIDIA AI RAN stack**, at which point the same PRB-GraphSAGE code will be re-used on cleaner, slot-accurate traces. The limitations of the current dataset and their impact on model behavior will be discussed in detail in the final report.

### preproc_tti_trust.py

This script is a preprocessing scaffold for the TTI-Trust project. It converts lightweight attack/benign CSV “shims” into parquet datasets, derives scheduler-native, identity-agnostic features, and emits windowed shards plus cross-validation splits for PyTorch.

It first materializes the CSVs into parquet partitioned by `run_id`, adding a relative time axis `sec_rel` if one is missing. For each run it then auto-detects the UE resource-block columns, computes `total_rb`, normalizes RB usage into per-UE PRB shares (RBs divided by `C_PRB`), and derives Jain’s fairness index `J`.

On top of these per-UE shares it builds “roles”: it tracks an exponential moving average of each share, selects the top-K dominant UEs per TTI, and aggregates everyone else into a `rest_share` with corresponding “small RB” indicators. It also adds 1-second presence flags to capture coarse duty-cycle behavior.

Using these role-based features and the `phase` labels, it slides long and short windows (240 and 64 TTIs) at multiple strides. A window is labeled attack (2) when at least 90% of its TTIs fall in an attack phase and the top role passes a dominance test governed by `ATTACK_DOMINANCE`, `OTHERS_MAX`, and `DUTY_KEEP`; it is labeled benign (0) when at least 90% of its TTIs fall in a benign phase; otherwise it is labeled ambiguous (1). Windows are batched into compressed NumPy `.npz` shards alongside parquet meta tables that record window boundaries, labels, and summary statistics.

The script writes enriched per-run parquet files, then populates a `win_shards` directory with both long and short window shards for the attack and benign sources. Finally, it builds grouped K-fold splits in which each `run_id` is kept entirely in either train or validation, returns index pairs for each fold, and prints a brief summary so downstream training can stream shards or assemble cross-validation sets directly from the generated metadata.
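
As a small illustration of the fairness feature described above, the sketch below computes Jain’s index for a single TTI in plain Python. The function name `jain_index` is hypothetical (it does not appear in the script); the zero-allocation case returns 0.0 to mirror the `nan_to_num` handling in `_compute_shares_and_fairness`.

```python
def jain_index(shares):
    """Jain's fairness index for one TTI: (sum x)^2 / (n * sum x^2).

    `shares` holds one PRB share per UE column. An all-zero TTI is
    mapped to 0.0, mirroring the nan_to_num handling in the script.
    """
    n = max(1, len(shares))
    total = sum(shares)
    total_sq = sum(s * s for s in shares)
    if total_sq == 0:
        return 0.0
    return (total * total) / (n * total_sq)


# Four UEs sharing the carrier equally: perfectly fair.
equal = jain_index([0.25, 0.25, 0.25, 0.25])

# One UE holds every PRB: the index drops to 1/n.
starved = jain_index([1.0, 0.0, 0.0, 0.0])

# Nothing scheduled in this TTI.
idle = jain_index([0.0, 0.0, 0.0])

print(equal, starved, idle)
```

Because `n` is the number of UE columns rather than the number of active UEs, a TTI in which one UE takes everything scores `1/n`, so the floor of `J` depends on how many UEs the run tracks.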

In [7]:
# TTI-Trust preprocessing scaffold for PyTorch: identity-agnostic feature composer,
# phase-aware windowing, and grouped K-fold splits.

from __future__ import annotations
import os, re, warnings, itertools
from dataclasses import dataclass
from typing import List, Tuple
from glob import glob

import numpy as np
import pandas as pd
import pyarrow as pa, pyarrow.csv as pv, pyarrow.parquet as pq

# -------------------- DATASETS --------------------
# Lightweight shims are committed to the repo because of storage quotas;
# the full 15 GB dataset is available upon request.
ATTACK_FN = "data/shims/attack_shim_16000_24000.csv"
BENIGN_FN = "data/shims/benign_shim_8000_12000.csv"

# -------------------- Constants aligned to paper --------------------
TTI_SEC = 0.0005
C_PRB   = 106

W_LONG  = 240
W_SHORT = 64

STRIDES_LONG  = [8, 12]
STRIDES_SHORT = [4, 6]

TOP_K = 6
SMALL_RB_EPS = 4

ATTACK_DOMINANCE = 0.90
OTHERS_MAX       = 0.10
DUTY_KEEP        = 0.90

# -------------------- Helpers --------------------

def csv_to_parquet(src_csv: str, dest_root: str):
    os.makedirs(dest_root, exist_ok=True)
    for _root, _dirs, files in os.walk(dest_root):
        if any(f.endswith(".parquet") for f in files):
            return

    convert_cols = {"run_id": pa.large_string(), "phase": pa.large_string()}
    read_opts  = pv.ReadOptions(block_size=1<<26)
    conv_opts  = pv.ConvertOptions(column_types=convert_cols, strings_can_be_null=True)
    t = pv.read_csv(src_csv, read_options=read_opts, convert_options=conv_opts)

    if "run_id" not in t.schema.names:
        stem = os.path.splitext(os.path.basename(src_csv))[0]
        run_id = pa.array(np.full(t.num_rows, stem), type=pa.large_string())
        t = t.append_column("run_id", run_id)

    if "sec_rel" not in t.schema.names:
        n = t.num_rows
        sec_rel = pa.array((np.arange(n, dtype=np.int32) * np.float32(TTI_SEC)), type=pa.float32())
        t = t.append_column("sec_rel", sec_rel)

    pq.write_to_dataset(t, root_path=dest_root, partition_cols=["run_id"],
                        compression="zstd", use_dictionary=True)

def _detect_ue_rb_columns(df: pd.DataFrame) -> List[str]:
    cand = [c for c in df.columns if re.fullmatch(r'UE\d+_rb', c, flags=re.IGNORECASE)]
    if not cand:
        cand = [c for c in df.columns if c.lower().endswith('_rb')]
    def _key(c):
        m = re.search(r'(\d+)', c)
        return int(m.group(1)) if m else 10**9
    return sorted(set(cand), key=_key)

def _ensure_total_rb(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    if 'total_rb' not in df.columns:
        df['total_rb'] = df[ue_cols].sum(axis=1)
    return df

def _compute_shares_and_fairness(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    shares = (df[ue_cols].clip(lower=0).astype('float32') / np.float32(C_PRB))
    shares.columns = [c.replace('_rb','_share') for c in ue_cols]
    for c in shares.columns:
        df[c] = shares[c]
    x = shares.to_numpy(dtype='float32', copy=False)
    sum_x  = x.sum(axis=1, dtype='float32')
    sum_x2 = (x*x).sum(axis=1, dtype='float32')
    n = np.float32(max(1, len(ue_cols)))
    with np.errstate(divide='ignore', invalid='ignore'):
        J = (sum_x*sum_x) / (n * sum_x2)
    df['J'] = np.nan_to_num(J, nan=0.0, posinf=0.0, neginf=0.0).astype('float32')
    return df

def _compose_roles(df: pd.DataFrame, ue_cols: List[str], top_k: int = TOP_K) -> pd.DataFrame:
    share_cols = [c.replace('_rb','_share') for c in ue_cols]
    alpha = 2.0 / (20 + 1.0)
    for c in share_cols:
        df[c+"_ema"] = df[c].astype('float32').ewm(alpha=alpha, adjust=False).mean().astype('float32')

    S   = df[share_cols].to_numpy(dtype='float32', copy=False)
    E   = df[[c+"_ema" for c in share_cols]].to_numpy(dtype='float32', copy=False)
    PRB = df[ue_cols].to_numpy(dtype='int32',    copy=False)

    T, U = S.shape
    k = min(top_k, U)
    order  = np.argsort(-E, axis=1)
    top_ix = order[:, :k]

    roles       = np.take_along_axis(S,   top_ix, axis=1).astype('float32')
    roles_small = np.take_along_axis((PRB < SMALL_RB_EPS).astype('float32'), top_ix, axis=1)

    sum_all    = S.sum(axis=1, dtype='float32')
    sum_topk   = roles.sum(axis=1, dtype='float32')
    rest_share = (sum_all - sum_topk).clip(min=0).astype('float32')

    if U > k:
        row_ix = np.arange(T)[:, None]
        mask = np.ones_like(S, dtype=bool); mask[row_ix, top_ix] = False
        cnt  = mask.sum(axis=1)
        rest_small = ((PRB < SMALL_RB_EPS).astype('float32') * mask).sum(axis=1) / np.maximum(cnt, 1)
    else:
        rest_small = np.zeros(T, dtype='float32')

    for i in range(k):
        df[f'role{i+1}_share']   = roles[:, i]
        df[f'role{i+1}_smallrb'] = roles_small[:, i]
    for i in range(k, top_k):
        df[f'role{i+1}_share']   = np.float32(0.0)
        df[f'role{i+1}_smallrb'] = np.float32(0.0)
    df['rest_share']   = rest_share
    df['rest_smallrb'] = rest_small

    sec_bin = np.floor(df['sec_rel'].to_numpy(dtype='float32')).astype('int32')
    max_sec = int(sec_bin.max(initial=0))
    def sec_presence(vec_bool: np.ndarray) -> np.ndarray:
        agg = np.zeros(max_sec + 1, dtype='uint8')
        np.maximum.at(agg, sec_bin, vec_bool.astype('uint8'))
        return agg[sec_bin]
    for i in range(top_k):
        df[f'role{i+1}_present_1s'] = sec_presence((df[f'role{i+1}_share'].to_numpy() > 0))
    df['rest_present_1s'] = sec_presence((df['rest_share'].to_numpy() > 0))
    return df

def _label_window(shares_mat: np.ndarray) -> bool:
    if shares_mat.size == 0:
        return False
    top    = shares_mat[:, 0]
    others = shares_mat[:, 1:]
    ok = (top >= ATTACK_DOMINANCE) & (np.max(others, axis=1) <= OTHERS_MAX)
    return (ok.mean() >= DUTY_KEEP)

@dataclass
class WindowSpec:
    length_tti: int
    stride_tti: int

def iter_windows(df: pd.DataFrame, run_id: str, spec: WindowSpec, role_cols: List[str],
                 phases_attack={"attack"}, phases_benign={"baseline","recovery","benign"},
                 batch_windows: int = 8192):
    W, S = spec.length_tti, spec.stride_tti
    R = df[role_cols].to_numpy(dtype='float32', copy=False)
    phase = df['phase'].astype(str).to_numpy()
    T = len(df)
    starts = np.arange(0, max(0, T - W + 1), S, dtype=np.int64)

    bufX, rows = [], []
    for s in starts:
        e = s + W
        blk = R[s:e, :]
        if blk.shape[0] != W:
            continue
        win_phase = phase[s:e]
        frac_attack = np.isin(win_phase, list(phases_attack)).mean()
        frac_benign = np.isin(win_phase, list(phases_benign)).mean()
        looks_attacky = _label_window(blk)
        if (frac_attack >= 0.90) and looks_attacky: y = 2
        elif (frac_benign >= 0.90):                y = 0
        else:                                      y = 1

        bufX.append(blk)
        rows.append((run_id, int(s), int(e), int(W), int(y),
                     float(frac_attack), float(frac_benign), int(looks_attacky)))

        if len(bufX) == batch_windows:
            X = np.stack(bufX, axis=0)
            meta = pd.DataFrame.from_records(
                rows,
                columns=["run_id","start_tti","end_tti","window_len","label",
                         "frac_attack","frac_benign","looks_attacky"]
            )
            yield X, meta
            bufX.clear(); rows.clear()

    if bufX:
        X = np.stack(bufX, axis=0)
        meta = pd.DataFrame.from_records(
            rows,
            columns=["run_id","start_tti","end_tti","window_len","label",
                     "frac_attack","frac_benign","looks_attacky"]
        )
        yield X, meta

# -------------------- Main entry: build datasets (partitioned, low RAM) --------------------

csv_to_parquet(ATTACK_FN, "parquet/attack")
csv_to_parquet(BENIGN_FN, "parquet/benign")

def iter_run_partitions(root: str):
    for run_dir in sorted(glob(os.path.join(root, "run_id=*"))):
        rid = os.path.basename(run_dir).split("=", 1)[1]
        df = pd.read_parquet(run_dir, engine="pyarrow")
        if 'run_id' not in df.columns:
            df['run_id'] = rid
        if 'phase' not in df.columns:
            df['phase'] = 'unknown'
        if 'sec_rel' not in df.columns:
            df['sec_rel'] = (np.arange(len(df), dtype=np.int32) * np.float32(TTI_SEC)).astype('float32')
        yield rid, df

def process_source(root: str, src_name: str):
    if not os.path.isdir(root):
        print(f"[{src_name}] No parquet root found at {root} (skipping).")
        return

    os.makedirs("win_shards", exist_ok=True)
    shard_counter = itertools.count(0)

    runs = list(iter_run_partitions(root))
    if not runs:
        print(f"[{src_name}] No run_id partitions found in {root}.")
        return
    print(f"[{src_name}] Found {len(runs)} run_id partitions.")

    for rid, df in runs:
        ue_cols = _detect_ue_rb_columns(df)
        if not ue_cols:
            warnings.warn(f"[{src_name}] run_id={rid}: no UE *_rb columns; skipping")
            continue

        df = _ensure_total_rb(df, ue_cols)
        df = _compute_shares_and_fairness(df, ue_cols)
        df = _compose_roles(df, ue_cols, top_k=TOP_K)

        keep = [f'role{i}_share' for i in range(1, TOP_K+1)] + ['rest_share','J','sec_rel','run_id','phase']
        enriched_out = f"{src_name}_enriched_{rid}.parquet"
        df[keep].to_parquet(enriched_out, index=False)
        print(f"[{src_name}] run_id={rid} → {enriched_out}  (rows={len(df):,})")

        role_cols = [f'role{k}_share' for k in range(1, TOP_K+1)] + ['rest_share']
        g = df.reset_index(drop=True)

        for stride in STRIDES_LONG:
            spec = WindowSpec(length_tti=W_LONG, stride_tti=stride)
            for X, meta in iter_windows(g, rid, spec, role_cols):
                sid = next(shard_counter)
                np.savez_compressed(f"win_shards/long_s{stride}_{src_name}_{rid}_{sid}.npz", X=X)
                meta["source"] = src_name; meta["stride"] = stride; meta["kind"] = "long"
                meta.to_parquet(f"win_shards/long_s{stride}_{src_name}_{rid}_{sid}_meta.parquet", index=False)

        for stride in STRIDES_SHORT:
            spec = WindowSpec(length_tti=W_SHORT, stride_tti=stride)
            for X, meta in iter_windows(g, rid, spec, role_cols):
                sid = next(shard_counter)
                np.savez_compressed(f"win_shards/short_s{stride}_{src_name}_{rid}_{sid}.npz", X=X)
                meta["source"] = src_name; meta["stride"] = stride; meta["kind"] = "short"
                meta.to_parquet(f"win_shards/short_s{stride}_{src_name}_{rid}_{sid}_meta.parquet", index=False)

process_source("parquet/attack", "attack")
process_source("parquet/benign", "benign")

def load_all_meta(kind: str) -> pd.DataFrame:
    metas = [pd.read_parquet(p) for p in glob(f"win_shards/{kind}_*_meta.parquet")]
    return pd.concat(metas, ignore_index=True) if metas else pd.DataFrame()

M_long  = load_all_meta("long")
M_short = load_all_meta("short")

def grouped_kfold(meta: pd.DataFrame, K: int = 5, seed: int = 42) -> List[Tuple[np.ndarray, np.ndarray]]:
    if meta.empty:
        return []
    rng = np.random.default_rng(seed)
    run_ids = meta['run_id'].astype(str).unique().tolist()
    rng.shuffle(run_ids)
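    # Round-robin assignment of whole runs to folds. If there are fewer run_ids
    # than K, the surplus folds get an empty validation set (and train on everything).
    folds = [set() for _ in range(K)]
    for i, rid in enumerate(run_ids):
        folds[i % K].add(rid)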
    splits = []
    for i in range(K):
        val_rids = folds[i]
        train_idx = meta.index[~meta['run_id'].astype(str).isin(val_rids)].to_numpy()
        val_idx   = meta.index[ meta['run_id'].astype(str).isin(val_rids)].to_numpy()
        splits.append((train_idx, val_idx))
    return splits

splits_long  = grouped_kfold(M_long,  K=5, seed=2025)
splits_short = grouped_kfold(M_short, K=5, seed=2025)

if len(M_long):
    print(f"\n[long] meta={M_long.shape}, groups={M_long['run_id'].nunique()}, folds={len(splits_long)}")
if len(M_short):
    print(f"[short] meta={M_short.shape}, groups={M_short['run_id'].nunique()}, folds={len(splits_short)}")

print("\nPreprocessing complete. Shards are in ./win_shards; use meta to assemble CV splits or stream for training.")

[attack] Found 1 run_id partitions.
[attack] run_id=040933 → attack_enriched_040933.parquet  (rows=8,001)
[benign] Found 1 run_id partitions.
[benign] run_id=004912 → benign_enriched_004912.parquet  (rows=4,001)

[long] meta=(2403, 11), groups=2, folds=5
[short] meta=(4950, 11), groups=2, folds=5

Preprocessing complete. Shards are in ./win_shards; use meta to assemble CV splits or stream for training.
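
The window-labeling rule in the cell above can be hard to follow inside the batching loop, so here is a simplified plain-Python restatement on toy data. The thresholds are copied from the script; the helper names `looks_attacky` and `label_window` and the toy share rows are illustrative assumptions, not part of the pipeline.

```python
ATTACK_DOMINANCE = 0.90  # top role must hold at least this share in a TTI
OTHERS_MAX = 0.10        # every other role must stay at or below this share
DUTY_KEEP = 0.90         # fraction of TTIs that must satisfy both conditions
BENIGN_PHASES = {"baseline", "recovery", "benign"}


def looks_attacky(window):
    """window: list of per-TTI rows [role1_share, role2_share, ..., rest_share]."""
    if not window:
        return False
    ok = [row[0] >= ATTACK_DOMINANCE and max(row[1:]) <= OTHERS_MAX
          for row in window]
    return sum(ok) / len(ok) >= DUTY_KEEP


def label_window(window, phases):
    """Return 2 (attack), 0 (benign), or 1 (ambiguous) for one window."""
    frac_attack = sum(p == "attack" for p in phases) / len(phases)
    frac_benign = sum(p in BENIGN_PHASES for p in phases) / len(phases)
    if frac_attack >= 0.90 and looks_attacky(window):
        return 2
    if frac_benign >= 0.90:
        return 0
    return 1


dominated = [0.95, 0.03, 0.02]  # one role holds almost every PRB
shared = [0.50, 0.30, 0.20]     # PRBs spread across roles

y_attack = label_window([dominated] * 10, ["attack"] * 10)
y_benign = label_window([shared] * 10, ["baseline"] * 10)
# Attack phase, but the dominance pattern is absent: falls through to ambiguous.
y_ambiguous = label_window([shared] * 10, ["attack"] * 10)

print(y_attack, y_benign, y_ambiguous)
```

Note that an attack-phase window without the dominance pattern is labeled ambiguous rather than attack, so the count of label-2 windows depends on both the phase annotations and the share pattern.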


### Data Loader

This code builds a full PyTorch Geometric pipeline that turns PRB time-series windows into UE-level contention graphs and then into train/validation loaders for a PRB-GraphSAGE model.

It starts by detecting UE RB columns (e.g., `UE1_rb`) in the per-run parquet “PF ledger” data, then constructs seven node features for each UE in a window: the mean, min, max, and standard deviation of its PRB share, its active and zero fractions, and its longest zero (starvation) run. An edge is created between two UEs that are co-scheduled in at least one TTI of the window, and each edge is stored in both directions, yielding an undirected contention graph.

For each window in the preprocessed metadata (`win_shards/{kind}_*_meta.parquet`), `build_graph_from_window` slices the run dataframe, computes node features and edges, and assigns a graph label from the window label (binary, where only label 2 counts as attack, or multiclass). It packages everything into a `torch_geometric.data.Data` object plus a compact metadata dict. `build_prb_graphs` loads and filters all meta files, caches run-level parquet data and UE column detection by `(source, run_id)`, iterates over windows to construct graphs, and returns the list of graphs along with a metadata DataFrame.

On top of this, `make_prb_graph_loaders` performs a grouped train/validation split by `run_id` (so windows from the same run don’t leak across splits), builds `GeoDataLoader` instances for training and validation, and returns them alongside split indices and run ID sets. Finally, the script instantiates these loaders for short windows, infers the node feature dimensionality with `detect_node_dim`, and prints it, providing everything needed to feed PRB-GraphSAGE with graph-structured inputs derived from scheduler PRB logs.
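
To make the graph construction concrete, the sketch below applies the same ideas to a hand-written 4-TTI, 3-UE window using plain Python lists. It is a simplified stand-in for `build_node_features` and `build_edges`, computing only three of the seven features; the function name `toy_graph` and the toy window are assumptions made for illustration.

```python
C_PRB = 106  # total PRBs, as in the notebook


def longest_zero_run(values):
    """Length of the longest consecutive run of zeros."""
    longest = current = 0
    for v in values:
        current = current + 1 if v == 0 else 0
        longest = max(longest, current)
    return longest


def toy_graph(prb_window):
    """prb_window: list of W rows, each a list of U per-UE PRB counts."""
    num_ues = len(prb_window[0])
    # Transpose to one PRB time series per UE.
    columns = [[row[u] for row in prb_window] for u in range(num_ues)]

    features = [
        {
            "mean_share": sum(col) / len(col) / C_PRB,
            "zero_frac": sum(v == 0 for v in col) / len(col),
            "longest_zero": longest_zero_run(col),
        }
        for col in columns
    ]

    # Connect two UEs if both received PRBs in at least one common TTI;
    # store each edge in both directions, as the loader does.
    edges = []
    for i in range(num_ues):
        for j in range(i + 1, num_ues):
            if any(a > 0 and b > 0 for a, b in zip(columns[i], columns[j])):
                edges += [(i, j), (j, i)]
    return features, edges


window = [
    [106, 0, 0],    # UE0 takes the whole carrier
    [100, 6, 0],    # UE0 and UE1 are co-scheduled
    [106, 0, 0],
    [0, 0, 106],    # UE2 is scheduled alone
]
features, edges = toy_graph(window)
print(edges)
print([f["longest_zero"] for f in features])
```

In this toy window UE2 is never scheduled in the same TTI as another UE, so it becomes an isolated node; GraphSAGE then sees only its own features until global pooling.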

In [26]:
import os
import re
import glob
from typing import List, Tuple, Optional, Dict

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeoDataLoader


# -------------------- Shared constants --------------------

C_PRB = 106  # total PRBs; must match preprocessing


# -------------------- UE column detection --------------------

def _detect_ue_rb_columns(df: pd.DataFrame) -> List[str]:
    """
    Detect UE RB columns, e.g. UE1_rb, UE2_rb, ...
    Falls back to '*_rb' if needed.
    """
    cand = [c for c in df.columns if re.fullmatch(r'UE\d+_rb', c, flags=re.IGNORECASE)]
    if not cand:
        cand = [c for c in df.columns if c.lower().endswith("_rb")]

    def _key(c):
        m = re.search(r"(\d+)", c)
        return int(m.group(1)) if m else 10**9

    return sorted(set(cand), key=_key)


# -------------------- Node feature construction --------------------

def _longest_zero_run(arr: np.ndarray) -> float:
    """Longest consecutive run of zeros in a 1D array."""
    longest = 0
    current = 0
    for v in arr:
        if v == 0:
            current += 1
            if current > longest:
                longest = current
        else:
            current = 0
    return float(longest)


def build_node_features(prb_window: np.ndarray) -> np.ndarray:
    """
    Build UE-level node features for a single window.

    prb_window: [W, U] PRBs per TTI per UE
    Returns: [U, d_node] node feature matrix.
    """
    # shape sanity
    assert prb_window.ndim == 2
    W, U = prb_window.shape
    if U == 0:
        return np.zeros((0, 1), dtype=np.float32)

    prb = prb_window.astype(np.float32)
    share = prb / np.float32(C_PRB)  # [W, U]

    # basic stats (each is a length-U vector, one entry per UE)
    mean_share = share.mean(axis=0)
    min_share = share.min(axis=0)
    max_share = share.max(axis=0)
    std_share = share.std(axis=0)

    # activity / starvation
    active = (prb > 0).astype(np.float32)   # [W, U] 1.0 where the UE got PRBs
    active_frac = active.mean(axis=0)       # fraction of TTIs with PRBs
    zero_frac = (prb == 0).mean(axis=0)     # fraction of TTIs with none
    longest_zero = np.array(
        [_longest_zero_run(prb[:, u]) for u in range(U)], dtype=np.float32
    )                                       # longest starvation run, in TTIs

    # stack into [U, d_node]
    feats = np.stack(
        [
            mean_share,
            min_share,
            max_share,
            std_share,
            active_frac,
            zero_frac,
            longest_zero,
        ],
        axis=1,
    ).astype(np.float32)

    return feats  # [U, 7]


# -------------------- Edge construction --------------------

def build_edges(prb_window: np.ndarray) -> torch.Tensor:
    """
    Build undirected edges between UEs that are co-scheduled at least once.

    prb_window: [W, U]
    Returns: edge_index [2, E] (torch.long)
    """
    prb = prb_window.astype(np.float32)
    W, U = prb.shape
    if U <= 1:
        return torch.empty((2, 0), dtype=torch.long)

    edge_src = []
    edge_dst = []

    for i in range(U):
        for j in range(i + 1, U):
            both_active = (prb[:, i] > 0) & (prb[:, j] > 0)
            if both_active.any():
                edge_src.extend([i, j])
                edge_dst.extend([j, i])

    if not edge_src:
        return torch.empty((2, 0), dtype=torch.long)

    edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
    return edge_index


# -------------------- Graph construction --------------------

def build_graph_from_window(
    df_run: pd.DataFrame,
    ue_cols: List[str],
    row_meta: pd.Series,
    label_mode: str = "binary",
) -> Tuple[Data, Dict]:
    """
    Build a PyG Data graph from one window.

    df_run: full run dataframe (all TTIs), index aligned so that start_tti/end_tti are valid iloc indices.
    ue_cols: list of UE*_rb columns.
    row_meta: single row from window meta parquet.
    label_mode: "binary" (attack vs non-attack) or "multiclass" (0,1,2).

    Returns:
        (Data, small_meta_dict)
    """
    start = int(row_meta["start_tti"])
    end = int(row_meta["end_tti"])
    run_id = str(row_meta.get("run_id", ""))
    source = str(row_meta.get("source", ""))

    df_w = df_run.iloc[start:end].reset_index(drop=True)
    if df_w.empty:
        # Return an empty graph; caller should skip if desired.
        data = Data(
            x=torch.zeros((0, 1), dtype=torch.float32),
            edge_index=torch.empty((2, 0), dtype=torch.long),
            y=torch.tensor(0, dtype=torch.long),
        )
        meta_small = {
            "run_id": run_id,
            "source": source,
            "start_tti": start,
            "end_tti": end,
            "label_raw": int(row_meta.get("label", 0)),
            "label_bin": 0,
        }
        return data, meta_small

    prb_window = df_w[ue_cols].to_numpy(dtype=np.float32)  # [W, U]
    node_feats = build_node_features(prb_window)           # [U, d_node]
    edge_index = build_edges(prb_window)                   # [2, E]

    label_raw = int(row_meta.get("label", 0))  # 0: benign, 1: proximal, 2: attack
    if label_mode == "binary":
        label_bin = 1 if label_raw == 2 else 0
        y = torch.tensor(label_bin, dtype=torch.long)
    elif label_mode == "multiclass":
        y = torch.tensor(label_raw, dtype=torch.long)
        label_bin = 1 if label_raw == 2 else 0
    else:
        raise ValueError(f"Unknown label_mode: {label_mode}")

    data = Data(
        x=torch.from_numpy(node_feats),  # [U, d_node]
        edge_index=edge_index,
        y=y,
    )

    meta_small = {
        "run_id": run_id,
        "source": source,
        "start_tti": start,
        "end_tti": end,
        "label_raw": label_raw,
        "label_bin": label_bin,
        "kind": str(row_meta.get("kind", "")),
        "stride": int(row_meta.get("stride", 0)),
    }

    return data, meta_small


# -------------------- Build all graphs from meta + parquet --------------------

def build_prb_graphs(
    kind: str = "short",
    meta_glob: Optional[str] = None,
    parquet_root_attack: str = "parquet/attack",
    parquet_root_benign: str = "parquet/benign",
    label_mode: str = "binary",
    restrict_sources: Optional[List[str]] = None,
    max_windows: Optional[int] = None,
) -> Tuple[List[Data], pd.DataFrame]:
    """
    Build a list of PRB-GraphSAGE graphs from preprocessed meta + PF ledger parquet.

    kind: "short" or "long" (matches the window kind in preprocessing)
    meta_glob: optional override for meta file glob pattern; if None, uses f"win_shards/{kind}_*_meta.parquet"
    parquet_root_attack / parquet_root_benign: roots for run-level PF ledgers
    label_mode: "binary" or "multiclass"
    restrict_sources: optional ["attack","benign"] filter
    max_windows: optional cap on number of windows (for dev / shims)
    """
    if meta_glob is None:
        meta_glob = f"win_shards/{kind}_*_meta.parquet"

    meta_files = sorted(glob.glob(meta_glob))
    if not meta_files:
        raise FileNotFoundError(f"No meta parquet files found for pattern: {meta_glob}")

    # Load and concatenate all meta rows
    meta_rows = []
    for mp in meta_files:
        m = pd.read_parquet(mp)
        if restrict_sources is not None and "source" in m.columns:
            m = m[m["source"].isin(restrict_sources)]
        if len(m) == 0:
            continue
        m = m.copy()
        m["meta_path"] = mp
        meta_rows.append(m)

    if not meta_rows:
        raise RuntimeError("No meta rows found after filtering.")

    meta_all = pd.concat(meta_rows, ignore_index=True)
    graphs: List[Data] = []
    gmeta_rows: List[Dict] = []

    # Cache run-level DF + UE cols so we don't re-read parquet per window
    run_cache: Dict[Tuple[str, str], Tuple[pd.DataFrame, List[str]]] = {}

    for idx, row in meta_all.iterrows():
        source = str(row.get("source", ""))
        run_id = str(row.get("run_id", ""))

        if restrict_sources is not None and source not in restrict_sources:
            continue

        if source == "attack":
            root = parquet_root_attack
        elif source == "benign":
            root = parquet_root_benign
        else:
            # fallback; interpret non-attack as benign
            root = parquet_root_benign

        key = (source, run_id)
        if key not in run_cache:
            run_dir = os.path.join(root, f"run_id={run_id}")
            if not os.path.isdir(run_dir) and not run_dir.endswith(".parquet"):
                # try direct parquet file as well
                if os.path.exists(run_dir + ".parquet"):
                    run_dir = run_dir + ".parquet"
            df_run = pd.read_parquet(run_dir)
            df_run = df_run.reset_index(drop=True)
            ue_cols = _detect_ue_rb_columns(df_run)
            if not ue_cols:
                # skip runs with no UE RB columns
                continue
            run_cache[key] = (df_run, ue_cols)

        df_run, ue_cols = run_cache[key]

        data, ms = build_graph_from_window(df_run, ue_cols, row, label_mode=label_mode)
        # skip empty graphs (no UEs or no rows)
        if data.x.size(0) == 0:
            continue

        graphs.append(data)
        ms["graph_idx"] = len(graphs) - 1
        gmeta_rows.append(ms)

        if max_windows is not None and len(graphs) >= max_windows:
            break

    if not graphs:
        raise RuntimeError("No graphs were constructed; check meta and parquet roots.")

    graph_meta = pd.DataFrame(gmeta_rows).reset_index(drop=True)
    return graphs, graph_meta


# -------------------- Utility: detect node feature dimension --------------------

def detect_node_dim(loader: GeoDataLoader) -> int:
    """
    Inspect the first batch from a PyG DataLoader to infer node feature dimension.
    """
    batch = next(iter(loader))
    return batch.x.size(-1)


# -------------------- High-level loader factory --------------------

def make_prb_graph_loaders(
    kind: str = "short",
    batch_size: int = 32,
    num_workers: int = 0,
    seed: int = 2025,
    val_ratio: float = 0.2,
    label_mode: str = "binary",
    restrict_sources: Optional[List[str]] = None,
    max_windows: Optional[int] = None,
) -> Tuple[GeoDataLoader, GeoDataLoader, pd.DataFrame, Dict]:
    """
    Build train/val PyG DataLoaders for PRB-GraphSAGE.

    Splits by run_id (grouped), so windows from the same run_id do not cross train/val.

    Returns:
        train_loader, val_loader, graph_meta, splits_dict
    """
    graphs, graph_meta = build_prb_graphs(
        kind=kind,
        label_mode=label_mode,
        restrict_sources=restrict_sources,
        max_windows=max_windows,
    )

    # Grouped split by run_id
    run_ids = graph_meta["run_id"].astype(str).unique().tolist()
    rng = np.random.default_rng(seed)
    rng.shuffle(run_ids)

    n_val_runs = max(1, int(len(run_ids) * val_ratio))
    val_runs = set(run_ids[:n_val_runs])

    train_idx: List[int] = []
    val_idx: List[int] = []

    for i, rid in enumerate(graph_meta["run_id"].astype(str).tolist()):
        if rid in val_runs:
            val_idx.append(i)
        else:
            train_idx.append(i)

    train_graphs = [graphs[i] for i in train_idx]
    val_graphs = [graphs[i] for i in val_idx]

    train_loader = GeoDataLoader(
        train_graphs,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = GeoDataLoader(
        val_graphs,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    splits = {
        "train_idx": np.array(train_idx, dtype=np.int64),
        "val_idx": np.array(val_idx, dtype=np.int64),
        "val_run_ids": list(val_runs),
        "all_run_ids": run_ids,
    }

    return train_loader, val_loader, graph_meta, splits

train_loader, val_loader, graph_meta, splits = make_prb_graph_loaders(
    kind="short",
    batch_size=GNN_HYP["batch_size"],
    val_ratio=0.2,
    label_mode="binary",   # attack vs non-attack
    max_windows=None,      # or a small number for quick dev
)

D_NODE = detect_node_dim(train_loader)
print("Node feature dim:", D_NODE)

Node feature dim: 7
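
The grouped hold-out split in `make_prb_graph_loaders` behaves in a particular way when there are very few runs. The sketch below reproduces its logic in plain Python; it uses `random.Random` instead of NumPy’s generator, so the shuffle order will differ from the notebook, and the function name `grouped_holdout` and the toy list of run IDs are assumptions for illustration.

```python
import random


def grouped_holdout(run_id_per_graph, val_ratio=0.2, seed=2025):
    """Split graph indices so that each run_id lands wholly in train or val."""
    run_ids = sorted(set(run_id_per_graph))
    random.Random(seed).shuffle(run_ids)

    # Same rule as the loader factory: at least one run is always held out.
    n_val_runs = max(1, int(len(run_ids) * val_ratio))
    val_runs = set(run_ids[:n_val_runs])

    train_idx = [i for i, r in enumerate(run_id_per_graph) if r not in val_runs]
    val_idx = [i for i, r in enumerate(run_id_per_graph) if r in val_runs]
    return train_idx, val_idx, val_runs


# Five toy graphs drawn from the two run IDs seen in the preprocessing output.
graph_runs = ["040933", "040933", "040933", "004912", "004912"]
train_idx, val_idx, val_runs = grouped_holdout(graph_runs)

print("validation runs:", val_runs)
print("train graphs:", len(train_idx), "val graphs:", len(val_idx))
```

With only two runs, `int(2 * 0.2)` is 0, so the `max(1, ...)` floor holds out one whole run. Since the preprocessing output lists one attack-source run and one benign-source run, every validation window then comes from a single source file, which is worth keeping in mind when reading the validation metrics.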


### Base Model Architecture

#### PRB-GraphSAGE Architecture
This code defines `PRBGraphSAGE`, a configurable PyTorch Geometric graph neural network for classifying UE-level contention graphs in a PRB starvation detection task. The model builds a GNN backbone with a user-specified number of layers (`num_layers`), hidden dimension, and convolution type (`SAGEConv`, `GCNConv`, `GraphConv`, or `GATv2Conv`), optionally followed by batch normalization and dropout after each layer. In the forward pass, node features `x` are iteratively updated using the chosen conv layers over edges `edge_index`, then aggregated into graph-level embeddings via a selectable global pooling strategy (`mean`, `max`, or concatenated `mean+max`). This pooled representation is passed through an MLP head whose depth and hidden size are also configurable, producing either a scalar logit per graph for binary classification (`num_classes == 1`) or a multi-dimensional logit vector for multi-class problems. The interface is designed to be hyperparameter-sweep friendly, with `edge_attr` included in the signature as a placeholder for future edge-aware enhancements, while the current implementation focuses on node features and graph-level pooling.
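
The global pooling step can be sketched without any deep-learning library. The plain-Python function below mimics what mean, max, and concatenated mean+max pooling do to a batch of node embeddings; in the notebook this role is played by `global_mean_pool` and `global_max_pool` from `torch_geometric.nn`. The function name `global_pool` and the toy embeddings are illustrative assumptions, and the sketch assumes every graph in the batch has at least one node.

```python
def global_pool(node_embeddings, batch, mode="mean+max"):
    """Pool node vectors into one vector per graph.

    node_embeddings: list of N equal-length vectors
    batch:           list of N graph IDs in 0..G-1, one per node
    """
    num_graphs = max(batch) + 1
    grouped = [[] for _ in range(num_graphs)]
    for vec, g in zip(node_embeddings, batch):
        grouped[g].append(vec)

    pooled = []
    for vecs in grouped:
        # zip(*vecs) iterates over feature dimensions across the graph's nodes.
        mean_vec = [sum(col) / len(col) for col in zip(*vecs)]
        max_vec = [max(col) for col in zip(*vecs)]
        if mode == "mean":
            pooled.append(mean_vec)
        elif mode == "max":
            pooled.append(max_vec)
        elif mode == "mean+max":
            pooled.append(mean_vec + max_vec)  # concatenation doubles the width
        else:
            raise ValueError(f"Unknown pooling mode: {mode}")
    return pooled


# Three nodes with 2-dimensional embeddings; the first two belong to graph 0.
h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]

pooled = global_pool(h, batch, mode="mean+max")
print(pooled)
```

The practical consequence is that with `mean+max` the graph embedding is twice as wide as the hidden dimension, so the first linear layer of the MLP head has to accept that doubled width.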

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    SAGEConv,
    GCNConv,
    GraphConv,
    GATv2Conv,
    global_mean_pool,
    global_max_pool,
)


class PRBGraphSAGE(nn.Module):
    """
    PRB-GraphSAGE: UE-level contention graph classifier for PRB starvation detection.

    This model is designed to be easy to sweep over hyperparameters:
      - in_dim:         node feature dimension
      - hidden_dim:     hidden channel width
      - num_layers:     number of GNN layers (>=1)
      - conv_type:      GNN layer type: {"sage", "gcn", "graph", "gat"}
      - aggr:           global pooling: {"mean", "max", "mean+max"}
      - mlp_hidden_dim: hidden size of the final MLP head
      - mlp_layers:     number of linear layers in the MLP head (>=1)
      - dropout:        dropout probability applied after each conv & MLP layer
      - num_classes:    1 for binary logit, >1 for multi-class logits

    Forward signature:
        logits = model(x, edge_index, batch, edge_attr=None)

      - x:          [N, in_dim] node features
      - edge_index: [2, E] COO edge indices
      - batch:      [N] graph IDs for global pooling
      - edge_attr:  [E, d_edge] (currently ignored; placeholder for future use)
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 2,
        conv_type: str = "sage",
        aggr: str = "mean",
        mlp_hidden_dim: int = 64,
        mlp_layers: int = 2,
        dropout: float = 0.1,
        num_classes: int = 1,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        assert num_layers >= 1, "num_layers must be >= 1"
        assert mlp_layers >= 1, "mlp_layers must be >= 1"
        assert conv_type in {"sage", "gcn", "graph", "gat"}, f"Unknown conv_type: {conv_type}"
        assert aggr in {"mean", "max", "mean+max"}, f"Unknown aggr: {aggr}"
        self.aggr = aggr
        self.dropout = float(dropout)
        self.num_classes = int(num_classes)
        self.use_batchnorm = bool(use_batchnorm)

        # ---- GNN backbone ----
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList() if use_batchnorm else None

        # input layer
        self.convs.append(self._make_conv(conv_type, in_dim, hidden_dim))
        if use_batchnorm:
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # hidden layers
        for _ in range(num_layers - 1):
            self.convs.append(self._make_conv(conv_type, hidden_dim, hidden_dim))
            if use_batchnorm:
                self.bns.append(nn.BatchNorm1d(hidden_dim))

        # ---- Readout (global pooling) ----
        # mean+max concatenation doubles the channel dimension
        readout_dim = hidden_dim if aggr in {"mean", "max"} else hidden_dim * 2

        # ---- MLP classifier head ----
        mlp_layers_list = []
        in_mlp = readout_dim
        for li in range(mlp_layers):
            out_mlp = mlp_hidden_dim if li < mlp_layers - 1 else num_classes
            mlp_layers_list.append(nn.Linear(in_mlp, out_mlp))
            if li < mlp_layers - 1:
                mlp_layers_list.append(nn.SiLU())
                if dropout > 0:
                    mlp_layers_list.append(nn.Dropout(dropout))
            in_mlp = out_mlp

        self.mlp = nn.Sequential(*mlp_layers_list)

    @staticmethod
    def _make_conv(conv_type: str, in_dim: int, out_dim: int):
        """Factory for different conv types to make hyperparameter sweeps easy."""
        if conv_type == "sage":
            return SAGEConv(in_dim, out_dim)
        if conv_type == "gcn":
            return GCNConv(in_dim, out_dim)
        if conv_type == "graph":
            return GraphConv(in_dim, out_dim)
        if conv_type == "gat":
            return GATv2Conv(in_dim, out_dim, heads=1, concat=False)
        raise ValueError(f"Unknown conv_type: {conv_type}")

    def forward(self, x, edge_index, batch, edge_attr=None):
        """
        x:          [N, in_dim]
        edge_index: [2, E]
        batch:      [N] graph IDs
        edge_attr:  [E, d_edge] (ignored for now; reserved for future edge-aware layers)
        """
        # ---- GNN backbone ----
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = F.silu(x)
            if self.use_batchnorm:
                x = self.bns[i](x)
            if self.dropout > 0:
                x = F.dropout(x, p=self.dropout, training=self.training)

        # ---- Global pooling ----
        if self.aggr == "mean":
            g = global_mean_pool(x, batch)
        elif self.aggr == "max":
            g = global_max_pool(x, batch)
        else:  # "mean+max"
            g_mean = global_mean_pool(x, batch)
            g_max = global_max_pool(x, batch)
            g = torch.cat([g_mean, g_max], dim=-1)

        # ---- MLP head ----
        logits = self.mlp(g)  # [num_graphs, num_classes]

        # For binary classification, most training loops will want shape [num_graphs]
        if self.num_classes == 1:
            logits = logits.squeeze(-1)

        return logits

### Base Model Initialization

#### Hyperparameters

This code sets up initialization, configuration, and lightweight hyperparameter sweeping for a PRB-GraphSAGE GNN used in PRB starvation detection. It begins by seeding Python, NumPy, and PyTorch RNGs for reproducibility, choosing a CUDA or CPU device, and defining `D_NODE` as the node feature dimension (with a placeholder warning until real graph data is available).

A `GNN_HYP` dictionary gathers all key hyperparameters in one place: training settings (batch size, epochs, learning rate and grid, weight decay), backbone architecture (hidden size, number of layers, dropout, conv type, pooling strategy, batchnorm), MLP head configuration, output dimension (binary vs multi-class), early-stopping criteria, and scheduler parameters (warmup and total steps). A helper `compute_pos_weight_from_labels` computes class weights for `BCEWithLogitsLoss` from imbalanced binary labels, with a default `pos_weight` of 1.0 until real labels are known.

The `make_gnn_model` factory instantiates a `PRBGraphSAGE` model from a hyperparameter dict and moves it to the chosen device, while `count_params` reports the trainable parameter count. Optimizer and scheduler builders create an AdamW optimizer and a `LambdaLR` scheduler implementing linear warmup followed by cosine decay over `total_steps`. The script then instantiates a default model, optimizer, scheduler, and a binary `BCEWithLogitsLoss` using the current `pos_weight`, printing summaries.

Finally, `iter_gnn_configs` defines a simple grid-search generator over selected hyperparameters (learning rate, hidden size, number of layers, dropout, aggregation, and conv type), and the code demonstrates its use by instantiating and printing the first few candidate configurations along with their parameter counts, providing a ready-made loop for hyperparameter sweeps.
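
The warmup-plus-cosine schedule is easier to reason about with a few concrete values. The sketch below is a pure-Python mirror of the `lr_lambda` multiplier for steps between 0 and `total_steps`; the function name `lr_multiplier` and the sampled steps are illustrative assumptions, and it does not construct a real `LambdaLR`.

```python
import math


def lr_multiplier(step: int, warmup_steps: int = 200, total_steps: int = 10000) -> float:
    """Multiplier applied to the base LR at a given optimizer step (illustrative mirror)."""
    if step < warmup_steps:
        # Linear ramp from 0 to 1 over the warmup window
        return step / max(1, warmup_steps)
    # Cosine decay from 1 down to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


base_lr = 3e-4  # the "lr" entry of GNN_HYP
for step in (0, 100, 200, 5100, 10000):
    print(f"step={step:5d}  lr={base_lr * lr_multiplier(step):.6f}")
```

With the placeholder values `warmup_steps=200` and `total_steps=10000`, a short run of only a few hundred optimizer steps would spend most of its time in the warmup ramp or close to the peak rate, so the cosine decay would barely come into play.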

In [22]:
# ==== CELL: GNN Model + Hyperparameter Initialization ====
import os, math, json, random
from typing import Dict, Iterable

import numpy as np
import torch
import torch.nn as nn

# ---- Reproducibility & device ----
SEED = 2025

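def seed_all(seed: int = 42):
    """Seed Python, NumPy, and PyTorch RNGs and request deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels with autotuning off favor repeatable runs over speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False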

seed_all(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

try:
    D_NODE
except NameError:
    D_NODE = 16  # placeholder for later
    print("WARNING: D_NODE is using a placeholder value (16). Set it to your actual node feature dim.")

# ---- GNN Hyperparameters (with small grids for sweeps) ----
GNN_HYP: Dict = {
    # ---- Training-level hparams ----
    "batch_size": 32,              
    "epochs": 40,                  
    "lr": 3e-4,                    
    "lr_grid": [2e-4, 3e-4, 5e-4], 
    "weight_decay": 1e-4,

    # ---- Backbone architecture ----
    # Base config (used if you don't sweep)
    "hidden_dim": 64,
    "num_layers": 2,
    "dropout": 0.1,
    "conv_type": "sage",           
    "aggr": "mean",                
    "use_batchnorm": True,

    # Grids for a compact but meaningful sweep
    "hidden_dim_grid": [32, 64],           
    "num_layers_grid": [1, 2, 3],          
    "dropout_grid": [0.0, 0.1, 0.3],       
    "aggr_grid": ["mean", "mean+max"],     
    "conv_type_grid": ["sage", "gcn"],     

    # ---- MLP head ----
    "mlp_hidden_dim": 64,
    "mlp_layers": 2,

    # ---- Output ----
    # 1 = binary logit (BCEWithLogits); >1 = multi-class (CrossEntropy)
    "num_classes": 1,

    # ---- Early-stop config (used by your training loop) ----
    "early_stop": {
        "monitor": "f1",            
        "mode": "max",              
        "patience": 5,              
        "min_delta": 1e-3,          
    },

    # ---- Scheduler placeholders (updated once loaders are known) ----
    "warmup_steps": 200,
    "total_steps": 10000,
}

print("GNN Hyperparameters:", json.dumps(GNN_HYP, indent=2))


# ---- Optional class weights for binary BCEWithLogitsLoss ----
# You can adapt this to your graph meta if you materialize labels in a DataFrame.
def compute_pos_weight_from_labels(labels: np.ndarray) -> torch.Tensor:
    """
    Compute pos_weight for BCEWithLogitsLoss from a binary label vector.
    pos_weight = (#neg / #pos)
    """
    labels = labels.astype(np.int64).ravel()
    pos = (labels == 1).sum()
    neg = (labels == 0).sum()
    if pos == 0:
        return torch.tensor(1.0, dtype=torch.float32)
    return torch.tensor(float(neg) / float(pos), dtype=torch.float32)

pos_weight = torch.tensor(1.0, dtype=torch.float32)
print("Initial pos_weight (binary):", float(pos_weight))


# ---- PRB-GraphSAGE factory ----
def make_gnn_model(hyp: Dict, in_dim: int) -> PRBGraphSAGE:
    """
    Instantiate a PRBGraphSAGE model from a hyperparameter dict.
    """
    model = PRBGraphSAGE(
        in_dim=in_dim,
        hidden_dim=hyp["hidden_dim"],
        num_layers=hyp["num_layers"],
        conv_type=hyp["conv_type"],      
        aggr=hyp["aggr"],                
        mlp_hidden_dim=hyp["mlp_hidden_dim"],
        mlp_layers=hyp["mlp_layers"],
        dropout=hyp["dropout"],
        num_classes=hyp["num_classes"],
        use_batchnorm=hyp["use_batchnorm"],
    )
    return model.to(DEVICE)



def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


# ---- Optimizer & Scheduler helpers ----
def make_gnn_optimizer(model: nn.Module, hyp: Dict) -> torch.optim.Optimizer:
    return torch.optim.AdamW(
        model.parameters(),
        lr=hyp["lr"],
        weight_decay=hyp["weight_decay"],
    )

def make_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """Linear warmup to the base LR, then cosine decay to zero at total_steps."""
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(1.0, progress)  # hold at zero past total_steps instead of rising again
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# ---- Instantiate a default GNN model + optimizer/scheduler ----
gnn_model = make_gnn_model(GNN_HYP, D_NODE)
print(f"GNN params: {count_params(gnn_model):,}")

gnn_optim = make_gnn_optimizer(gnn_model, GNN_HYP)
gnn_scheduler = make_scheduler(gnn_optim, GNN_HYP["warmup_steps"], GNN_HYP["total_steps"])

# Binary classification by default (y ∈ {0,1})
gnn_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(DEVICE))

print("GNN model, optimizer, scheduler, and loss are initialized.")


# ---- Hyperparameter sweep helper (config generator only) ----
def iter_gnn_configs(base: Dict) -> Iterable[Dict]:
    """
    Simple grid over a few key hyperparameters.
    You can extend/prune this as needed.
    """
    for lr in base["lr_grid"]:
        for hidden in base["hidden_dim_grid"]:
            for num_layers in base["num_layers_grid"]:
                for dropout in base["dropout_grid"]:
                    for aggr in base["aggr_grid"]:
                        for conv_type in base.get("conv_type_grid", [base["conv_type"]]):
                            cfg = dict(base)  # shallow copy
                            cfg["lr"] = lr
                            cfg["hidden_dim"] = hidden
                            cfg["num_layers"] = num_layers
                            cfg["dropout"] = dropout
                            cfg["aggr"] = aggr
                            cfg["conv_type"] = conv_type
                            yield cfg

print("\nInitialization")
for i, cfg in enumerate(iter_gnn_configs(GNN_HYP)):
    if i >= 3:
        break
    m = make_gnn_model(cfg, D_NODE)
    print(
        f"  Config {i}: lr={cfg['lr']}, hidden={cfg['hidden_dim']}, "
        f"layers={cfg['num_layers']}, dropout={cfg['dropout']}, "
        f"aggr={cfg['aggr']}, conv_type={cfg['conv_type']}, "
        f"params={count_params(m):,}"
    )

Device: cuda
GNN Hyperparameters: {
  "batch_size": 32,
  "epochs": 40,
  "lr": 0.0003,
  "lr_grid": [
    0.0002,
    0.0003,
    0.0005
  ],
  "weight_decay": 0.0001,
  "hidden_dim": 64,
  "num_layers": 2,
  "dropout": 0.1,
  "conv_type": "sage",
  "aggr": "mean",
  "use_batchnorm": true,
  "hidden_dim_grid": [
    32,
    64
  ],
  "num_layers_grid": [
    1,
    2,
    3
  ],
  "dropout_grid": [
    0.0,
    0.1,
    0.3
  ],
  "aggr_grid": [
    "mean",
    "mean+max"
  ],
  "conv_type_grid": [
    "sage",
    "gcn"
  ],
  "mlp_hidden_dim": 64,
  "mlp_layers": 2,
  "num_classes": 1,
  "early_stop": {
    "monitor": "f1",
    "mode": "max",
    "patience": 5,
    "min_delta": 0.001
  },
  "warmup_steps": 200,
  "total_steps": 10000
}
Initial pos_weight (binary): 1.0
GNN params: 14,849
GNN model, optimizer, scheduler, and loss are initialized.

Initialization
  Config 0: lr=0.0002, hidden=32, layers=1, dropout=0.0, aggr=mean, conv_type=sage, params=3,297
  Config 1: lr=0.0002, hidde

### Model Training Loop

#### PRB-GraphSAGE

##### Helper Functions

These helpers implement a compact training/evaluation loop for a binary graph classifier and standard performance metrics. `compute_binary_metrics` takes ground-truth binary labels and predicted probabilities (post-sigmoid), thresholds them into hard predictions, and computes accuracy, precision, recall, and F1 score with small numerical safeguards.

`run_epoch` iterates over a PyTorch Geometric `DataLoader` for one epoch in either training or evaluation mode, moves each batch to the target device, runs the model forward on graph data (`x`, `edge_index`, `batch`), computes a scalar loss against flattened labels using a given criterion (e.g., `BCEWithLogitsLoss`), and in training mode performs backpropagation, optimizer steps, and optional scheduler updates. It accumulates loss weighted by batch size to return an average epoch loss, collects all sigmoid probabilities and labels across batches, and finally calls `compute_binary_metrics` to produce a metrics dictionary augmented with the averaged loss, or returns zeros if no samples were seen.
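
A small worked example helps show how the epsilon safeguards behave. The sketch below applies the same formulas to raw confusion counts; the function name `binary_metrics` and both sets of counts are made up for illustration and do not come from the dataset.

```python
def binary_metrics(tp: int, tn: int, fp: int, fn: int, eps: float = 1e-9) -> dict:
    """Accuracy, precision, recall, and F1 from confusion counts (illustrative helper)."""
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    return {"acc": acc, "precision": precision, "recall": recall, "f1": f1}


# Toy counts (assumed): a mixed case and a case with no positive labels at all.
toy = binary_metrics(tp=3, tn=4, fp=1, fn=2)
no_pos = binary_metrics(tp=0, tn=5, fp=5, fn=0)

print({k: round(v, 4) for k, v in toy.items()})
print({k: round(v, 4) for k, v in no_pos.items()})
```

When a split holds no positive labels, `tp` and `fn` are both zero, so the epsilon turns the undefined recall into 0.0 and F1 is pinned at 0.0 whatever the model predicts.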

In [17]:
import copy
from typing import Dict, Tuple, List
import numpy as np
import torch
import torch.nn as nn

def compute_binary_metrics(y_true: np.ndarray,
                           y_prob: np.ndarray,
                           threshold: float = 0.5) -> Dict[str, float]:
    """
    y_true: [N] in {0,1}
    y_prob: [N] in [0,1]  (sigmoid outputs)
    """
    y_true = y_true.astype(np.int64).ravel()
    y_prob = y_prob.astype(np.float32).ravel()
    y_pred = (y_prob >= threshold).astype(np.int64)

    tp = np.sum((y_pred == 1) & (y_true == 1))
    tn = np.sum((y_pred == 0) & (y_true == 0))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))

    eps = 1e-9
    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1        = 2 * precision * recall / (precision + recall + eps)
    acc       = (tp + tn) / max(1, len(y_true))

    return {
        "acc": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }

def run_epoch(model: nn.Module,
              loader,
              optimizer: torch.optim.Optimizer,
              criterion: nn.Module,
              device: torch.device,
              scheduler=None,
              train: bool = True) -> Tuple[float, Dict[str, float]]:
    """
    Runs one epoch over loader.
    Returns (avg_loss, metrics_dict).
    """
    if train:
        model.train()
    else:
        model.eval()

    all_probs: List[np.ndarray] = []   # one array of sigmoid outputs per batch
    all_labels: List[np.ndarray] = []  # one array of labels per batch
    total_loss = 0.0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        # y: [num_graphs] or [num_graphs, 1]
        y = batch.y.view(-1).float()

        with torch.set_grad_enabled(train):
            logits = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs] or [num_graphs,1]
            logits = logits.view_as(y)
            loss = criterion(logits, y)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

        bs = y.size(0)
        total_loss += float(loss.item()) * bs
        total_samples += bs

        probs = torch.sigmoid(logits.detach()).cpu().numpy()
        labels = y.detach().cpu().numpy()
        all_probs.append(probs)
        all_labels.append(labels)

    if total_samples == 0:
        # Include "loss" so callers that log metrics["loss"] do not raise KeyError.
        return 0.0, {"acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "loss": 0.0}

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    avg_loss = total_loss / total_samples
    metrics = compute_binary_metrics(all_labels, all_probs, threshold=0.5)
    metrics["loss"] = float(avg_loss)
    return avg_loss, metrics

##### Training Loop (with Early Stop)

This code implements an early-stopping mechanism and a high-level training loop for a single GNN configuration. The `EarlyStopper` class tracks a chosen validation metric (e.g., F1) across epochs and signals a stop once the metric has failed to improve by at least `min_delta` for `patience` consecutive epochs; it also keeps a deep copy of the best model state seen so far. Because the initial best value is negative infinity in `max` mode, the first epoch always counts as an improvement.

The `train_gnn_model` function wires everything together for one run: it constructs a GNN using `make_gnn_model`, sets up an AdamW optimizer and a warmup plus cosine-decay scheduler via `make_gnn_optimizer` and `make_scheduler`, and uses a shared loss (`gnn_criterion`, typically `BCEWithLogitsLoss` for binary classification). For each epoch, it calls `run_epoch` on both the training and validation loaders, logs train/val loss and F1, appends the metrics to a `history` dict, and passes the validation metrics to `EarlyStopper` to decide whether to stop. When training ends, either by exhausting the epoch budget or by early stopping, the best-performing weights (if any) are restored onto the model, and a result dictionary is returned containing the trained model, metric history, the best monitored metric value, its name, and the corresponding `state_dict` for downstream evaluation or checkpointing.
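
The patience logic can be traced without a model. The sketch below replays the `max`-mode update rule on a plain list of metric values; `trace_early_stop` and both input sequences are illustrative assumptions rather than notebook code.

```python
from typing import List, Optional, Tuple


def trace_early_stop(values: List[float], patience: int = 5,
                     min_delta: float = 1e-3) -> Tuple[Optional[int], Optional[int]]:
    """Return (stop_epoch, best_epoch), both 1-based, for a max-mode early stopper.

    stop_epoch is None when patience is never exhausted (illustrative helper).
    """
    best = float("-inf")
    num_bad = 0
    best_epoch = None
    for epoch, current in enumerate(values, start=1):
        if current > best + min_delta:
            # Improvement: remember it and reset the patience counter
            best, num_bad, best_epoch = current, 0, epoch
        else:
            num_bad += 1
        if num_bad >= patience:
            return epoch, best_epoch
    return None, best_epoch


# A metric stuck at 0.0 for a full 40-epoch budget (assumed input).
flat = trace_early_stop([0.0] * 40)
# A metric that improves twice, then moves by less than min_delta (assumed input).
plateau = trace_early_stop([0.1, 0.2, 0.2005, 0.2, 0.2, 0.2, 0.2, 0.2])
print(flat, plateau)
```

For a metric that never leaves 0.0, the first epoch is recorded as the best and the stop fires five epochs later, so the checkpoint that gets restored is the one from epoch 1.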


In [18]:
class EarlyStopper:
    def __init__(self, monitor: str = "f1", mode: str = "max",
                 patience: int = 5, min_delta: float = 1e-3):
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta

        if mode == "max":
            self.best = -float("inf")
        elif mode == "min":
            self.best = float("inf")
        else:
            raise ValueError("mode must be 'max' or 'min'")

        self.num_bad = 0
        self.best_state = None

    def step(self, metrics: Dict[str, float], model: nn.Module) -> bool:
        """
        Returns True if we should stop (patience exhausted).
        """
        current = metrics.get(self.monitor, None)
        if current is None:
            return False

        improved = False
        if self.mode == "max":
            if current > self.best + self.min_delta:
                improved = True
        else:  # 'min'
            if current < self.best - self.min_delta:
                improved = True

        if improved:
            self.best = current
            self.num_bad = 0
            self.best_state = copy.deepcopy(model.state_dict())
        else:
            self.num_bad += 1

        return self.num_bad >= self.patience


def train_gnn_model(
    hyp: Dict,
    train_loader,
    val_loader,
    in_dim: int,
    device: torch.device = DEVICE,
) -> Dict:
    """
    High-level training loop for a single GNN config.
    Returns a dict with best metrics, history, and best_state_dict.
    """

    model = make_gnn_model(hyp, in_dim=in_dim).to(device)  # honor the `device` argument
    optimizer = make_gnn_optimizer(model, hyp)
    scheduler = make_scheduler(optimizer, hyp["warmup_steps"], hyp["total_steps"])

    criterion = gnn_criterion  # shared module-level loss (BCEWithLogitsLoss)

    es_cfg = hyp["early_stop"]
    early_stopper = EarlyStopper(
        monitor=es_cfg.get("monitor", "f1"),
        mode=es_cfg.get("mode", "max"),
        patience=es_cfg.get("patience", 5),
        min_delta=es_cfg.get("min_delta", 1e-3),
    )

    history = {"train": [], "val": []}

    for epoch in range(1, hyp["epochs"] + 1):
        train_loss, train_metrics = run_epoch(
            model, train_loader, optimizer, criterion, device, scheduler, train=True
        )
        val_loss, val_metrics = run_epoch(
            model, val_loader, optimizer, criterion, device, scheduler=None, train=False
        )

        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        print(
            f"[Epoch {epoch:03d}] "
            f"train_loss={train_metrics['loss']:.4f}, train_f1={train_metrics['f1']:.3f} | "
            f"val_loss={val_metrics['loss']:.4f}, val_f1={val_metrics['f1']:.3f}"
        )

        stop = early_stopper.step(val_metrics, model)
        if stop:
            print(f"Early stopping triggered at epoch {epoch}. Best {early_stopper.monitor} = {early_stopper.best:.4f}")
            break

    # Restore best weights if we got any improvement
    if early_stopper.best_state is not None:
        model.load_state_dict(early_stopper.best_state)

    result = {
        "model": model,
        "history": history,
        "best_metric": early_stopper.best,
        "best_monitor": early_stopper.monitor,
        "best_state_dict": early_stopper.best_state,
    }
    return result

##### Initiate Training

This code performs a simple grid search over GNN hyperparameters and tracks the best-performing configuration based on validation metrics. It iterates through each candidate config generated by `iter_gnn_configs(GNN_HYP)`, prints a short summary of that config (convolution type, hidden dimension, number of layers, dropout, aggregation, and learning rate), and trains a model with `train_gnn_model` on the given train/validation loaders and node feature dimension. The result `out` from each run includes the best monitored metric (e.g., validation F1) for that configuration; the loop keeps `best_overall` and `best_cfg` updated whenever a configuration surpasses the current best metric. After all configurations have been evaluated, it prints the best hyperparameter dictionary and the corresponding best validation score, effectively selecting the strongest model configuration discovered during the sweep.
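
To gauge the size of the sweep and how ties are resolved, the sketch below enumerates the same grid with `itertools.product` and replays the strict `>` selection rule. The `grids` dict copies the grid values from `GNN_HYP`; the constant score is a stand-in assumption for a sweep in which every configuration reports the same validation F1.

```python
from itertools import product

# Grid values copied from GNN_HYP; key order matches the nesting of iter_gnn_configs
# (lr is the outermost loop, conv_type the innermost).
grids = {
    "lr": [2e-4, 3e-4, 5e-4],
    "hidden_dim": [32, 64],
    "num_layers": [1, 2, 3],
    "dropout": [0.0, 0.1, 0.3],
    "aggr": ["mean", "mean+max"],
    "conv_type": ["sage", "gcn"],
}
configs = [dict(zip(grids, values)) for values in product(*grids.values())]
print(len(configs))  # 3 * 2 * 3 * 3 * 2 * 2 = 216

# Replay the selection rule with a constant score (assumed: every run ties at 0.0).
best_idx, best_score = None, None
for idx, cfg in enumerate(configs):
    score = 0.0  # stand-in for out["best_metric"]
    if best_score is None or score > best_score:
        best_idx, best_score = idx, score

print(best_idx, configs[best_idx])
```

Because the comparison is strict, a sweep in which all runs tie keeps the first configuration it tried, so the reported "best" configuration then reflects enumeration order rather than a real performance difference.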

In [27]:
best_overall = None
best_cfg = None

for cfg in iter_gnn_configs(GNN_HYP):
    print("\n=== Config:", cfg["conv_type"], cfg["hidden_dim"], cfg["num_layers"],
          cfg["dropout"], cfg["aggr"], "lr", cfg["lr"], "===")
    out = train_gnn_model(cfg, train_loader, val_loader, in_dim=D_NODE, device=DEVICE)

    if best_overall is None or out["best_metric"] > best_overall["best_metric"]:
        best_overall = out
        best_cfg = cfg

print("\nBest config:", best_cfg)
print("Best validation", best_overall["best_monitor"], "=", best_overall["best_metric"])


=== Config: sage 32 1 0.0 mean lr 0.0002 ===
[Epoch 001] train_loss=0.7448, train_f1=0.000 | val_loss=0.7011, val_f1=0.000
[Epoch 002] train_loss=0.7092, train_f1=0.000 | val_loss=0.6753, val_f1=0.000
[Epoch 003] train_loss=0.6385, train_f1=0.000 | val_loss=0.5761, val_f1=0.000
[Epoch 004] train_loss=0.5223, train_f1=0.000 | val_loss=0.4332, val_f1=0.000
[Epoch 005] train_loss=0.3577, train_f1=0.000 | val_loss=0.2211, val_f1=0.000
[Epoch 006] train_loss=0.2160, train_f1=0.000 | val_loss=0.1255, val_f1=0.000
Early stopping triggered at epoch 6. Best f1 = 0.0000

=== Config: gcn 32 1 0.0 mean lr 0.0002 ===
[Epoch 001] train_loss=0.6901, train_f1=0.000 | val_loss=0.7423, val_f1=0.000
[Epoch 002] train_loss=0.6493, train_f1=0.000 | val_loss=0.7568, val_f1=0.000
[Epoch 003] train_loss=0.5687, train_f1=0.000 | val_loss=0.7092, val_f1=0.000
[Epoch 004] train_loss=0.4417, train_f1=0.000 | val_loss=0.6321, val_f1=0.000
[Epoch 005] train_loss=0.2875, train_f1=0.000 | val_loss=0.5438, val_f1=0.0



[Epoch 001] train_loss=0.7084, train_f1=0.000 | val_loss=0.6589, val_f1=0.000
[Epoch 002] train_loss=0.6442, train_f1=0.000 | val_loss=0.4521, val_f1=0.000
[Epoch 003] train_loss=0.5044, train_f1=0.000 | val_loss=0.3314, val_f1=0.000
[Epoch 004] train_loss=0.2873, train_f1=0.000 | val_loss=0.2197, val_f1=0.000
[Epoch 005] train_loss=0.1336, train_f1=0.000 | val_loss=0.1256, val_f1=0.000
[Epoch 006] train_loss=0.0684, train_f1=0.000 | val_loss=0.0845, val_f1=0.000
Early stopping triggered at epoch 6. Best f1 = 0.0000

=== Config: gcn 32 1 0.0 mean+max lr 0.0002 ===
[Epoch 001] train_loss=0.7043, train_f1=0.000 | val_loss=0.6631, val_f1=0.000
[Epoch 002] train_loss=0.6468, train_f1=0.000 | val_loss=0.5306, val_f1=0.000
[Epoch 003] train_loss=0.5238, train_f1=0.000 | val_loss=0.3470, val_f1=0.000
[Epoch 004] train_loss=0.3157, train_f1=0.000 | val_loss=0.1584, val_f1=0.000
[Epoch 005] train_loss=0.1412, train_f1=0.000 | val_loss=0.0554, val_f1=0.000
[Epoch 006] train_loss=0.0711, train_f1

### Model Evaluation

#### PRB-GraphSAGE

##### Helper Functions

This code provides evaluation utilities for a trained PRB-GraphSAGE model, computing metrics and confusion statistics on a given dataset. The `evaluate_gnn_model` function runs in no-grad mode, sets the model to eval, and iterates over a PyTorch Geometric loader, moving each batch to the target device, forwarding node features and graph structure through the model to obtain logits, and computing a loss with the same `gnn_criterion` used during training. It accumulates batch-weighted loss, collects sigmoid-transformed probabilities and true labels, and at the end concatenates them to compute average loss and standard binary classification metrics (accuracy, precision, recall, F1) via the existing `compute_binary_metrics` helper, returning both the metrics dictionary and the full `y_true` and `y_prob` arrays. The `summarize_eval_split` helper prints these metrics in a compact, human-readable line for a named split (e.g., “train”, “val”, “test”). Finally, `confusion_from_probs` turns true labels and predicted probabilities into a small 2×2 confusion matrix (tp, tn, fp, fn) using a configurable threshold, making it easy to inspect the model’s error profile on any evaluation split.
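
The batch-weighted loss accumulation matters whenever the final batch is smaller than the rest. The sketch below compares it with a naive mean over batches; the helper name `weighted_mean_loss` and the loss values are made up for illustration, and it assumes the criterion already averages over the graphs within each batch (the default `mean` reduction).

```python
from typing import Sequence


def weighted_mean_loss(batch_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Per-sample average loss from per-batch mean losses (illustrative helper)."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    count = sum(batch_sizes)
    return total / count if count else 0.0


# Assumed values: two full batches of 32 graphs and a final partial batch of 8.
losses = [0.5, 0.5, 1.0]
sizes = [32, 32, 8]

weighted = weighted_mean_loss(losses, sizes)
unweighted = sum(losses) / len(losses)
print(f"weighted={weighted:.4f}  unweighted={unweighted:.4f}")
```

The naive mean lets the eight graphs in the last batch count as much as 32 graphs elsewhere, while the weighted version gives every graph equal influence on the reported loss.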

In [29]:
# ==== CELL: PRB-GraphSAGE Evaluation (Best Model) ====
import numpy as np
import torch
from typing import Dict, Tuple


@torch.no_grad()
def evaluate_gnn_model(
    model: torch.nn.Module,
    loader,
    device: torch.device,
    threshold: float = 0.5,
) -> Tuple[Dict[str, float], np.ndarray, np.ndarray]:
    """
    Evaluate a trained PRB-GraphSAGE model on a given loader.

    Returns:
        metrics: dict with loss, acc, precision, recall, f1
        y_true:  np.ndarray of true labels (0/1)
        y_prob:  np.ndarray of predicted probabilities p(y=1)
    """
    model.eval()
    all_probs = []
    all_labels = []
    total_loss = 0.0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        y = batch.y.view(-1).float()            # [B]

        logits = model(batch.x, batch.edge_index, batch.batch)  # [B] or [B,1]
        logits = logits.view_as(y)
        loss = gnn_criterion(logits, y)

        bs = y.size(0)
        total_loss += float(loss.item()) * bs
        total_samples += bs

        probs = torch.sigmoid(logits).cpu().numpy()
        labels = y.cpu().numpy()
        all_probs.append(probs)
        all_labels.append(labels)

    if total_samples == 0:
        metrics = {"loss": 0.0, "acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}
        return metrics, np.array([]), np.array([])

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    avg_loss = total_loss / total_samples

    # Reuse the helper you already defined earlier
    metrics = compute_binary_metrics(all_labels, all_probs, threshold=threshold)
    metrics["loss"] = float(avg_loss)
    return metrics, all_labels, all_probs


def summarize_eval_split(name: str, metrics: Dict[str, float]):
    print(
        f"[{name}] "
        f"loss={metrics['loss']:.4f}  "
        f"acc={metrics['acc']:.3f}  "
        f"precision={metrics['precision']:.3f}  "
        f"recall={metrics['recall']:.3f}  "
        f"f1={metrics['f1']:.3f}"
    )


def confusion_from_probs(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> Dict[str, int]:
    """
    Tiny 2x2 confusion matrix helper for reporting.
    """
    y_true = y_true.astype(np.int64).ravel()
    y_pred = (y_prob >= threshold).astype(np.int64).ravel()

    tp = int(((y_pred == 1) & (y_true == 1)).sum())
    tn = int(((y_pred == 0) & (y_true == 0)).sum())
    fp = int(((y_pred == 1) & (y_true == 0)).sum())
    fn = int(((y_pred == 0) & (y_true == 1)).sum())

    return {"tp": tp, "tn": tn, "fp": fp, "fn": fn}

##### Evaluation

This code evaluates the best-performing PRB-GraphSAGE model found during the hyperparameter sweep on both the training and validation splits to assess fit and generalization. It first retrieves the model with its restored best weights from `best_overall`, moves it to the active device, switches it to evaluation mode, and prints the corresponding best hyperparameter configuration and monitored metric (e.g., validation F1). It then calls `evaluate_gnn_model` on the training `DataLoader` to compute loss and binary metrics, uses `summarize_eval_split` to print a concise summary line, and derives a 2×2 confusion matrix with `confusion_from_probs` to inspect true/false positives and negatives. The same process is repeated on the validation `DataLoader`; the validation metrics and confusion matrix are the primary numbers reported for this prototype.

In [30]:
# Grab the best model (weights are already restored in train_gnn_model)
best_model = best_overall["model"].to(DEVICE)
best_model.eval()

print("Best config:", best_cfg)
print(f"Best monitored metric ({best_overall['best_monitor']}): {best_overall['best_metric']:.4f}\n")

# Evaluate on training split (to see overfit / underfit)
train_metrics, train_y, train_p = evaluate_gnn_model(best_model, train_loader, DEVICE, threshold=0.5)
summarize_eval_split("train", train_metrics)
train_cm = confusion_from_probs(train_y, train_p, threshold=0.5)
print("  train confusion:", train_cm)

# Evaluate on validation split (primary number for the report)
val_metrics, val_y, val_p = evaluate_gnn_model(best_model, val_loader, DEVICE, threshold=0.5)
summarize_eval_split("val", val_metrics)
val_cm = confusion_from_probs(val_y, val_p, threshold=0.5)
print("  val confusion:", val_cm)

Best config: {'batch_size': 32, 'epochs': 40, 'lr': 0.0002, 'lr_grid': [0.0002, 0.0003, 0.0005], 'weight_decay': 0.0001, 'hidden_dim': 32, 'num_layers': 1, 'dropout': 0.0, 'conv_type': 'sage', 'aggr': 'mean', 'use_batchnorm': True, 'hidden_dim_grid': [32, 64], 'num_layers_grid': [1, 2, 3], 'dropout_grid': [0.0, 0.1, 0.3], 'aggr_grid': ['mean', 'mean+max'], 'conv_type_grid': ['sage', 'gcn'], 'mlp_hidden_dim': 64, 'mlp_layers': 2, 'num_classes': 1, 'early_stop': {'monitor': 'f1', 'mode': 'max', 'patience': 5, 'min_delta': 0.001}, 'warmup_steps': 200, 'total_steps': 10000}
Best monitored metric (f1): 0.0000

[train] loss=0.7396  acc=0.270  precision=0.000  recall=0.000  f1=0.000
  train confusion: {'tp': 0, 'tn': 444, 'fp': 1198, 'fn': 0}
[val] loss=0.7011  acc=0.156  precision=0.000  recall=0.000  f1=0.000
  val confusion: {'tp': 0, 'tn': 515, 'fp': 2793, 'fn': 0}
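
The confusion counts above can be turned back into per-class label counts, since `tp + fn` is the number of graphs labeled 1 and `tn + fp` the number labeled 0. The sketch below does that for the two printed dictionaries; the helper names `class_counts` and `pos_weight_from_counts` are illustrative assumptions, with the second mirroring the fallback rule of `compute_pos_weight_from_labels`.

```python
from typing import Dict, Tuple


def class_counts(cm: Dict[str, int]) -> Tuple[int, int]:
    """Return (positives, negatives) implied by a confusion dict (illustrative helper)."""
    return cm["tp"] + cm["fn"], cm["tn"] + cm["fp"]


def pos_weight_from_counts(pos: int, neg: int) -> float:
    """Same rule as compute_pos_weight_from_labels: neg / pos, or 1.0 when pos is 0."""
    return 1.0 if pos == 0 else neg / pos


# Confusion dicts copied from the evaluation output above.
train_cm = {"tp": 0, "tn": 444, "fp": 1198, "fn": 0}
val_cm = {"tp": 0, "tn": 515, "fp": 2793, "fn": 0}

for name, cm in (("train", train_cm), ("val", val_cm)):
    pos, neg = class_counts(cm)
    acc = cm["tn"] / (pos + neg)  # tp is 0 here, so accuracy reduces to tn / total
    print(f"{name}: positives={pos} negatives={neg} acc={acc:.3f} "
          f"pos_weight={pos_weight_from_counts(pos, neg)}")
```

Both splits come out with zero positives, which is consistent with the printed accuracies of 0.270 and 0.156 and means F1 cannot move off 0.0 on these splits. Why no graph carries label 1 is not visible from this output, so checking the label counts produced by the graph loaders would be a reasonable next step.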
