# PRB-GraphSAGE: Resource Management Reconstruction GNN  
## CPSC 8810 — ML4G Term Project
### Authored by Ryan Barker

This notebook implements a minimal viable **PRB-GraphSAGE** model: an offline, UE-level graph neural network that reconstructs resource-management behavior from noisy per-TTI proportional-fair (PF) scheduler logs. Using the existing TTI-Trust Anamoly Detector + Classifier preprocessing pipeline, each sliding window of TTIs is converted into a **UE-contention graph** whose nodes represent UEs active in the window and whose edges capture co-scheduling (PRB contention) relationships. Node features summarize PRB usage statistics, fairness/PF surrogates, and starvation patterns, while a compact **GraphSAGE → global pooling → MLP** architecture performs **graph-level classification** to distinguish benign scheduling behavior from PF-based PRB starvation indicative of attack conditions. This MVP focuses on methodology rather than headline accuracy, acknowledging the dataset’s timing noise and simulator artifacts, and serves as a prototype for future training on GPU-native, slot-accurate Aerial/cuMAC traces.

### preproc_tti_trust.py

This script is a preprocessing scaffold for the TTI-Trust project that converts lightweight attack/benign CSV “shims” into parquet datasets, derives scheduler-native, identity-agnostic features, and emits windowed shards plus cross-validation splits for PyTorch. It first materializes CSVs into partitioned parquet (by `run_id`), adding a relative time axis `sec_rel` if needed, then for each run it auto-detects UE resource-block columns, computes `total_rb`, normalizes RB usage into per-UE PRB shares, and derives a Jain’s fairness index `J`. On top of these per-UE shares it builds “roles” by tracking an exponential-moving-average of the shares, selecting the top-K dominant UEs per TTI, and aggregating everyone else into a `rest_share` and corresponding “small RB” indicators, plus 1-second presence flags to capture coarse duty cycle behavior. Using these role-based features and the `phase` labels, it slides long and short windows (e.g., 240- and 64-TTI) with multiple strides, labels each window as benign, attack, or ambiguous via a dominance test on the top role (with thresholds like `ATTACK_DOMINANCE`, `OTHERS_MAX`, and `DUTY_KEEP`), and batches them into compressed NumPy `.npz` shards alongside parquet meta tables that record window boundaries, labels, and summary statistics. The script writes enriched per-run parquet files, then populates a `win_shards` directory with both long and short window shards for attack and benign sources. Finally, it builds grouped K-fold splits where entire `run_id`s are kept together in either train or validation, returning index pairs for each fold and printing a brief summary so downstream training can stream shards or assemble cross-validation sets directly from the generated metadata.

In [7]:
# TTI-Trust preprocessing scaffold for PyTorch: identity-agnostic feature composer,
# phase-aware windowing, and grouped K-fold splits.

from __future__ import annotations
import os, re, warnings, itertools
from dataclasses import dataclass
from typing import List, Tuple
from glob import glob

import numpy as np
import pandas as pd
import pyarrow as pa, pyarrow.csv as pv, pyarrow.parquet as pq

# -------------------- DATASETS --------------------
ATTACK_FN = "data/shims/attack_shim_16000_24000.csv" # Committing lightweight shims to the repo due to storage quotas, full 15GB dataset available upon request
BENIGN_FN = "data/shims/benign_shim_8000_12000.csv"  # Committing lightweight shims to the repo due to storage quotas, full 15GB dataset available upon request

# -------------------- Constants aligned to paper --------------------
TTI_SEC = 0.0005
C_PRB   = 106

W_LONG  = 240
W_SHORT = 64

STRIDES_LONG  = [8, 12]
STRIDES_SHORT = [4, 6]

TOP_K = 6
SMALL_RB_EPS = 4

ATTACK_DOMINANCE = 0.90
OTHERS_MAX       = 0.10
DUTY_KEEP        = 0.90

# -------------------- Helpers --------------------

def csv_to_parquet(src_csv: str, dest_root: str):
    os.makedirs(dest_root, exist_ok=True)
    for _root, _dirs, files in os.walk(dest_root):
        if any(f.endswith(".parquet") for f in files):
            return

    convert_cols = {"run_id": pa.large_string(), "phase": pa.large_string()}
    read_opts  = pv.ReadOptions(block_size=1<<26)
    conv_opts  = pv.ConvertOptions(column_types=convert_cols, strings_can_be_null=True)
    t = pv.read_csv(src_csv, read_options=read_opts, convert_options=conv_opts)

    if "run_id" not in t.schema.names:
        stem = os.path.splitext(os.path.basename(src_csv))[0]
        run_id = pa.array(np.full(t.num_rows, stem), type=pa.large_string())
        t = t.append_column("run_id", run_id)

    if "sec_rel" not in t.schema.names:
        n = t.num_rows
        sec_rel = pa.array((np.arange(n, dtype=np.int32) * np.float32(TTI_SEC)), type=pa.float32())
        t = t.append_column("sec_rel", sec_rel)

    pq.write_to_dataset(t, root_path=dest_root, partition_cols=["run_id"],
                        compression="zstd", use_dictionary=True)

def _detect_ue_rb_columns(df: pd.DataFrame) -> List[str]:
    cand = [c for c in df.columns if re.fullmatch(r'UE\d+_rb', c, flags=re.IGNORECASE)]
    if not cand:
        cand = [c for c in df.columns if c.lower().endswith('_rb')]
    def _key(c):
        m = re.search(r'(\d+)', c)
        return int(m.group(1)) if m else 10**9
    return sorted(set(cand), key=_key)

