In [1]:
"""
Graph WaveNet / MTGNN-style (Adaptive adjacency + Dilated Gated TCN)
3-node graph trading model: ETH (target), BTC, ADA on 1-minute data.

Notes on labels:
- Your triple-barrier labels already provide a clean 3-class target:
    y_tb: 0=down (SHORT), 1=flat (FLAT / no-trade), 2=up (LONG)
- This notebook trains a single 3-class model (SHORT/FLAT/LONG) and reports:
    trade_auc: AUC(trade vs no-trade) where trade = {SHORT,LONG} vs FLAT
    dir_auc:   AUC(direction) on true-trade samples only (LONG vs SHORT)
"""


'\nGraph WaveNet / MTGNN-style (Adaptive adjacency + Dilated Gated TCN)\n3-node graph trading model: ETH (target), BTC, ADA on 1-minute data.\n\nNotes on labels:\n- Your triple-barrier labels already provide a clean 3-class target:\n    y_tb: 0=down (SHORT), 1=flat (FLAT / no-trade), 2=up (LONG)\n- This notebook trains a single 3-class model (SHORT/FLAT/LONG) and reports:\n    trade_auc: AUC(trade vs no-trade) where trade = {SHORT,LONG} vs FLAT\n    dir_auc:   AUC(direction) on true-trade samples only (LONG vs SHORT)\n'

In [2]:
# Step 0: imports + config + seed  (COPIED AS-IS FROM YOUR NOTEBOOK)
# ======================================================================

import os
import math
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.preprocessing import RobustScaler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score


def seed_everything(seed: int = 1234) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(100)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

torch.set_num_threads(max(1, os.cpu_count() or 4))

CFG: Dict[str, Any] = {
    # data
    "freq": "1min",
    "data_dir": Path("../dataset"),
    "final_test_frac": 0.10,

    # order book
    "book_levels": 15,
    "top_levels": 5,
    "near_levels": 5,

    # walk-forward windows (in sample-space)
    "train_min_frac": 0.50,
    "val_window_frac": 0.10,
    "test_window_frac": 0.10,
    "step_window_frac": 0.10,

    # scaling
    "max_abs_feat": 10.0,
    "max_abs_edge": 6.0,

    # correlations / graph
    "corr_windows": [6 * 5, 12 * 5, 24 * 5, 48 * 5, 84 * 5],  # 30m,1h,2h,4h,7h
    "corr_lags": [0, 1, 2, 5],  # lead-lag (no leakage)
    "edges_mode": "all_pairs",  # "manual" | "all_pairs"
    "edges": [("ADA", "BTC"), ("ADA", "ETH"), ("ETH", "BTC")],  # used if edges_mode="manual"
    "add_self_loops": True,
    "edge_transform": "fisher",  # "none" | "fisher"
    "edge_scale": True,
    "edge_dropout": 0.10,

    # triple-barrier
    "tb_horizon": 1 * 30,
    "lookback": 4 * 12 * 5,
    "tb_pt_mult": 1.2,
    "tb_sl_mult": 1.1,
    "tb_min_barrier": 0.001,
    "tb_max_barrier": 0.006,

    # training
    "batch_size": 128,
    "epochs": 20,
    "lr": 3e-4,
    "weight_decay": 5e-4,
    "grad_clip": 1.0,
    "dropout": 0.15,

    # stability tricks
    "label_smoothing": 0.02,
    "use_weighted_sampler": True,
    "use_onecycle": True,

    # model dims
    "hidden": 128,
    "gnn_layers": 3,

    # --- Temporal (Conv -> AttnPool)
    "tcn_channels": 128,
    "tcn_layers": 3,
    "tcn_kernel": 2,
    "tcn_dropout": 0.20,
    "tcn_causal": True,

    "attn_pool_hidden": 128,
    "attn_pool_dropout": 0.10,

    # --- Learnable adjacency (MTGNN-style)
    # A_learned options:
    #   "emb": A = softmax((E1 @ E2^T)/temp)
    #   "matrix": A = softmax(A_logits/temp)
    "adj_mode": "emb",
    "adj_emb_dim": 8,
    "adj_temperature": 1.0,

    # A_prior from edge_attr (last timestep of the sequence)
    "prior_use_abs": False,       # if True: use abs(mean(edge_attr)) for weights
    "prior_diag_boost": 1.0,      # ensure diag >= this before row-normalization
    "prior_row_normalize": True,

    # mixing alpha
    "alpha_mode": "learned",      # "fixed" | "learned"
    "adj_alpha": 0.50,            # used if alpha_mode="fixed"
    "adj_alpha_min": 0.05,        # clamp if learned
    "adj_alpha_max": 0.95,

    # adjacency regularization
    "adj_l1_lambda": 1e-3,
    "adj_prior_lambda": 1e-2,

    # trading eval
    "cost_bps": 1.0,

    # threshold sweep grids (val only)
    "thr_trade_grid": [0.50, 0.55, 0.60, 0.65, 0.70, 0.75],
    "thr_dir_grid":   [0.50, 0.55, 0.60, 0.65, 0.70],

    # min trades constraints
    "eval_min_trades": 50,

    # anti-overtrading threshold selection
    "max_trade_rate_val": 0.65,
    "trade_rate_penalty": 0.10,
    "thr_objective": "pnl_sum",  # "pnl_sum" | "pnl_sharpe" | "pnl_per_trade"

    # dynamic quantile thresholds for thr_trade
    "proxy_target_trades": [50, 100, 200],
}

ASSETS = ["ADA", "BTC", "ETH"]
ASSET2IDX = {a: i for i, a in enumerate(ASSETS)}
TARGET_ASSET = "ETH"
TARGET_NODE = ASSET2IDX[TARGET_ASSET]


def build_edge_list(cfg: Dict[str, Any], assets: List[str]) -> List[Tuple[str, str]]:
    mode = str(cfg.get("edges_mode", "manual"))
    if mode == "manual":
        edges = list(cfg["edges"])
    elif mode == "all_pairs":
        edges = [(s, t) for s in assets for t in assets if s != t]
    else:
        raise ValueError(f"Unknown edges_mode={mode}")

    if bool(cfg.get("add_self_loops", True)):
        edges = edges + [(a, a) for a in assets]
    return edges


EDGE_LIST = build_edge_list(CFG, ASSETS)
EDGE_NAMES = [f"{s}->{t}" for s, t in EDGE_LIST]
EDGE_INDEX = torch.tensor([[ASSET2IDX[s], ASSET2IDX[t]] for (s, t) in EDGE_LIST], dtype=torch.long)

print("EDGE_LIST:", EDGE_NAMES)
print("EDGE_INDEX:", EDGE_INDEX.tolist())


DEVICE: cpu
EDGE_LIST: ['ADA->BTC', 'ADA->ETH', 'BTC->ADA', 'BTC->ETH', 'ETH->ADA', 'ETH->BTC', 'ADA->ADA', 'BTC->BTC', 'ETH->ETH']
EDGE_INDEX: [[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1], [0, 0], [1, 1], [2, 2]]


In [3]:
# Step 1: my original data loading (UNCHANGED / COPIED AS-IS)
# ======================================================================

def load_asset(asset: str, freq: str, data_dir: Path, book_levels: int, part: Tuple[int, int] = (0, 80)) -> pd.DataFrame:
    path = data_dir / f"{asset}_{freq}.csv"
    df = pd.read_csv(path)
    df = df.iloc[int(len(df) * part[0] / 100): int(len(df) * part[1] / 100)]

    df["timestamp"] = pd.to_datetime(df["system_time"]).dt.round("min")
    df = df.sort_values("timestamp").set_index("timestamp")

    bid_cols = [f"bids_notional_{i}" for i in range(book_levels)]
    ask_cols = [f"asks_notional_{i}" for i in range(book_levels)]

    needed = ["midpoint", "spread", "buys", "sells"] + bid_cols + ask_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise ValueError(f"{asset}: missing columns in CSV: {missing[:10]}{'...' if len(missing) > 10 else ''}")

    return df[needed]


def load_all_assets() -> pd.DataFrame:
    freq = CFG["freq"]
    data_dir = CFG["data_dir"]
    book_levels = CFG["book_levels"]

    def rename_cols(df_one: pd.DataFrame, asset: str) -> pd.DataFrame:
        rename_map = {
            "midpoint": asset,
            "buys": f"buys_{asset}",
            "sells": f"sells_{asset}",
            "spread": f"spread_{asset}",
        }
        for i in range(book_levels):
            rename_map[f"bids_notional_{i}"] = f"bids_vol_{asset}_{i}"
            rename_map[f"asks_notional_{i}"] = f"asks_vol_{asset}_{i}"
        return df_one.rename(columns=rename_map)

    df_ada = rename_cols(load_asset("ADA", freq, data_dir, book_levels, part=(0, 75)), "ADA")
    df_btc = rename_cols(load_asset("BTC", freq, data_dir, book_levels, part=(0, 75)), "BTC")
    df_eth = rename_cols(load_asset("ETH", freq, data_dir, book_levels, part=(0, 75)), "ETH")

    df = df_ada.join(df_btc).join(df_eth).reset_index()
    return df


df = load_all_assets()
for a in ASSETS:
    df[f"lr_{a}"] = np.log(df[a]).diff().fillna(0.0)

print("Loaded df:", df.shape)
print("Columns example:", df.columns[:20].tolist())
print("Time range:", df["timestamp"].min(), "->", df["timestamp"].max())
print(df.head(2))


Loaded df: (12831, 106)
Columns example: ['timestamp', 'ADA', 'spread_ADA', 'buys_ADA', 'sells_ADA', 'bids_vol_ADA_0', 'bids_vol_ADA_1', 'bids_vol_ADA_2', 'bids_vol_ADA_3', 'bids_vol_ADA_4', 'bids_vol_ADA_5', 'bids_vol_ADA_6', 'bids_vol_ADA_7', 'bids_vol_ADA_8', 'bids_vol_ADA_9', 'bids_vol_ADA_10', 'bids_vol_ADA_11', 'bids_vol_ADA_12', 'bids_vol_ADA_13', 'bids_vol_ADA_14']
Time range: 2021-04-07 11:34:00+00:00 -> 2021-04-16 10:15:00+00:00
                  timestamp      ADA  spread_ADA      buys_ADA      sells_ADA  \
