# TTI-Trust: TCN + SVD Intelligent QoS Classifier

## Ryan Barker

### IS-WiN Open RAN Zero-Trust Security: PRB Starvation Classifier

### preproc_tti_trust.py

In [None]:
# TTI-Trust preprocessing scaffold for PyTorch: identity-agnostic feature composer,
# phase-aware windowing, and grouped K-fold splits.

from __future__ import annotations
import os, re, warnings, itertools
from dataclasses import dataclass
from typing import List, Tuple
from glob import glob

import numpy as np
import pandas as pd
import pyarrow as pa, pyarrow.csv as pv, pyarrow.parquet as pq

# -------------------- DATASETS --------------------
ATTACK_FN = "prb_tti_evidence_ai_attack.csv"  # raw attack-run CSV, converted to parquet by csv_to_parquet
BENIGN_FN = "prb_tti_evidence_ai_benign.csv"  # raw benign-run CSV

# -------------------- Constants aligned to paper --------------------
TTI_SEC = 0.0005   # seconds per TTI (0.5 ms); sec_rel is synthesized as row_index * TTI_SEC when absent
C_PRB   = 106      # divisor turning per-UE *_rb PRB counts into shares (presumably the carrier's PRB total — confirm)

W_LONG  = 240      # long window length in rows (one row per TTI assumed)
W_SHORT = 64       # short window length in rows

STRIDES_LONG  = [8, 12]   # hops between long-window starts; one shard set is written per stride
STRIDES_SHORT = [4, 6]    # hops between short-window starts

TOP_K = 6          # number of EMA-ranked UE "role" channels (a seventh channel aggregates the rest)
SMALL_RB_EPS = 4   # a PRB count strictly below this is flagged as a small allocation

# Used by _label_window: a window "looks attacky" when role1's share is >= ATTACK_DOMINANCE
# and every other channel is <= OTHERS_MAX in at least DUTY_KEEP of its TTIs.
ATTACK_DOMINANCE = 0.90
OTHERS_MAX       = 0.10
DUTY_KEEP        = 0.90

# -------------------- Helpers --------------------

def csv_to_parquet(src_csv: str, dest_root: str):
    """Convert one raw CSV into a parquet dataset partitioned by ``run_id``.

    Does nothing if any ``*.parquet`` file already exists anywhere under
    ``dest_root``. ``run_id`` and ``phase`` are parsed as nullable large strings.
    A missing ``run_id`` column is filled with the CSV file stem for every row.
    A missing ``sec_rel`` column is synthesized as ``row_index * TTI_SEC``
    (float32) over the whole file.
    """
    os.makedirs(dest_root, exist_ok=True)
    already_converted = any(
        fname.endswith(".parquet")
        for _dir, _subdirs, fnames in os.walk(dest_root)
        for fname in fnames
    )
    if already_converted:
        return

    string_cols = {"run_id": pa.large_string(), "phase": pa.large_string()}
    table = pv.read_csv(
        src_csv,
        read_options=pv.ReadOptions(block_size=1 << 26),   # 64 MiB read blocks
        convert_options=pv.ConvertOptions(column_types=string_cols, strings_can_be_null=True),
    )
    present = table.schema.names

    if "run_id" not in present:
        file_stem = os.path.splitext(os.path.basename(src_csv))[0]
        table = table.append_column(
            "run_id", pa.array(np.full(table.num_rows, file_stem), type=pa.large_string())
        )

    if "sec_rel" not in present:
        offsets = np.arange(table.num_rows, dtype=np.int32) * np.float32(TTI_SEC)
        table = table.append_column("sec_rel", pa.array(offsets, type=pa.float32()))

    pq.write_to_dataset(table, root_path=dest_root, partition_cols=["run_id"],
                        compression="zstd", use_dictionary=True)

def _detect_ue_rb_columns(df: pd.DataFrame) -> List[str]:
    cand = [c for c in df.columns if re.fullmatch(r'UE\d+_rb', c, flags=re.IGNORECASE)]
    if not cand:
        cand = [c for c in df.columns if c.lower().endswith('_rb')]
    def _key(c):
        m = re.search(r'(\d+)', c)
        return int(m.group(1)) if m else 10**9
    return sorted(set(cand), key=_key)

def _ensure_total_rb(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    if 'total_rb' not in df.columns:
        df['total_rb'] = df[ue_cols].sum(axis=1)
    return df

def _compute_shares_and_fairness(df: pd.DataFrame, ue_cols: List[str]) -> pd.DataFrame:
    """Add per-UE share columns and Jain's fairness index ``J`` to ``df`` in place.

    Share of UE u at a row = max(PRB_u, 0) / C_PRB, stored as float32 under the
    column name obtained by replacing ``'_rb'`` with ``'_share'``. This is a
    case-sensitive replace, and _compose_roles derives the same names the same
    way. ``J = (Σ x)^2 / (n · Σ x^2)`` over the shares, with n = max(1, #UEs).
    Rows where J is undefined or infinite (e.g. all-zero shares) are set to 0.0.
    Returns ``df``.
    """
    share_names = [c.replace('_rb', '_share') for c in ue_cols]
    share_frame = df[ue_cols].clip(lower=0).astype('float32') / np.float32(C_PRB)
    share_frame.columns = share_names
    for name in share_names:
        df[name] = share_frame[name]

    mat = share_frame.to_numpy(dtype='float32', copy=False)
    total = mat.sum(axis=1, dtype='float32')
    total_sq = np.square(mat).sum(axis=1, dtype='float32')
    n_users = np.float32(max(1, len(ue_cols)))
    with np.errstate(divide='ignore', invalid='ignore'):
        jain = (total * total) / (n_users * total_sq)
    df['J'] = np.nan_to_num(jain, nan=0.0, posinf=0.0, neginf=0.0).astype('float32')
    return df

def _compose_roles(df: pd.DataFrame, ue_cols: List[str], top_k: int = TOP_K) -> pd.DataFrame:
    """Add identity-agnostic "role" columns to ``df`` in place and return it.

    In each row, UEs are ranked by an exponential moving average of their share
    (alpha = 2/21, which is pandas' span=20 convention). The top-k become
    role1..role{k}, so role columns describe rank rather than a fixed UE
    identity. This requires the ``*_share`` columns written by
    _compute_shares_and_fairness and a ``sec_rel`` column.

    Columns added:
      <share>_ema          EMA of every per-UE share column.
      role{i}_share        Instantaneous share of the i-th ranked UE; 0.0 padding
                           when there are fewer UEs than top_k.
      role{i}_smallrb      1.0 if that UE's PRB count < SMALL_RB_EPS, else 0.0
                           (0.0 padding as above).
      rest_share           Summed share of UEs outside the top-k, clipped at 0.
      rest_smallrb         Fraction of the outside UEs whose PRB count is
                           < SMALL_RB_EPS; 0 when all UEs are in the top-k.
      role{i}_present_1s,  1 (uint8) if that channel had a positive share in any
      rest_present_1s      row with the same floor(sec_rel) second, else 0.
    """
    # Must mirror the naming used by _compute_shares_and_fairness (case-sensitive '_rb' replace).
    share_cols = [c.replace('_rb','_share') for c in ue_cols]
    # EMA smoothing factor: alpha = 2 / (span + 1) with span = 20.
    alpha = 2.0 / (20 + 1.0)
    for c in share_cols:
        # EMA runs over all rows of df in their current order.
        df[c+"_ema"] = df[c].astype('float32').ewm(alpha=alpha, adjust=False).mean().astype('float32')

    S   = df[share_cols].to_numpy(dtype='float32', copy=False)   # instantaneous shares [T, U]
    E   = df[[c+"_ema" for c in share_cols]].to_numpy(dtype='float32', copy=False)  # EMA shares [T, U]
    PRB = df[ue_cols].to_numpy(dtype='int32',    copy=False)     # raw PRB counts [T, U]

    T, U = S.shape
    k = min(top_k, U)
    # Per-row column indices in descending EMA order. NOTE(review): argsort's default
    # kind is not stable, so ties (e.g. several UEs with zero EMA) may be ranked arbitrarily.
    order  = np.argsort(-E, axis=1)
    top_ix = order[:, :k]

    # Ranking uses the EMA, but the stored role values are the instantaneous shares.
    roles       = np.take_along_axis(S,   top_ix, axis=1).astype('float32')
    roles_small = np.take_along_axis((PRB < SMALL_RB_EPS).astype('float32'), top_ix, axis=1)

    sum_all    = S.sum(axis=1, dtype='float32')
    sum_topk   = roles.sum(axis=1, dtype='float32')
    # Share left to UEs outside the top-k; clip guards against float round-off going negative.
    rest_share = (sum_all - sum_topk).clip(min=0).astype('float32')

    if U > k:
        row_ix = np.arange(T)[:, None]
        # mask is True for UEs that are NOT in this row's top-k.
        mask = np.ones_like(S, dtype=bool); mask[row_ix, top_ix] = False
        cnt  = mask.sum(axis=1)
        rest_small = ((PRB < SMALL_RB_EPS).astype('float32') * mask).sum(axis=1) / np.maximum(cnt, 1)
    else:
        rest_small = np.zeros(T, dtype='float32')

    for i in range(k):
        df[f'role{i+1}_share']   = roles[:, i]
        df[f'role{i+1}_smallrb'] = roles_small[:, i]
    # Pad missing roles with zeros so downstream always sees exactly top_k role columns.
    for i in range(k, top_k):
        df[f'role{i+1}_share']   = np.float32(0.0)
        df[f'role{i+1}_smallrb'] = np.float32(0.0)
    df['rest_share']   = rest_share
    df['rest_smallrb'] = rest_small

    # Integer second of each row. NOTE(review): assumes sec_rel >= 0; a negative bin
    # would index agg from the end.
    sec_bin = np.floor(df['sec_rel'].to_numpy(dtype='float32')).astype('int32')
    max_sec = int(sec_bin.max(initial=0))
    def sec_presence(vec_bool: np.ndarray) -> np.ndarray:
        # Per-second OR of the flag (unbuffered maximum.at), broadcast back to every row.
        agg = np.zeros(max_sec + 1, dtype='uint8')
        np.maximum.at(agg, sec_bin, vec_bool.astype('uint8'))
        return agg[sec_bin]
    for i in range(top_k):
        df[f'role{i+1}_present_1s'] = sec_presence((df[f'role{i+1}_share'].to_numpy() > 0))
    df['rest_present_1s'] = sec_presence((df['rest_share'].to_numpy() > 0))
    return df

def _label_window(shares_mat: np.ndarray) -> bool:
    if shares_mat.size == 0:
        return False
    top    = shares_mat[:, 0]
    others = shares_mat[:, 1:]
    ok = (top >= ATTACK_DOMINANCE) & (np.max(others, axis=1) <= OTHERS_MAX)
    return (ok.mean() >= DUTY_KEEP)

@dataclass
class WindowSpec:
    """Sliding-window geometry used by iter_windows, counted in rows (TTIs)."""

    length_tti: int  # number of consecutive rows in each window
    stride_tti: int  # hop between consecutive window start rows

def iter_windows(df: pd.DataFrame, run_id: str, spec: WindowSpec, role_cols: List[str],
                 phases_attack=frozenset({"attack"}),
                 phases_benign=frozenset({"baseline", "recovery", "benign"}),
                 batch_windows: int = 8192):
    """Slice one run into fixed-length windows and yield them in batches.

    Window starts are ``0, S, 2S, ...`` with ``W = spec.length_tti`` and
    ``S = spec.stride_tti``. A window covers rows ``[s, s + W)``, and only full
    windows are produced. Labels:
      2 - at least 90% of the rows are in an attack phase AND _label_window accepts the shares;
      0 - otherwise, at least 90% of the rows are in a benign phase;
      1 - everything else (mixed/transition windows, or attack phases whose
          shares do not look attack-dominated).

    Args:
        df: one run, with ``role_cols`` and a ``phase`` column.
        run_id: value copied into every meta row.
        spec: window length/stride in rows.
        role_cols: columns stacked into X (role shares, rest last).
        phases_attack / phases_benign: phase names counted as attack / benign.
            Any iterable of strings is accepted. The defaults are immutable
            frozensets, replacing shared mutable set defaults.
        batch_windows: windows per yielded batch; the last batch may be smaller.

    Yields:
        (X, meta). X is float32 ``[n, W, len(role_cols)]``. meta has one row per
        window with columns run_id, start_tti, end_tti, window_len, label,
        frac_attack, frac_benign, looks_attacky.
    """
    meta_columns = ["run_id", "start_tti", "end_tti", "window_len", "label",
                    "frac_attack", "frac_benign", "looks_attacky"]
    W, S = spec.length_tti, spec.stride_tti
    R = df[role_cols].to_numpy(dtype='float32', copy=False)
    phase = df['phase'].astype(str).to_numpy()
    T = len(df)
    starts = np.arange(0, max(0, T - W + 1), S, dtype=np.int64)

    # Prefix counts of phase membership. A window's phase fraction is then an O(1)
    # difference instead of an O(W) np.isin scan per window. The counts are exact
    # integers, so the fractions equal the former per-window .mean() exactly.
    def _prefix_count(phases) -> np.ndarray:
        hits = np.isin(phase, list(phases))
        return np.concatenate(([0], np.cumsum(hits, dtype=np.int64)))

    attack_cum = _prefix_count(phases_attack)
    benign_cum = _prefix_count(phases_benign)

    def _emit(blocks, records):
        # np.stack and from_records both copy, so callers may reuse/rebind the buffers.
        X = np.stack(blocks, axis=0)
        meta = pd.DataFrame.from_records(records, columns=meta_columns)
        return X, meta

    bufX, rows = [], []
    for s in starts:
        e = s + W
        blk = R[s:e, :]
        if blk.shape[0] != W:
            continue
        frac_attack = (attack_cum[e] - attack_cum[s]) / W
        frac_benign = (benign_cum[e] - benign_cum[s]) / W
        looks_attacky = _label_window(blk)
        if (frac_attack >= 0.90) and looks_attacky:
            y = 2
        elif frac_benign >= 0.90:
            y = 0
        else:
            y = 1

        bufX.append(blk)
        rows.append((run_id, int(s), int(e), int(W), int(y),
                     float(frac_attack), float(frac_benign), int(looks_attacky)))

        if len(bufX) == batch_windows:
            yield _emit(bufX, rows)
            bufX, rows = [], []

    if bufX:
        yield _emit(bufX, rows)

# -------------------- Main entry: build datasets (partitioned, low RAM) --------------------

# Convert each raw CSV into a run_id-partitioned parquet dataset. This is a no-op
# for a destination that already contains parquet files.
csv_to_parquet(ATTACK_FN, "parquet/attack")
csv_to_parquet(BENIGN_FN, "parquet/benign")

def iter_run_partitions(root: str):
    """Yield ``(run_id, DataFrame)`` for each ``run_id=<value>`` partition under ``root``.

    Partitions are visited in sorted directory order and loaded one at a time.
    Columns missing from a partition are back-filled:
      - ``run_id`` from the directory name;
      - ``phase`` as ``'unknown'``;
      - ``sec_rel`` as ``row_index * TTI_SEC`` (float32).
    """
    # Local import: the loader cell later runs ``import glob``, which rebinds the
    # module-level name ``glob`` to the module and breaks a bare ``glob(...)`` call.
    from glob import glob

    for run_dir in sorted(glob(os.path.join(root, "run_id=*"))):
        rid = os.path.basename(run_dir).split("=", 1)[1]
        df = pd.read_parquet(run_dir, engine="pyarrow")
        if 'run_id' not in df.columns:
            df['run_id'] = rid
        if 'phase' not in df.columns:
            df['phase'] = 'unknown'
        if 'sec_rel' not in df.columns:
            df['sec_rel'] = (np.arange(len(df), dtype=np.int32) * np.float32(TTI_SEC)).astype('float32')
        yield rid, df

def process_source(root: str, src_name: str):
    """Enrich every run partition under ``root`` and cut it into window shards.

    For each ``run_id=<rid>`` partition:
      1. detect the per-UE PRB columns; a run without any is skipped with a warning;
      2. add total_rb, per-UE shares, Jain's J and the role features;
      3. write ``<src_name>_enriched_<rid>.parquet`` with the role shares,
         rest_share, J, sec_rel, run_id and phase columns;
      4. write long (W_LONG x STRIDES_LONG) and short (W_SHORT x STRIDES_SHORT)
         window shards to ``win_shards/`` as
         ``<kind>_s<stride>_<src>_<rid>_<sid>.npz``, each with a matching
         ``_meta.parquet``. The meta gains source/stride/kind columns.

    Partitions are streamed from iter_run_partitions rather than collected into
    a list first. Only about one partition is resident at a time, instead of the
    whole source.
    """
    # Local import: the loader cell later runs ``import glob``, which rebinds the
    # module-level name ``glob`` to the module and breaks a bare ``glob(...)`` call.
    from glob import glob

    if not os.path.isdir(root):
        print(f"[{src_name}] No parquet root found at {root} (skipping).")
        return

    os.makedirs("win_shards", exist_ok=True)
    shard_counter = itertools.count(0)

    # Count partitions up front (same pattern iter_run_partitions walks) without loading them.
    n_runs = len(glob(os.path.join(root, "run_id=*")))
    if not n_runs:
        print(f"[{src_name}] No run_id partitions found in {root}.")
        return
    print(f"[{src_name}] Found {n_runs} run_id partitions.")

    share_cols = [f'role{i}_share' for i in range(1, TOP_K+1)]
    keep = share_cols + ['rest_share', 'J', 'sec_rel', 'run_id', 'phase']
    role_cols = share_cols + ['rest_share']
    # (shard kind, window length, strides) — long windows first, as before.
    window_plan = (("long", W_LONG, STRIDES_LONG), ("short", W_SHORT, STRIDES_SHORT))

    for rid, df in iter_run_partitions(root):
        ue_cols = _detect_ue_rb_columns(df)
        if not ue_cols:
            warnings.warn(f"[{src_name}] run_id={rid}: no UE *_rb columns; skipping")
            continue

        df = _ensure_total_rb(df, ue_cols)
        df = _compute_shares_and_fairness(df, ue_cols)
        df = _compose_roles(df, ue_cols, top_k=TOP_K)

        enriched_out = f"{src_name}_enriched_{rid}.parquet"
        df[keep].to_parquet(enriched_out, index=False)
        print(f"[{src_name}] run_id={rid} → {enriched_out}  (rows={len(df):,})")

        g = df.reset_index(drop=True)
        for kind, win_len, strides in window_plan:
            for stride in strides:
                spec = WindowSpec(length_tti=win_len, stride_tti=stride)
                for X, meta in iter_windows(g, rid, spec, role_cols):
                    sid = next(shard_counter)
                    stem = f"win_shards/{kind}_s{stride}_{src_name}_{rid}_{sid}"
                    np.savez_compressed(f"{stem}.npz", X=X)
                    meta["source"] = src_name; meta["stride"] = stride; meta["kind"] = kind
                    meta.to_parquet(f"{stem}_meta.parquet", index=False)
        # Drop this loop's references so the partition can be freed once the generator advances.
        del df, g

# Write enriched per-run parquet files and window shards (win_shards/) for each source.
process_source("parquet/attack", "attack")
process_source("parquet/benign", "benign")

def load_all_meta(kind: str) -> pd.DataFrame:
    """Concatenate every ``win_shards/<kind>_*_meta.parquet`` into one DataFrame.

    Files are read in sorted path order. ``glob`` returns entries in arbitrary,
    filesystem-dependent order, which would make the row order (and therefore
    the index arrays returned by grouped_kfold) irreproducible. Returns an empty
    DataFrame when no shard exists.
    """
    # Local import: the loader cell later runs ``import glob``, which rebinds the
    # module-level name ``glob`` to the module and breaks a bare ``glob(...)`` call.
    from glob import glob

    paths = sorted(glob(f"win_shards/{kind}_*_meta.parquet"))
    if not paths:
        return pd.DataFrame()
    return pd.concat([pd.read_parquet(p) for p in paths], ignore_index=True)

# Window metadata across all shards of each kind (an empty DataFrame when none exist).
M_long  = load_all_meta("long")
M_short = load_all_meta("short")

def grouped_kfold(meta: pd.DataFrame, K: int = 5, seed: int = 42) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split window metadata into run-grouped folds.

    All windows of one ``run_id`` land in the same validation fold, so a run
    never appears in both train and val. Run ids are shuffled with
    ``np.random.default_rng(seed)`` and dealt round-robin into the folds.

    If ``K`` exceeds the number of distinct run ids, it is reduced to that number
    with a warning. Otherwise the extra folds would have empty validation sets.

    Args:
        meta: DataFrame with a ``run_id`` column.
        K: requested number of folds (>= 1).
        seed: shuffle seed.

    Returns:
        A list of ``(train_idx, val_idx)`` arrays of ``meta.index`` labels, or
        ``[]`` for an empty ``meta``.

    Raises:
        ValueError: if ``K < 1``.
    """
    if meta.empty:
        return []
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    rid = meta['run_id'].astype(str)
    run_ids = rid.unique().tolist()
    rng = np.random.default_rng(seed)
    rng.shuffle(run_ids)

    n_folds = min(K, len(run_ids))
    if n_folds < K:
        warnings.warn(f"grouped_kfold: only {len(run_ids)} run_id groups; using {n_folds} folds instead of {K}")
    # Round-robin dealing: fold i gets run_ids[i], run_ids[i + n_folds], ...
    folds = [set(run_ids[i::n_folds]) for i in range(n_folds)]

    splits = []
    for val_rids in folds:
        in_val = rid.isin(val_rids).to_numpy()
        splits.append((meta.index[~in_val].to_numpy(), meta.index[in_val].to_numpy()))
    return splits

# Run-grouped 5-fold splits: windows of one run_id never appear in both train and val.
# The index arrays refer to the rows of M_long / M_short respectively.
splits_long  = grouped_kfold(M_long,  K=5, seed=2025)
splits_short = grouped_kfold(M_short, K=5, seed=2025)

if len(M_long):
    print(f"\n[long] meta={M_long.shape}, groups={M_long['run_id'].nunique()}, folds={len(splits_long)}")
if len(M_short):
    print(f"[short] meta={M_short.shape}, groups={M_short['run_id'].nunique()}, folds={len(splits_short)}")

print("\nPreprocessing complete. Shards are in ./win_shards; use meta to assemble CV splits or stream for training.")

### Data Loader

In [None]:
# PyTorch feature space & data loader for TTI-Trust (scheduler-native, identity-agnostic)
# Optimized to stream large datasets from shards: win_shards/{long,short}_s*_... .npz + *_meta.parquet
#
# Emits batches directly (IterableDataset) to avoid in-memory concatenation or per-item pandas operations.

import os, re, glob, math
from typing import Dict, List, Tuple, Optional, Iterator
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, DataLoader

# --- Shared constants (align with preprocessing) ---
# These re-declare the preprocessing values and must stay identical to them.
TTI_SEC = 0.0005
C_PRB   = 106
TOP_K   = 6
SMALL_RB_EPS = 4

# PRB levels for approx_plateau_hist. Each re-quantized PRB value is counted at its
# nearest level (ties go to the lower level, since argmin returns the first index).
# The levels are spaced 15 apart and end at C_PRB.
PLATEAUS = np.array([16, 31, 46, 61, 76, 91, 106], dtype=np.int32)

# ---------- Feature helpers (NumPy; no pandas in the hot path) ----------

def jains_fairness_seq(shares: np.ndarray) -> np.ndarray:
    """Jain's fairness index per time step.

    ``J = (Σ x)^2 / (D · Σ x^2 + 1e-9)`` over the last axis. The epsilon makes an
    all-zero row give 0 instead of NaN.

    Args:
        shares: ``[W, D]`` (one window) or ``[N, W, D]`` (a batch).

    Returns:
        float32 ``[W]`` for 2-D input, or ``[N, W]`` for 3-D input. A batch of
        size 1 keeps its leading axis; previously ``[1, W, D]`` was squeezed to
        ``[W]``.
    """
    was_2d = shares.ndim == 2
    batch = shares[None, ...] if was_2d else shares
    sum_x  = batch.sum(axis=2, dtype=np.float32)               # [N, W]
    sum_x2 = (batch * batch).sum(axis=2, dtype=np.float32)     # [N, W]
    n = batch.shape[2]
    J = ((sum_x * sum_x) / (np.float32(n) * sum_x2 + 1e-9)).astype(np.float32)
    return J[0] if was_2d else J

def rolling_min_med_causal(x: np.ndarray, win: int) -> Tuple[np.ndarray, np.ndarray]:
    """Trailing-window minimum and median of a 1-D series.

    Output position t summarizes ``x[max(0, t - win + 1) : t + 1]``: at most
    ``win`` samples ending at t, with no look-ahead. An empty slice, which only
    happens when ``win <= 0``, yields 0.0. Both results are float32 with
    ``len(x)`` entries.
    """
    n = x.shape[0]
    mins = np.empty(n, dtype=np.float32)
    meds = np.empty(n, dtype=np.float32)
    for stop in range(1, n + 1):
        window = x[max(0, stop - win):stop]
        if window.size == 0:
            mins[stop - 1] = 0.0
            meds[stop - 1] = 0.0
            continue
        mins[stop - 1] = window.min()
        meds[stop - 1] = np.median(window)
    return mins, meds

def small_rb_runlength_1d(streak_bin: np.ndarray) -> np.ndarray:
    """Running count of consecutive truthy entries in a 1-D flag array.

    Position i holds how many entries ending at i (inclusive) are all nonzero.
    Each zero entry resets the count to 0. The result is float32 and has the
    same shape as the input.
    """
    flags = streak_bin.astype(bool)
    positions = np.arange(flags.shape[0])
    # Index of the most recent reset (falsy entry) at or before each position; -1 if none yet.
    last_reset = np.maximum.accumulate(np.where(flags, -1, positions))
    return (positions - last_reset).astype(np.float32)

def build_c_runs(srb_bin: np.ndarray) -> np.ndarray:
    """Apply small_rb_runlength_1d to each column of a ``[W, D]`` 0/1 mask.

    Returns a float32 ``[W, D]`` array of per-channel run lengths.
    """
    per_channel = [small_rb_runlength_1d(channel) for channel in srb_bin.T]
    return np.stack(per_channel, axis=1).astype(np.float32)

def approx_plateau_hist(share_seq: np.ndarray) -> np.ndarray:
    """Normalized histogram of nearest-plateau assignments for a ``[W, D]`` share window.

    Shares are re-quantized to PRB counts (``round(share * C_PRB)``). Each entry
    is assigned to the closest PLATEAUS level, with ties going to the lower level.
    The counts are then divided by their total (left as zeros when the total is 0).
    Returns float32 with ``len(PLATEAUS)`` entries.
    """
    prb_counts = np.rint(share_seq * C_PRB).astype(np.int32)
    gaps = np.abs(prb_counts[:, :, None] - PLATEAUS[None, None, :])   # [W, D, P]
    nearest = np.argmin(gaps, axis=2)
    counts = np.bincount(nearest.ravel(), minlength=len(PLATEAUS)).astype(np.float32)
    total = counts.sum()
    if total > 0:
        counts = counts / total
    return counts

def benign_grant_fraction(share_seq: np.ndarray) -> float:
    """Fraction of time steps where any channel after the first has a positive share."""
    granted = np.any(share_seq[:, 1:] > 0, axis=1)
    return float(np.mean(granted))

def contiguous_zero_runs(share_seq: np.ndarray) -> Tuple[float, float]:
    """Longest and mean length of runs where every channel after the first is zero.

    Returns ``(max_len, mean_len)`` as floats, or ``(0.0, 0.0)`` when no such
    run exists.
    """
    zero_rows = np.all(share_seq[:, 1:] == 0, axis=1)
    # Pad with False on both sides so every run has a rising and a falling edge.
    padded = np.concatenate(([False], zero_rows, [False])).astype(np.int8)
    edges = np.flatnonzero(np.diff(padded))
    lengths = edges[1::2] - edges[::2]
    if lengths.size == 0:
        return (0.0, 0.0)
    return (float(lengths.max()), float(lengths.mean()))

# ---------- IterableDataset over shard files ----------

class TTINPZShardBatches(IterableDataset):
    """
    Streams batches from NPZ shards produced by preprocessing.
    Each shard file has X: [N, W, D] where D = TOP_K+1 (roles incl. rest), and a paired *_meta.parquet.

    Yields (X_seq, X_aux, y, meta_list) where:
      X_seq: [B, W, 3*D + 3]  (roles, d_roles, J, J_min, J_med, smallRB runs) -> 24 for TOP_K=6
      X_aux: [B, 44]          (plateau hist (len(PLATEAUS)=7), benign frac, zero-run max/mean,
                               RADIO_DIM zero placeholders, radio flag)
      y:     [B] {0,1,2}
      meta_list: list of dicts with run_id/source/stride/window_len/start_tti/end_tti

    Multi-process loading: every DataLoader worker iterates its own copy of this
    dataset. Shard files are therefore split round-robin across workers, so each
    window is yielded once per pass rather than once per worker.
    """
    # Placeholder radio-hint slots in X_aux (always zero until radio arrays are produced).
    RADIO_DIM = 33

    def __init__(self, kind: str, batch_size: int = 1024, shard_glob: Optional[str] = None,
                 restrict_sources: Optional[List[str]] = None):
        assert kind in ("long", "short")
        self.kind = kind
        self.batch_size = int(batch_size)
        # Discover shards
        patt = shard_glob or f"win_shards/{kind}_*.npz"
        self.shard_paths = sorted(glob.glob(patt))
        if not self.shard_paths:
            raise FileNotFoundError(f"No shards found for pattern: {patt}. Run preprocessing to generate win_shards.")
        self.restrict_sources = set(restrict_sources) if restrict_sources else None

    def _worker_shard_paths(self) -> List[str]:
        """Return the shard files this process should read (all of them outside a worker)."""
        from torch.utils.data import get_worker_info
        info = get_worker_info()
        if info is None:      # num_workers=0 or direct iteration
            return self.shard_paths
        return self.shard_paths[info.id::info.num_workers]

    @staticmethod
    def _sequence_features(seq: np.ndarray) -> np.ndarray:
        """[B, W, D] role shares -> float32 [B, W, 3*D + 3] per-TTI channels."""
        # Δshares; the first step is 0 because the first frame is prepended.
        dseq = np.diff(seq, axis=1, prepend=seq[:, 0:1, :])

        # Jain's J(t) plus causal rolling min/median over up to 60 TTIs, per window.
        J_rows, Jmin_rows, Jmed_rows = [], [], []
        for window in seq:
            Jb = jains_fairness_seq(window)                                   # [W]
            Jmin, Jmed = rolling_min_med_causal(Jb, win=min(60, Jb.shape[0]))
            J_rows.append(Jb); Jmin_rows.append(Jmin); Jmed_rows.append(Jmed)
        J   = np.stack(J_rows,    axis=0)[:, :, None]                         # [B, W, 1]
        Jmn = np.stack(Jmin_rows, axis=0)[:, :, None]                         # [B, W, 1]
        Jmd = np.stack(Jmed_rows, axis=0)[:, :, None]                         # [B, W, 1]

        # Small-RB run lengths, approximated by re-quantizing shares to PRB counts.
        srb = (np.rint(seq * C_PRB) < SMALL_RB_EPS).astype(np.float32)
        c_runs = np.stack([build_c_runs(window) for window in srb], axis=0)  # [B, W, D]

        return np.concatenate([seq, dseq, J, Jmn, Jmd, c_runs], axis=2).astype(np.float32)

    @classmethod
    def _aux_features(cls, seq: np.ndarray, have_radio: bool = False) -> np.ndarray:
        """[B, W, D] role shares -> float32 [B, len(PLATEAUS) + 3 + RADIO_DIM + 1] per-window summaries."""
        radio_flag = np.array([1.0 if have_radio else 0.0], dtype=np.float32)
        rows = []
        for window in seq:
            zmax, zmean = contiguous_zero_runs(window)
            stats = np.array([benign_grant_fraction(window), zmax, zmean], dtype=np.float32)
            rows.append(np.concatenate([approx_plateau_hist(window), stats,
                                        np.zeros(cls.RADIO_DIM, dtype=np.float32), radio_flag]))
        return np.stack(rows, axis=0).astype(np.float32)

    @staticmethod
    def _meta_records(meta: pd.DataFrame, i0: int, i1: int, default_window_len: int) -> List[dict]:
        """Essential metadata for rows [i0, i1); absent columns fall back to defaults."""
        part = meta.iloc[i0:i1]
        n = len(part)

        def col(name, default):
            return part[name].tolist() if name in part.columns else [default] * n

        return [
            {"run_id": str(r), "source": str(src), "stride": int(st),
             "window_len": int(wl), "start_tti": int(a), "end_tti": int(b)}
            for r, src, st, wl, a, b in zip(
                col("run_id", ""), col("source", ""), col("stride", 0),
                col("window_len", default_window_len), col("start_tti", 0), col("end_tti", 0))
        ]

    def __iter__(self) -> Iterator:
        for npz_path in self._worker_shard_paths():
            meta_path = npz_path.replace(".npz", "_meta.parquet")
            if not os.path.exists(meta_path):
                continue
            meta = pd.read_parquet(meta_path)
            if self.restrict_sources is not None and "source" in meta.columns:
                # Preprocessing writes one source per shard, so a shard is kept or skipped whole.
                if not any(src in self.restrict_sources for src in meta["source"].unique().tolist()):
                    continue
            # The context manager closes the underlying zip handle once X has been read.
            with np.load(npz_path) as npz:
                X = npz["X"]  # [N, W, D], float32 from preprocessing
            if "label" not in meta.columns:
                raise ValueError(f"Meta {meta_path} missing 'label' column.")
            y_all = meta["label"].to_numpy(dtype=np.int64)

            N, W, D = X.shape
            assert D == (TOP_K + 1), f"Expected D=TOP_K+1, got D={D}"
            if len(meta) != N:
                raise ValueError(f"Meta {meta_path} has {len(meta)} rows but {npz_path} holds {N} windows.")

            for i0 in range(0, N, self.batch_size):
                i1 = min(N, i0 + self.batch_size)
                seq = X[i0:i1]                                       # [B, W, D]
                X_seq = self._sequence_features(seq)
                X_aux = self._aux_features(seq)
                y = torch.from_numpy(y_all[i0:i1])                   # [B], int64
                meta_batch = self._meta_records(meta, i0, i1, default_window_len=W)
                # Zero-copy to torch for X_seq/X_aux
                yield torch.from_numpy(X_seq), torch.from_numpy(X_aux), y, meta_batch

# ---------- High-throughput DataLoader factory ----------

def make_loader(kind: str, batch_size: int = 1024, num_workers: int = 4,
                restrict_sources: Optional[List[str]] = None) -> DataLoader:
    """Build a DataLoader that streams pre-batched windows from TTINPZShardBatches.

    The dataset already yields whole batches, so automatic batching is off
    (``batch_size=None``). ``persistent_workers`` and ``prefetch_factor`` are
    only valid with worker processes; DataLoader raises ValueError if they are
    given with ``num_workers=0``. They are therefore passed only when
    ``num_workers > 0``. Memory pinning is requested only when CUDA is available,
    because it has no effect without a GPU.
    """
    ds = TTINPZShardBatches(kind=kind, batch_size=batch_size, restrict_sources=restrict_sources)
    worker_kwargs = {"persistent_workers": True, "prefetch_factor": 2} if num_workers > 0 else {}
    return DataLoader(
        ds,
        batch_size=None,                           # dataset already yields batches
        num_workers=num_workers,                   # tune: 2–8 based on CPU/IO
        pin_memory=torch.cuda.is_available(),      # faster H2D copies when a GPU exists
        **worker_kwargs,
    )

# ---------- Smoke test (requires win_shards present) ----------

# Smoke test: build streaming loaders and pull one batch to check tensor shapes.
# Any failure (e.g. FileNotFoundError from TTINPZShardBatches when win_shards/ has no
# matching shards) is printed as a "Note:" line instead of stopping the notebook.
try:
    # NOTE(review): both loaders use identical arguments (same kind, no restrict_sources),
    # so val_loader streams exactly the same windows as train_loader. This is a placeholder;
    # real train/val streams should be derived from the run-grouped folds.
    train_loader = make_loader(kind="short", batch_size=1024, num_workers=4)
    val_loader   = make_loader(kind="short", batch_size=1024, num_workers=4)
    print("Streaming loader ready. First batch shape:")
    xb_seq, xb_aux, yb, mb = next(iter(train_loader))
    print("X_seq:", tuple(xb_seq.shape), "X_aux:", tuple(xb_aux.shape), "y:", tuple(yb.shape), "meta[0]:", mb[0])
except Exception as e:
    print("Note:", str(e))

# NOTE(review): DualHeadTCN.forward takes only x_seq; none of the models defined in this
# file consume x_aux, despite the usage hint printed below.
print("\nUse: for xb_seq, xb_aux, yb, mb in make_loader('short', ...): model(xb_seq, xb_aux) ...")

### Base Model Architecture

#### Dual-head TCN Fairness Seismograph

In [None]:
# ==== Dual-head TCN (τ-scale + edge-scale) ====
# Input:  x_seq ∈ R^{B × W × D_seq}
# Output: logits ∈ R^{B×3}  (classes: 0=benign, 1=proximal, 2=attack),
#         probs  ∈ R^{B×3}  (temperature-scaled softmax for calibration)

from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1dSame(nn.Conv1d):
    """1-D convolution that pads only on the left.

    The output has the same length T as the input, and output[t] depends only on
    inputs at times <= t. Input and output are ``[B, C, T]``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=1, padding=0, dilation=dilation, bias=bias)
        # History needed so that output[t] sees exactly `kernel_size` dilated taps ending at t.
        self.left_pad = dilation * (kernel_size - 1)

    def forward(self, x):
        padded = F.pad(x, [self.left_pad, 0])
        return super().forward(padded)

class TCNBlock(nn.Module):
    """One residual unit of the TCN.

    It applies two dilated causal convolutions of equal width, so the residual
    add is shape-compatible. Each is followed by SiLU, with dropout between them,
    and LayerNorm over channels is applied at every time step. Input and output
    are ``[B, C, T]``.
    """
    def __init__(self, channels: int, kernel: int, dilation: int, dropout: float = 0.05):
        super().__init__()
        # Submodule creation order is kept: it fixes parameter init order and state_dict keys.
        self.conv1 = CausalConv1dSame(channels, channels, kernel, dilation=dilation)
        self.conv2 = CausalConv1dSame(channels, channels, kernel, dilation=dilation)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(channels)

    def forward(self, x):
        hidden = self.dropout(F.silu(self.conv1(x)))
        hidden = F.silu(self.conv2(hidden))
        out = (hidden + x).transpose(1, 2)      # [B, T, C] so LayerNorm acts on channels
        return self.ln(out).transpose(1, 2)     # back to [B, C, T]

class TCNHead(nn.Module):
    """Stack of residual TCN blocks followed by time-average pooling.

    A 1x1 convolution projects ``[B, W, D]`` inputs to ``hidden`` channels. One
    TCNBlock is applied per entry of ``dilations``, then the result is averaged
    over time and layer-normalized. Output is ``[B, hidden]``.
    """
    def __init__(self, in_channels: int, hidden: int, dilations: List[int], kernel: int = 3, dropout: float = 0.05):
        super().__init__()
        self.in_proj = nn.Conv1d(in_channels, hidden, kernel_size=1)
        self.blocks = nn.ModuleList(TCNBlock(hidden, kernel, d, dropout) for d in dilations)
        self.out_norm = nn.LayerNorm(hidden)

    def forward(self, x):
        h = self.in_proj(x.transpose(1, 2))   # [B, W, D] -> [B, H, W]
        for block in self.blocks:
            h = block(h)
        return self.out_norm(h.mean(dim=-1))  # pool over time -> [B, H]

class DualHeadTCN(nn.Module):
    """
    Two TCN heads with different receptive fields, fused by a small MLP classifier.

      - τ-scale head:    7 blocks, dilations [1,2,4,8,16,32,64]
      - Edge-scale head: 4 blocks, dilations [1,2,4,8]

    Receptive field: each TCNBlock stacks two causal convs (kernel 3, dilation d),
    so a block widens the field by 2·(3-1)·d = 4d TTIs. The 1x1 input projection
    adds nothing. Hence RF = 1 + 4·Σd:
      - τ-scale:    1 + 4·127 = 509 TTIs ≈ 254.5 ms at 0.5 ms/TTI
      - edge-scale: 1 + 4·15  =  61 TTIs ≈  30.5 ms
    NOTE(review): the τ-head RF exceeds both the 240-TTI and 64-TTI windows built
    in preprocessing (W_LONG / W_SHORT), so its deepest layers see mostly causal
    zero padding. Confirm this is intended.

    Classes: 0=benign, 1=proximal, 2=attack.
    """
    def __init__(self, d_seq: int, hidden: int = 32, dropout: float = 0.05, temperature: float = 1.0):
        super().__init__()
        self.head_tau  = TCNHead(in_channels=d_seq, hidden=hidden, dilations=[1,2,4,8,16,32,64],
                                 kernel=3, dropout=dropout)
        self.head_edge = TCNHead(in_channels=d_seq, hidden=hidden, dilations=[1,2,4,8],
                                 kernel=3, dropout=dropout)
        self.classifier = nn.Sequential(
            nn.Linear(2*hidden, 2*hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(2*hidden, 3)   # 0=benign, 1=proximal, 2=attack
        )
        # Temperature for post-hoc calibration (validation-time tuning). It is an
        # nn.Parameter, so it is included in model.parameters(). It only scales `probs`:
        # a loss computed on `logits` gives it no gradient.
        self.log_temp = nn.Parameter(torch.log(torch.tensor(temperature, dtype=torch.float32)))

    def forward(self, x_seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x_seq: [B, W, D_seq] — the per-TTI channels emitted by TTINPZShardBatches
               (shares, Δshares, J, J_min, J_med, small-RB run lengths).
        returns (logits, probs): logits [B, 3] are uncalibrated; probs [B, 3] are
               softmax(logits / T) with T = exp(log_temp) clamped to >= 1e-3.
        """
        h_tau  = self.head_tau(x_seq)           # [B, H]
        h_edge = self.head_edge(x_seq)          # [B, H]
        h = torch.cat([h_tau, h_edge], dim=-1)  # [B, 2H]
        logits = self.classifier(h)             # [B, 3]
        T = torch.exp(self.log_temp).clamp(min=1e-3)
        probs = F.softmax(logits / T, dim=-1)
        return logits, probs

# Example (synthetic) — set D_seq to your dataset’s channel count (e.g., 24)
# B, W, D = 8, 240, 24
# x = torch.randn(B, W, D)
# model = DualHeadTCN(d_seq=D, hidden=32, dropout=0.05)
# logits, probs = model(x)
# logits.shape, probs.shape

#### SVD PF On-Edge Detector

In [None]:
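# ==== CELL 2: SVD head + τ-aligned hysteresis FSM ====
# SVD Head:
#   Input:    shares_seq ∈ R^{B × W × (K+1)}   (top-K roles + rest)  ⟵ slice the first (K+1) channels from x_seq
#   Features: [σ1/Σσ, normalized spectral entropy, σ1/σ2, arccos(|cos(v1, e1)|)/π],
#             computed on the column-centered shares
#   Output:   (p_attack_svd ∈ (0,1) of shape [B], the 4 features of shape [B, 4])
#
# Hysteresis:
#   Two-threshold FSM over the attack-probability time series with τ-aligned dwell times.
#   NOTE: the FSM (TauHysteresis, used in the usage example below) is not defined in this cell.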

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SVDGeometryHead(nn.Module):
    """
    Low-rank spectral feature extractor over windowed role shares.
    Produces a single attack probability, which can be fused with the TCN in an ensemble.

    Each window's shares matrix is column-centered and decomposed with an economy
    SVD. Four scalar features are formed, with r = min(W, D):
      dom_ratio = σ1 / Σσ
      entropy   = spectral entropy of σ/Σσ, normalized by log(max(r, 2))
      stability = σ1 / σ2   (σ2 := 1e-6 when r == 1)
      angle     = arccos(|cos(v1, e1)|) / π ∈ [0, 0.5], where v1 is the leading
                  right singular vector and e1 the role-1 axis; abs() removes the
                  SVD sign ambiguity
    A small MLP maps them to one logit.

    NOTE(review): `stability` is unbounded (it grows like 1/σ2 as σ2 -> 0) and is
    fed to the MLP unscaled. Consider log-scaling it if training is unstable.
    """
    def __init__(self, k_roles_plus_rest: int, hidden: int = 16):
        super().__init__()
        # 4 scalar features → small MLP to 1 logit
        self.mlp = nn.Sequential(
            nn.Linear(4, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1)
        )
        # store D = K+1 for angle computation
        self.D = k_roles_plus_rest
        # unit vector e1 along the "role1" axis
        self.register_buffer("e1", torch.zeros(self.D))
        self.e1[0] = 1.0

    @staticmethod
    def _spectral_entropy(sigmas: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        # sigmas: [B, r] singular values ≥ 0
        p = sigmas / (sigmas.sum(dim=-1, keepdim=True) + eps)         # normalize
        H = -(p * (p.add(eps).log())).sum(dim=-1)                     # entropy
        # normalize by log(r) to [0,1]
        r = sigmas.shape[-1]
        return H / math.log(max(r, 2))

    def forward(self, shares_seq: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        shares_seq: [B, W, D] with D = K+1 (top-K roles + rest)
        returns:    (p_svd, F_svd)
                    p_svd: [B]    sigmoid probability of attack
                    F_svd: [B, 4] [dom_ratio, entropy, stability, angle] (for logging / fusion)

        All windows are decomposed in one batched torch.linalg.svd call. If that
        raises RuntimeError, it is retried once with 1e-6 Gaussian jitter added to
        the whole batch.
        """
        B, W, D = shares_seq.shape
        assert D == self.D, f"Expected D={self.D}, got {D}"

        # Center each role column within its window.
        Xc = shares_seq - shares_seq.mean(dim=1, keepdim=True)         # [B, W, D]
        try:
            _, S, Vh = torch.linalg.svd(Xc, full_matrices=False)      # S: [B, r], Vh: [B, r, D]
        except RuntimeError:
            # fallback: small jitter
            _, S, Vh = torch.linalg.svd(Xc + 1e-6 * torch.randn_like(Xc), full_matrices=False)

        # σ features
        sum_sigma = S.sum(dim=-1) + 1e-9
        sigma1 = S[:, 0]
        if S.shape[-1] > 1:
            sigma2 = S[:, 1]
        else:
            sigma2 = torch.full_like(sigma1, 1e-6)
        dom_ratio = sigma1 / sum_sigma
        stability = sigma1 / (sigma2 + 1e-9)
        H = self._spectral_entropy(S)                                  # [B], normalized

        # Angle between v1 (leading right singular vector) and e1 = [1, 0, 0, ...].
        v1 = Vh[:, 0, :]                                               # [B, D]
        cos = (v1 @ self.e1) / (v1.norm(dim=-1) * self.e1.norm() + 1e-9)
        cos = torch.clamp(cos, -1.0, 1.0)
        angle = torch.arccos(torch.abs(cos)) / math.pi                 # [0, 0.5]

        F_svd = torch.stack([dom_ratio, H, stability, angle], dim=-1)  # [B, 4]
        logit = self.mlp(F_svd).squeeze(-1)                            # [B]
        p = torch.sigmoid(logit)
        return p, F_svd

# Example usage:
#   tcn = DualHeadTCN(d_seq=24)
#   svd = SVDGeometryHead(k_roles_plus_rest=7)
#   logits, p_tcn = tcn(x_seq)                      # p_tcn: [B,3]
#   p_svd, F_svd = svd(x_seq[:, :, :7])            # first 7 channels are shares(K+1)
#   p_attack = 0.7*p_tcn[:,2] + 0.3*p_svd          # simple fusion (calibrate offline)
#   fsm = TauHysteresis(theta_on=0.7, theta_off=0.5, L_on=100, L_off=100)
#   enforce_flags = [fsm.step(float(p)) for p in p_attack.tolist()]

### Base Model Initialization

#### Fusion Head

In [None]:
# ==== CELL: FusionHead (learned ensemble) ====
# Input choices for z:
#   Minimal now:           [p_tcn_attack, p_svd]                     → in_dim = 2
#   When AE arrives:       [p_tcn_attack, p_svd, ae_score]           → in_dim = 3
#   With SVD features:     [+ svd_feats(4)]                          → in_dim = 6 or 7
#
# Output: p_attack_fused ∈ (0,1)  (binary "attack present" prob per window)

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionHead(nn.Module):
    """Tiny MLP mapping a fused score vector ``z`` of shape ``[B, in_dim]`` to an attack probability ``[B]``.

    Layout: Linear(in_dim, hidden) -> SiLU -> Dropout -> Linear(hidden, 1) -> sigmoid.
    """
    def __init__(self, in_dim: int, hidden: int = 8, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.net(z).squeeze(-1))   # p_attack_fused, [B]

#### Hyperparameters

In [None]:
# ==== CELL: Model + Hyperparameter Initialization (with early-stop params) ====
import os, math, json, random
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

# ---- Reproducibility & device ----
SEED = 2025
def seed_all(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Args:
        seed: value fed to every generator.
        deterministic: False (default) keeps cuDNN autotuning on for speed,
            so GPU runs may still differ slightly despite the seeding. True
            requests deterministic cuDNN kernels and disables autotuning,
            trading throughput for run-to-run reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN choose the fastest (possibly nondeterministic)
    # algorithm per input shape, which undermines the seeding above.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

seed_all(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ---- Shared constants (sync with preprocessing/loader) ----
C_PRB = 106
TOP_K = 6
K_PLUS_REST = TOP_K + 1
# Feature depth for x_seq from Dataset:
#   shares(K+1) + dshares(K+1) + {J, J_min, J_med} + smallRB_run(K+1)
D_SEQ = (K_PLUS_REST) + (K_PLUS_REST) + 3 + (K_PLUS_REST)  # 24 for TOP_K=6

# ---- Hyperparameters (now includes early-stop + tiny LR grid placeholders) ----
HYP = {
    "batch_size": 64,
    "epochs": 30,
    "lr": 3e-4,
    "lr_grid": [2e-4, 3e-4, 5e-4],     # optional small grid; pick one to run
    "weight_decay": 1e-4,
    "dropout": 0.05,
    "hidden": 32,
    "alpha_fusion": 0.7,               # p_attack = α*p_TCN + (1-α)*p_SVD
    "label_smoothing": 0.05,
    # τ-aligned hysteresis (≈50 ms @ 0.5 ms/TTI ⇒ 100 TTIs)
    "theta_on": 0.7,
    "theta_off": 0.5,
    "L_on": 100,
    "L_off": 100,
    # Early-stop config (used by TCN/SVD/Fusion loops).
    # NOTE: those loops always maximise validation F1; "monitor"/"mode" are
    # currently informational only.
    "early_stop": {
        "monitor": "f1",               # metric to monitor on validation
        "mode": "max",                 # 'max' for f1, 'min' for loss
        "patience": 5,                 # epochs with no improvement
        "min_delta": 1e-3              # minimum change to qualify as improvement
    },
    # scheduler placeholders (will refresh once loaders exist)
    "warmup_steps": 500,
    "total_steps": 20000,
}
# Non-destructive defaults for grid search
HYP.setdefault("lr_grid", [2e-4, 3e-4, 5e-4])
HYP.setdefault("warmup_grid", [200, 500])
HYP.setdefault("label_smoothing_grid", [0.0, 0.05, 0.1])
HYP.setdefault("class_weight_cap", 3.0)        # optional; cap max class weight at 3× mean
HYP.setdefault("fusion_use_svd_feats", False)  # start with [p_TCN,p_SVD]; toggle True to add 4 SVD feats
# Gradient-clipping max-norm. The SVD and fusion training loops index
# HYP["clip_norm"] directly (a falsy value disables clipping there), so the key
# must exist or they raise KeyError.
HYP.setdefault("clip_norm", 1.0)
print("Hyperparameters:", json.dumps(HYP, indent=2))

# ---- Class weights (optional) ----
def class_weights_from_labels(labels, num_classes: int = 3) -> Optional[torch.Tensor]:
    """Inverse-frequency class weights normalised to sum to ``num_classes``.

    Classes that never occur in ``labels`` get the mean inverse frequency of
    the classes that do occur, i.e. a neutral weight. Substituting a count of 1
    instead would give an absent class the largest weight, which distorts
    normalisation, any mean-based cap, and the label-smoothing term of a
    weighted cross-entropy.

    Args:
        labels: iterable of integer class ids in ``range(num_classes)``.
        num_classes: number of classes (weight vector length).

    Returns:
        float32 tensor of shape [num_classes], or None if no label in
        ``range(num_classes)`` is present.
    """
    import pandas as pd
    counts = (
        pd.Series(labels).value_counts()
        .reindex(list(range(num_classes)))
        .fillna(0)
        .to_numpy(dtype=np.float64)
    )
    present = counts > 0
    if not present.any():
        return None
    inv = np.zeros_like(counts)
    inv[present] = 1.0 / counts[present]
    inv[~present] = inv[present].mean()
    w = inv / inv.sum() * num_classes
    return torch.tensor(w, dtype=torch.float32)


def compute_class_weights(meta_path: str) -> Optional[torch.Tensor]:
    """Load window labels from a parquet meta file and derive class weights.

    Returns None when the file does not exist, has no ``label`` column, or
    contains no label in {0, 1, 2}; otherwise see ``class_weights_from_labels``.
    """
    if not os.path.exists(meta_path):
        return None
    import pandas as pd
    df = pd.read_parquet(meta_path)
    if "label" not in df.columns:
        return None
    return class_weights_from_labels(df["label"], num_classes=3)

# Class weights come from the first window meta that exists (long, then short).
# Fall back to uniform weights when neither file yields any.
cw = compute_class_weights("windows_long_meta.parquet")
if cw is None:
    cw = compute_class_weights("windows_short_meta.parquet")
if cw is None:
    cw = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32)
class_weights = cw
print("Class weights:", class_weights.tolist())

# Cap class weights to avoid instability under extreme skew.
# Each weight is clamped to class_weight_cap x the mean of the *uncapped*
# weights. The result is not renormalised afterwards.
if HYP["class_weight_cap"] is not None:
    mean_w = class_weights.float().mean()
    class_weights = torch.clamp(class_weights.float(), max=HYP["class_weight_cap"] * mean_w)
    print("Capped class weights:", class_weights.tolist())

# ---- Instantiate models (cells defining these must have been run) ----
# A NameError here most likely means DualHeadTCN / SVDGeometryHead have not
# been defined yet; it is re-raised as a RuntimeError with a hint.
try:
    tcn_model = DualHeadTCN(d_seq=D_SEQ, hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0).to(DEVICE)
    svd_head  = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16).to(DEVICE)
except NameError as e:
    raise RuntimeError("Run the TCN + SVD definition cells before this one.") from e

def count_params(m):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``m``."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(f"TCN params: {count_params(tcn_model):,}")
print(f"SVD head params: {count_params(svd_head):,}")

# ---- Optimizer & Scheduler ----
# One AdamW over both base models, with the same LR for each param group.
# NOTE: the grid-search and fine-tune cells later rebind the global `optim`
# to their own per-trial optimizers.
optim = torch.optim.AdamW([
    {"params": tcn_model.parameters(), "lr": HYP["lr"]},
    {"params": svd_head.parameters(),  "lr": HYP["lr"]},
], weight_decay=HYP["weight_decay"])

def make_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """Linear warm-up followed by a half-cosine decay, wrapped in ``LambdaLR``.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps`` steps.
    It then follows ``0.5 * (1 + cos(pi * progress))`` down to 0 at
    ``total_steps``. Progress is clamped to 1, so if training runs past
    ``total_steps`` the multiplier stays at 0. Otherwise it would climb back up
    the next cosine period.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` bound to ``optimizer``.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = make_scheduler(optim, HYP["warmup_steps"], HYP["total_steps"])

# ---- Loss ----
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE), label_smoothing=HYP["label_smoothing"])

# ---- Optional loaders for fold 0 (only if metas exist) ----
train_loader = val_loader = None
if os.path.exists("windows_long_meta.parquet") or os.path.exists("windows_short_meta.parquet"):
    kind = "long" if os.path.exists("windows_long_meta.parquet") else "short"
    import pandas as pd
    meta = pd.read_parquet(f"windows_{kind}_meta.parquet")

    # 5-fold grouped split by run_id: every window of a run lands in the same
    # fold, so no run contributes to both train and validation.
    rng = np.random.default_rng(SEED)
    run_ids = meta['run_id'].astype(str).unique().tolist()
    rng.shuffle(run_ids)
    folds = [set() for _ in range(5)]
    for i, rid in enumerate(run_ids):
        folds[i % 5].add(rid)
    fold = 0
    val_rids = folds[fold]
    # Use row *positions*, not index labels. pd.read_parquet restores whatever
    # index was saved, and a non-contiguous one (e.g. a filtered frame written
    # without reset_index) would otherwise make Subset pick the wrong windows.
    # This assumes row i of TtiTrustDataset(kind) is meta row i (TODO confirm);
    # with a default RangeIndex it is identical to the label-based selection.
    in_val = meta['run_id'].astype(str).isin(val_rids).to_numpy()
    train_idx = np.flatnonzero(~in_val)
    val_idx   = np.flatnonzero(in_val)

    ds_full  = TtiTrustDataset(kind=kind)
    ds_train = Subset(ds_full, train_idx)
    ds_val   = Subset(ds_full,  val_idx)

    train_loader = DataLoader(ds_train,
        batch_size=HYP["batch_size"], shuffle=True,
        num_workers=8, pin_memory=True, prefetch_factor=4, persistent_workers=True)

    val_loader = DataLoader(ds_val,
        batch_size=HYP["batch_size"], shuffle=False,
        num_workers=8, pin_memory=True, prefetch_factor=4, persistent_workers=True)

    # Update scheduler steps with real batch counts
    HYP["total_steps"] = max(HYP["epochs"] * max(1, len(train_loader)), HYP["warmup_steps"] + 1000)
    scheduler = make_scheduler(optim, HYP["warmup_steps"], HYP["total_steps"])
    print(f"\nData: kind={kind}  train_batches={len(train_loader)}  val_batches={len(val_loader)}")
else:
    print("\nNo window metas yet. Models/optim/scheduler are initialized—create metas and re-run to attach loaders.")

# ---- Inference-time fusion + τ-hysteresis config ----
FUSION_ALPHA = HYP["alpha_fusion"]  # p_attack = α*p_TCN + (1-α)*p_SVD

class TauHysteresis:
    """Debounced ON/OFF switch driven by a stream of attack probabilities.

    While OFF, the state flips ON after ``L_on`` consecutive samples with
    ``p >= theta_on``. While ON, it flips OFF after ``L_off`` consecutive
    samples with ``p <= theta_off``. Any sample that breaks the current streak
    resets the counter to 0. ``step`` returns the state after consuming ``p``.
    """

    def __init__(self, theta_on=0.7, theta_off=0.5, L_on=100, L_off=100):
        self.theta_on, self.theta_off = float(theta_on), float(theta_off)
        self.L_on, self.L_off = int(L_on), int(L_off)
        self.state = False
        self.streak = 0

    def reset(self):
        self.state = False
        self.streak = 0

    def step(self, p):
        # Which condition extends the streak, and how long it must last,
        # depends on the current state.
        if self.state:
            toward_flip, needed = p <= self.theta_off, self.L_off
        else:
            toward_flip, needed = p >= self.theta_on, self.L_on
        if not toward_flip:
            self.streak = 0
            return self.state
        self.streak += 1
        if self.streak >= needed:
            self.state = not self.state
            self.streak = 0
        return self.state

# Hysteresis FSM configured from HYP. It is stateful (state/streak), so call
# .reset() before replaying a new probability stream through it.
FSM_CFG = TauHysteresis(HYP["theta_on"], HYP["theta_off"], HYP["L_on"], HYP["L_off"])

print("\nInitialization complete: models on device, optimizer/scheduler ready, loaders (if metas present).")

### Base Model Training

#### TCN

In [None]:
# ==== CELL B: TCN grid search (LR × Warmup × Label Smoothing) with early-stop ====
import copy, time, math, torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Defined only if absent: re-running this cell keeps an existing _epoch_tcn
# (e.g. one defined by an earlier cell) instead of replacing it.
if '_epoch_tcn' not in globals():
    import torch.nn.functional as F  # not used by _epoch_tcn itself

    def _epoch_tcn(model, loader, device, train=True):
        """Run one training or evaluation epoch of a DualHeadTCN.

        Args:
            model: module returning ``(logits, probs)``; ``probs`` is indexed
                as [B, 3] below.
            loader: iterable of ``(x_seq, x_aux, y, meta)`` batches.
            device: device that ``x_seq`` and ``y`` are moved to.
            train: True updates weights; False runs with autograd disabled.

        Reads the module globals ``criterion`` (always) and ``optim``,
        ``scaler`` and ``scheduler`` (training only). The grid-search cell
        binds them before calling. Metrics treat class 2 as the positive
        ("attack") class and threshold ``probs[:, 2]`` at 0.5.

        Returns:
            dict with ``loss`` (sample-weighted mean), ``acc``, ``prec``,
            ``rec``, ``f1``, ``longest_pred_run_attack`` and
            ``longest_pred_run_benign``.
        """
        model.train(train)
        total_loss = 0.0; n = 0
        tp=fp=tn=fn=0
        longest_pred_run_attack = 0
        longest_pred_run_benign = 0
        cur_run_attack = 0
        cur_run_benign = 0

        for x_seq, x_aux, y, meta in loader:
            # NOTE: streamed loader already yields float32; keep cast cheap
            x_seq = x_seq.to(device, non_blocking=True)
            y     = y.to(device, non_blocking=True)

            if train:
                optim.zero_grad(set_to_none=True)

            # AMP + no_grad toggled by `train`
            with torch.set_grad_enabled(train), autocast(enabled=torch.cuda.is_available()):
                logits, probs = model(x_seq)      # probs: [B,3]
                loss = criterion(logits, y)

            if train:
                # Scale, unscale for clip, step, update
                # (unscale_ first so the max-norm of 1.0 applies to true gradients)
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optim)
                scaler.update()
                scheduler.step()

            bs = y.size(0)
            total_loss += float(loss.item()) * bs
            n += bs

            # Metrics (attack-vs-not threshold @ 0.5 on p_attack)
            p_attack = probs[:, 2].detach()
            y_attack = (y == 2)
            pred = (p_attack >= 0.5)

            tp += int(((pred == 1) & (y_attack == 1)).sum().item())
            fp += int(((pred == 1) & (y_attack == 0)).sum().item())
            tn += int(((pred == 0) & (y_attack == 0)).sum().item())
            fn += int(((pred == 0) & (y_attack == 1)).sum().item())

            # Run-length proxies (fast enough per batch)
            # Longest streak of consecutive predicted-attack samples among true
            # attacks (cur_run_attack), and among non-attacks, i.e. false alarms
            # (cur_run_benign). Streaks follow loader order and carry across
            # batches; with a shuffled loader they are not time-ordered runs.
            # NOTE: `pa` is a local that shadows the file-level pyarrow alias
            # only inside this function.
            for pa, ya in zip(pred.tolist(), y_attack.tolist()):
                if ya:
                    cur_run_attack = cur_run_attack + 1 if pa else 0
                    longest_pred_run_attack = max(longest_pred_run_attack, cur_run_attack)
                    cur_run_benign = 0
                else:
                    cur_run_benign = cur_run_benign + 1 if pa else 0
                    longest_pred_run_benign = max(longest_pred_run_benign, cur_run_benign)
                    cur_run_attack = 0

            # drop meta quickly to avoid ref holding
            del meta

        loss_mean = total_loss / max(1, n)
        prec = tp / max(1, (tp + fp))
        rec  = tp / max(1, (tp + fn))
        f1   = 2 * prec * rec / max(1e-9, (prec + rec))
        acc  = (tp + tn) / max(1, (tp + tn + fp + fn))
        return {
            "loss": loss_mean, "acc": acc, "prec": prec, "rec": rec, "f1": f1,
            "longest_pred_run_attack": longest_pred_run_attack,
            "longest_pred_run_benign": longest_pred_run_benign,
        }

def tcn_grid_search():
    """Grid-search the DualHeadTCN over LR x warm-up x label smoothing.

    Every trial starts from the same initial weights. It trains with AdamW and
    a linear-warm-up / cosine LR schedule, and early-stops on the validation
    attack-F1 reported by ``_epoch_tcn``. Each trial's best epoch is saved to
    ``tcn_lr{lr}_warm{warm}_ls{ls}.pt``, and the overall winner is copied to
    ``tcn_best.pt``.

    Side effect: rebinds the globals ``tcn_model``, ``optim``, ``scheduler``,
    ``criterion`` and ``scaler``, because ``_epoch_tcn`` reads them from
    module scope. When the search ends they hold the LAST trial's objects,
    not the best one; reload ``tcn_best.pt`` to get the winner.
    """
    global tcn_model, optim, scheduler, criterion, scaler
    assert train_loader and val_loader, "Attach loaders before grid search."

    best = {"f1": -1.0, "cfg": None}
    # One shared initialisation so trials differ only in their hyperparameters.
    base_state = DualHeadTCN(d_seq=D_SEQ, hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0).state_dict()

    for lr in HYP["lr_grid"]:
        for warm in HYP["warmup_grid"]:
            for ls in HYP["label_smoothing_grid"]:
                # fresh model per trial
                model = DualHeadTCN(d_seq=D_SEQ, hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0).to(DEVICE)
                model.load_state_dict(base_state, strict=False)

                # optimizer/scheduler per trial
                opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=HYP["weight_decay"])

                # `warm` is bound as a default so the schedule is pinned to
                # this trial's value (no late-binding closure). Progress is
                # clamped so the LR never rises again past HYP["total_steps"].
                def lr_lambda(step, warm=warm):
                    if step < warm:
                        return step / max(1, warm)
                    progress = min(1.0, (step - warm) / max(1, HYP["total_steps"] - warm))
                    return 0.5 * (1.0 + math.cos(math.pi * progress))

                sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

                # criterion per trial
                crit = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE), label_smoothing=ls)

                # make objects visible to _epoch_tcn
                tcn_model, optim, scheduler, criterion = model, opt, sched, crit
                scaler = GradScaler(enabled=torch.cuda.is_available())

                # early stop on validation F1 (always maximised)
                es = HYP["early_stop"]; patience = int(es["patience"]); min_delta = float(es["min_delta"])
                best_val, wait = None, 0
                tag = f"tcn_lr{lr}_warm{warm}_ls{ls}"
                path = f"{tag}.pt"

                for ep in range(1, HYP["epochs"] + 1):
                    tr = _epoch_tcn(tcn_model, train_loader, DEVICE, train=True)
                    va = _epoch_tcn(tcn_model, val_loader,   DEVICE, train=False)
                    score = va["f1"]
                    # Explicit None test: `best_val or -1.0` would treat a best
                    # F1 of exactly 0.0 as "no best yet". Every later 0.0 epoch
                    # would then count as an improvement and early stopping
                    # would never fire.
                    improved = best_val is None or (score - best_val) > min_delta
                    print(f"[TCN GRID] {tag} ep {ep:02d} | tr loss {tr['loss']:.3f} | val f1 {score:.3f} {'*' if improved else ''}")

                    if improved:
                        best_val, wait = score, 0
                        torch.save({"model": tcn_model.state_dict(),
                                    "HYP": {**HYP, "lr": lr, "warmup_steps": warm, "label_smoothing": ls}}, path)
                    else:
                        wait += 1
                        if wait >= patience:
                            break

                # track global best (a best F1 of 0.0 is still a saved, valid result)
                if best_val is not None and best_val > best["f1"]:
                    best = {"f1": best_val,
                            "cfg": {"lr": lr, "warmup_steps": warm, "label_smoothing": ls, "ckpt": path}}

    print("TCN best:", best)
    if best["cfg"] is not None:
        torch.save(torch.load(best["cfg"]["ckpt"], map_location="cpu"), "tcn_best.pt")

# Run (comment if you want to delay)
tcn_grid_search()

#### TCN Fine Tuning (for sub 50 ms windows)

In [None]:
# ==== CELL: TCN classifier fine-tune on SHORT windows (streamed, AMP, classifier-only) ====
import os, torch, math
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Guard: fine-tuning starts from the long-window grid-search winner, and the
# training loop in this cell later saves back to this same path.
assert os.path.exists("tcn_best.pt"), "Run long-window grid first to create tcn_best.pt."

# 1) Build short-window streamed loaders (train/val split by run_id using meta shards)
from glob import glob
import pandas as pd
from numpy.random import default_rng

def load_short_meta():
    """Concatenate every short-window meta shard under ``./win_shards/``.

    Raises:
        FileNotFoundError: if no ``short_*_meta.parquet`` shard exists.
    """
    shard_paths = glob("win_shards/short_*_meta.parquet")
    if not shard_paths:
        raise FileNotFoundError("No short-window meta shards found in ./win_shards/. Run preprocessing.")
    frames = [pd.read_parquet(shard) for shard in shard_paths]
    return pd.concat(frames, ignore_index=True)

# Same 5-way grouped-by-run_id split recipe as the initialisation cell, fold 0
# held out. NOTE(review): run_ids here come from shard metas concatenated in
# glob() order, which is arbitrary, so this fold 0 need not contain the same
# runs as the long-window fold 0 even with the same SEED.
SEED = globals().get("SEED", 42)
meta_s = load_short_meta()
rng = default_rng(SEED)
run_ids = meta_s['run_id'].astype(str).unique().tolist()
rng.shuffle(run_ids)
folds = [set() for _ in range(5)]
for i, rid in enumerate(run_ids):
    folds[i % 5].add(rid)
val_rids = folds[0]

# Restrict sources via loader factory; it still streams shards
def runs_to_sources(rids: set, meta: pd.DataFrame):
    """Return the distinct ``source`` values (as str, first-seen order) for rows whose run_id is in ``rids``."""
    keep = meta['run_id'].astype(str).isin(rids)
    picked = meta.loc[keep, 'source'].astype(str)
    return picked.unique().tolist()

train_sources = runs_to_sources(set(run_ids) - val_rids, meta_s)
val_sources   = runs_to_sources(val_rids, meta_s)

# NOTE(review): train_sources / val_sources are computed but never passed on.
# Both loaders below are built with identical arguments. Unless make_loader
# separates them by some means not visible here (TODO confirm), validation F1
# is measured on the same windows used for training. That means the
# early-stop / checkpoint decision in the fine-tune loop is not a held-out
# estimate, and the leakage is not limited to shard granularity. Wire the
# source lists (or a run_id filter) into make_loader to fix this.

# The streamed loaders (already yield batches)
train_loader_short = make_loader(kind="short", batch_size=globals().get("BATCH_SZ", 1024), num_workers=4)
val_loader_short   = make_loader(kind="short", batch_size=globals().get("BATCH_SZ", 1024), num_workers=4)

# 2) Reload best TCN, freeze all but classifier
# NOTE(review): the architecture (hidden/dropout) comes from the current HYP,
# not from ckpt["HYP"], and strict=False silently ignores missing or unexpected
# keys. If these settings ever diverge, the load succeeds with partly random
# weights.
ckpt = torch.load("tcn_best.pt", map_location="cpu")
tcn_model = DualHeadTCN(d_seq=D_SEQ, hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0).to(DEVICE)
tcn_model.load_state_dict(ckpt["model"], strict=False)

# Freeze everything, then re-enable only the classifier sub-module.
for p in tcn_model.parameters(): 
    p.requires_grad = False
for p in tcn_model.classifier.parameters(): 
    p.requires_grad = True

# 3) Tiny-LR fine-tune on short windows (3–5 epochs), AMP, proper grad-clip
# Only classifier parameters are handed to the optimizer; the LR is constant.
# NOTE(review): label smoothing uses the global HYP value, not the grid
# winner's ckpt["HYP"]["label_smoothing"].
optim = torch.optim.AdamW(tcn_model.classifier.parameters(), lr=3e-5, weight_decay=HYP["weight_decay"])
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda step: 1.0)  # constant LR
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE), label_smoothing=HYP["label_smoothing"])
scaler = GradScaler(enabled=torch.cuda.is_available())

@torch.no_grad()
def eval_loop(model, loader):
    """Score ``model`` on ``loader`` with autograd disabled.

    Uses the module-level ``DEVICE`` and ``criterion``. Returns the
    sample-weighted mean loss and the attack-vs-rest F1 obtained by
    thresholding the class-2 probability at 0.5.
    """
    model.eval()
    use_amp = torch.cuda.is_available()
    loss_sum, seen = 0.0, 0
    hits = false_alarms = misses = 0
    for x_seq, _aux, y, meta in loader:
        x_seq = x_seq.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        with autocast(enabled=use_amp):
            logits, probs = model(x_seq)
            loss = criterion(logits, y)
        batch_n = y.size(0)
        loss_sum += float(loss.item()) * batch_n
        seen += batch_n
        flagged = probs[:, 2].detach() >= 0.5
        is_attack = y == 2
        hits += int((flagged & is_attack).sum().item())
        false_alarms += int((flagged & ~is_attack).sum().item())
        misses += int((~flagged & is_attack).sum().item())
        del meta
    precision = hits / max(1, hits + false_alarms)
    recall = hits / max(1, hits + misses)
    f1 = 2 * precision * recall / max(1e-9, precision + recall)
    return {"loss": loss_sum / max(1, seen), "f1": f1}

def train_loop(model, loader):
    """Run one AMP fine-tuning epoch of ``model`` over ``loader``.

    Uses the module globals ``DEVICE``, ``optim``, ``criterion``, ``scaler``
    and ``scheduler``. Gradients are clipped (max-norm 1.0) over the
    parameters of ``model`` that require grad. In this cell that is exactly
    the classifier, because the rest is frozen above. Clipping the global
    ``tcn_model.classifier`` instead would silently do the wrong thing for any
    other ``model`` passed in.

    NOTE(review): ``model.train(True)`` also puts frozen layers in train mode;
    if the backbone has BatchNorm, its running statistics will still update.

    Returns:
        ``{"loss": sample-weighted mean training loss}``.
    """
    model.train(True)
    trainable = [p for p in model.parameters() if p.requires_grad]
    tot=0.0; n=0
    for x_seq, x_aux, y, meta in loader:
        x_seq = x_seq.to(DEVICE, non_blocking=True)
        y     = y.to(DEVICE, non_blocking=True)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            logits, probs = model(x_seq)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.unscale_(optim)  # so clipping sees true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        scaler.step(optim); scaler.update()
        scheduler.step()
        tot += float(loss.item()) * y.size(0); n += y.size(0)
        del meta
    return {"loss": tot/max(1,n)}

# Baseline: score the reloaded checkpoint before any fine-tuning. tcn_best.pt
# (which this loop overwrites) is then only replaced by a model that actually
# beats it on the short-window validation loader. Starting from -1.0 would make
# epoch 1 always overwrite it, even when fine-tuning made things worse.
best_f1 = eval_loop(tcn_model, val_loader_short)["f1"]
print(f"[TCN FT] baseline | val f1 {best_f1:.3f}")
patience, wait = 3, 0
for ep in range(1, 6):  # up to 5 epochs; stop after `patience` non-improving ones
    tr = train_loop(tcn_model, train_loader_short)
    va = eval_loop (tcn_model, val_loader_short)
    print(f"[TCN FT] ep {ep:02d} | train loss {tr['loss']:.4f} | val f1 {va['f1']:.3f}")
    if va["f1"] > best_f1 + 1e-3:
        best_f1, wait = va["f1"], 0
        torch.save({"model": tcn_model.state_dict(), "HYP": {**HYP, "fine_tuned_short": True}}, "tcn_best.pt")
    else:
        wait += 1
        if wait >= patience:
            print("[TCN FT] early stop"); break

#### SVD

In [None]:
# ==== CELL C: SVD grid search (LR × Warmup) with early-stop, AMP ====
import math, torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Defined only if absent: re-running this cell keeps an existing _epoch_svd.
if '_epoch_svd' not in globals():
    import torch.nn.functional as F

    def _epoch_svd(model, loader, device, train=True):
        """Run one training or evaluation epoch of the SVD geometry head.

        The head's forward (``SVDGeometryHead`` above) returns
        ``(p_attack [B], F_svd [B, 4])``, i.e. a sigmoid probability rather
        than 3-class logits. It is fed only the first ``K_PLUS_REST``
        (role-share) channels, as in ``batch_base_outputs``. The loss is
        therefore binary cross-entropy on ``y == 2`` (attack vs. rest). The
        global ``criterion`` (multi-class CE) is not used, so its class
        weights and label smoothing do not apply here.

        Reads the module globals ``K_PLUS_REST``, ``HYP`` (optional
        ``clip_norm``, default 1.0; falsy disables clipping) and, when
        training, ``optim``, ``scaler`` and ``scheduler``.

        Returns:
            dict with ``loss`` (sample-weighted mean BCE), ``acc``, ``prec``,
            ``rec``, ``f1`` for attack-vs-rest at threshold 0.5.
        """
        model.train(train)
        amp_on = torch.cuda.is_available()
        clip_norm = HYP.get("clip_norm", 1.0)
        total_loss = 0.0; n = 0
        tp=fp=tn=fn=0

        for x_seq, x_aux, y, meta in loader:
            x_seq = x_seq.to(device, non_blocking=True)   # [B, W, D_in] float32 from loader
            shares = x_seq[:, :, :K_PLUS_REST]            # role shares only
            y     = y.to(device, non_blocking=True).long()
            y_attack = (y == 2)

            if train:
                optim.zero_grad(set_to_none=True)

            with torch.set_grad_enabled(train):
                with autocast(enabled=amp_on):
                    p_attack, _ = model(shares)           # [B] sigmoid probability
                # binary_cross_entropy raises inside CUDA autocast regions;
                # evaluate it in fp32 after leaving autocast.
                p_attack = p_attack.float().reshape(-1)
                loss = F.binary_cross_entropy(p_attack, y_attack.float())

            if train:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                if clip_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(clip_norm))
                scaler.step(optim); scaler.update()
                scheduler.step()

            bs = y.size(0); total_loss += float(loss.item()) * bs; n += bs

            # metrics @ 0.5 on p_attack
            pred = (p_attack.detach() >= 0.5)

            tp += int(((pred == 1) & (y_attack == 1)).sum().item())
            fp += int(((pred == 1) & (y_attack == 0)).sum().item())
            tn += int(((pred == 0) & (y_attack == 0)).sum().item())
            fn += int(((pred == 0) & (y_attack == 1)).sum().item())

            del meta  # drop references ASAP

        loss_mean = total_loss / max(1, n)
        prec = tp / max(1, (tp + fp))
        rec  = tp / max(1, (tp + fn))
        f1   = 2 * prec * rec / max(1e-9, (prec + rec))
        acc  = (tp + tn) / max(1, (tp + tn + fp + fn))
        return {"loss": loss_mean, "acc": acc, "prec": prec, "rec": rec, "f1": f1}

def svd_grid_search():
    """Grid-search the SVD geometry head over LR x warm-up with early stopping.

    Every trial starts from the same initial weights and is trained through
    ``_epoch_svd`` with AdamW and a linear-warm-up / cosine LR schedule. It
    early-stops on validation attack-F1. Each trial's best epoch is saved to
    ``svd_lr{lr}_warm{warm}.pt``, and the overall winner is copied to
    ``svd_head_best.pt``.

    Side effect: rebinds the globals ``svd_head``, ``optim``, ``scheduler``,
    ``criterion`` and ``scaler`` (``_epoch_svd`` reads them from module scope).
    Afterwards they hold the LAST trial's objects, not the best one.
    """
    global svd_head, optim, scheduler, criterion, scaler
    assert train_loader and val_loader, "Attach loaders before grid search."
    best = {"f1": -1.0, "cfg": None}

    # base weights for consistent init across trials
    base_state = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16).state_dict()

    for lr in HYP["lr_grid"]:
        for warm in HYP["warmup_grid"]:
            head = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16).to(DEVICE)
            head.load_state_dict(base_state, strict=False)

            opt = torch.optim.AdamW(head.parameters(), lr=lr, weight_decay=HYP["weight_decay"])

            # `warm` is bound as a default (no late-binding closure); progress
            # is clamped so the LR cannot rise again past HYP["total_steps"].
            def lr_lambda(step, warm=warm):
                if step < warm:
                    return step / max(1, warm)
                progress = min(1.0, (step - warm) / max(1, HYP["total_steps"] - warm))
                return 0.5 * (1.0 + math.cos(math.pi * progress))

            sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
            crit = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE), label_smoothing=HYP.get("label_smoothing", 0.0))

            # expose to epoch fn
            svd_head, optim, scheduler, criterion = head, opt, sched, crit
            scaler = GradScaler(enabled=torch.cuda.is_available())

            es = HYP["early_stop"]; patience = int(es["patience"]); min_delta = float(es["min_delta"])
            best_val, wait = None, 0
            tag  = f"svd_lr{lr}_warm{warm}"
            path = f"{tag}.pt"

            for ep in range(1, HYP["epochs"] + 1):
                tr = _epoch_svd(svd_head, train_loader, DEVICE, train=True)
                va = _epoch_svd(svd_head, val_loader,   DEVICE, train=False)
                score = va["f1"]
                # Explicit None test: `best_val or -1.0` would treat a best F1
                # of 0.0 as "no best yet". A run stuck at F1 = 0 would then
                # count as improving every epoch and never early-stop.
                improved = best_val is None or (score - best_val) > min_delta
                print(f"[SVD GRID] {tag} ep {ep:02d} | tr loss {tr['loss']:.3f} | val f1 {score:.3f} {'*' if improved else ''}")

                if improved:
                    best_val, wait = score, 0
                    torch.save({"model": svd_head.state_dict(),
                                "HYP": {**HYP, "lr": lr, "warmup_steps": warm}}, path)
                else:
                    wait += 1
                    if wait >= patience:
                        break

            # A best F1 of 0.0 is still a valid (saved) result.
            if best_val is not None and best_val > best["f1"]:
                best = {"f1": best_val,
                        "cfg": {"lr": lr, "warmup_steps": warm, "ckpt": path}}

    print("SVD best:", best)
    if best["cfg"] is not None:
        torch.save(torch.load(best["cfg"]["ckpt"], map_location="cpu"), "svd_head_best.pt")

# Run (comment if you want to delay)
svd_grid_search()

### Fusion Training

#### FusionHead + Inference Wrappers

In [None]:
# ==== CELL: FusionHead + base inference wrappers (ready for fusion training) ====
import os, torch, torch.nn as nn, torch.nn.functional as F
from typing import Optional, Tuple, Dict, List

# Tiny learned fusion (start minimal; can add SVD feats / AE later)
class FusionHead(nn.Module):
    """Tiny MLP that fuses base-model scores into one attack probability.

    Args:
        in_dim: number of fusion inputs per window (default 2:
            ``[p_tcn_attack, p_svd]``).
        hidden: hidden-layer width.
        dropout: dropout probability after the SiLU activation.
    """

    def __init__(self, in_dim: int = 2, hidden: int = 8, dropout: float = 0.0):
        super().__init__()
        stages = [nn.Linear(in_dim, hidden), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden, 1)]
        self.net = nn.Sequential(*stages)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map ``z`` of shape [B, in_dim] to probabilities of shape [B]."""
        prob = self.net(z).sigmoid()
        return prob.squeeze(-1)

def load_frozen_bases(tcn_ckpt="tcn_best.pt", svd_ckpt="svd_head_best.pt"):
    """Build the TCN and SVD base models, load checkpoints, and freeze them.

    Each checkpoint is loaded only if its file exists, from its ``"model"``
    entry with ``strict=False``. Both models are moved to ``DEVICE``, put in
    eval mode, and have ``requires_grad`` turned off.

    Returns:
        ``(tcn, svd)``.
    """
    d_seq = K_PLUS_REST + K_PLUS_REST + 3 + K_PLUS_REST
    tcn = DualHeadTCN(d_seq=d_seq, hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0)
    svd = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16)

    for module, ckpt_path in ((tcn, tcn_ckpt), (svd, svd_ckpt)):
        if os.path.exists(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            module.load_state_dict(state["model"], strict=False)

    for module in (tcn, svd):
        module.to(DEVICE).eval()
        for param in module.parameters():
            param.requires_grad = False
    return tcn, svd

@torch.no_grad()
def batch_base_outputs(tcn: nn.Module, svd: nn.Module, x_seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run both frozen base models on one batch with autograd disabled.

    Args:
        tcn: model returning ``(logits, probs)`` with ``probs`` of shape [B, 3].
        svd: model returning ``(p, features)``; it receives only the first
            ``K_PLUS_REST`` channels (the role shares) of ``x_seq``.
        x_seq: window batch of shape [B, W, D_seq].

    Returns:
        ``(p_tcn_attack, p_svd)``, each of shape [B]. The TCN score is its
        class-2 (attack) probability.
    """
    _, tcn_probs = tcn(x_seq)
    role_shares = x_seq[:, :, :K_PLUS_REST]
    p_svd, _ = svd(role_shares)
    return tcn_probs[:, 2], p_svd

def make_fusion_batch(tcn: nn.Module, svd: nn.Module,
                      batch, include_ae: bool = False, ae_score: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the fusion-head input ``z`` and binary target for one loader batch.

    Columns of ``z`` are, in order: ``p_tcn_attack``, ``p_svd``, then the 4
    SVD geometry features when ``HYP["fusion_use_svd_feats"]`` is set, then an
    AE score column when ``include_ae`` is True (zeros if ``ae_score`` is None).

    Args:
        batch: ``(x_seq, x_aux, y, meta)`` tuple from a window loader.

    Returns:
        ``(z, y_attack)``, where ``y_attack`` is 1.0 for label 2 and else 0.0.
    """
    x_seq, _x_aux, y, _meta = batch
    x_seq = x_seq.to(DEVICE).float()
    target = (y.to(DEVICE).long() == 2).float()  # [B]

    p_tcn, p_svd = batch_base_outputs(tcn, svd, x_seq)
    columns = [p_tcn.unsqueeze(-1), p_svd.unsqueeze(-1)]

    if HYP.get("fusion_use_svd_feats", False):
        with torch.no_grad():
            _, geometry = svd(x_seq[:, :, :K_PLUS_REST])   # [B,4]
        columns.append(geometry)

    if include_ae:
        ae_column = torch.zeros_like(p_tcn) if ae_score is None else ae_score
        columns.append(ae_column.unsqueeze(-1))

    return torch.cat(columns, dim=-1), target

# Width of the default fusion input: [p_TCN, p_SVD] plus 4 SVD geometry
# features when enabled. The optional AE column from make_fusion_batch
# (include_ae=True) is NOT counted here. NOTE: fusion_grid_search builds its
# own local FusionHead, so this global instance is not the one it trains.
IN_DIM = 2 + (4 if HYP.get("fusion_use_svd_feats", False) else 0)
fusion = FusionHead(in_dim=IN_DIM, hidden=8, dropout=0.0).to(DEVICE)

#### Fusion training loop

In [None]:
# ==== CELL D: Fusion grid search (toggle SVD features; early-stop) ====
import math, torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def fusion_grid_search(include_svd_feats=None):
    """Train the learned FusionHead on top of the frozen TCN / SVD bases.

    Args:
        include_svd_feats: when True, append the 4 SVD geometry features to
            ``[p_tcn_attack, p_svd]``. ``None`` (default) uses
            ``HYP["fusion_use_svd_feats"]``.

    The target is binary (``y == 2`` means attack). Training early-stops on
    validation F1 at threshold 0.5. The best epoch is saved to
    ``fusion_p2.pt`` / ``fusion_p2p4.pt``, and the overall winner is copied to
    ``fusion_head_best.pt``.
    """
    assert train_loader and val_loader, "Attach loaders before grid search."

    # Frozen bases
    tcn_frozen, svd_frozen = load_frozen_bases()
    tcn_frozen.eval(); svd_frozen.eval()
    for p in tcn_frozen.parameters(): p.requires_grad = False
    for p in svd_frozen.parameters(): p.requires_grad = False

    amp_on = torch.cuda.is_available()
    # HYP may not define "clip_norm" at all; a falsy value disables clipping.
    clip_norm = HYP.get("clip_norm", 1.0)

    best = {"f1": -1.0, "cfg": None}
    use_svd_feats_list = [HYP["fusion_use_svd_feats"]] if include_svd_feats is None else [include_svd_feats]

    for use_svd_feats in use_svd_feats_list:
        IN_DIM = 2 + (4 if use_svd_feats else 0)
        fusion = FusionHead(in_dim=IN_DIM, hidden=8, dropout=0.0).to(DEVICE)
        opt = torch.optim.AdamW(fusion.parameters(), lr=HYP["lr"], weight_decay=HYP["weight_decay"])
        # keep LR constant for stability on tiny head
        sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)

        # FusionHead.forward returns probabilities in (0,1) with shape [B].
        # BCELoss raises inside a CUDA autocast region, so _epoch applies it in
        # fp32 after leaving autocast.
        criterion = nn.BCELoss()
        scaler = GradScaler(enabled=amp_on)

        es = HYP["early_stop"]; patience=int(es["patience"]); min_delta=float(es["min_delta"])
        best_val, wait = None, 0
        tag  = f"fusion_{'p2' if not use_svd_feats else 'p2p4'}"
        path = f"{tag}.pt"

        def _epoch(loader, train=True):
            """One pass over ``loader``; returns mean BCE loss and attack F1."""
            fusion.train(train)
            tot = 0.0; n = 0; tp=fp=tn=fn=0

            for x_seq, x_aux, y, meta in loader:
                x_seq = x_seq.to(DEVICE, non_blocking=True)   # [B, W, Dseq]
                y_bin = (y.to(DEVICE, non_blocking=True).long() == 2).float()  # [B], attack vs not

                # --- base model outputs (no grad) ---
                with torch.no_grad(), autocast(enabled=amp_on):
                    # Returns probabilities p_tcn, p_svd in [0,1]
                    p_tcn, p_svd = batch_base_outputs(tcn_frozen, svd_frozen, x_seq)

                    feats = [p_tcn.unsqueeze(-1), p_svd.unsqueeze(-1)]
                    if use_svd_feats:
                        # SVD geometry features (4 dims); the head sees only the K+1 share channels
                        _, svd_feats = svd_frozen(x_seq[:, :, :K_PLUS_REST])
                        feats.append(svd_feats)  # [B,4]
                    # Base outputs can differ in dtype under autocast; feed one fp32 tensor.
                    z = torch.cat([f.float() for f in feats], dim=-1)  # [B, IN_DIM]

                if train:
                    opt.zero_grad(set_to_none=True)

                # --- fusion forward + loss ---
                with torch.set_grad_enabled(train):
                    with autocast(enabled=amp_on):
                        p = fusion(z)
                    # Flatten to [B] to match y_bin (BCELoss requires equal shapes)
                    # and compute the loss in fp32 outside autocast.
                    p = p.float().reshape(-1)
                    loss = criterion(p, y_bin)

                if train:
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    if clip_norm:
                        torch.nn.utils.clip_grad_norm_(fusion.parameters(), float(clip_norm))
                    scaler.step(opt); scaler.update(); sched.step()

                bs = y_bin.size(0)
                tot += float(loss.item()) * bs; n += bs

                pred = (p.detach() >= 0.5)        # threshold at 0.5, [B]
                tp += int(((pred==1)&(y_bin==1)).sum().item())
                fp += int(((pred==1)&(y_bin==0)).sum().item())
                tn += int(((pred==0)&(y_bin==0)).sum().item())
                fn += int(((pred==0)&(y_bin==1)).sum().item())

                del meta  # drop ref quickly

            prec = tp / max(1, (tp+fp)); rec = tp / max(1, (tp+fn))
            f1 = 2*prec*rec / max(1e-9, (prec+rec)); loss_mean = tot / max(1, n)
            return {"loss": loss_mean, "f1": f1}

        # --- training with early stop ---
        for ep in range(1, max(5, HYP["epochs"]//2) + 1):
            tr = _epoch(train_loader, train=True)
            va = _epoch(val_loader,   train=False)
            score = va["f1"]
            # Explicit None test: `best_val or -1.0` would treat a best F1 of 0.0
            # as "no best yet" and keep counting 0.0 epochs as improvements.
            improved = best_val is None or (score - best_val) > min_delta
            print(f"[FUSION GRID] {tag} ep {ep:02d} | tr loss {tr['loss']:.4f} | val f1 {score:.3f} {'*' if improved else ''}")

            if improved:
                best_val, wait = score, 0
                torch.save({"fusion": fusion.state_dict(),
                            "HYP": {**HYP, "fusion_use_svd_feats": use_svd_feats, "in_dim": IN_DIM}}, path)
            else:
                wait += 1
                if wait >= patience:
                    break

        if best_val is not None and best_val > best["f1"]:
            best = {"f1": best_val, "cfg": {"use_svd_feats": use_svd_feats, "ckpt": path}}

    print("Fusion best:", best)
    if best["cfg"] is not None:
        torch.save(torch.load(best["cfg"]["ckpt"], map_location="cpu"), "fusion_head_best.pt")

# Run (comment if you want to delay)
fusion_grid_search()

### Inference + Hysteresis

In [None]:
# ==== CELL: Inference + τ-Hysteresis + Confusion Matrices (Final Model) ====
# This cell evaluates the full stack (TCN, SVD, Fusion, Fusion+Hysteresis)
# and prints confusion matrices + basic metrics for each.
#
# Assumes you have:
#   - DualHeadTCN, SVDGeometryHead, FusionHead classes defined
#   - DEVICE, HYP, FUSION_ALPHA, TOP_K, K_PLUS_REST
#   - a dataloader (e.g., val_loader) yielding (x_seq, x_aux, y, meta)
#
# Positive class is 'attack' (label == 2).
#
# How it works:
#   1) Loads frozen TCN/SVD (best checkpoints if present) and FusionHead if saved.
#   2) Streams the loader to compute p_TCN, p_SVD, p_FUSED per window.
#   3) Applies τ-hysteresis to p_FUSED to get final ENFORCE decisions.
#   4) Prints confusion matrices + precision/recall/F1 for:
#        - TCN (p>=0.5)
#        - SVD (p>=0.5)
#        - Fusion (p>=0.5)
#        - Fusion + Hysteresis (ENFORCE boolean)
#
# Optional: set INCLUDE_AE=True and provide ae_provider(x_seq, meta)->[B] to include AE in fusion.
#
from typing import Dict, Any, Tuple
import os, math, json
import numpy as np
import torch
import torch.nn.functional as F

# Set True once an autoencoder score is wired in via AE_PROVIDER; evaluation then
# builds the fusion head with a third input column for that score.
INCLUDE_AE: bool = False   # set True when AE is wired
# Callable (x_seq, meta) -> torch.Tensor [B] with values in [0, 1], or None if no AE.
AE_PROVIDER = None   # function (x_seq, meta)-> torch.Tensor [B], values in [0,1]

# --- τ-hysteresis ---
class TauHysteresis:
    """Debounced two-threshold latch over a stream of probabilities.

    While OFF, the latch counts consecutive samples with ``p >= theta_on`` and
    switches ON after ``L_on`` of them. While ON, it counts consecutive samples
    with ``p <= theta_off`` and switches OFF after ``L_off`` of them. Any sample
    that breaks the current run resets the counter to zero, and the counter
    also restarts after every switch.
    """

    def __init__(self, theta_on=0.7, theta_off=0.5, L_on=100, L_off=100):
        self.theta_on = float(theta_on)
        self.theta_off = float(theta_off)
        self.L_on = int(L_on)
        self.L_off = int(L_off)
        self.state = False   # current latch output (True = ENFORCE)
        self.streak = 0      # consecutive samples supporting a switch

    def reset(self):
        """Return the latch to OFF with an empty run counter."""
        self.state = False
        self.streak = 0

    def step(self, p: float) -> bool:
        """Feed one probability and return the (possibly updated) latch state."""
        # Choose the test and run length for the transition away from the current state.
        if self.state:
            supports_switch = p <= self.theta_off
            run_needed = self.L_off
        else:
            supports_switch = p >= self.theta_on
            run_needed = self.L_on

        if not supports_switch:
            self.streak = 0
            return self.state

        self.streak += 1
        if self.streak >= run_needed:
            self.state = not self.state
            self.streak = 0
        return self.state

def load_frozen_bases(tcn_ckpt="tcn_best.pt", svd_ckpt="svd_head_best.pt"):
    """Build the TCN and SVD base models, load checkpoints if present, and freeze them.

    Each checkpoint is expected to hold the model state dict under ``"model"``.
    It is loaded non-strictly. Missing checkpoint files and any missing or
    unexpected state-dict keys are reported via ``warnings.warn``, so that an
    evaluation silently running on freshly constructed (untrained) weights is
    visible.

    Args:
        tcn_ckpt: Path to the DualHeadTCN checkpoint.
        svd_ckpt: Path to the SVDGeometryHead checkpoint.

    Returns:
        (tcn, svd): both on DEVICE, in eval mode, with ``requires_grad=False``.
    """
    import warnings

    tcn = DualHeadTCN(d_seq=(K_PLUS_REST + K_PLUS_REST + 3 + K_PLUS_REST),
                      hidden=HYP["hidden"], dropout=HYP["dropout"], temperature=1.0)
    svd = SVDGeometryHead(k_roles_plus_rest=K_PLUS_REST, hidden=16)

    for label, model, ckpt in (("TCN", tcn, tcn_ckpt), ("SVD", svd, svd_ckpt)):
        if not os.path.exists(ckpt):
            warnings.warn(f"{label} checkpoint '{ckpt}' not found; "
                          f"evaluating with untrained {label} weights.")
            continue
        sd = torch.load(ckpt, map_location="cpu")
        result = model.load_state_dict(sd["model"], strict=False)
        # strict=False would otherwise hide architecture/checkpoint drift.
        missing = list(getattr(result, "missing_keys", []))
        unexpected = list(getattr(result, "unexpected_keys", []))
        if missing or unexpected:
            warnings.warn(f"{label} checkpoint '{ckpt}' loaded non-strictly: "
                          f"missing={missing}, unexpected={unexpected}")

    tcn.to(DEVICE).eval()
    svd.to(DEVICE).eval()
    for p in tcn.parameters():
        p.requires_grad = False
    for p in svd.parameters():
        p.requires_grad = False
    return tcn, svd

def load_or_default_fusion(in_dim: int = 2):
    """Return the frozen fusion module consumed by `fused_probs`.

    With ``fusion_head_best.pt`` present:
        A FusionHead (hidden=8, dropout=0.0) is rebuilt with the input width
        recorded in the checkpoint, and its ``"fusion"`` state dict is loaded
        non-strictly.

    Without the checkpoint:
        A parameter-free ``torch.nn.Identity`` placeholder is returned instead
        of a randomly initialised FusionHead. `fused_probs` treats a module
        with no parameters as "no learned fusion" and uses the FUSION_ALPHA
        blend of p_TCN and p_SVD. Previously an untrained head was used
        silently, so that fallback never ran.

    Args:
        in_dim: Expected number of fusion inputs (2 = p_TCN, p_SVD; 3 adds an
            AE score). Used only when the checkpoint does not record in_dim.

    Returns:
        An eval-mode module on DEVICE with all parameters frozen.
    """
    import warnings

    ckpt_path = "fusion_head_best.pt"
    if not os.path.exists(ckpt_path):
        warnings.warn(f"'{ckpt_path}' not found; fused probabilities will use the "
                      f"FUSION_ALPHA blend instead of an untrained FusionHead.")
        return torch.nn.Identity().to(DEVICE).eval()

    sd = torch.load(ckpt_path, map_location="cpu")
    # The fusion grid search stores in_dim inside the "HYP" dict, not at the top
    # level. A top-level "in_dim" key is still honoured for other writers.
    hyp = sd.get("HYP") or {}
    ck_in_dim = int(sd.get("in_dim", hyp.get("in_dim", in_dim)))
    # NOTE(review): the grid search also records "fusion_use_svd_feats". If a
    # checkpoint's in_dim is neither 2 nor 3, fused_probs must feed matching
    # inputs — confirm before relying on such a checkpoint.
    fh = FusionHead(in_dim=ck_in_dim, hidden=8, dropout=0.0).to(DEVICE)
    fh.load_state_dict(sd["fusion"], strict=False)
    fh.eval()
    for p in fh.parameters():
        p.requires_grad = False
    return fh

@torch.no_grad()
def base_probs(tcn, svd, x_seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the attack probability from each frozen base model (no gradients).

    The TCN receives the full ``x_seq``. The SVD head receives only the first
    K_PLUS_REST channels of the last axis. NOTE(review): ``x_seq`` is indexed
    as rank 3, presumably (batch, time, channels) — confirm against the loader.

    Returns:
        (p_tcn, p_svd). p_tcn is column 2 (the 'attack' class) of the TCN's
        second output. p_svd is the SVD head's first output.
    """
    _tcn_logits, tcn_class_probs = tcn(x_seq)
    attack_prob_tcn = tcn_class_probs[:, 2]
    svd_input = x_seq[:, :, :K_PLUS_REST]
    attack_prob_svd, _svd_aux = svd(svd_input)
    return attack_prob_tcn, attack_prob_svd

@torch.no_grad()
def fused_probs(tcn, svd, fusion_head, x_seq: torch.Tensor, ae_score: torch.Tensor=None) -> torch.Tensor:
    """Combine the base-model attack probabilities into one fused probability.

    If ``fusion_head`` has any parameters, it is applied to the stacked columns
    [p_tcn, p_svd] (plus ``ae_score``, moved to p_tcn's device, when given).
    A parameter-free module instead yields the fixed blend
    ``FUSION_ALPHA * p_tcn + (1 - FUSION_ALPHA) * p_svd``; ``ae_score`` is
    ignored on that path.
    """
    p_tcn, p_svd = base_probs(tcn, svd, x_seq)

    # The branch depends only on whether the module has parameters, not on
    # whether a checkpoint was loaded.
    n_params = sum(param.numel() for param in fusion_head.parameters())
    if n_params == 0:
        return FUSION_ALPHA * p_tcn + (1.0 - FUSION_ALPHA) * p_svd

    columns = [p_tcn, p_svd]
    if ae_score is not None:
        columns.append(ae_score.to(p_tcn.device))
    return fusion_head(torch.stack(columns, dim=-1))

def confusion(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Any]:
    """Binary confusion matrix and summary metrics for attack (1) vs not-attack (0).

    Both inputs are converted with ``np.asarray`` and flattened. This means a
    column vector of predictions is compared element-wise with a flat label
    vector, instead of silently broadcasting to an N x N comparison, and plain
    lists are accepted.

    Args:
        y_true: Ground-truth 0/1 labels (any array-like).
        y_pred: Predicted 0/1 labels with the same number of elements.

    Returns:
        dict with
        - "matrix": ``[[tn, fp], [fn, tp]]`` (rows = true class, cols = predicted)
        - "acc", "prec", "rec", "f1": floats, 0.0 when a denominator is zero
        - "tp", "tn", "fp", "fn": ints

    Raises:
        ValueError: if y_true and y_pred have different numbers of elements.
    """
    t = np.asarray(y_true).reshape(-1)
    p = np.asarray(y_pred).reshape(-1)
    if t.size != p.size:
        raise ValueError(f"confusion: y_true has {t.size} elements but y_pred has {p.size}")

    true_pos, true_neg = (t == 1), (t == 0)
    pred_pos, pred_neg = (p == 1), (p == 0)
    tp = int((true_pos & pred_pos).sum())
    tn = int((true_neg & pred_neg).sum())
    fp = int((true_neg & pred_pos).sum())
    fn = int((true_pos & pred_neg).sum())

    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    return {
        "matrix": [[tn, fp],
                   [fn, tp]],
        "acc": acc, "prec": prec, "rec": rec, "f1": f1,
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
    }

def _eval_flat_list(values) -> list:
    """Return per-window values (tensor on any device, or array-like) as a flat Python list."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values).reshape(-1).tolist()


def _eval_hysteresis_per_run(probs, metas, th_on, th_off, L_on, L_off) -> np.ndarray:
    """Replay τ-hysteresis on each run's windows in start_tti order; return 0/1 ENFORCE flags."""
    from collections import defaultdict

    by_run = defaultdict(list)
    for i, m in enumerate(metas):
        by_run[m["run_id"]].append((m["start_tti"], i))

    decisions = np.zeros(len(probs), dtype=int)
    for windows in by_run.values():
        # A fresh latch per run keeps state from leaking across runs. The sort
        # is stable, so windows sharing a start_tti keep loader order.
        windows.sort(key=lambda w: w[0])
        fsm = TauHysteresis(th_on, th_off, L_on, L_off)
        for _, idx in windows:
            decisions[idx] = 1 if fsm.step(float(probs[idx])) else 0
    return decisions


def _eval_print_confusion(cm: Dict[str, Any], name: str) -> None:
    """Print one confusion matrix (rows = true class, cols = predicted) and its metrics."""
    M = cm["matrix"]
    print(f"\n=== {name} ===")
    print("Confusion Matrix (rows=true [0=benign,1=attack], cols=pred):")
    print(f"[{M[0][0]:5d}  {M[0][1]:5d}]  ← true 0")
    print(f"[{M[1][0]:5d}  {M[1][1]:5d}]  ← true 1")
    print(f"acc={cm['acc']:.3f}  prec={cm['prec']:.3f}  rec={cm['rec']:.3f}  f1={cm['f1']:.3f} "
          f"(tp={cm['tp']}, fp={cm['fp']}, tn={cm['tn']}, fn={cm['fn']})")


def evaluate_confusions(loader) -> Dict[str, Any]:
    """Evaluate TCN, SVD, Fusion and Fusion + τ-hysteresis as binary attack detectors.

    Streams ``loader`` once and collects per-window attack probabilities from
    each stage. It then thresholds them at 0.5, replays τ-hysteresis on the
    fused stream per run (ordered by start_tti), and prints a confusion matrix
    for each stage.

    Args:
        loader: Sized iterable yielding ``(x_seq, x_aux, y, meta)`` batches.
            Label 2 in ``y`` means 'attack'. ``meta["run_id"][i]`` and
            ``meta["start_tti"][i]`` describe sample i. ``x_aux`` is unused.

    Returns:
        dict with "cm_tcn", "cm_svd", "cm_fuse", "cm_hyst" (each as returned by
        `confusion`), "probs" ({"tcn", "svd", "fuse"} -> flat lists of floats)
        and "y_true" (list of 0/1).

    Raises:
        AssertionError: if the loader is None or empty.
        ValueError: if the stages produce different numbers of per-window values.
    """
    import warnings

    assert loader is not None and len(loader) > 0, "Loader is empty."

    # Feed an AE score only when a provider is actually wired. Otherwise a
    # 3-input fusion head would be built but receive 2-column inputs.
    use_ae = bool(INCLUDE_AE) and AE_PROVIDER is not None
    if INCLUDE_AE and AE_PROVIDER is None:
        warnings.warn("INCLUDE_AE=True but AE_PROVIDER is None; evaluating without an AE score.")

    tcn, svd = load_frozen_bases()
    fusion = load_or_default_fusion(in_dim=3 if use_ae else 2)

    probs_tcn, probs_svd, probs_fused = [], [], []
    labels = []
    metas = []   # one dict per window: run_id, start_tti

    for x_seq, _x_aux, y, meta in loader:
        x_seq = x_seq.to(DEVICE).float()
        ae_score = AE_PROVIDER(x_seq, meta) if use_ae else None

        p_fused = fused_probs(tcn, svd, fusion, x_seq, ae_score=ae_score)
        # NOTE: fused_probs already ran base_probs internally. It is called again
        # here to report per-base results without changing fused_probs' interface.
        p_tcn, p_svd = base_probs(tcn, svd, x_seq)

        # Flatten so [B] and [B, 1] outputs are handled identically downstream.
        probs_tcn.extend(_eval_flat_list(p_tcn))
        probs_svd.extend(_eval_flat_list(p_svd))
        probs_fused.extend(_eval_flat_list(p_fused))
        labels.extend(int(v) for v in _eval_flat_list(y == 2))

        for i in range(len(y)):
            metas.append({
                "run_id": str(meta["run_id"][i]),
                "start_tti": int(meta["start_tti"][i]),
            })

    n = len(labels)
    if not (len(probs_tcn) == len(probs_svd) == len(probs_fused) == len(metas) == n):
        raise ValueError(
            f"evaluate_confusions: length mismatch (labels={n}, tcn={len(probs_tcn)}, "
            f"svd={len(probs_svd)}, fuse={len(probs_fused)}, metas={len(metas)})")

    y_true = np.array(labels, dtype=int)
    y_tcn = (np.array(probs_tcn) >= 0.5).astype(int)
    y_svd = (np.array(probs_svd) >= 0.5).astype(int)
    y_fuse = (np.array(probs_fused) >= 0.5).astype(int)

    cm_tcn = confusion(y_true, y_tcn)
    cm_svd = confusion(y_true, y_svd)
    cm_fuse = confusion(y_true, y_fuse)

    y_fuse_hyst = _eval_hysteresis_per_run(
        probs_fused, metas, HYP["theta_on"], HYP["theta_off"], HYP["L_on"], HYP["L_off"])
    cm_hyst = confusion(y_true, y_fuse_hyst)

    _eval_print_confusion(cm_tcn, "TCN (p>=0.5)")
    _eval_print_confusion(cm_svd, "SVD (p>=0.5)")
    _eval_print_confusion(cm_fuse, "Fusion (p>=0.5)")
    _eval_print_confusion(cm_hyst, "Fusion + τ-Hysteresis (final ENFORCE)")

    return {
        "cm_tcn": cm_tcn, "cm_svd": cm_svd, "cm_fuse": cm_fuse, "cm_hyst": cm_hyst,
        "probs": {"tcn": probs_tcn, "svd": probs_svd, "fuse": probs_fused},
        "y_true": y_true.tolist(),
    }

# --- Run on validation loader by default ---
# Evaluate only when a non-empty `val_loader` already exists in this namespace;
# otherwise tell the user what to build first.
if not (globals().get("val_loader") and len(val_loader) > 0):
    print("No validation loader attached. Build metas/loaders, then re-run this cell.")
else:
    _ = evaluate_confusions(val_loader)