def _ensure_total_rb(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    if 'total_rb' not in df.columns:
        df['total_rb'] = df[ue_cols].sum(axis=1)
    return df

def _compute_shares_and_fairness(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    shares = (df[ue_cols].clip(lower=0).astype('float32') / np.float32(C_PRB))
    shares.columns = [c.replace('_rb','_share') for c in ue_cols]
    for c in shares.columns:
        df[c] = shares[c]
    x = shares.to_numpy(dtype='float32', copy=False)
    sum_x  = x.sum(axis=1, dtype='float32')
    sum_x2 = (x*x).sum(axis=1, dtype='float32')
    n = np.float32(max(1, len(ue_cols)))
    with np.errstate(divide='ignore', invalid='ignore'):
        J = (sum_x*sum_x) / (n * sum_x2)
    df['J'] = np.nan_to_num(J, nan=0.0, posinf=0.0, neginf=0.0).astype('float32')
    return df

def _compose_roles(df: pd.DataFrame, ue_cols: List[str], top_k: int = TOP_K) -> pd.DataFrame:
    share_cols = [c.replace('_rb','_share') for c in ue_cols]
    alpha = 2.0 / (20 + 1.0)
    for c in share_cols:
        df[c+"_ema"] = df[c].astype('float32').ewm(alpha=alpha, adjust=False).mean().astype('float32')

    S   = df[share_cols].to_numpy(dtype='float32', copy=False)
    E   = df[[c+"_ema" for c in share_cols]].to_numpy(dtype='float32', copy=False)
    PRB = df[ue_cols].to_numpy(dtype='int32',    copy=False)

    T, U = S.shape
    k = min(top_k, U)
    order  = np.argsort(-E, axis=1)
    top_ix = order[:, :k]

    roles       = np.take_along_axis(S,   top_ix, axis=1).astype('float32')
    roles_small = np.take_along_axis((PRB < SMALL_RB_EPS).astype('float32'), top_ix, axis=1)

    sum_all    = S.sum(axis=1, dtype='float32')
    sum_topk   = roles.sum(axis=1, dtype='float32')
    rest_share = (sum_all - sum_topk).clip(min=0).astype('float32')

    if U > k:
        row_ix = np.arange(T)[:, None]
        mask = np.ones_like(S, dtype=bool); mask[row_ix, top_ix] = False
        cnt  = mask.sum(axis=1)
        rest_small = ((PRB < SMALL_RB_EPS).astype('float32') * mask).sum(axis=1) / np.maximum(cnt, 1)
    else:
        rest_small = np.zeros(T, dtype='float32')

    for i in range(k):
        df[f'role{i+1}_share']   = roles[:, i]
        df[f'role{i+1}_smallrb'] = roles_small[:, i]
    for i in range(k, top_k):
        df[f'role{i+1}_share']   = np.float32(0.0)
        df[f'role{i+1}_smallrb'] = np.float32(0.0)
    df['rest_share']   = rest_share
    df['rest_smallrb'] = rest_small

    sec_bin = np.floor(df['sec_rel'].to_numpy(dtype='float32')).astype('int32')
    max_sec = int(sec_bin.max(initial=0))
    def sec_presence(vec_bool: np.ndarray) -> np.ndarray:
        agg = np.zeros(max_sec + 1, dtype='uint8')
        np.maximum.at(agg, sec_bin, vec_bool.astype('uint8'))
        return agg[sec_bin]
    for i in range(top_k):
        df[f'role{i+1}_present_1s'] = sec_presence((df[f'role{i+1}_share'].to_numpy() > 0))
    df['rest_present_1s'] = sec_presence((df['rest_share'].to_numpy() > 0))
    return df

def _label_window(shares_mat: np.ndarray) -> bool:
    if shares_mat.size == 0:
        return False
    top    = shares_mat[:, 0]
    others = shares_mat[:, 1:]
    ok = (top >= ATTACK_DOMINANCE) & (np.max(others, axis=1) <= OTHERS_MAX)
    return (ok.mean() >= DUTY_KEEP)

@dataclass
class WindowSpec:
    length_tti: int
    stride_tti: int

def iter_windows(df: pd.DataFrame, run_id: str, spec: WindowSpec, role_cols: List[str],
                 phases_attack={"attack"}, phases_benign={"baseline","recovery","benign"},
                 batch_windows: int = 8192):
    W, S = spec.length_tti, spec.stride_tti
    R = df[role_cols].to_numpy(dtype='float32', copy=False)
    phase = df['phase'].astype(str).to_numpy()
    T = len(df)
    starts = np.arange(0, max(0, T - W + 1), S, dtype=np.int64)

    bufX, rows = [], []
    for s in starts:
        e = s + W
        blk = R[s:e, :]
        if blk.shape[0] != W:
            continue
        win_phase = phase[s:e]
        frac_attack = np.isin(win_phase, list(phases_attack)).mean()
        frac_benign = np.isin(win_phase, list(phases_benign)).mean()
        looks_attacky = _label_window(blk)
        if (frac_attack >= 0.90) and looks_attacky: y = 2
        elif (frac_benign >= 0.90):                y = 0
        else:                                      y = 1

        bufX.append(blk)
        rows.append((run_id, int(s), int(e), int(W), int(y),
                     float(frac_attack), float(frac_benign), int(looks_attacky)))

        if len(bufX) == batch_windows:
            X = np.stack(bufX, axis=0)
            meta = pd.DataFrame.from_records(
                rows,
                columns=["run_id","start_tti","end_tti","window_len","label",
                         "frac_attack","frac_benign","looks_attacky"]
            )
            yield X, meta
            bufX.clear(); rows.clear()

    if bufX:
        X = np.stack(bufX, axis=0)
        meta = pd.DataFrame.from_records(
            rows,
            columns=["run_id","start_tti","end_tti","window_len","label",
                     "frac_attack","frac_benign","looks_attacky"]
        )
        yield X, meta

# -------------------- Main entry: build datasets (partitioned, low RAM) --------------------

csv_to_parquet(ATTACK_FN, "parquet/attack")
csv_to_parquet(BENIGN_FN, "parquet/benign")

def iter_run_partitions(root: str):
    for run_dir in sorted(glob(os.path.join(root, "run_id=*"))):
        rid = os.path.basename(run_dir).split("=", 1)[1]
        df = pd.read_parquet(run_dir, engine="pyarrow")
        if 'run_id' not in df.columns:
            df['run_id'] = rid
        if 'phase' not in df.columns:
            df['phase'] = 'unknown'
        if 'sec_rel' not in df.columns:
            df['sec_rel'] = (np.arange(len(df), dtype=np.int32) * np.float32(TTI_SEC)).astype('float32')
        yield rid, df

def process_source(root: str, src_name: str):
    if not os.path.isdir(root):
        print(f"[{src_name}] No parquet root found at {root} (skipping).")
        return

    os.makedirs("win_shards", exist_ok=True)
    shard_counter = itertools.count(0)

    runs = list(iter_run_partitions(root))
    if not runs:
        print(f"[{src_name}] No run_id partitions found in {root}.")
        return
    print(f"[{src_name}] Found {len(runs)} run_id partitions.")

    for rid, df in runs:
        ue_cols = _detect_ue_rb_columns(df)
        if not ue_cols:
            warnings.warn(f"[{src_name}] run_id={rid}: no UE *_rb columns; skipping")
            continue

        df = _ensure_total_rb(df, ue_cols)
        df = _compute_shares_and_fairness(df, ue_cols)
        df = _compose_roles(df, ue_cols, top_k=TOP_K)

        keep = [f'role{i}_share' for i in range(1, TOP_K+1)] + ['rest_share','J','sec_rel','run_id','phase']
        enriched_out = f"{src_name}_enriched_{rid}.parquet"
        df[keep].to_parquet(enriched_out, index=False)
        print(f"[{src_name}] run_id={rid} → {enriched_out}  (rows={len(df):,})")

        role_cols = [f'role{k}_share' for k in range(1, TOP_K+1)] + ['rest_share']
        g = df.reset_index(drop=True)

        for stride in STRIDES_LONG:
            spec = WindowSpec(length_tti=W_LONG, stride_tti=stride)
            for X, meta in iter_windows(g, rid, spec, role_cols):
                sid = next(shard_counter)
                np.savez_compressed(f"win_shards/long_s{stride}_{src_name}_{rid}_{sid}.npz", X=X)
                meta["source"] = src_name; meta["stride"] = stride; meta["kind"] = "long"
                meta.to_parquet(f"win_shards/long_s{stride}_{src_name}_{rid}_{sid}_meta.parquet", index=False)

        for stride in STRIDES_SHORT:
            spec = WindowSpec(length_tti=W_SHORT, stride_tti=stride)
            for X, meta in iter_windows(g, rid, spec, role_cols):
                sid = next(shard_counter)
                np.savez_compressed(f"win_shards/short_s{stride}_{src_name}_{rid}_{sid}.npz", X=X)
                meta["source"] = src_name; meta["stride"] = stride; meta["kind"] = "short"
                meta.to_parquet(f"win_shards/short_s{stride}_{src_name}_{rid}_{sid}_meta.parquet", index=False)

process_source("parquet/attack", "attack")
process_source("parquet/benign", "benign")

def load_all_meta(kind: str) -> pd.DataFrame:
    metas = [pd.read_parquet(p) for p in glob(f"win_shards/{kind}_*_meta.parquet")]
    return pd.concat(metas, ignore_index=True) if metas else pd.DataFrame()

M_long  = load_all_meta("long")
M_short = load_all_meta("short")

def grouped_kfold(meta: pd.DataFrame, K: int = 5, seed: int = 42) -> List[Tuple[np.ndarray, np.ndarray]]:
    if meta.empty:
        return []
    rng = np.random.default_rng(seed)
    run_ids = meta['run_id'].astype(str).unique().tolist()
    rng.shuffle(run_ids)
    folds = [set() for _ in range(K)]
    for i, rid in enumerate(run_ids):
        folds[i % K].add(rid)
    splits = []
    for i in range(K):
        val_rids = folds[i]
        train_idx = meta.index[~meta['run_id'].astype(str).isin(val_rids)].to_numpy()
        val_idx   = meta.index[ meta['run_id'].astype(str).isin(val_rids)].to_numpy()
        splits.append((train_idx, val_idx))
    return splits

splits_long  = grouped_kfold(M_long,  K=5, seed=2025)
splits_short = grouped_kfold(M_short, K=5, seed=2025)

if len(M_long):
    print(f"\n[long] meta={M_long.shape}, groups={M_long['run_id'].nunique()}, folds={len(splits_long)}")
if len(M_short):
    print(f"[short] meta={M_short.shape}, groups={M_short['run_id'].nunique()}, folds={len(splits_short)}")

print("\nPreprocessing complete. Shards are in ./win_shards; use meta to assemble CV splits or stream for training.")

[attack] Found 1 run_id partitions.
[attack] run_id=040933 → attack_enriched_040933.parquet  (rows=8,001)
[benign] Found 1 run_id partitions.
[benign] run_id=004912 → benign_enriched_004912.parquet  (rows=4,001)

[long] meta=(2403, 11), groups=2, folds=5
[short] meta=(4950, 11), groups=2, folds=5

Preprocessing complete. Shards are in ./win_shards; use meta to assemble CV splits or stream for training.


### Data Loader

This code defines a high-throughput PyTorch data pipeline for the TTI-Trust project that streams scheduler-native, identity-agnostic features directly from precomputed NPZ shards and their parquet metadata. It provides a set of NumPy-based feature helpers to compute Jain’s fairness index over time, causal rolling minima/medians, run-lengths of “small RB” streaks, approximate PRB plateau histograms, benign-grant fractions, and contiguous zero-run statistics across non-dominant UEs. The `TTINPZShardBatches` `IterableDataset` discovers shard files under `win_shards`, loads each shard’s `[N, W, D]` role-based share tensor and labels, and for each batch constructs a rich sequential feature tensor `X_seq` by concatenating the original shares, temporal deltas, fairness trajectories (instant, min, median), and small-RB run-lengths across the window. In parallel, it builds a per-window auxiliary feature vector `X_aux` that captures plateau occupancy, how often other UEs receive grants, zero-run lengths, and placeholder “radio hint” slots, while also packaging minimal metadata (run ID, source, stride, window bounds) into a Python dict list. The `make_loader` factory wraps this iterable in a PyTorch `DataLoader` configured for streaming (no secondary batching), with multiple workers, pinned memory, persistent workers, and prefetching tuned for large datasets. A small smoke test at the end attempts to instantiate short-window loaders, pull the first batch, and print tensor shapes and sample metadata, illustrating how downstream models should consume `(xb_seq, xb_aux, yb, mb)` in training.

In [21]:
# PyTorch feature space & data loader for TTI-Trust (scheduler-native, identity-agnostic)
# Optimized to stream large datasets from shards: win_shards/{long,short}_s*_... .npz + *_meta.parquet
#
# Emits batches directly (IterableDataset) to avoid in-memory concatenation or per-item pandas operations.

import os, re, glob, math
from typing import Dict, List, Tuple, Optional, Iterator
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, DataLoader

# --- Shared constants (align with preprocessing) ---
TTI_SEC = 0.0005
C_PRB   = 106
TOP_K   = 6
SMALL_RB_EPS = 4

PLATEAUS = np.array([16, 31, 46, 61, 76, 91, 106], dtype=np.int32)

# ---------- Feature helpers (NumPy; no pandas in the hot path) ----------

def jains_fairness_seq(shares: np.ndarray) -> np.ndarray:
    # shares: [N, W, D] or [W, D]
    if shares.ndim == 2:
        shares = shares[None, ...]
    sum_x  = shares.sum(axis=2, dtype=np.float32)                     # [N, W]
    sum_x2 = (shares * shares).sum(axis=2, dtype=np.float32)          # [N, W]
    n = shares.shape[2]
    J = (sum_x * sum_x) / (np.float32(n) * sum_x2 + 1e-9)
    return J.squeeze(0).astype(np.float32) if J.shape[0] == 1 else J.astype(np.float32)

def rolling_min_med_causal(x: np.ndarray, win: int) -> Tuple[np.ndarray, np.ndarray]:
    # x: [W] (per-window). W<=240 → simple loop is fine and cache-friendly.
    W = x.shape[0]
    rmin = np.empty(W, dtype=np.float32)
    rmed = np.empty(W, dtype=np.float32)
    for t in range(W):
        a = max(0, t - win + 1)
        sl = x[a:t+1]
        rmin[t] = sl.min() if sl.size else 0.0
        rmed[t] = np.median(sl) if sl.size else 0.0
    return rmin, rmed

def small_rb_runlength_1d(streak_bin: np.ndarray) -> np.ndarray:
    out = np.zeros_like(streak_bin, dtype=np.float32)
    c = 0.0
    for i, b in enumerate(streak_bin.astype(bool)):
        c = c + 1.0 if b else 0.0
        out[i] = c
    return out

def build_c_runs(srb_bin: np.ndarray) -> np.ndarray:
    # srb_bin: [W, D] 0/1 → run-length per channel
    return np.stack([small_rb_runlength_1d(srb_bin[:, j]) for j in range(srb_bin.shape[1])], axis=1).astype(np.float32)

def approx_plateau_hist(share_seq: np.ndarray) -> np.ndarray:
    # share_seq: [W, D] → PRBs → closest plateau histogram normalized
    prbs = np.rint(share_seq * C_PRB).astype(np.int32)
    idx  = np.abs(prbs[:, :, None] - PLATEAUS[None, None, :]).argmin(axis=2)  # [W, D]
    hist = np.bincount(idx.ravel(), minlength=len(PLATEAUS)).astype(np.float32)
    if hist.sum() > 0: hist /= hist.sum()
    return hist

def benign_grant_fraction(share_seq: np.ndarray) -> float:
    others = share_seq[:, 1:]
    return float((others > 0).any(axis=1).mean())

def contiguous_zero_runs(share_seq: np.ndarray) -> Tuple[float, float]:
    others_all_zero = (share_seq[:, 1:] == 0).all(axis=1).astype(np.int32)  # [W]
    runs, c = [], 0
    for b in others_all_zero:
        if b: c += 1
        else:
            if c: runs.append(c); c = 0
    if c: runs.append(c)
    return (0.0, 0.0) if not runs else (float(max(runs)), float(np.mean(runs)))

# ---------- IterableDataset over shard files ----------

class TTINPZShardBatches(IterableDataset):
    """
    Streams batches from NPZ shards produced by preprocessing.
    Each shard file has X: [N, W, D] where D = TOP_K+1 (roles incl. rest), and a paired *_meta.parquet.

    Yields (X_seq, X_aux, y, meta_list) where:
      X_seq: [B, W, Dseq]  (roles, d_roles, J, J_min, J_med, smallRB runs)
      X_aux: [B, Daux]     (plateau hist, benign frac, zero-run stats, radio hints if available)
      y:     [B] {0,1,2}
      meta_list: list of dicts with run_id/source/stride/window_len/start_tti/end_tti
    """
    def __init__(self, kind: str, batch_size: int = 1024, shard_glob: Optional[str] = None,
                 restrict_sources: Optional[List[str]] = None):
        assert kind in ("long", "short")
        self.kind = kind
        self.batch_size = int(batch_size)
        # Discover shards
        patt = shard_glob or f"win_shards/{kind}_*.npz"
        self.shard_paths = sorted(glob.glob(patt))
        if not self.shard_paths:
            raise FileNotFoundError(f"No shards found for pattern: {patt}. Run preprocessing to generate win_shards.")
        self.restrict_sources = set(restrict_sources) if restrict_sources else None

    def __iter__(self) -> Iterator:
        for npz_path in self.shard_paths:
            meta_path = npz_path.replace(".npz", "_meta.parquet")
            if not os.path.exists(meta_path):
                continue
            meta = pd.read_parquet(meta_path)
            if self.restrict_sources is not None:
                if "source" in meta.columns:
                    if not any(src in self.restrict_sources for src in meta["source"].unique().tolist()):
                        continue
            X = np.load(npz_path)["X"]  # [N, W, D], float32 from preprocessing
            # Labels
            if "label" not in meta.columns:
                raise ValueError(f"Meta {meta_path} missing 'label' column.")
            y_all = meta["label"].to_numpy(dtype=np.int64)

            # Optional radio hints (rare). If present as per-TTI per-role, we’d read a second array.
            have_radio = False  # keep simple; emit zeros + flag

            N, W, D = X.shape
            assert D == (TOP_K + 1), f"Expected D=TOP_K+1, got D={D}"

            # Batch over this shard
            for i0 in range(0, N, self.batch_size):
                i1 = min(N, i0 + self.batch_size)
                seq = X[i0:i1]                                  # [B, W, D]
                # Δshares
                dseq = np.diff(seq, axis=1, prepend=seq[:, 0:1, :])  # [B, W, D]

                # Jain's J(t) + causal rolling stats (per sample)
                J_list, Jmin_list, Jmed_list = [], [], []
                for b in range(seq.shape[0]):
                    Jb = jains_fairness_seq(seq[b])                   # [W]
                    Jmin, Jmed = rolling_min_med_causal(Jb, win=min(60, Jb.shape[0]))
                    J_list.append(Jb); Jmin_list.append(Jmin); Jmed_list.append(Jmed)
                J   = np.stack(J_list,    axis=0)[:, :, None]         # [B, W, 1]
                Jmn = np.stack(Jmin_list, axis=0)[:, :, None]         # [B, W, 1]
                Jmd = np.stack(Jmed_list, axis=0)[:, :, None]         # [B, W, 1]

                # small-RB run-lengths (approx from shares threshold)
                prb = np.rint(seq * C_PRB)                            # [B, W, D]
                srb = (prb < SMALL_RB_EPS).astype(np.float32)
                c_runs = np.stack([build_c_runs(srb[b]) for b in range(srb.shape[0])], axis=0)  # [B, W, D]

                # Time-step channels
                X_seq = np.concatenate([seq, dseq, J, Jmn, Jmd, c_runs], axis=2).astype(np.float32)  # [B, W, Dseq]

                # Aux features per window
                aux_list = []
                for b in range(seq.shape[0]):
                    ph = approx_plateau_hist(seq[b])
                    bf = benign_grant_fraction(seq[b])
                    zmax, zmean = contiguous_zero_runs(seq[b])
                    if have_radio:
                        radio_aux = np.zeros(33, dtype=np.float32)
                        radio_flag = 1.0
                    else:
                        radio_aux = np.zeros(33, dtype=np.float32)
                        radio_flag = 0.0
                    aux = np.concatenate([ph, np.array([bf, zmax, zmean], dtype=np.float32),
                                          radio_aux, np.array([radio_flag], dtype=np.float32)])
                    aux_list.append(aux)
                X_aux = np.stack(aux_list, axis=0).astype(np.float32)  # [B, 44]

                y = torch.from_numpy(y_all[i0:i1])                     # [B], int64

                # Build meta dicts list (keep only essentials)
                meta_batch = []
                cols = meta.columns
                for j in range(i0, i1):
                    row = meta.iloc[j]
                    meta_batch.append({
                        "run_id": str(row.get("run_id", "")),
                        "source": str(row.get("source", "")),
                        "stride": int(row.get("stride", 0)),
                        "window_len": int(row.get("window_len", seq.shape[1])),
                        "start_tti": int(row.get("start_tti", 0)),
                        "end_tti": int(row.get("end_tti", 0)),
                    })

                # Zero-copy to torch for X_seq/X_aux
                yield torch.from_numpy(X_seq), torch.from_numpy(X_aux), y, meta_batch

# ---------- High-throughput DataLoader factory ----------

def make_loader(kind: str, batch_size: int = 1024, num_workers: int = 4,
                restrict_sources: Optional[List[str]] = None) -> DataLoader:
    ds = TTINPZShardBatches(kind=kind, batch_size=batch_size, restrict_sources=restrict_sources)
    loader = DataLoader(
        ds,
        batch_size=None,               # dataset already yields batches
        num_workers=num_workers,       # tune: 2–8 based on CPU/IO
        pin_memory=True,               # faster H2D copies
        persistent_workers=True,
        prefetch_factor=2
    )
    return loader

# ---------- Smoke test (requires win_shards present) ----------

try:
    train_loader = make_loader(kind="short", batch_size=1024, num_workers=4)
    val_loader   = make_loader(kind="short", batch_size=1024, num_workers=4)
    print("Streaming loader ready. First batch shape:")
    xb_seq, xb_aux, yb, mb = next(iter(train_loader))
    print("X_seq:", tuple(xb_seq.shape), "X_aux:", tuple(xb_aux.shape), "y:", tuple(yb.shape), "meta[0]:", mb[0])
except Exception as e:
    print("Note:", str(e))

print("\nUse: for xb_seq, xb_aux, yb, mb in make_loader('short', ...): model(xb_seq, xb_aux) ...")

Streaming loader ready. First batch shape:
X_seq: (1024, 64, 24) X_aux: (1024, 44) y: (1024,) meta[0]: {'run_id': '040933', 'source': 'attack', 'stride': 4, 'window_len': 64, 'start_tti': 0, 'end_tti': 64}

Use: for xb_seq, xb_aux, yb, mb in make_loader('short', ...): model(xb_seq, xb_aux) ...


### Base Model Architecture

#### PRB-GraphSAGE Architecture
This code defines `PRBGraphSAGE`, a configurable PyTorch Geometric graph neural network for classifying UE-level contention graphs in a PRB starvation detection task. The model builds a GNN backbone with a user-specified number of layers (`num_layers`), hidden dimension, and convolution type (`SAGEConv`, `GCNConv`, `GraphConv`, or `GATv2Conv`), optionally followed by batch normalization and dropout after each layer. In the forward pass, node features `x` are iteratively updated using the chosen conv layers over edges `edge_index`, then aggregated into graph-level embeddings via a selectable global pooling strategy (`mean`, `max`, or concatenated `mean+max`). This pooled representation is passed through an MLP head whose depth and hidden size are also configurable, producing either a scalar logit per graph for binary classification (`num_classes == 1`) or a multi-dimensional logit vector for multi-class problems. The interface is designed to be hyperparameter-sweep friendly, with `edge_attr` included in the signature as a placeholder for future edge-aware enhancements, while the current implementation focuses on node features and graph-level pooling.

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    SAGEConv,
    GCNConv,
    GraphConv,
    GATv2Conv,
    global_mean_pool,
    global_max_pool,
)


class PRBGraphSAGE(nn.Module):
    """
    PRB-GraphSAGE: UE-level contention graph classifier for PRB starvation detection.

    This model is designed to be easy to sweep over hyperparameters:
      - in_dim:         node feature dimension
      - hidden_dim:     hidden channel width
      - num_layers:     number of GNN layers (>=1)
      - conv_type:      GNN layer type: {"sage", "gcn", "graph", "gat"}
      - aggr:           global pooling: {"mean", "max", "mean+max"}
      - mlp_hidden_dim: hidden size of the final MLP head
      - mlp_layers:     number of linear layers in the MLP head (>=1)
      - dropout:        dropout probability applied after each conv & MLP layer
      - num_classes:    1 for binary logit, >1 for multi-class logits

    Forward signature:
        logits = model(x, edge_index, batch, edge_attr=None)

      - x:          [N, in_dim] node features
      - edge_index: [2, E] COO edge indices
      - batch:      [N] graph IDs for global pooling
      - edge_attr:  [E, d_edge] (currently ignored; placeholder for future use)
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 2,
        conv_type: str = "sage",
        aggr: str = "mean",
        mlp_hidden_dim: int = 64,
        mlp_layers: int = 2,
        dropout: float = 0.1,
        num_classes: int = 1,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        assert num_layers >= 1, "num_layers must be >= 1"
        assert mlp_layers >= 1, "mlp_layers must be >= 1"
        assert conv_type in {"sage", "gcn", "graph", "gat"}, f"Unknown conv_type: {conv_type}"
        assert aggr in {"mean", "max", "mean+max"}, f"Unknown aggr: {aggr}"
        self.aggr = aggr
        self.dropout = float(dropout)
        self.num_classes = int(num_classes)
        self.use_batchnorm = bool(use_batchnorm)

        # ---- GNN backbone ----
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList() if use_batchnorm else None

        # input layer
        self.convs.append(self._make_conv(conv_type, in_dim, hidden_dim))
        if use_batchnorm:
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # hidden layers
        for _ in range(num_layers - 1):
            self.convs.append(self._make_conv(conv_type, hidden_dim, hidden_dim))
            if use_batchnorm:
                self.bns.append(nn.BatchNorm1d(hidden_dim))

        # ---- Readout (global pooling) ----
        # mean+max concatenation doubles the channel dimension
        readout_dim = hidden_dim if aggr in {"mean", "max"} else hidden_dim * 2

        # ---- MLP classifier head ----
        mlp_layers_list = []
        in_mlp = readout_dim
        for li in range(mlp_layers):
            out_mlp = mlp_hidden_dim if li < mlp_layers - 1 else num_classes
            mlp_layers_list.append(nn.Linear(in_mlp, out_mlp))
            if li < mlp_layers - 1:
                mlp_layers_list.append(nn.SiLU())
                if dropout > 0:
                    mlp_layers_list.append(nn.Dropout(dropout))
            in_mlp = out_mlp

        self.mlp = nn.Sequential(*mlp_layers_list)

    @staticmethod
    def _make_conv(conv_type: str, in_dim: int, out_dim: int):
        """Factory for different conv types to make hyperparameter sweeps easy."""
        if conv_type == "sage":
            return SAGEConv(in_dim, out_dim)
        if conv_type == "gcn":
            return GCNConv(in_dim, out_dim)
        if conv_type == "graph":
            return GraphConv(in_dim, out_dim)
        if conv_type == "gat":
            return GATv2Conv(in_dim, out_dim, heads=1, concat=False)
        raise ValueError(f"Unknown conv_type: {conv_type}")

    def forward(self, x, edge_index, batch, edge_attr=None):
        """
        x:          [N, in_dim]
        edge_index: [2, E]
        batch:      [N] graph IDs
        edge_attr:  [E, d_edge] (ignored for now; reserved for future edge-aware layers)
        """
        # ---- GNN backbone ----
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = F.silu(x)
            if self.use_batchnorm:
                x = self.bns[i](x)
            if self.dropout > 0:
                x = F.dropout(x, p=self.dropout, training=self.training)

        # ---- Global pooling ----
        if self.aggr == "mean":
            g = global_mean_pool(x, batch)
        elif self.aggr == "max":
            g = global_max_pool(x, batch)
        else:  # "mean+max"
            g_mean = global_mean_pool(x, batch)
            g_max = global_max_pool(x, batch)
            g = torch.cat([g_mean, g_max], dim=-1)

        # ---- MLP head ----
        logits = self.mlp(g)  # [num_graphs, num_classes]

        # For binary classification, most training loops will want shape [num_graphs]
        if self.num_classes == 1:
            logits = logits.squeeze(-1)

        return logits

### Base Model Initialization

#### Hyperparameters

This code sets up initialization, configuration, and lightweight hyperparameter sweeping for a PRB-GraphSAGE GNN used in PRB starvation detection. It begins by seeding Python, NumPy, and PyTorch RNGs for reproducibility, choosing a CUDA or CPU device, and defining `D_NODE` as the node feature dimension (with a placeholder warning until real graph data is available). A `GNN_HYP` dictionary gathers all key hyperparameters in one place—training settings (batch size, epochs, learning rate and grid, weight decay), backbone architecture (hidden size, number of layers, dropout, conv type, pooling strategy, batchnorm), MLP head configuration, output dimension (binary vs multi-class), early-stopping criteria, and scheduler parameters (warmup and total steps). A helper `compute_pos_weight_from_labels` computes class weights for `BCEWithLogitsLoss` from imbalanced binary labels, with a default `pos_weight` of 1.0 until real labels are known. The `make_gnn_model` factory instantiates a `PRBGraphSAGE` model from a hyperparameter dict and moves it to the chosen device, while `count_params` reports trainable parameter count. Optimizer and scheduler builders create an AdamW optimizer and a LambdaLR scheduler implementing linear warmup followed by cosine decay over `total_steps`. The script then instantiates a default model, optimizer, scheduler, and a binary `BCEWithLogitsLoss` using the current `pos_weight`, printing summaries. Finally, `iter_gnn_configs` defines a simple grid-search generator over selected hyperparameters (learning rate, hidden size, number of layers, dropout, aggregation, and conv type), and the code demonstrates its use by instantiating and printing the first few candidate configurations along with their parameter counts, providing a ready-made loop for hyperparameter sweeps.

In [22]:
# ==== CELL: GNN Model + Hyperparameter Initialization ====
import os, math, json, random
from typing import Dict, Iterable

import numpy as np
import torch
import torch.nn as nn

# ---- Reproducibility & device ----
SEED = 2025

def seed_all(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_all(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

try:
    D_NODE
except NameError:
    D_NODE = 16  # placeholder for later
    print("WARNING: D_NODE is using a placeholder value (16). Set it to your actual node feature dim.")

# ---- GNN Hyperparameters (with small grids for sweeps) ----
GNN_HYP: Dict = {
    # ---- Training-level hparams ----
    "batch_size": 32,              
    "epochs": 40,                  
    "lr": 3e-4,                    
    "lr_grid": [2e-4, 3e-4, 5e-4], 
    "weight_decay": 1e-4,

    # ---- Backbone architecture ----
    # Base config (used if you don't sweep)
    "hidden_dim": 64,
    "num_layers": 2,
    "dropout": 0.1,
    "conv_type": "sage",           
    "aggr": "mean",                
    "use_batchnorm": True,

    # Grids for a compact but meaningful sweep
    "hidden_dim_grid": [32, 64],           
    "num_layers_grid": [1, 2, 3],          
    "dropout_grid": [0.0, 0.1, 0.3],       
    "aggr_grid": ["mean", "mean+max"],     
    "conv_type_grid": ["sage", "gcn"],     

    # ---- MLP head ----
    "mlp_hidden_dim": 64,
    "mlp_layers": 2,

    # ---- Output ----
    # 1 = binary logit (BCEWithLogits); >1 = multi-class (CrossEntropy)
    "num_classes": 1,

    # ---- Early-stop config (used by your training loop) ----
    "early_stop": {
        "monitor": "f1",            
        "mode": "max",              
        "patience": 5,              
        "min_delta": 1e-3,          
    },

    # ---- Scheduler placeholders (updated once loaders are known) ----
    "warmup_steps": 200,
    "total_steps": 10000,
}

print("GNN Hyperparameters:", json.dumps(GNN_HYP, indent=2))


# ---- Optional class weights for binary BCEWithLogitsLoss ----
# You can adapt this to your graph meta if you materialize labels in a DataFrame.
def compute_pos_weight_from_labels(labels: np.ndarray) -> torch.Tensor:
    """
    Compute pos_weight for BCEWithLogitsLoss from a binary label vector.
    pos_weight = (#neg / #pos)
    """
    labels = labels.astype(np.int64).ravel()
    pos = (labels == 1).sum()
    neg = (labels == 0).sum()
    if pos == 0:
        return torch.tensor(1.0, dtype=torch.float32)
    return torch.tensor(float(neg) / float(pos), dtype=torch.float32)

pos_weight = torch.tensor(1.0, dtype=torch.float32)
print("Initial pos_weight (binary):", float(pos_weight))


# ---- PRB-GraphSAGE factory ----
def make_gnn_model(hyp: Dict, in_dim: int) -> PRBGraphSAGE:
    """
    Instantiate a PRBGraphSAGE model from a hyperparameter dict.
    """
    model = PRBGraphSAGE(
        in_dim=in_dim,
        hidden_dim=hyp["hidden_dim"],
        num_layers=hyp["num_layers"],
        conv_type=hyp["conv_type"],      
        aggr=hyp["aggr"],                
        mlp_hidden_dim=hyp["mlp_hidden_dim"],
        mlp_layers=hyp["mlp_layers"],
        dropout=hyp["dropout"],
        num_classes=hyp["num_classes"],
        use_batchnorm=hyp["use_batchnorm"],
    )
    return model.to(DEVICE)



def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


# ---- Optimizer & Scheduler helpers ----
def make_gnn_optimizer(model: nn.Module, hyp: Dict) -> torch.optim.Optimizer:
    return torch.optim.AdamW(
        model.parameters(),
        lr=hyp["lr"],
        weight_decay=hyp["weight_decay"],
    )

def make_scheduler(optimizer, warmup_steps: int, total_steps: int):
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# ---- Instantiate a default GNN model + optimizer/scheduler ----
gnn_model = make_gnn_model(GNN_HYP, D_NODE)
print(f"GNN params: {count_params(gnn_model):,}")

gnn_optim = make_gnn_optimizer(gnn_model, GNN_HYP)
gnn_scheduler = make_scheduler(gnn_optim, GNN_HYP["warmup_steps"], GNN_HYP["total_steps"])

# Binary classification by default (y ∈ {0,1})
gnn_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(DEVICE))

print("GNN model, optimizer, scheduler, and loss are initialized.")


# ---- Hyperparameter sweep helper (config generator only) ----
def iter_gnn_configs(base: Dict) -> Iterable[Dict]:
    """
    Simple grid over a few key hyperparameters.
    You can extend/prune this as needed.
    """
    for lr in base["lr_grid"]:
        for hidden in base["hidden_dim_grid"]:
            for num_layers in base["num_layers_grid"]:
                for dropout in base["dropout_grid"]:
                    for aggr in base["aggr_grid"]:
                        for conv_type in base.get("conv_type_grid", [base["conv_type"]]):
                            cfg = dict(base)  # shallow copy
                            cfg["lr"] = lr
                            cfg["hidden_dim"] = hidden
                            cfg["num_layers"] = num_layers
                            cfg["dropout"] = dropout
                            cfg["aggr"] = aggr
                            cfg["conv_type"] = conv_type
                            yield cfg

print("\nInitialization")
for i, cfg in enumerate(iter_gnn_configs(GNN_HYP)):
    if i >= 3:
        break
    m = make_gnn_model(cfg, D_NODE)
    print(
        f"  Config {i}: lr={cfg['lr']}, hidden={cfg['hidden_dim']}, "
        f"layers={cfg['num_layers']}, dropout={cfg['dropout']}, "
        f"aggr={cfg['aggr']}, conv_type={cfg['conv_type']}, "
        f"params={count_params(m):,}"
    )

Device: cuda
GNN Hyperparameters: {
  "batch_size": 32,
  "epochs": 40,
  "lr": 0.0003,
  "lr_grid": [
    0.0002,
    0.0003,
    0.0005
  ],
  "weight_decay": 0.0001,
  "hidden_dim": 64,
  "num_layers": 2,
  "dropout": 0.1,
  "conv_type": "sage",
  "aggr": "mean",
  "use_batchnorm": true,
  "hidden_dim_grid": [
    32,
    64
  ],
  "num_layers_grid": [
    1,
    2,
    3
  ],
  "dropout_grid": [
    0.0,
    0.1,
    0.3
  ],
  "aggr_grid": [
    "mean",
    "mean+max"
  ],
  "conv_type_grid": [
    "sage",
    "gcn"
  ],
  "mlp_hidden_dim": 64,
  "mlp_layers": 2,
  "num_classes": 1,
  "early_stop": {
    "monitor": "f1",
    "mode": "max",
    "patience": 5,
    "min_delta": 0.001
  },
  "warmup_steps": 200,
  "total_steps": 10000
}
Initial pos_weight (binary): 1.0
GNN params: 14,849
GNN model, optimizer, scheduler, and loss are initialized.

Initialization
  Config 0: lr=0.0002, hidden=32, layers=1, dropout=0.0, aggr=mean, conv_type=sage, params=3,297
  Config 1: lr=0.0002, hidde

### Model Training Loop

#### PRB-GraphSAGE

##### Helper Functions

In [17]:
import copy
from typing import Dict, Tuple, List
import numpy as np
import torch
import torch.nn as nn

def compute_binary_metrics(y_true: np.ndarray,
                           y_prob: np.ndarray,
                           threshold: float = 0.5) -> Dict[str, float]:
    """
    y_true: [N] in {0,1}
    y_prob: [N] in [0,1]  (sigmoid outputs)
    """
    y_true = y_true.astype(np.int64).ravel()
    y_prob = y_prob.astype(np.float32).ravel()
    y_pred = (y_prob >= threshold).astype(np.int64)

    tp = np.sum((y_pred == 1) & (y_true == 1))
    tn = np.sum((y_pred == 0) & (y_true == 0))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))

    eps = 1e-9
    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1        = 2 * precision * recall / (precision + recall + eps)
    acc       = (tp + tn) / max(1, len(y_true))

    return {
        "acc": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }

def run_epoch(model: nn.Module,
              loader,
              optimizer: torch.optim.Optimizer,
              criterion: nn.Module,
              device: torch.device,
              scheduler=None,
              train: bool = True) -> Tuple[float, Dict[str, float]]:
    """
    Runs one epoch over loader.
    Returns (avg_loss, metrics_dict).
    """
    if train:
        model.train()
    else:
        model.eval()

    all_probs: List[float] = []
    all_labels: List[int] = []
    total_loss = 0.0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        # y: [num_graphs] or [num_graphs, 1]
        y = batch.y.view(-1).float()

        with torch.set_grad_enabled(train):
            logits = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs] or [num_graphs,1]
            logits = logits.view_as(y)
            loss = criterion(logits, y)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

        bs = y.size(0)
        total_loss += float(loss.item()) * bs
        total_samples += bs

        probs = torch.sigmoid(logits.detach()).cpu().numpy()
        labels = y.detach().cpu().numpy()
        all_probs.append(probs)
        all_labels.append(labels)

    if total_samples == 0:
        return 0.0, {"acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    avg_loss = total_loss / total_samples
    metrics = compute_binary_metrics(all_labels, all_probs, threshold=0.5)
    metrics["loss"] = float(avg_loss)
    return avg_loss, metrics

##### Training Loop (with Early Stop)

In [18]:
class EarlyStopper:
    def __init__(self, monitor: str = "f1", mode: str = "max",
                 patience: int = 5, min_delta: float = 1e-3):
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta

        if mode == "max":
            self.best = -float("inf")
        elif mode == "min":
            self.best = float("inf")
        else:
            raise ValueError("mode must be 'max' or 'min'")

        self.num_bad = 0
        self.best_state = None

    def step(self, metrics: Dict[str, float], model: nn.Module) -> bool:
        """
        Returns True if we should stop (patience exhausted).
        """
        current = metrics.get(self.monitor, None)
        if current is None:
            return False

        improved = False
        if self.mode == "max":
            if current > self.best + self.min_delta:
                improved = True
        else:  # 'min'
            if current < self.best - self.min_delta:
                improved = True

        if improved:
            self.best = current
            self.num_bad = 0
            self.best_state = copy.deepcopy(model.state_dict())
        else:
            self.num_bad += 1

        return self.num_bad >= self.patience


def train_gnn_model(
    hyp: Dict,
    train_loader,
    val_loader,
    in_dim: int,
    device: torch.device = DEVICE,
) -> Dict:
    """
    High-level training loop for a single GNN config.
    Returns a dict with best metrics, history, and best_state_dict.
    """

    model = make_gnn_model(hyp, in_dim=in_dim)
    optimizer = make_gnn_optimizer(model, hyp)
    scheduler = make_scheduler(optimizer, hyp["warmup_steps"], hyp["total_steps"])

    criterion = gnn_criterion 

    es_cfg = hyp["early_stop"]
    early_stopper = EarlyStopper(
        monitor=es_cfg.get("monitor", "f1"),
        mode=es_cfg.get("mode", "max"),
        patience=es_cfg.get("patience", 5),
        min_delta=es_cfg.get("min_delta", 1e-3),
    )

    history = {"train": [], "val": []}

    for epoch in range(1, hyp["epochs"] + 1):
        train_loss, train_metrics = run_epoch(
            model, train_loader, optimizer, criterion, device, scheduler, train=True
        )
        val_loss, val_metrics = run_epoch(
            model, val_loader, optimizer, criterion, device, scheduler=None, train=False
        )

        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        print(
            f"[Epoch {epoch:03d}] "
            f"train_loss={train_metrics['loss']:.4f}, train_f1={train_metrics['f1']:.3f} | "
            f"val_loss={val_metrics['loss']:.4f}, val_f1={val_metrics['f1']:.3f}"
        )

        stop = early_stopper.step(val_metrics, model)
        if stop:
            print(f"Early stopping triggered at epoch {epoch}. Best {early_stopper.monitor} = {early_stopper.best:.4f}")
            break

    # Restore best weights if we got any improvement
    if early_stopper.best_state is not None:
        model.load_state_dict(early_stopper.best_state)

    result = {
        "model": model,
        "history": history,
        "best_metric": early_stopper.best,
        "best_monitor": early_stopper.monitor,
        "best_state_dict": early_stopper.best_state,
    }
    return result

##### Initiate Training

In [19]:
best_overall = None
best_cfg = None

for cfg in iter_gnn_configs(GNN_HYP):
    print("\n=== Config:", cfg["conv_type"], cfg["hidden_dim"], cfg["num_layers"],
          cfg["dropout"], cfg["aggr"], "lr", cfg["lr"], "===")
    out = train_gnn_model(cfg, train_loader, val_loader, in_dim=D_NODE, device=DEVICE)

    if best_overall is None or out["best_metric"] > best_overall["best_metric"]:
        best_overall = out
        best_cfg = cfg

print("\nBest config:", best_cfg)
print("Best validation", best_overall["best_monitor"], "=", best_overall["best_metric"])


=== Config: sage 32 1 0.0 mean lr 0.0002 ===


AttributeError: 'list' object has no attribute 'to'

### Model Evaluation

#### PRB-GraphSAGE

In [None]:
# ==== CELL: Inference + τ-Hysteresis + Confusion Matrices (Final Model) ====
# This cell evaluates the full stack (TCN, SVD, Fusion, Fusion+Hysteresis)
# and prints confusion matrices + basic metrics for each.
#
# Assumes you have:
#   - DualHeadTCN, SVDGeometryHead, FusionHead classes defined
#   - DEVICE, HYP, FUSION_ALPHA, TOP_K, K_PLUS_REST
#   - a dataloader (e.g., val_loader) yielding (x_seq, x_aux, y, meta)
#
# Positive class is 'attack' (label == 2).
#
# How it works:
#   1) Loads frozen TCN/SVD (best checkpoints if present) and FusionHead if saved.
#   2) Streams the loader to compute p_TCN, p_SVD, p_FUSED per window.
#   3) Applies τ-hysteresis to p_FUSED to get final ENFORCE decisions.
#   4) Prints confusion matrices + precision/recall/F1 for:
#        - TCN (p>=0.5)
#        - SVD (p>=0.5)
#        - Fusion (p>=0.5)
#        - Fusion + Hysteresis (ENFORCE boolean)
#
# Optional: set INCLUDE_AE=True and provide ae_provider(x_seq, meta)->[B] to include AE in fusion.
#
from typing import Dict, Any, Tuple
import os, math, json
import numpy as np
import torch
import torch.nn.functional as F

INCLUDE_AE = False   # set True when AE is wired
AE_PROVIDER = None   # function (x_seq, meta)-> torch.Tensor [B], values in [0,1]

# --- τ-hysteresis ---
class TauHysteresis:
    def __init__(self, theta_on=0.7, theta_off=0.5, L_on=100, L_off=100):
        self.theta_on=float(theta_on); self.theta_off=float(theta_off)
        self.L_on=int(L_on); self.L_off=int(L_off)
        self.state=False; self.streak=0
    def reset(self):
        self.state=False; self.streak=0
    def step(self, p: float) -> bool:
        if not self.state:
            if p >= self.theta_on:
                self.streak += 1
                if self.streak >= self.L_on:
                    self.state = True; self.streak = 0
            else:
                self.streak = 0
        else:
            if p <= self.theta_off:
                self.streak += 1
                if self.streak >= self.L_off:
                    self.state = False; self.streak = 0
            else:
                self.streak = 0
        return self.state

def load_frozen_bases(tcn_ckpt="tcn_best.pt", svd_ckpt="svd_head_best.pt"):
    tcn = DualHeadTCN(d_seq=(K_PLUS_REST + K_PLUS_REST + 3 + K_PLUS_REST),
                      hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0)
    svd = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16)
    if os.path.exists(tcn_ckpt):
        sd = torch.load(tcn_ckpt, map_location="cpu")
        tcn.load_state_dict(sd["model"], strict=False)
    if os.path.exists(svd_ckpt):
        sd = torch.load(svd_ckpt, map_location="cpu")
        svd.load_state_dict(sd["model"], strict=False)
    tcn.to(DEVICE).eval(); svd.to(DEVICE).eval()
    for p in tcn.parameters(): p.requires_grad=False
    for p in svd.parameters(): p.requires_grad=False
    return tcn, svd

def load_or_default_fusion(in_dim: int = 2):
    fh = FusionHead(in_dim=in_dim, hidden=8, dropout=0.0).to(DEVICE)
    if os.path.exists("fusion_head_best.pt"):
        sd = torch.load("fusion_head_best.pt", map_location="cpu")
        ck_in_dim = sd.get("in_dim", in_dim)
        if ck_in_dim != in_dim:
            fh = FusionHead(in_dim=ck_in_dim, hidden=8, dropout=0.0).to(DEVICE)
        fh.load_state_dict(sd["fusion"], strict=False)
    fh.eval()
    for p in fh.parameters(): p.requires_grad=False
    return fh

@torch.no_grad()
def base_probs(tcn, svd, x_seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    logits, p_all = tcn(x_seq)
    p_tcn = p_all[:, 2]
    p_svd, _ = svd(x_seq[:, :, :K_PLUS_REST])
    return p_tcn, p_svd

@torch.no_grad()
def fused_probs(tcn, svd, fusion_head, x_seq: torch.Tensor, ae_score: torch.Tensor=None) -> torch.Tensor:
    p_tcn, p_svd = base_probs(tcn, svd, x_seq)
    # if FusionHead checkpoint exists and has params, use it; else α-blend
    has_params = sum(p.numel() for p in fusion_head.parameters()) > 0
    if has_params:
        if ae_score is not None:
            z = torch.stack([p_tcn, p_svd, ae_score.to(p_tcn.device)], dim=-1)
        else:
            z = torch.stack([p_tcn, p_svd], dim=-1)
        return fusion_head(z)
    return FUSION_ALPHA * p_tcn + (1.0 - FUSION_ALPHA) * p_svd

def confusion(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Any]:
    # y_true, y_pred: 0/1 arrays for not-attack/attack
    tp = int(((y_true == 1) & (y_pred == 1)).sum())
    tn = int(((y_true == 0) & (y_pred == 0)).sum())
    fp = int(((y_true == 0) & (y_pred == 1)).sum())
    fn = int(((y_true == 1) & (y_pred == 0)).sum())
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec  = tp / (tp + fn) if (tp + fn) else 0.0
    f1   = 2*prec*rec / (prec + rec) if (prec + rec) else 0.0
    acc  = (tp + tn) / max(1, (tp+tn+fp+fn))
    return {
        "matrix": [[tn, fp],
                   [fn, tp]],
        "acc": acc, "prec": prec, "rec": rec, "f1": f1,
        "tp": tp, "tn": tn, "fp": fp, "fn": fn
    }

def evaluate_confusions(loader) -> Dict[str, Any]:
    assert loader is not None and len(loader) > 0, "Loader is empty."
    tcn, svd = load_frozen_bases()
    fusion = load_or_default_fusion(in_dim=3 if INCLUDE_AE else 2)

    # Collect predictions/probs + labels and meta per-run for hysteresis
    probs_tcn, probs_svd, probs_fused = [], [], []
    labels = []
    metas = []   # list of dicts per item: run_id, start_tti

    for x_seq, x_aux, y, meta in loader:
        x_seq = x_seq.to(DEVICE).float()
        ae_score = None
        if INCLUDE_AE and AE_PROVIDER is not None:
            ae_score = AE_PROVIDER(x_seq, meta)

        p_fused = fused_probs(tcn, svd, fusion, x_seq, ae_score=ae_score).detach().cpu().numpy()
        p_tcn, p_svd = base_probs(tcn, svd, x_seq)
        probs_tcn.extend(p_tcn.detach().cpu().numpy().tolist())
        probs_svd.extend(p_svd.detach().cpu().numpy().tolist())
        probs_fused.extend(p_fused.tolist())
        labels.extend(((y == 2).numpy().astype(int)).tolist())

        B = y.size(0)
        for i in range(B):
            metas.append({
                "run_id": str(meta["run_id"][i]),
                "start_tti": int(meta["start_tti"][i])
            })

    y_true = np.array(labels, dtype=int)
    y_tcn  = (np.array(probs_tcn)  >= 0.5).astype(int)
    y_svd  = (np.array(probs_svd)  >= 0.5).astype(int)
    y_fuse = (np.array(probs_fused)>= 0.5).astype(int)

    # Confusions without hysteresis
    cm_tcn  = confusion(y_true, y_tcn)
    cm_svd  = confusion(y_true, y_svd)
    cm_fuse = confusion(y_true, y_fuse)

    # Apply hysteresis on the fused stream per run (ordered by start_tti)
    th_on, th_off, Lon, Loff = HYP["theta_on"], HYP["theta_off"], HYP["L_on"], HYP["L_off"]
    # Group by run
    from collections import defaultdict
    by_run = defaultdict(list)
    for i, m in enumerate(metas):
        by_run[m["run_id"]].append((m["start_tti"], i))

    y_fuse_hyst = np.zeros_like(y_fuse)
    for rid, lst in by_run.items():
        lst.sort(key=lambda x: x[0])
        fsm = TauHysteresis(th_on, th_off, Lon, Loff)
        for _, idx in lst:
            state = fsm.step(float(probs_fused[idx]))
            y_fuse_hyst[idx] = 1 if state else 0

    cm_hyst = confusion(y_true, y_fuse_hyst)

    # Pretty print
    def pretty(cm: Dict[str, Any], name: str):
        M = cm["matrix"]
        print(f"\n=== {name} ===")
        print("Confusion Matrix (rows=true [0=benign,1=attack], cols=pred):")
        print(f"[{M[0][0]:5d}  {M[0][1]:5d}]  ← true 0")
        print(f"[{M[1][0]:5d}  {M[1][1]:5d}]  ← true 1")
        print(f"acc={cm['acc']:.3f}  prec={cm['prec']:.3f}  rec={cm['rec']:.3f}  f1={cm['f1']:.3f} "
              f"(tp={cm['tp']}, fp={cm['fp']}, tn={cm['tn']}, fn={cm['fn']})")

    pretty(cm_tcn,  "TCN (p>=0.5)")
    pretty(cm_svd,  "SVD (p>=0.5)")
    pretty(cm_fuse, "Fusion (p>=0.5)")
    pretty(cm_hyst, "Fusion + τ-Hysteresis (final ENFORCE)")

    return {
        "cm_tcn": cm_tcn, "cm_svd": cm_svd, "cm_fuse": cm_fuse, "cm_hyst": cm_hyst,
        "probs": {"tcn": probs_tcn, "svd": probs_svd, "fuse": probs_fused},
        "y_true": y_true.tolist()
    }

# --- Run on validation loader by default ---
if 'val_loader' in globals() and val_loader and len(val_loader) > 0:
    _ = evaluate_confusions(val_loader)
else:
    print("No validation loader attached. Build metas/loaders, then re-run this cell.")