0 2021-04-07 11:34:00+00:00  1.16205      0.0001  56936.467913  258248.957367   
1 2021-04-07 11:35:00+00:00  1.16800      0.0022  56491.336799   78665.286640   

   bids_vol_ADA_0  bids_vol_ADA_1  bids_vol_ADA_2  bids_vol_ADA_3  \
0      876.869995     5984.169922        5.810000       18.240000   
1    33769.671875    23137.169922      550.299988      550.299988   

   bids_vol_ADA_4  ...  asks_vol_ETH_8  asks_vol_ETH_9  asks_vol_ETH_10  \
0    19844.640625  ...      37

In [4]:
# Step 1b: edge features (UNCHANGED / COPIED AS-IS)
# ======================================================================

def _fisher_z(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    x = np.clip(x, -0.999, 0.999)
    return 0.5 * np.log((1.0 + x + eps) / (1.0 - x + eps))


def build_corr_array(
    df_: pd.DataFrame,
    corr_windows: List[int],
    edges: List[Tuple[str, str]],
    lags: List[int],
    transform: str = "fisher",
) -> np.ndarray:
    """
    Edge features per time:
      for edge s->t:
        for lag in lags:
          corr(lr_s.shift(lag), lr_t) over rolling window
    No leakage: shift(lag>0) uses past of source.
    Self-loop edges a->a: constant 1.0.
    """
    T_ = len(df_)
    E_ = len(edges)
    W_ = len(corr_windows)
    Lg = len(lags)
    out = np.zeros((T_, E_, W_ * Lg), dtype=np.float32)

    lr_map = {a: df_[f"lr_{a}"].astype(float) for a in ASSETS}

    for ei, (s, t) in enumerate(edges):
        if s == t:
            out[:, ei, :] = 1.0
            continue

        src0 = lr_map[s]
        dst0 = lr_map[t]

        feat_idx = 0
        for lag in lags:
            src = src0.shift(int(lag)) if int(lag) > 0 else src0

            for w in corr_windows:
                r = src.rolling(int(w), min_periods=1).corr(dst0)
                r = np.nan_to_num(r.to_numpy(dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
                if transform == "fisher":
                    r = _fisher_z(r).astype(np.float32)
                out[:, ei, feat_idx] = r
                feat_idx += 1

    return out.astype(np.float32)


edge_feat = build_corr_array(
    df,
    CFG["corr_windows"],
    EDGE_LIST,
    CFG["corr_lags"],
    transform=str(CFG.get("edge_transform", "fisher")),
)

print("edge_feat shape:", edge_feat.shape, "(T,E,edge_dim)")
print("edge_dim =", edge_feat.shape[-1], " = windows * lags =", len(CFG["corr_windows"]) * len(CFG["corr_lags"]))
print("Edge names:", EDGE_NAMES)
print("edge_feat sample [t=100, first 3 edges]:\n", edge_feat[100, :3, :])
print("edge_feat stats: mean=", float(edge_feat.mean()), "std=", float(edge_feat.std()))


edge_feat shape: (12831, 9, 20) (T,E,edge_dim)
edge_dim = 20  = windows * lags = 20
Edge names: ['ADA->BTC', 'ADA->ETH', 'BTC->ADA', 'BTC->ETH', 'ETH->ADA', 'ETH->BTC', 'ADA->ADA', 'BTC->BTC', 'ETH->ETH']
edge_feat sample [t=100, first 3 edges]:
 [[ 6.7925054e-01  7.8185719e-01  8.3443433e-01  8.3443433e-01
   8.3443433e-01  3.4026209e-01  1.4996846e-01  1.6863135e-01
   1.6863135e-01  1.6863135e-01 -2.9266590e-02 -1.5999632e-01
  -2.5908518e-01 -2.5908518e-01 -2.5908518e-01  1.4277337e-01
  -1.0870378e-02  5.0888385e-04  5.0888385e-04  5.0888385e-04]
 [ 6.3383293e-01  6.9067067e-01  8.7368768e-01  8.7368768e-01
   8.7368768e-01  3.4575835e-01  1.5627505e-01  2.0534241e-01
   2.0534241e-01  2.0534241e-01  6.9384493e-02 -1.1163518e-01
  -1.7551416e-01 -1.7551416e-01 -1.7551416e-01 -1.6881377e-01
  -1.0781832e-01 -6.3380465e-02 -6.3380465e-02 -6.3380465e-02]
 [ 6.7925054e-01  7.8185719e-01  8.3443433e-01  8.3443433e-01
   8.3443433e-01 -4.8168253e-02 -1.6241662e-01 -1.6193953e-01
  -1.61

In [5]:
# Step 1c: triple-barrier labels (UNCHANGED / COPIED AS-IS)
# ======================================================================

def triple_barrier_labels_from_lr(
    lr: pd.Series,
    horizon: int,
    vol_window: int,
    pt_mult: float,
    sl_mult: float,
    min_barrier: float,
    max_barrier: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Returns:
      y_tb: {0=down, 1=flat/no-trade, 2=up}
      exit_ret: realized log-return to exit (tp/sl/timeout)
      exit_t: exit index
      thr: barrier per t (float array, len T)
    No leakage: vol is shift(1).
    """
    lr = lr.astype(float).copy()
    T = len(lr)

    vol = lr.rolling(vol_window, min_periods=max(10, vol_window // 10)).std().shift(1)
    thr = (vol * np.sqrt(horizon)).clip(lower=min_barrier, upper=max_barrier)

    y = np.ones(T, dtype=np.int64)
    exit_ret = np.zeros(T, dtype=np.float32)
    exit_t = np.arange(T, dtype=np.int64)

    lr_np = lr.fillna(0.0).to_numpy(dtype=np.float64)
    thr_np = thr.fillna(min_barrier).to_numpy(dtype=np.float64)

    for t in range(T - horizon - 1):
        up = pt_mult * thr_np[t]
        dn = -sl_mult * thr_np[t]

        cum = 0.0
        hit = 1
        et = t + horizon
        er = 0.0

        for dt in range(1, horizon + 1):
            cum += lr_np[t + dt]
            if cum >= up:
                hit, et, er = 2, t + dt, cum
                break
            if cum <= dn:
                hit, et, er = 0, t + dt, cum
                break

        if hit == 1:
            er = float(np.sum(lr_np[t + 1: t + horizon + 1]))
            et = t + horizon

        y[t] = hit
        exit_ret[t] = er
        exit_t[t] = et

    return y, exit_ret, exit_t, thr_np


y_tb, exit_ret, exit_t, tb_thr = triple_barrier_labels_from_lr(
    df["lr_ETH"],
    horizon=CFG["tb_horizon"],
    vol_window=CFG["lookback"],
    pt_mult=CFG["tb_pt_mult"],
    sl_mult=CFG["tb_sl_mult"],
    min_barrier=CFG["tb_min_barrier"],
    max_barrier=CFG["tb_max_barrier"],
)

# two-stage labels
y_trade = (y_tb != 1).astype(np.int64)  # 1=trade, 0=no-trade
y_dir = (y_tb == 2).astype(np.int64)    # 1=up, 0=down (meaningful only when y_trade==1)

dist = np.bincount(y_tb, minlength=3)
print("TB dist [down,flat,up]:", dist)
print("Trade ratio (true):", float(y_trade.mean()))


TB dist [down,flat,up]: [2875 7413 2543]
Trade ratio (true): 0.42225859247135844


In [6]:
# Step 1d: node tensor (UNCHANGED / COPIED AS-IS)
# ======================================================================

EPS = 1e-6


def safe_log1p(x: np.ndarray) -> np.ndarray:
    return np.log1p(np.maximum(x, 0.0))


def build_node_tensor(df_: pd.DataFrame) -> Tuple[np.ndarray, List[str]]:
    """
    Features per asset:
      lr, spread,
      log_buys, log_sells, ofi,
      DI_15,
      DI_L0..DI_L4,
      near_ratio_bid, near_ratio_ask,
      di_near, di_far
    """
    book_levels = CFG["book_levels"]
    top_k = CFG["top_levels"]
    near_k = CFG["near_levels"]

    if near_k >= book_levels:
        raise ValueError("CFG['near_levels'] must be < CFG['book_levels']")

    feat_names = [
        "lr", "spread",
        "log_buys", "log_sells", "ofi",
        "DI_15",
        "DI_L0", "DI_L1", "DI_L2", "DI_L3", "DI_L4",
        "near_ratio_bid", "near_ratio_ask",
        "di_near", "di_far",
    ]

    feats_all = []
    for a in ASSETS:
        lr = df_[f"lr_{a}"].values.astype(np.float32)
        spread = df_[f"spread_{a}"].values.astype(np.float32)

        buys = df_[f"buys_{a}"].values.astype(np.float32)
        sells = df_[f"sells_{a}"].values.astype(np.float32)

        log_buys = safe_log1p(buys).astype(np.float32)
        log_sells = safe_log1p(sells).astype(np.float32)

        ofi = ((buys - sells) / (buys + sells + EPS)).astype(np.float32)

        bids_lvls = np.stack([df_[f"bids_vol_{a}_{i}"].values.astype(np.float32) for i in range(book_levels)], axis=1)
        asks_lvls = np.stack([df_[f"asks_vol_{a}_{i}"].values.astype(np.float32) for i in range(book_levels)], axis=1)

        bid_sum = bids_lvls.sum(axis=1)
        ask_sum = asks_lvls.sum(axis=1)
        di_15 = ((bid_sum - ask_sum) / (bid_sum + ask_sum + EPS)).astype(np.float32)

        di_levels = []
        for i in range(top_k):
            b = bids_lvls[:, i]
            s = asks_lvls[:, i]
            di_levels.append(((b - s) / (b + s + EPS)).astype(np.float32))
        di_l0_4 = np.stack(di_levels, axis=1)  # (T,5)

        bid_near = bids_lvls[:, :near_k].sum(axis=1)
        ask_near = asks_lvls[:, :near_k].sum(axis=1)
        bid_far = bids_lvls[:, near_k:].sum(axis=1)
        ask_far = asks_lvls[:, near_k:].sum(axis=1)

        near_ratio_bid = (bid_near / (bid_far + EPS)).astype(np.float32)
        near_ratio_ask = (ask_near / (ask_far + EPS)).astype(np.float32)

        di_near = ((bid_near - ask_near) / (bid_near + ask_near + EPS)).astype(np.float32)
        di_far = ((bid_far - ask_far) / (bid_far + ask_far + EPS)).astype(np.float32)

        Xa = np.column_stack([
            lr, spread,
            log_buys, log_sells, ofi,
            di_15,
            di_l0_4[:, 0], di_l0_4[:, 1], di_l0_4[:, 2], di_l0_4[:, 3], di_l0_4[:, 4],
            near_ratio_bid, near_ratio_ask,
            di_near, di_far,
        ]).astype(np.float32)

        feats_all.append(Xa)

    X = np.stack(feats_all, axis=1).astype(np.float32)  # (T,N,F)
    return X, feat_names


X_node_raw, node_feat_names = build_node_tensor(df)
T = len(df)
L = CFG["lookback"]
H = CFG["tb_horizon"]

t_min = L - 1
t_max = T - H - 2
sample_t = np.arange(t_min, t_max + 1)
n_samples = len(sample_t)

print("X_node_raw:", X_node_raw.shape, "edge_feat:", edge_feat.shape)
print("node_feat_names:", node_feat_names)
print("n_samples:", n_samples, "| t range:", int(sample_t[0]), "->", int(sample_t[-1]))
print(
    "Feature stats (TARGET asset, lr):",
    "mean=", float(X_node_raw[:, TARGET_NODE, node_feat_names.index("lr")].mean()),
    "std=", float(X_node_raw[:, TARGET_NODE, node_feat_names.index("lr")].std()),
)


X_node_raw: (12831, 3, 15) edge_feat: (12831, 9, 20)
node_feat_names: ['lr', 'spread', 'log_buys', 'log_sells', 'ofi', 'DI_15', 'DI_L0', 'DI_L1', 'DI_L2', 'DI_L3', 'DI_L4', 'near_ratio_bid', 'near_ratio_ask', 'di_near', 'di_far']
n_samples: 12561 | t range: 239 -> 12799
Feature stats (TARGET asset, lr): mean= 1.5748046280350536e-05 std= 0.0010532913729548454


In [7]:
# Step 1e: splits (UNCHANGED / COPIED AS-IS)
# ======================================================================

def make_final_holdout_split(n_samples_: int, final_test_frac: float) -> Tuple[np.ndarray, np.ndarray]:
    if not (0.0 < final_test_frac < 0.5):
        raise ValueError("final_test_frac should be in (0, 0.5)")
    n_final = max(1, int(round(final_test_frac * n_samples_)))
    n_cv = n_samples_ - n_final
    if n_cv <= 50:
        raise ValueError("Too few samples left for CV after holdout split.")
    idx_cv = np.arange(0, n_cv, dtype=np.int64)
    idx_final = np.arange(n_cv, n_samples_, dtype=np.int64)
    return idx_cv, idx_final


def make_walk_forward_splits(
    n_samples_: int,
    train_min_frac: float,
    val_window_frac: float,
    test_window_frac: float,
    step_window_frac: float,
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    train_min = int(train_min_frac * n_samples_)
    val_w = max(1, int(val_window_frac * n_samples_))
    test_w = max(1, int(test_window_frac * n_samples_))
    step_w = max(1, int(step_window_frac * n_samples_))

    splits = []
    start = train_min
    while True:
        tr_end = start
        va_end = tr_end + val_w
        te_end = va_end + test_w
        if te_end > n_samples_:
            break

        idx_train = np.arange(0, tr_end, dtype=np.int64)
        idx_val = np.arange(tr_end, va_end, dtype=np.int64)
        idx_test = np.arange(va_end, te_end, dtype=np.int64)
        splits.append((idx_train, idx_val, idx_test))

        start += step_w

    return splits


idx_cv_all, idx_final_test = make_final_holdout_split(n_samples, CFG["final_test_frac"])
n_samples_cv = len(idx_cv_all)
n_samples_final = len(idx_final_test)

print("Holdout split:")
print(f"  n_samples total: {n_samples}")
print(f"  n_samples CV   : {n_samples_cv} ({100 * n_samples_cv / n_samples:.1f}%)")
print(f"  n_samples FINAL: {n_samples_final} ({100 * n_samples_final / n_samples:.1f}%)")
print("  CV range   :", int(idx_cv_all[0]), int(idx_cv_all[-1]))
print("  FINAL range:", int(idx_final_test[0]), int(idx_final_test[-1]))

walk_splits = make_walk_forward_splits(
    n_samples_=n_samples_cv,
    train_min_frac=CFG["train_min_frac"],
    val_window_frac=CFG["val_window_frac"],
    test_window_frac=CFG["test_window_frac"],
    step_window_frac=CFG["step_window_frac"],
)

print("\nWalk-forward folds:", len(walk_splits))
for i, (a, b, c) in enumerate(walk_splits, 1):
    print(f"  fold {i}: train={len(a)} | val={len(b)} | test={len(c)}")


Holdout split:
  n_samples total: 12561
  n_samples CV   : 11305 (90.0%)
  n_samples FINAL: 1256 (10.0%)
  CV range   : 0 11304
  FINAL range: 11305 12560

Walk-forward folds: 4
  fold 1: train=5652 | val=1130 | test=1130
  fold 2: train=6782 | val=1130 | test=1130
  fold 3: train=7912 | val=1130 | test=1130
  fold 4: train=9042 | val=1130 | test=1130


In [8]:
# Step 1f: dataset/scaling helpers (UNCHANGED / COPIED AS-IS)
# ======================================================================

class LobGraphSequenceDataset2Stage(Dataset):
    """
    Returns:
      x_seq: (L,N,F)
      e_seq: (L,E,edge_dim)
      y_trade: scalar
      y_dir: scalar
      exit_ret: scalar
    """
    def __init__(
        self,
        X_node: np.ndarray,
        E_feat: np.ndarray,
        y_trade_arr: np.ndarray,
        y_dir_arr: np.ndarray,
        exit_ret_arr: np.ndarray,
        sample_t_: np.ndarray,
        indices: np.ndarray,
        lookback: int,
    ):
        self.X_node = X_node
        self.E_feat = E_feat
        self.y_trade = y_trade_arr
        self.y_dir = y_dir_arr
        self.exit_ret = exit_ret_arr
        self.sample_t = sample_t_
        self.indices = indices.astype(np.int64)
        self.L = int(lookback)

    def __len__(self) -> int:
        return int(len(self.indices))

    def __getitem__(self, i: int):
        sidx = int(self.indices[i])
        t = int(self.sample_t[sidx])
        t0 = t - self.L + 1

        x_seq = self.X_node[t0:t + 1]  # (L,N,F)
        e_seq = self.E_feat[t0:t + 1]  # (L,E,D)

        yt = int(self.y_trade[t])
        yd = int(self.y_dir[t])
        er = float(self.exit_ret[t])

        return (
            torch.from_numpy(x_seq),
            torch.from_numpy(e_seq),
            torch.tensor(yt, dtype=torch.long),
            torch.tensor(yd, dtype=torch.long),
            torch.tensor(er, dtype=torch.float32),
        )


def collate_fn_2stage(batch):
    xs, es, yts, yds, ers = zip(*batch)
    return (
        torch.stack(xs, 0),   # (B,L,N,F)
        torch.stack(es, 0),   # (B,L,E,D)
        torch.stack(yts, 0),  # (B,)
        torch.stack(yds, 0),  # (B,)
        torch.stack(ers, 0),  # (B,)
    )


def fit_scale_nodes_train_only(
    X_node_raw_: np.ndarray,
    sample_t_: np.ndarray,
    idx_train: np.ndarray,
    max_abs: float = 10.0
) -> Tuple[np.ndarray, RobustScaler]:
    last_train_t = int(sample_t_[int(idx_train[-1])])
    train_time_mask = np.arange(0, last_train_t + 1)

    X_train_time = X_node_raw_[train_time_mask]  # (Ttr,N,F)
    _, _, Fdim = X_train_time.shape

    scaler = RobustScaler(with_centering=True, with_scaling=True, quantile_range=(5.0, 95.0))
    scaler.fit(X_train_time.reshape(-1, Fdim))

    X_scaled = scaler.transform(X_node_raw_.reshape(-1, Fdim)).reshape(X_node_raw_.shape).astype(np.float32)
    X_scaled = np.clip(X_scaled, -max_abs, max_abs).astype(np.float32)
    X_scaled = np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return X_scaled, scaler


def fit_scale_edges_train_only(
    E_raw_: np.ndarray,
    sample_t_: np.ndarray,
    idx_train: np.ndarray,
    max_abs: float = 6.0
) -> Tuple[np.ndarray, RobustScaler]:
    """
    Robust-scale edge features per fold (train timeline only).
    Fisher-transformed correlations can be heavy-tailed.
    """
    last_train_t = int(sample_t_[int(idx_train[-1])])
    train_time_mask = np.arange(0, last_train_t + 1)

    E_train_time = E_raw_[train_time_mask]  # (Ttr,E,D)
    _, _, D = E_train_time.shape

    scaler = RobustScaler(with_centering=True, with_scaling=True, quantile_range=(5.0, 95.0))
    scaler.fit(E_train_time.reshape(-1, D))

    E_scaled = scaler.transform(E_raw_.reshape(-1, D)).reshape(E_raw_.shape).astype(np.float32)
    E_scaled = np.clip(E_scaled, -max_abs, max_abs).astype(np.float32)
    E_scaled = np.nan_to_num(E_scaled, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return E_scaled, scaler


def subset_trade_indices(indices: np.ndarray, sample_t_: np.ndarray, y_trade_arr: np.ndarray) -> np.ndarray:
    tt = sample_t_[indices]
    mask = (y_trade_arr[tt] == 1)
    return indices[mask]


def split_trade_ratio(indices: np.ndarray, sample_t_: np.ndarray, y_trade_arr: np.ndarray) -> float:
    tt = sample_t_[indices]
    return float(y_trade_arr[tt].mean()) if len(tt) else float("nan")


In [None]:
# Step 0b: Graph WaveNet config additions (NEW)
# ======================================================================

CFG.update({
    # Graph WaveNet core channels
    "gwn_residual_channels": 64,
    "gwn_dilation_channels": 64,
    "gwn_skip_channels": 128,
    "gwn_end_channels": 128,

    # blocks/layers: dilations reset each block
    "gwn_blocks": 3,
    "gwn_layers_per_block": 2,
    "gwn_kernel_size": 2,

    # adaptive adjacency
    "adaptive_topk": 3,  # for 3 nodes, 3 keeps all; keep as knob for generalization

    # optional PnL-ish regularization during training
    "trade_prob_penalty": 0.01,  # penalize over-trading via mean(p_short+p_long)
Ðš
    # checkpointing
    "ckpt_dir": Path("./checkpoints_MTGNN1m_auc_3class"),
    "sel_metric_dir_weight": 0.50,  # selection metric: trade_auc + w * dir_auc
})

CFG["ckpt_dir"].mkdir(parents=True, exist_ok=True)

CLASS_NAMES = ["SHORT", "FLAT", "LONG"]
print("3-class mapping:", {0: "SHORT", 1: "FLAT", 2: "LONG"})
print("Checkpoint dir:", str(CFG["ckpt_dir"].resolve()))


3-class mapping: {0: 'SHORT', 1: 'FLAT', 2: 'LONG'}
Checkpoint dir: /Users/vitalii/Desktop/Model_Market_Microstructure/Graph_Neural_Network_for_Market_Microstructure/TGNN2026/checkpoints_gwnet_3class


In [10]:
# Step 2: dataset/dataloader (3-class)  (NEW)
# ======================================================================

class LobGraphSequenceDataset3Class(Dataset):
    """
    Returns:
      x_seq:    (L,N,F)
      e_seq:    (L,E,D)
      y_tb:     scalar in {0,1,2} (SHORT, FLAT, LONG)
      exit_ret: scalar (log-return to exit)
      sidx:     scalar sample index (for time-order reconstruction)
    """
    def __init__(
        self,
        X_node: np.ndarray,
        E_feat: np.ndarray,
        y_tb_arr: np.ndarray,
        exit_ret_arr: np.ndarray,
        sample_t_: np.ndarray,
        indices: np.ndarray,
        lookback: int,
    ):
        self.X_node = X_node
        self.E_feat = E_feat
        self.y_tb = y_tb_arr
        self.exit_ret = exit_ret_arr
        self.sample_t = sample_t_
        self.indices = indices.astype(np.int64)
        self.L = int(lookback)

    def __len__(self) -> int:
        return int(len(self.indices))

    def __getitem__(self, i: int):
        sidx = int(self.indices[i])
        t = int(self.sample_t[sidx])
        t0 = t - self.L + 1

        x_seq = self.X_node[t0:t + 1]  # (L,N,F)
        e_seq = self.E_feat[t0:t + 1]  # (L,E,D)
        y = int(self.y_tb[t])
        er = float(self.exit_ret[t])

        return (
            torch.from_numpy(x_seq),
            torch.from_numpy(e_seq),
            torch.tensor(y, dtype=torch.long),
            torch.tensor(er, dtype=torch.float32),
            torch.tensor(sidx, dtype=torch.long),
        )


def collate_fn_3class(batch):
    xs, es, ys, ers, sidxs = zip(*batch)
    return (
        torch.stack(xs, 0),    # (B,L,N,F)
        torch.stack(es, 0),    # (B,L,E,D)
        torch.stack(ys, 0),    # (B,)
        torch.stack(ers, 0),   # (B,)
        torch.stack(sidxs, 0), # (B,)
    )


def make_ce_weights_3class(y_np: np.ndarray) -> torch.Tensor:
    y_np = np.asarray(y_np, dtype=np.int64)
    counts = np.bincount(y_np, minlength=3).astype(np.float64)
    counts = np.maximum(counts, 1.0)
    w = counts.sum() / (3.0 * counts)
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)


def make_weighted_sampler_3class(y_np: np.ndarray) -> WeightedRandomSampler:
    y_np = np.asarray(y_np, dtype=np.int64)
    counts = np.bincount(y_np, minlength=3).astype(np.float64)
    counts = np.maximum(counts, 1.0)
    class_w = counts.sum() / (3.0 * counts)
    sample_w = class_w[y_np].astype(np.float64)
    sample_w = torch.tensor(sample_w, dtype=torch.double)
    return WeightedRandomSampler(weights=sample_w, num_samples=len(sample_w), replacement=True)


In [11]:
# Step 3: model definition (Graph WaveNet / MTGNN-style)  (NEW)
# ======================================================================

def build_static_adjacency_from_edges(edge_index: torch.Tensor, n_nodes: int, eps: float = 1e-8) -> torch.Tensor:
    """
    Build A_static (N,N), row-normalized, using the presence of edges in EDGE_LIST.
    """
    A = torch.zeros((n_nodes, n_nodes), dtype=torch.float32)
    src = edge_index[:, 0].long()
    dst = edge_index[:, 1].long()
    A[src, dst] = 1.0
    A = A / (A.sum(dim=-1, keepdim=True) + eps)
    return A


def build_adj_prior_from_edge_attr(
    edge_attr_last: torch.Tensor,    # (B,E,D)
    edge_index: torch.Tensor,        # (E,2) [src,dst]
    n_nodes: int,
    use_abs: bool = False,
    diag_boost: float = 1.0,
    row_normalize: bool = True,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Build A_prior (B,N,N) from edge_attr at the last timestep.
      w = sigmoid(mean(edge_attr)) in [0,1]
    Fill A[src,dst] = w, enforce diag >= diag_boost, row-normalize.
    """
    edge_attr_last = torch.nan_to_num(edge_attr_last, nan=0.0, posinf=0.0, neginf=0.0)
    B, E, D = edge_attr_last.shape
    r = edge_attr_last.mean(dim=-1)  # (B,E)
    if use_abs:
        r = r.abs()
    w = torch.sigmoid(r)  # (B,E)

    A = torch.zeros((B, n_nodes, n_nodes), device=edge_attr_last.device, dtype=edge_attr_last.dtype)
    src = edge_index[:, 0].to(edge_attr_last.device)
    dst = edge_index[:, 1].to(edge_attr_last.device)
    A[:, src, dst] = w

    diag = torch.arange(n_nodes, device=edge_attr_last.device)
    A[:, diag, diag] = torch.maximum(A[:, diag, diag], torch.full_like(A[:, diag, diag], float(diag_boost)))

    if row_normalize:
        A = A / (A.sum(dim=-1, keepdim=True) + eps)

    return torch.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0)


class AdaptiveAdjacency(nn.Module):
    """
    Graph WaveNet-style adaptive adjacency from node embeddings:
      logits = relu(E1 @ E2^T) / temp
      (optional) top-k per row
      A_adapt = softmax(logits_row)
    Also returns sparsity_proxy = sigmoid(logits) for L1(offdiag).
    """
    def __init__(self, n_nodes: int, cfg: Dict[str, Any]):
        super().__init__()
        self.n = int(n_nodes)
        k = int(cfg.get("adj_emb_dim", 8))
        self.E1 = nn.Parameter(0.01 * torch.randn(self.n, k))
        self.E2 = nn.Parameter(0.01 * torch.randn(self.n, k))
        self.temp = float(cfg.get("adj_temperature", 1.0))
        self.temp = max(self.temp, 1e-3)
        self.topk = int(cfg.get("adaptive_topk", self.n))

    def forward(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits = (self.E1 @ self.E2.t())
        logits = F.relu(logits) / self.temp  # (N,N)
        sparsity_proxy = torch.sigmoid(logits)

        if self.topk is not None and 0 < self.topk < self.n:
            vals, idx = torch.topk(logits, k=self.topk, dim=-1)
            mask = torch.full_like(logits, fill_value=float("-inf"))
            mask.scatter_(-1, idx, vals)
            logits = mask

        A = torch.softmax(logits, dim=-1)  # row-stochastic
        return A, sparsity_proxy, logits


class LearnableSupportMix(nn.Module):
    """
    Blend supports (static, prior, adapt) using softmax weights.
    """
    def __init__(self, n_supports: int = 3):
        super().__init__()
        self.w_logits = nn.Parameter(torch.zeros(n_supports, dtype=torch.float32))

    def forward(self) -> torch.Tensor:
        return torch.softmax(self.w_logits, dim=0)


class CausalConv2dTime(nn.Module):
    """
    2D convolution causal along time dimension only.
    Input:  (B,C,N,T)
    Conv kernel: (1,k), dilation: (1,d)
    """
    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, dilation: int):
        super().__init__()
        self.k = int(kernel_size)
        self.d = int(dilation)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=(1, self.k), dilation=(1, self.d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad_left = (self.k - 1) * self.d
        x = F.pad(x, (pad_left, 0, 0, 0))  # pad time (W) on the left
        return self.conv(x)


def graph_message_passing(x: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """
    x: (B,C,N,T)
    A: (B,N,N) with A[src,dst]
    returns: (B,C,N,T) where dst aggregates from src
    """
    return torch.einsum("bcnt,bnm->bcmt", x, A)


class GraphWaveNetBlock(nn.Module):
    def __init__(self, residual_ch: int, dilation_ch: int, skip_ch: int, kernel_size: int, dilation: int, dropout: float):
        super().__init__()
        self.filter_conv = CausalConv2dTime(residual_ch, dilation_ch, kernel_size=kernel_size, dilation=dilation)
        self.gate_conv = CausalConv2dTime(residual_ch, dilation_ch, kernel_size=kernel_size, dilation=dilation)

        self.residual_conv = nn.Conv2d(dilation_ch, residual_ch, kernel_size=(1, 1))
        self.skip_conv = nn.Conv2d(dilation_ch, skip_ch, kernel_size=(1, 1))

        self.dropout = nn.Dropout(float(dropout))
        self.bn = nn.BatchNorm2d(residual_ch)

    def forward(self, x: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x: (B,residual_ch,N,T)
        A: (B,N,N)
        """
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        A = torch.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0)

        residual = x

        f = torch.tanh(self.filter_conv(x))
        g = torch.sigmoid(self.gate_conv(x))
        z = f * g  # (B,dilation_ch,N,T)

        z = self.dropout(z)

        skip = self.skip_conv(z)  # (B,skip_ch,N,T)

        out = self.residual_conv(z)  # (B,residual_ch,N,T)
        out = graph_message_passing(out, A)  # spatial mixing
        out = out + residual  # residual
        out = self.bn(out)

        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        skip = torch.nan_to_num(skip, nan=0.0, posinf=0.0, neginf=0.0)
        return out, skip


class GraphWaveNet3Class(nn.Module):
    """
    Input:
      x_seq: (B,L,N,F)
      e_seq: (B,L,E,D)   (only used to build A_prior from last step)

    Output:
      logits_eth: (B,3) for ETH node only: [SHORT, FLAT, LONG]
    """
    def __init__(self, node_in: int, edge_dim: int, cfg: Dict[str, Any], n_nodes: int, target_node: int):
        super().__init__()
        self.cfg = cfg
        self.n_nodes = int(n_nodes)
        self.target_node = int(target_node)

        residual_ch = int(cfg["gwn_residual_channels"])
        dilation_ch = int(cfg["gwn_dilation_channels"])
        skip_ch = int(cfg["gwn_skip_channels"])
        end_ch = int(cfg["gwn_end_channels"])
        k = int(cfg["gwn_kernel_size"])
        blocks = int(cfg["gwn_blocks"])
        layers_per_block = int(cfg["gwn_layers_per_block"])
        drop = float(cfg.get("dropout", 0.0))

        self.in_proj = nn.Linear(int(node_in), residual_ch)

        # supports
        A_static = build_static_adjacency_from_edges(EDGE_INDEX, n_nodes=self.n_nodes)
        self.register_buffer("A_static", A_static)

        self.adapt = AdaptiveAdjacency(n_nodes=self.n_nodes, cfg=cfg)
        self.support_mix = LearnableSupportMix(n_supports=3)

        # blocks with dilation schedule resetting each block
        self.blocks = nn.ModuleList()
        for b in range(blocks):
            for l in range(layers_per_block):
                dilation = 2 ** l
                self.blocks.append(GraphWaveNetBlock(
                    residual_ch=residual_ch,
                    dilation_ch=dilation_ch,
                    skip_ch=skip_ch,
                    kernel_size=k,
                    dilation=dilation,
                    dropout=drop,
                ))

        self.end1 = nn.Conv2d(skip_ch, end_ch, kernel_size=(1, 1))
        self.end2 = nn.Conv2d(end_ch, 3, kernel_size=(1, 1))  # 3-class

        # init
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _compute_supports(self, e_seq: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Build A_prior (batch), A_adapt (global), then mix:
          A_mix = w0*A_static + w1*A_prior + w2*A_adapt
        """
        B, L_, E, D = e_seq.shape
        e_last = e_seq[:, -1, :, :]  # (B,E,D)

        A_prior = build_adj_prior_from_edge_attr(
            edge_attr_last=e_last,
            edge_index=EDGE_INDEX.to(e_seq.device),
            n_nodes=self.n_nodes,
            use_abs=bool(self.cfg.get("prior_use_abs", False)),
            diag_boost=float(self.cfg.get("prior_diag_boost", 1.0)),
            row_normalize=bool(self.cfg.get("prior_row_normalize", True)),
        )  # (B,N,N)

        A_adapt_base, sparsity_proxy, adapt_logits = self.adapt()  # (N,N)
        A_adapt = A_adapt_base.unsqueeze(0).expand(B, -1, -1)      # (B,N,N)

        w = self.support_mix()  # (3,)
        A_static = self.A_static.to(e_seq.device).to(e_seq.dtype).unsqueeze(0).expand(B, -1, -1)

        A_mix = w[0] * A_static + w[1] * A_prior + w[2] * A_adapt
        A_mix = A_mix / (A_mix.sum(dim=-1, keepdim=True) + 1e-8)

        # regularization terms (adapt only)
        N = self.n_nodes
        offdiag = (1.0 - torch.eye(N, device=e_seq.device, dtype=e_seq.dtype))
        l1_off = (sparsity_proxy.to(e_seq.dtype) * offdiag).abs().mean()
        mse_prior = ((A_adapt - A_prior) ** 2 * offdiag).mean()

        aux = {
            "support_w": w.detach().cpu().numpy().tolist(),
            "l1_off": float(l1_off.detach().cpu().item()),
            "mse_prior": float(mse_prior.detach().cpu().item()),
            "_l1_off_t": l1_off,
            "_mse_prior_t": mse_prior,
        }
        return A_mix, aux

    def forward(self, x_seq: torch.Tensor, e_seq: torch.Tensor, return_aux: bool = False):
        x_seq = torch.nan_to_num(x_seq, nan=0.0, posinf=0.0, neginf=0.0)
        e_seq = torch.nan_to_num(e_seq, nan=0.0, posinf=0.0, neginf=0.0)

        B, L_, N, Fdim = x_seq.shape
        assert N == self.n_nodes

        # (B,L,N,F) -> (B,N,L,residual_ch) -> (B,residual_ch,N,L)
        x = self.in_proj(x_seq)              # (B,L,N,C)
        x = x.permute(0, 3, 2, 1).contiguous()  # (B,C,N,T)

        A_mix, aux = self._compute_supports(e_seq)

        skip_sum = None
        for blk in self.blocks:
            x, skip = blk(x, A_mix)
            skip_sum = skip if skip_sum is None else (skip_sum + skip)

        y = F.relu(skip_sum)
        y = F.relu(self.end1(y))
        y = self.end2(y)  # (B,3,N,T)

        logits_eth = y[:, :, self.target_node, -1]  # (B,3)
        logits_eth = torch.nan_to_num(logits_eth, nan=0.0, posinf=0.0, neginf=0.0)

        if return_aux:
            return logits_eth, aux
        return logits_eth


In [None]:
# Step 4: training loop with folds (same split logic)  (NEW)
# ======================================================================

def total_loss_with_adj_reg(loss: torch.Tensor, aux: Dict[str, Any], cfg: Dict[str, Any]) -> torch.Tensor:
    lam_l1 = float(cfg.get("adj_l1_lambda", 0.0))
    lam_pr = float(cfg.get("adj_prior_lambda", 0.0))
    reg = 0.0
    if lam_l1 > 0:
        reg = reg + lam_l1 * aux["_l1_off_t"]
    if lam_pr > 0:
        reg = reg + lam_pr * aux["_mse_prior_t"]
    return loss + reg


def _safe_auc_binary(y_true: np.ndarray, score: np.ndarray) -> float:
    y_true = np.asarray(y_true, dtype=np.int64)
    score = np.asarray(score, dtype=np.float64)
    if y_true.size == 0 or len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, score))


def compute_trade_dir_auc_from_probs(y_tb_true: np.ndarray, prob3: np.ndarray) -> Tuple[float, float]:
    """
    y_tb_true: (n,) in {0,1,2}
    prob3: (n,3) in order [SHORT,FLAT,LONG]
    trade_auc: trade vs no-trade where trade={0,2} vs flat=1, score = 1 - p_flat
    dir_auc: LONG vs SHORT on true-trade samples only, score = p_long/(p_long+p_short)
    """
    y_tb_true = np.asarray(y_tb_true, dtype=np.int64)
    prob3 = np.asarray(prob3, dtype=np.float64)

    y_trade_bin = (y_tb_true != 1).astype(np.int64)
    p_trade = 1.0 - prob3[:, 1]
    trade_auc = _safe_auc_binary(y_trade_bin, p_trade)

    mask_trade = (y_tb_true != 1)
    y_dir_bin = (y_tb_true[mask_trade] == 2).astype(np.int64)  # 1=LONG, 0=SHORT
    p_short = prob3[mask_trade, 0]
    p_long = prob3[mask_trade, 2]
    p_dir = p_long / (p_long + p_short + 1e-12)
    dir_auc = _safe_auc_binary(y_dir_bin, p_dir)

    return trade_auc, dir_auc


def pnl_from_probs_3class(prob3: np.ndarray, exit_ret_arr: np.ndarray, thr_trade: float, thr_dir: float, cost_bps: float) -> Dict[str, Any]:
    prob3 = np.asarray(prob3, dtype=np.float64)
    exit_ret_arr = np.asarray(exit_ret_arr, dtype=np.float64)

    p_short = prob3[:, 0]
    p_flat = prob3[:, 1]
    p_long = prob3[:, 2]

    trade_conf = 1.0 - p_flat
    dir_prob = p_long / (p_long + p_short + 1e-12)
    dir_conf = np.maximum(dir_prob, 1.0 - dir_prob)

    mask = (trade_conf >= float(thr_trade)) & (dir_conf >= float(thr_dir))

    action = np.zeros_like(exit_ret_arr, dtype=np.float64)
    action[mask] = np.where(dir_prob[mask] >= 0.5, 1.0, -1.0)

    cost = (float(cost_bps) * 1e-4) * mask.astype(np.float64)
    pnl = action * exit_ret_arr - cost

    n = int(len(exit_ret_arr))
    n_tr = int(mask.sum())

    return {
        "n": n,
        "n_trades": n_tr,
        "trade_rate": float(n_tr / max(1, n)),
        "pnl_sum": float(pnl.sum()),
        "pnl_mean": float(pnl.mean()) if n else float("nan"),
        "pnl_per_trade": float(pnl.sum() / max(1, n_tr)),
        "pnl_sharpe": float((pnl.mean() / (pnl.std() + 1e-12)) * np.sqrt(288)) if n else float("nan"),
    }


def build_trade_threshold_grid(p_trade: np.ndarray, base_grid: Optional[List[float]], target_trades_list: Optional[List[int]]) -> List[float]:
    p_trade = np.asarray(p_trade, dtype=np.float64)
    p_trade = p_trade[np.isfinite(p_trade)]
    if p_trade.size == 0:
        return base_grid or [0.5]

    thrs = set(float(t) for t in (base_grid or []))

    if target_trades_list:
        N = int(p_trade.size)
        for k in target_trades_list:
            k = int(k)
            if k <= 0:
                continue
            if k >= N:
                thr = float(np.min(p_trade))
            else:
                q = 1.0 - (k / N)
                thr = float(np.quantile(p_trade, q))
            thrs.add(float(np.clip(thr, 0.01, 0.99)))

    out = sorted(thrs)
    cleaned = []
    for t in out:
        if not cleaned or abs(t - cleaned[-1]) > 1e-6:
            cleaned.append(float(t))
    return cleaned


def sweep_thresholds_3class(prob3: np.ndarray, exit_ret_arr: np.ndarray, cfg: Dict[str, Any], min_trades: int, target_trade_rate: Optional[float]) -> pd.DataFrame:
    prob3 = np.asarray(prob3, dtype=np.float64)
    p_flat = prob3[:, 1]
    p_trade = 1.0 - p_flat

    thr_trade_grid = build_trade_threshold_grid(
        p_trade=p_trade,
        base_grid=cfg.get("thr_trade_grid", [0.5]),
        target_trades_list=cfg.get("proxy_target_trades", None),
    )
    thr_dir_grid = cfg.get("thr_dir_grid", [0.5])

    obj = str(cfg.get("thr_objective", "pnl_sum"))
    max_rate = cfg.get("max_trade_rate_val", None)
    penalty = float(cfg.get("trade_rate_penalty", 0.0))

    rows = []
    for thr_t in thr_trade_grid:
        for thr_d in thr_dir_grid:
            m = pnl_from_probs_3class(prob3, exit_ret_arr, thr_t, thr_d, cfg["cost_bps"])
            if int(m["n_trades"]) < int(min_trades):
                continue
            if max_rate is not None and float(m["trade_rate"]) > float(max_rate):
                continue

            base = float(m.get(obj, np.nan))
            if not np.isfinite(base):
                continue

            if target_trade_rate is not None:
                score = base - penalty * abs(float(m["trade_rate"]) - float(target_trade_rate))
            else:
                score = base - penalty * float(m["trade_rate"])

            rows.append({"thr_trade": float(thr_t), "thr_dir": float(thr_d), "score": float(score), **m})

    if not rows:
        # relax
        return sweep_thresholds_3class(prob3, exit_ret_arr, cfg, min_trades=1, target_trade_rate=target_trade_rate)

    return pd.DataFrame(rows).sort_values(["score", "pnl_sum"], ascending=False)


@torch.no_grad()
def predict_probs_on_indices_3class(
    model: nn.Module,
    X_scaled: np.ndarray,
    edge_scaled: np.ndarray,
    indices: np.ndarray,
    cfg: Dict[str, Any],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    ds = LobGraphSequenceDataset3Class(
        X_node=X_scaled,
        E_feat=edge_scaled,
        y_tb_arr=y_tb,
        exit_ret_arr=exit_ret,
        sample_t_=sample_t,
        indices=indices,
        lookback=cfg["lookback"],
    )
    loader = DataLoader(ds, batch_size=int(cfg["batch_size"]), shuffle=False, collate_fn=collate_fn_3class, num_workers=0)

    probs = []
    ers = []
    ys = []
    for x, e, y, er, _sidx in loader:
        x = x.to(DEVICE).float()
        e = e.to(DEVICE).float()
        logits = model(x, e, return_aux=False)
        p = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        probs.append(p)
        ers.append(er.detach().cpu().numpy())
        ys.append(y.detach().cpu().numpy())

    return np.concatenate(probs, axis=0), np.concatenate(ers, axis=0), np.concatenate(ys, axis=0)


@torch.no_grad()
def eval_3class_on_indices(
    model: nn.Module,
    X_scaled: np.ndarray,
    edge_scaled: np.ndarray,
    indices: np.ndarray,
    loss_fn: nn.Module,
    cfg: Dict[str, Any],
) -> Dict[str, Any]:
    ds = LobGraphSequenceDataset3Class(
        X_node=X_scaled,
        E_feat=edge_scaled,
        y_tb_arr=y_tb,
        exit_ret_arr=exit_ret,
        sample_t_=sample_t,
        indices=indices,
        lookback=cfg["lookback"],
    )
    loader = DataLoader(ds, batch_size=int(cfg["batch_size"]), shuffle=False, collate_fn=collate_fn_3class, num_workers=0)

    model.eval()
    total_loss = 0.0
    n = 0

    probs = []
    ers = []
    ys = []

    for x, e, y, er, _sidx in loader:
        x = x.to(DEVICE).float()
        e = e.to(DEVICE).float()
        y = y.to(DEVICE).long()

        logits, aux = model(x, e, return_aux=True)
        ce = loss_fn(logits, y)
        loss = total_loss_with_adj_reg(ce, aux, cfg)

        total_loss += float(loss.item()) * int(y.size(0))
        n += int(y.size(0))

        p = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        probs.append(p)
        ers.append(er.detach().cpu().numpy())
        ys.append(y.detach().cpu().numpy())

    prob3 = np.concatenate(probs, axis=0) if probs else np.zeros((0, 3), dtype=np.float64)
    er_arr = np.concatenate(ers, axis=0) if ers else np.zeros((0,), dtype=np.float64)
    y_arr = np.concatenate(ys, axis=0) if ys else np.zeros((0,), dtype=np.int64)

    trade_auc, dir_auc = compute_trade_dir_auc_from_probs(y_arr, prob3)

    y_pred = prob3.argmax(axis=1) if len(y_arr) else np.array([], dtype=np.int64)
    acc = float(accuracy_score(y_arr, y_pred)) if len(y_arr) else float("nan")
    f1m = float(f1_score(y_arr, y_pred, average="macro")) if len(y_arr) else float("nan")
    cm = confusion_matrix(y_arr, y_pred, labels=[0, 1, 2]) if len(y_arr) else None

    return {
        "loss": float(total_loss / max(1, n)),
        "acc": acc,
        "f1m": f1m,
        "cm": cm,
        "trade_auc": float(trade_auc) if np.isfinite(trade_auc) else float("nan"),
        "dir_auc": float(dir_auc) if np.isfinite(dir_auc) else float("nan"),
        "prob3": prob3,
        "er": er_arr,
        "y": y_arr,
    }


def train_one_fold_3class(
    fold_id: int,
    X_scaled: np.ndarray,
    edge_scaled: np.ndarray,
    idx_train: np.ndarray,
    idx_val: np.ndarray,
    idx_test: np.ndarray,
    node_scaler: RobustScaler,
    edge_scaler: RobustScaler,
    cfg: Dict[str, Any],
) -> Dict[str, Any]:
    # train labels for weights/sampler
    t_train = sample_t[idx_train]
    y_train = y_tb[t_train].astype(np.int64)

    tr_ds = LobGraphSequenceDataset3Class(X_scaled, edge_scaled, y_tb, exit_ret, sample_t, idx_train, cfg["lookback"])
    va_ds = LobGraphSequenceDataset3Class(X_scaled, edge_scaled, y_tb, exit_ret, sample_t, idx_val,   cfg["lookback"])
    te_ds = LobGraphSequenceDataset3Class(X_scaled, edge_scaled, y_tb, exit_ret, sample_t, idx_test,  cfg["lookback"])

    sampler = None
    shuffle = True
    if bool(cfg.get("use_weighted_sampler", True)):
        sampler = make_weighted_sampler_3class(y_train)
        shuffle = False

    tr_loader = DataLoader(tr_ds, batch_size=int(cfg["batch_size"]), shuffle=shuffle, sampler=sampler, collate_fn=collate_fn_3class, num_workers=0)
    va_loader = DataLoader(va_ds, batch_size=int(cfg["batch_size"]), shuffle=False, collate_fn=collate_fn_3class, num_workers=0)
    te_loader = DataLoader(te_ds, batch_size=int(cfg["batch_size"]), shuffle=False, collate_fn=collate_fn_3class, num_workers=0)

    model = GraphWaveNet3Class(
        node_in=int(X_scaled.shape[-1]),
        edge_dim=int(edge_scaled.shape[-1]),
        cfg=cfg,
        n_nodes=len(ASSETS),
        target_node=TARGET_NODE,
    ).to(DEVICE)

    ce_w = make_ce_weights_3class(y_train)
    loss_fn = nn.CrossEntropyLoss(weight=ce_w, label_smoothing=float(cfg.get("label_smoothing", 0.0)))

    opt = torch.optim.AdamW(model.parameters(), lr=float(cfg["lr"]), weight_decay=float(cfg["weight_decay"]))

    use_onecycle = bool(cfg.get("use_onecycle", True))
    if use_onecycle:
        sch = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=float(cfg["lr"]),
            epochs=int(cfg["epochs"]),
            steps_per_epoch=max(1, len(tr_loader)),
            pct_start=0.15,
            div_factor=10.0,
            final_div_factor=50.0,
        )
    else:
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=3)

    best_sel = -1e18
    best_state = None
    best_epoch = -1
    patience = 7
    bad = 0

    sel_w_dir = float(cfg.get("sel_metric_dir_weight", 0.5))
    trade_pen = float(cfg.get("trade_prob_penalty", 0.0))

    def _sel_metric(trade_auc: float, dir_auc: float) -> float:
        ta = float(trade_auc) if np.isfinite(trade_auc) else -1e18
        da = float(dir_auc) if np.isfinite(dir_auc) else 0.0
        return ta + sel_w_dir * da

    for ep in range(1, int(cfg["epochs"]) + 1):
        model.train()
        tot_loss = 0.0
        n = 0

        for x, e, y, _er, _sidx in tr_loader:
            x = x.to(DEVICE).float()
            e = e.to(DEVICE).float()
            y = y.to(DEVICE).long()

            opt.zero_grad(set_to_none=True)

            logits, aux = model(x, e, return_aux=True)
            ce = loss_fn(logits, y)

            # optional: discourage overtrading probability mass
            if trade_pen > 0:
                p = torch.softmax(logits, dim=-1)
                p_trade = (p[:, 0] + p[:, 2]).mean()
                ce = ce + trade_pen * p_trade

            loss = total_loss_with_adj_reg(ce, aux, cfg)
            if not torch.isfinite(loss):
                continue

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            opt.step()
            if use_onecycle:
                sch.step()

            tot_loss += float(loss.item()) * int(y.size(0))
            n += int(y.size(0))

        tr_loss = tot_loss / max(1, n)

        # quick val metrics (AUCs) each epoch
        model.eval()
        va_probs = []
        va_ys = []
        for x, e, y, _er, _sidx in va_loader:
            x = x.to(DEVICE).float()
            e = e.to(DEVICE).float()
            logits = model(x, e, return_aux=False)
            p = torch.softmax(logits, dim=-1).detach().cpu().numpy()
            va_probs.append(p)
            va_ys.append(y.numpy())

        va_prob3 = np.concatenate(va_probs, axis=0) if va_probs else np.zeros((0, 3), dtype=np.float64)
        va_y = np.concatenate(va_ys, axis=0) if va_ys else np.zeros((0,), dtype=np.int64)
        trade_auc, dir_auc = compute_trade_dir_auc_from_probs(va_y, va_prob3)
        sel = _sel_metric(trade_auc, dir_auc)

        if sel > best_sel:
            best_sel = sel
            best_epoch = ep
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1

        if not use_onecycle:
            sch.step(sel)

        lr_now = opt.param_groups[0]["lr"]
        w_support = model.support_mix().detach().cpu().numpy().tolist()
        print(
            f"[fold {fold_id:02d}] ep {ep:02d} lr={lr_now:.2e} "
            f"tr_loss={tr_loss:.4f} val_trade_auc={trade_auc:.3f} val_dir_auc={dir_auc:.3f} "
            f"sel={sel:.3f} best={best_sel:.3f}@ep{best_epoch:02d} supports={np.round(w_support, 3).tolist()}"
        )

        if bad >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # final eval on val + test (and threshold selection on val only)
    val_eval = eval_3class_on_indices(model, X_scaled, edge_scaled, idx_val, loss_fn, cfg)
    test_eval = eval_3class_on_indices(model, X_scaled, edge_scaled, idx_test, loss_fn, cfg)

    # thresholds chosen on VAL for PnL, then applied to TEST
    true_val_trade_rate = split_trade_ratio(idx_val, sample_t, y_trade)
    sweep_val = sweep_thresholds_3class(
        prob3=val_eval["prob3"],
        exit_ret_arr=val_eval["er"],
        cfg=cfg,
        min_trades=int(cfg["eval_min_trades"]),
        target_trade_rate=float(true_val_trade_rate),
    )
    best_thr = sweep_val.iloc[0].to_dict()
    thr_trade = float(best_thr["thr_trade"])
    thr_dir = float(best_thr["thr_dir"])

    pnl_val = pnl_from_probs_3class(val_eval["prob3"], val_eval["er"], thr_trade, thr_dir, cfg["cost_bps"])
    pnl_test = pnl_from_probs_3class(test_eval["prob3"], test_eval["er"], thr_trade, thr_dir, cfg["cost_bps"])

    print(
        f"[fold {fold_id:02d}] chosen thresholds on VAL: thr_trade={thr_trade:.3f} thr_dir={thr_dir:.3f} "
        f"| val pnl_sum={pnl_val['pnl_sum']:.4f} val trade_rate={pnl_val['trade_rate']:.3f}"
    )
    print(
        f"[fold {fold_id:02d}] TEST (fixed thresholds from VAL): "
        f"trade_auc={test_eval['trade_auc']:.3f} dir_auc={test_eval['dir_auc']:.3f} "
        f"pnl_sum={pnl_test['pnl_sum']:.4f} trade_rate={pnl_test['trade_rate']:.3f} trades={pnl_test['n_trades']}"
    )

    return {
        "fold": int(fold_id),
        "model_state": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
        "cfg": cfg,
        "node_scaler": node_scaler,
        "edge_scaler": edge_scaler,
        "idx_train": idx_train,
        "idx_val": idx_val,
        "idx_test": idx_test,
        "best_epoch": int(best_epoch),
        "best_sel": float(best_sel),
        "val_eval": val_eval,
        "test_eval": test_eval,
        "thr_trade": thr_trade,
        "thr_dir": thr_dir,
        "pnl_val": pnl_val,
        "pnl_test": pnl_test,
        "sweep_val_head": sweep_val.head(5),
    }


def save_checkpoint(path: Path, payload: Dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, str(path))


def load_checkpoint(path: Path) -> Dict[str, Any]:
    return torch.load(str(path), map_location="cpu", weights_only=False)


def run_walk_forward_cv_3class() -> Tuple[pd.DataFrame, List[Dict[str, Any]], Path]:
    fold_artifacts = []
    rows = []

    best_overall_sel = -1e18
    best_overall_path = None

    for fi, (idx_tr, idx_va, idx_te) in enumerate(walk_splits, 1):
        print("\n" + "=" * 80)
        print(f"FOLD {fi}/{len(walk_splits)} sizes: train={len(idx_tr)} val={len(idx_va)} test={len(idx_te)}")
        print(f"True trade ratio (val):  {split_trade_ratio(idx_va, sample_t, y_trade):.3f}")
        print(f"True trade ratio (test): {split_trade_ratio(idx_te, sample_t, y_trade):.3f}")

        # fold scaling (fit only on fold train timeline)
        X_scaled, node_scaler = fit_scale_nodes_train_only(X_node_raw, sample_t, idx_tr, max_abs=CFG["max_abs_feat"])
        if bool(CFG.get("edge_scale", True)):
            edge_scaled, edge_scaler = fit_scale_edges_train_only(edge_feat, sample_t, idx_tr, max_abs=CFG["max_abs_edge"])
        else:
            edge_scaled = edge_feat.astype(np.float32)
            edge_scaler = None

        artifact = train_one_fold_3class(
            fold_id=fi,
            X_scaled=X_scaled,
            edge_scaled=edge_scaled,
            idx_train=idx_tr,
            idx_val=idx_va,
            idx_test=idx_te,
            node_scaler=node_scaler,
            edge_scaler=edge_scaler,
            cfg=CFG,
        )

        # save fold checkpoint
        ckpt_path = CFG["ckpt_dir"] / f"fold_{fi:02d}_best.pt"
        save_checkpoint(ckpt_path, {
            "kind": "fold_best",
            "fold": fi,
            "model_state": artifact["model_state"],
            "cfg": dict(CFG),
            "node_scaler": artifact["node_scaler"],
            "edge_scaler": artifact["edge_scaler"],
            "thr_trade": artifact["thr_trade"],
            "thr_dir": artifact["thr_dir"],
            "idx_train": artifact["idx_train"],
            "idx_val": artifact["idx_val"],
            "idx_test": artifact["idx_test"],
        })
        print("Saved fold checkpoint:", str(ckpt_path))

        # track overall best by selection metric
        if float(artifact["best_sel"]) > best_overall_sel:
            best_overall_sel = float(artifact["best_sel"])
            best_overall_path = ckpt_path

        fold_artifacts.append(artifact)

        rows.append({
            "fold": fi,
            "val_trade_auc": artifact["val_eval"]["trade_auc"],
            "val_dir_auc": artifact["val_eval"]["dir_auc"],
            "test_trade_auc": artifact["test_eval"]["trade_auc"],
            "test_dir_auc": artifact["test_eval"]["dir_auc"],
            "thr_trade": artifact["thr_trade"],
            "thr_dir": artifact["thr_dir"],
            "test_trade_rate_pred": artifact["pnl_test"]["trade_rate"],
            "test_pnl_sum": artifact["pnl_test"]["pnl_sum"],
            "test_pnl_mean": artifact["pnl_test"]["pnl_mean"],
            "test_n_trades": artifact["pnl_test"]["n_trades"],
            "best_sel": artifact["best_sel"],
        })

    cv_summary = pd.DataFrame(rows)
    assert best_overall_path is not None
    overall_copy = CFG["ckpt_dir"] / "overall_best.pt"
    save_checkpoint(overall_copy, {**load_checkpoint(best_overall_path), "kind": "overall_best", "source_ckpt": str(best_overall_path)})
    print("\nSaved overall best checkpoint:", str(overall_copy))
    return cv_summary, fold_artifacts, overall_copy


cv_summary_3c, fold_artifacts_3c, overall_best_ckpt = run_walk_forward_cv_3class()

print("\n" + "=" * 80)
print("CV summary (3-class model; TEST uses thresholds selected on VAL):")
print(cv_summary_3c)
print("\nMeans (debug only):")
print(cv_summary_3c.mean(numeric_only=True))



FOLD 1/4 sizes: train=5652 val=1130 test=1130
True trade ratio (val):  0.365
True trade ratio (test): 0.304
[fold 01] ep 01 lr=9.84e-05 tr_loss=1.1131 val_trade_auc=0.476 val_dir_auc=0.509 sel=0.730 best=0.730@ep01 supports=[0.333, 0.333, 0.333]
[fold 01] ep 02 lr=2.34e-04 tr_loss=0.9962 val_trade_auc=0.473 val_dir_auc=0.529 sel=0.737 best=0.737@ep02 supports=[0.334, 0.333, 0.334]
[fold 01] ep 03 lr=3.00e-04 tr_loss=0.9589 val_trade_auc=0.476 val_dir_auc=0.519 sel=0.735 best=0.737@ep02 supports=[0.334, 0.333, 0.334]
[fold 01] ep 04 lr=2.97e-04 tr_loss=0.9220 val_trade_auc=0.485 val_dir_auc=0.534 sel=0.751 best=0.751@ep04 supports=[0.335, 0.33, 0.335]
[fold 01] ep 05 lr=2.90e-04 tr_loss=0.8931 val_trade_auc=0.493 val_dir_auc=0.580 sel=0.783 best=0.783@ep05 supports=[0.337, 0.326, 0.337]
[fold 01] ep 06 lr=2.77e-04 tr_loss=0.8848 val_trade_auc=0.489 val_dir_auc=0.591 sel=0.785 best=0.785@ep06 supports=[0.339, 0.321, 0.339]
[fold 01] ep 07 lr=2.61e-04 tr_loss=0.8391 val_trade_auc=0.500 v

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL pathlib.PosixPath was not an allowed global by default. Please use `torch.serialization.add_safe_globals([pathlib.PosixPath])` or the `torch.serialization.safe_globals([pathlib.PosixPath])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [17]:

def load_checkpoint(path: Path) -> Dict[str, Any]:
    return torch.load(str(path), map_location="cpu", weights_only=False)

overall_best_ckpt = CFG["ckpt_dir"] / "overall_best.pt"
best_overall_path = CFG["ckpt_dir"] / "fold_04_best.pt"

save_checkpoint(overall_best_ckpt, {**load_checkpoint(best_overall_path), "kind": "overall_best", "source_ckpt": str(best_overall_path)})


In [18]:
# Step 5: evaluation (final holdout) + production-fit  (NEW)
# ======================================================================

def apply_node_scaler(X_raw: np.ndarray, scaler: RobustScaler, max_abs: float) -> np.ndarray:
    Fdim = X_raw.shape[-1]
    X_scaled = scaler.transform(X_raw.reshape(-1, Fdim)).reshape(X_raw.shape).astype(np.float32)
    X_scaled = np.clip(X_scaled, -max_abs, max_abs).astype(np.float32)
    return np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def apply_edge_scaler(E_raw: np.ndarray, scaler: RobustScaler, max_abs: float) -> np.ndarray:
    D = E_raw.shape[-1]
    E_scaled = scaler.transform(E_raw.reshape(-1, D)).reshape(E_raw.shape).astype(np.float32)
    E_scaled = np.clip(E_scaled, -max_abs, max_abs).astype(np.float32)
    return np.nan_to_num(E_scaled, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def build_model_from_ckpt(ckpt: Dict[str, Any]) -> nn.Module:
    cfg = ckpt["cfg"]
    model = GraphWaveNet3Class(
        node_in=int(X_node_raw.shape[-1]),
        edge_dim=int(edge_feat.shape[-1]),
        cfg=cfg,
        n_nodes=len(ASSETS),
        target_node=TARGET_NODE,
    ).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


@torch.no_grad()
def evaluate_checkpoint_on_indices(ckpt_path: Path, indices: np.ndarray, label: str) -> Dict[str, Any]:
    ckpt = load_checkpoint(ckpt_path)
    model = build_model_from_ckpt(ckpt)

    node_scaler = ckpt["node_scaler"]
    edge_scaler = ckpt["edge_scaler"]

    X_scaled = apply_node_scaler(X_node_raw, node_scaler, max_abs=float(CFG["max_abs_feat"]))
    if edge_scaler is not None:
        E_scaled = apply_edge_scaler(edge_feat, edge_scaler, max_abs=float(CFG["max_abs_edge"]))
    else:
        E_scaled = edge_feat.astype(np.float32)

    # loss fn for reporting
    t_train = sample_t[np.asarray(ckpt["idx_train"], dtype=np.int64)]
    y_train = y_tb[t_train].astype(np.int64)
    ce_w = make_ce_weights_3class(y_train)
    loss_fn = nn.CrossEntropyLoss(weight=ce_w, label_smoothing=float(CFG.get("label_smoothing", 0.0)))

    ev = eval_3class_on_indices(model, X_scaled, E_scaled, indices.astype(np.int64), loss_fn, CFG)

    pnl = pnl_from_probs_3class(ev["prob3"], ev["er"], float(ckpt["thr_trade"]), float(ckpt["thr_dir"]), float(CFG["cost_bps"]))

    print("\n" + "=" * 80)
    print(f"{label}")
    print(f"ckpt: {str(ckpt_path)}")
    print(f"trade_auc={ev['trade_auc']:.3f} | dir_auc={ev['dir_auc']:.3f}")
    print(f"pnl_sum={pnl['pnl_sum']:.4f} | trade_rate={pnl['trade_rate']:.3f} | trades={pnl['n_trades']}")
    return {"eval": ev, "pnl": pnl}


# Evaluate overall best (from CV) on FINAL holdout (10%) without refit
holdout_indices = idx_final_test.astype(np.int64)
_ = evaluate_checkpoint_on_indices(overall_best_ckpt, holdout_indices, label="FINAL HOLDOUT (10%) using overall_best.pt (no refit)")

# Production-fit: train on CV(90%) with final val window; select thresholds on val_final; evaluate on holdout (10%)
def run_production_fit_3class() -> Dict[str, Any]:
    print("\n" + "=" * 80)
    print("PRODUCTION-FIT (train on CV(90%) -> select thresholds on val_final -> eval on FINAL holdout(10%))")

    val_w = max(1, int(CFG["val_window_frac"] * n_samples_cv))
    train_end = n_samples_cv - val_w

    idx_train_final = np.arange(0, train_end, dtype=np.int64)
    idx_val_final = np.arange(train_end, n_samples_cv, dtype=np.int64)
    idx_holdout = idx_final_test.astype(np.int64)

    print("Sizes:")
    print("  train_final:", len(idx_train_final))
    print("  val_final  :", len(idx_val_final))
    print("  holdout    :", len(idx_holdout))
    print(f"True trade ratio (val_final): {split_trade_ratio(idx_val_final, sample_t, y_trade):.3f}")
    print(f"True trade ratio (holdout):   {split_trade_ratio(idx_holdout, sample_t, y_trade):.3f}")

    X_scaled, node_scaler = fit_scale_nodes_train_only(X_node_raw, sample_t, idx_train_final, max_abs=CFG["max_abs_feat"])
    if bool(CFG.get("edge_scale", True)):
        edge_scaled, edge_scaler = fit_scale_edges_train_only(edge_feat, sample_t, idx_train_final, max_abs=CFG["max_abs_edge"])
    else:
        edge_scaled = edge_feat.astype(np.float32)
        edge_scaler = None

    artifact = train_one_fold_3class(
        fold_id=99,  # production id
        X_scaled=X_scaled,
        edge_scaled=edge_scaled,
        idx_train=idx_train_final,
        idx_val=idx_val_final,
        idx_test=idx_holdout,
        node_scaler=node_scaler,
        edge_scaler=edge_scaler,
        cfg=CFG,
    )

    ckpt_path = CFG["ckpt_dir"] / "production_best.pt"
    save_checkpoint(ckpt_path, {
        "kind": "production_best",
        "model_state": artifact["model_state"],
        "cfg": dict(CFG),
        "node_scaler": artifact["node_scaler"],
        "edge_scaler": artifact["edge_scaler"],
        "thr_trade": artifact["thr_trade"],
        "thr_dir": artifact["thr_dir"],
        "idx_train": artifact["idx_train"],
        "idx_val": artifact["idx_val"],
        "idx_test": artifact["idx_test"],  # holdout
    })
    print("Saved production checkpoint:", str(ckpt_path))

    print("\nPRODUCTION FINAL HOLDOUT RESULT:")
    print(f"trade_auc={artifact['test_eval']['trade_auc']:.3f} | dir_auc={artifact['test_eval']['dir_auc']:.3f}")
    print(f"trade_rate={artifact['pnl_test']['trade_rate']:.3f} | pnl_sum={artifact['pnl_test']['pnl_sum']:.4f} | trades={artifact['pnl_test']['n_trades']}")

    return {"artifact": artifact, "ckpt_path": ckpt_path}


prod_out = run_production_fit_3class()



FINAL HOLDOUT (10%) using overall_best.pt (no refit)
ckpt: checkpoints_gwnet_3class/overall_best.pt
trade_auc=0.534 | dir_auc=0.521
pnl_sum=-0.1026 | trade_rate=0.613 | trades=770

PRODUCTION-FIT (train on CV(90%) -> select thresholds on val_final -> eval on FINAL holdout(10%))
Sizes:
  train_final: 10175
  val_final  : 1130
  holdout    : 1256
True trade ratio (val_final): 0.504
True trade ratio (holdout):   0.576
[fold 99] ep 01 lr=9.80e-05 tr_loss=1.1010 val_trade_auc=0.545 val_dir_auc=0.511 sel=0.800 best=0.800@ep01 supports=[0.333, 0.334, 0.333]
[fold 99] ep 02 lr=2.34e-04 tr_loss=1.0059 val_trade_auc=0.616 val_dir_auc=0.556 sel=0.894 best=0.894@ep02 supports=[0.333, 0.334, 0.333]
[fold 99] ep 03 lr=3.00e-04 tr_loss=0.9740 val_trade_auc=0.627 val_dir_auc=0.559 sel=0.906 best=0.906@ep03 supports=[0.333, 0.334, 0.333]
[fold 99] ep 04 lr=2.97e-04 tr_loss=0.9573 val_trade_auc=0.626 val_dir_auc=0.565 sel=0.908 best=0.908@ep04 supports=[0.335, 0.329, 0.335]
[fold 99] ep 05 lr=2.90e-04 

In [19]:
# Step 6: checkpoint save/load utilities + "test-only from checkpoint"  (NEW)
# ======================================================================

# Edit this path to any saved checkpoint:
TEST_ONLY_CKPT_PATH = prod_out["ckpt_path"]  # e.g. CFG["ckpt_dir"] / "production_best.pt"

# Choose what to evaluate:
# - For fold checkpoints: use the saved idx_test (fold test)
# - For production checkpoint: idx_test is the FINAL holdout
ckpt_loaded = load_checkpoint(TEST_ONLY_CKPT_PATH)
idx_eval = np.asarray(ckpt_loaded["idx_test"], dtype=np.int64)

_ = evaluate_checkpoint_on_indices(TEST_ONLY_CKPT_PATH, idx_eval, label="TEST-ONLY FROM CHECKPOINT")



TEST-ONLY FROM CHECKPOINT
ckpt: checkpoints_gwnet_3class/production_best.pt
trade_auc=0.503 | dir_auc=0.570
pnl_sum=0.0915 | trade_rate=0.205 | trades=258


In [20]:
# Step 7: debug / sanity checks  (NEW)
# ======================================================================

# One forward pass shape sanity check
with torch.no_grad():
    # use first 2 samples from CV space
    idx_dbg = np.arange(0, min(2, len(idx_cv_all)), dtype=np.int64)

    # scale using a simple train-only scaler on a small prefix (debug only)
    idx_tr_dbg = np.arange(0, max(200, int(0.2 * len(idx_cv_all))), dtype=np.int64)
    X_dbg, _scn = fit_scale_nodes_train_only(X_node_raw, sample_t, idx_tr_dbg, max_abs=CFG["max_abs_feat"])
    E_dbg, _sce = fit_scale_edges_train_only(edge_feat, sample_t, idx_tr_dbg, max_abs=CFG["max_abs_edge"])

    ds_dbg = LobGraphSequenceDataset3Class(X_dbg, E_dbg, y_tb, exit_ret, sample_t, idx_dbg, CFG["lookback"])
    x_seq, e_seq, y0, er0, sidx0 = ds_dbg[0]

    print("Single sample shapes:")
    print("  x_seq:", tuple(x_seq.shape), "(L,N,F)")
    print("  e_seq:", tuple(e_seq.shape), "(L,E,D)")
    print("  y_tb:", int(y0.item()), "->", CLASS_NAMES[int(y0.item())])
    print("  exit_ret:", float(er0.item()), "sidx:", int(sidx0.item()))

    m_dbg = GraphWaveNet3Class(
        node_in=int(X_node_raw.shape[-1]),
        edge_dim=int(edge_feat.shape[-1]),
        cfg=CFG,
        n_nodes=len(ASSETS),
        target_node=TARGET_NODE,
    ).to(DEVICE)

    xb = x_seq.unsqueeze(0).to(DEVICE).float()  # (1,L,N,F)
    eb = e_seq.unsqueeze(0).to(DEVICE).float()  # (1,L,E,D)
    logits, aux = m_dbg(xb, eb, return_aux=True)

    print("\nModel forward sanity:")
    print("  logits:", logits.shape, "expected (B,3)")
    print("  logits finite:", bool(torch.isfinite(logits).all().item()))
    print("  support weights:", aux["support_w"])
    print("  reg l1_off:", aux["l1_off"], "mse_prior:", aux["mse_prior"])


Single sample shapes:
  x_seq: (240, 3, 15) (L,N,F)
  e_seq: (240, 9, 20) (L,E,D)
  y_tb: 0 -> SHORT
  exit_ret: -0.008445847779512405 sidx: 0

Model forward sanity:
  logits: torch.Size([1, 3]) expected (B,3)
  logits finite: True
  support weights: [0.3333333432674408, 0.3333333432674408, 0.3333333432674408]
  reg l1_off: 0.33333662152290344 mse_prior: 0.005591095890849829
