## Step 0

In [1]:
# ЛОГИЧЕСКИЙ БЛОК: imports + reproducibility + GLOBAL config
# ИСПОЛНЕНИЕ БЛОКА:

import os, math, random
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import RobustScaler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score

def seed_everything(seed=1234):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(100)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# -------------------------------
# GLOBAL CONFIG (всё тут)
# -------------------------------
CFG = {
    # data
    "freq": "5min",
    "data_dir": Path("../dataset"),  
    # NEW: holdout final test split (по времени, на sample-space)
    "final_test_frac": 0.10, 

    "book_levels": 15,         # сколько уровней стакана грузим
    "top_levels": 5,           # DI_L0..DI_L4
    "near_levels": 5,          # near=0..4, far=5..14

    # walk-forward windows (в sample-space)
    "train_min_frac": 0.50,
    "val_window_frac": 0.10,
    "test_window_frac": 0.10,
    "step_window_frac": 0.10,

    # scaling
    "max_abs_feat": 10.0,

    # correlations
    "corr_windows": [6, 12, 24, 48, 84],  # 30m,1h,2h,4h,7h
    "edges": [("ADA","BTC"), ("ADA","ETH"), ("ETH","BTC")],

    # triple-barrier (labels)
    "tb_horizon": 1*12,       # 1h     # нужен для sample_t (чтобы TB-exit не вылезал за конец)
    "lookback": 7*12,   
    "tb_pt_mult": 1.2,
    "tb_sl_mult": 1.1,
    "tb_min_barrier": 0.001,
    "tb_max_barrier": 0.006,
    # training (общие)
    "batch_size": 64,
    "epochs": 30,
    "lr": 2e-4,
    "weight_decay": 1e-3,
    "grad_clip": 1.0,
    "dropout": 0.2,
    "hidden": 64,
    "gnn_layers": 2,
    "lstm_hidden": 64,
    "lstm_layers": 1,
    "use_amp": True,

    "temporal_mode": "transformer",   # "transformer" | "attnpool"

    # transformer params
    "n_heads": 4,          # 2-4 обычно ок при hidden=64
    "n_layers": 1,         # 1-2 слоя
    "attn_dropout": 0.2,   # dropout внутри attention/ffn трансформера
    "use_posenc": True,    # positional encoding on/off

    # trading eval
    "cost_bps": 2.0,

    # confidence thresholds (для PnL по порогу)
    "thr_trade_grid": [0.50, 0.55, 0.60, 0.65, 0.70],
    "thr_dir_grid":   [0.50, 0.55, 0.60, 0.65, 0.70],

    # ---- PnL proxy during DIR training (grid selector)
    # можно сделать уже/шире, но по умолчанию переиспользуем thr_*_grid
    "proxy_thr_trade_grid": None,  # None -> использовать thr_trade_grid
    "proxy_thr_dir_grid":   None,  # None -> использовать thr_dir_grid
    "proxy_min_trades": 50,        # защита от "лучший pnl = 0 потому что 0 трейдов"
}

ASSETS = ["ADA", "BTC", "ETH"]
ASSET2IDX = {a:i for i,a in enumerate(ASSETS)}
TARGET_ASSET = "ETH"
TARGET_NODE = ASSET2IDX[TARGET_ASSET]

EDGES = CFG["edges"]
EDGE_INDEX = torch.tensor([[ASSET2IDX[s], ASSET2IDX[t]] for (s,t) in EDGES], dtype=torch.long)  # [E,2]
print("EDGE_INDEX:", EDGE_INDEX.tolist())


DEVICE: cpu
EDGE_INDEX: [[0, 1], [0, 2], [2, 1]]


## 1. load data + basic returns

In [2]:
# ЛОГИЧЕСКИЙ БЛОК: load data + log returns (без target) + все уровни стакана
# ИСПОЛНЕНИЕ БЛОКА:

def load_asset(asset: str, freq: str, data_dir: Path, book_levels: int, part = [0,100]) -> pd.DataFrame:
    path = data_dir / f"{asset}_{freq}.csv"
    df = pd.read_csv(path)
    df = df.iloc[int(len(df)*part[0]/100) : int(len(df)*part[1]/100)]
    df["timestamp"] = pd.to_datetime(df["system_time"]).dt.round("min")
    df = df.sort_values("timestamp").set_index("timestamp")

    bid_cols = [f"bids_notional_{i}" for i in range(book_levels)]
    ask_cols = [f"asks_notional_{i}" for i in range(book_levels)]

    needed = ["midpoint", "spread", "buys", "sells"] + bid_cols + ask_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise ValueError(f"{asset}: missing columns in CSV: {missing[:10]}{'...' if len(missing) > 10 else ''}")

    return df[needed]


def load_all_assets() -> pd.DataFrame:
    freq = CFG["freq"]
    data_dir = CFG["data_dir"]
    book_levels = CFG["book_levels"]

    def rename_asset_cols(df_one: pd.DataFrame, asset: str) -> pd.DataFrame:
        rename_map = {
            "midpoint": asset,
            "buys": f"buys_{asset}",
            "sells": f"sells_{asset}",
            "spread": f"spread_{asset}",
        }
        for i in range(book_levels):
            rename_map[f"bids_notional_{i}"] = f"bids_vol_{asset}_{i}"
            rename_map[f"asks_notional_{i}"] = f"asks_vol_{asset}_{i}"
        return df_one.rename(columns=rename_map)

    df_ADA = rename_asset_cols(load_asset("ADA", freq, data_dir, book_levels, part = [0, 80]), "ADA")
    df_BTC = rename_asset_cols(load_asset("BTC", freq, data_dir, book_levels, part = [0, 80]), "BTC")
    df_ETH = rename_asset_cols(load_asset("ETH", freq, data_dir, book_levels, part = [0, 80]), "ETH")

    df = df_ADA.join(df_BTC).join(df_ETH)
    df = df.reset_index()  # timestamp column remains
    return df


df = load_all_assets()
T = len(df)

# log returns
for a in ASSETS:
    df[f"lr_{a}"] = np.log(df[a]).diff().fillna(0.0)

print("Loaded df:", df.shape)
print("Example columns:", df.columns[:25].tolist())


Loaded df: (2693, 106)
Example columns: ['timestamp', 'ADA', 'spread_ADA', 'buys_ADA', 'sells_ADA', 'bids_vol_ADA_0', 'bids_vol_ADA_1', 'bids_vol_ADA_2', 'bids_vol_ADA_3', 'bids_vol_ADA_4', 'bids_vol_ADA_5', 'bids_vol_ADA_6', 'bids_vol_ADA_7', 'bids_vol_ADA_8', 'bids_vol_ADA_9', 'bids_vol_ADA_10', 'bids_vol_ADA_11', 'bids_vol_ADA_12', 'bids_vol_ADA_13', 'bids_vol_ADA_14', 'asks_vol_ADA_0', 'asks_vol_ADA_1', 'asks_vol_ADA_2', 'asks_vol_ADA_3', 'asks_vol_ADA_4']


## 2. multi-window correlations → edge features (T,E,W)

In [3]:
# ЛОГИЧЕСКИЙ БЛОК: multi-window correlations -> corr_array (T,E,W)
# ИСПОЛНЕНИЕ БЛОКА:

candidate_windows = CFG["corr_windows"]
edges = EDGES

n_w = len(candidate_windows)
n_edges = len(edges)
T = len(df)

corr_array = np.zeros((T, n_edges, n_w), dtype=np.float32)

for wi, w in enumerate(candidate_windows):
    r_ADA_BTC = df["lr_ADA"].rolling(w, min_periods=1).corr(df["lr_BTC"])
    r_ADA_ETH = df["lr_ADA"].rolling(w, min_periods=1).corr(df["lr_ETH"])
    r_ETH_BTC = df["lr_ETH"].rolling(w, min_periods=1).corr(df["lr_BTC"])

    corr_array[:, 0, wi] = np.nan_to_num(r_ADA_BTC)
    corr_array[:, 1, wi] = np.nan_to_num(r_ADA_ETH)
    corr_array[:, 2, wi] = np.nan_to_num(r_ETH_BTC)

print("corr_array shape:", corr_array.shape)  # (T,E,W)


corr_array shape: (2693, 3, 5)


## 3. triple-barrier → y_tb + exit_ret → two-stage labels

In [4]:
# ЛОГИЧЕСКИЙ БЛОК: triple-barrier labels -> y_tb + exit_ret + two-stage labels
# ИСПОЛНЕНИЕ БЛОКА:

def triple_barrier_labels_from_lr(
    lr: pd.Series,
    horizon: int,
    vol_window: int,
    pt_mult: float,
    sl_mult: float,
    min_barrier: float,
    max_barrier: float,
):
    """
    Returns:
      y_tb: {0=down, 1=flat/no-trade, 2=up}
      exit_ret: realized log-return to exit (tp/sl/timeout)
      exit_t: exit index
      thr: barrier per t
    No leakage: vol is shift(1).
    """
    lr = lr.astype(float).copy()
    T = len(lr)

    vol = lr.rolling(vol_window, min_periods=max(10, vol_window//10)).std().shift(1)
    thr = (vol * np.sqrt(horizon)).clip(lower=min_barrier, upper=max_barrier)

    y = np.ones(T, dtype=np.int64)
    exit_ret = np.zeros(T, dtype=np.float32)
    exit_t = np.arange(T, dtype=np.int64)

    lr_np = lr.fillna(0.0).to_numpy(dtype=np.float64)
    thr_np = thr.fillna(min_barrier).to_numpy(dtype=np.float64)

    for t in range(T - horizon - 1):
        up = pt_mult * thr_np[t]
        dn = -sl_mult * thr_np[t]

        cum = 0.0
        hit = 1
        et = t + horizon
        er = 0.0

        for dt in range(1, horizon + 1):
            cum += lr_np[t + dt]
            if cum >= up:
                hit = 2
                et = t + dt
                er = cum
                break
            if cum <= dn:
                hit = 0
                et = t + dt
                er = cum
                break

        if hit == 1:
            er = float(np.sum(lr_np[t+1:t+horizon+1]))
            et = t + horizon

        y[t] = hit
        exit_ret[t] = er
        exit_t[t] = et

    return y, exit_ret, exit_t, thr_np

# --- build TB on ETH ---
y_tb, exit_ret, exit_t, thr = triple_barrier_labels_from_lr(
    df["lr_ETH"],
    horizon=1*12, 
    vol_window=7*12,
    pt_mult=1.2,
    sl_mult=1.1,
    min_barrier=0.001,
    max_barrier=0.006,
)

# two-stage labels
y_trade = (y_tb != 1).astype(np.int64)      # 1=trade, 0=no-trade
y_dir   = (y_tb == 2).astype(np.int64)      # 1=up, 0=down (для trade-сэмплов)

print("TB dist [down,flat,up]:", np.bincount(y_tb, minlength=3))
print("Trade ratio:", y_trade.mean())


TB dist [down,flat,up]: [ 655 1311  727]
Trade ratio: 0.5131823245451169


## 4. build node tensor + edge tensor + sample_t

In [5]:
# ЛОГИЧЕСКИЙ БЛОК: build node features (T,N,F) + edge features (T,E,W) + sample_t
# ИСПОЛНЕНИЕ БЛОКА:

EPS = 1e-6

def safe_log1p(x: np.ndarray) -> np.ndarray:
    return np.log1p(np.maximum(x, 0.0))

def build_node_tensor(df: pd.DataFrame):
    """
    Фичи на asset:
      lr, spread,
      log_buys, log_sells, ofi,
      DI_15,
      DI_L0..DI_L4,
      near_ratio_bid, near_ratio_ask,
      di_near, di_far
    """
    feats = []
    feat_names = [
        "lr", "spread",
        "log_buys", "log_sells", "ofi",
        "DI_15",
        "DI_L0", "DI_L1", "DI_L2", "DI_L3", "DI_L4",
        "near_ratio_bid", "near_ratio_ask",
        "di_near", "di_far",
    ]

    book_levels = CFG["book_levels"]
    top_k = CFG["top_levels"]     # 5
    near_k = CFG["near_levels"]   # 5
    far_k = book_levels - near_k
    if far_k <= 0:
        raise ValueError("CFG['near_levels'] must be < CFG['book_levels']")

    for a in ASSETS:
        lr = df[f"lr_{a}"].values.astype(np.float32)
        spread = df[f"spread_{a}"].values.astype(np.float32)

        buys = df[f"buys_{a}"].values.astype(np.float32)
        sells = df[f"sells_{a}"].values.astype(np.float32)

        log_buys = safe_log1p(buys).astype(np.float32)
        log_sells = safe_log1p(sells).astype(np.float32)

        ofi = ((buys - sells) / (buys + sells + EPS)).astype(np.float32)

        # уровни стакана
        bids_lvls = np.stack([df[f"bids_vol_{a}_{i}"].values.astype(np.float32) for i in range(book_levels)], axis=1)  # (T,15)
        asks_lvls = np.stack([df[f"asks_vol_{a}_{i}"].values.astype(np.float32) for i in range(book_levels)], axis=1)  # (T,15)

        bid_sum_15 = bids_lvls.sum(axis=1)
        ask_sum_15 = asks_lvls.sum(axis=1)
        DI_15 = ((bid_sum_15 - ask_sum_15) / (bid_sum_15 + ask_sum_15 + EPS)).astype(np.float32)

        # DI_L0..DI_L4
        di_levels = []
        for i in range(top_k):
            b = bids_lvls[:, i]
            s = asks_lvls[:, i]
            di_levels.append(((b - s) / (b + s + EPS)).astype(np.float32))
        DI_L0_4 = np.stack(di_levels, axis=1)  # (T,5)

        # near vs far
        bid_near = bids_lvls[:, :near_k].sum(axis=1)
        ask_near = asks_lvls[:, :near_k].sum(axis=1)
        bid_far = bids_lvls[:, near_k:].sum(axis=1)
        ask_far = asks_lvls[:, near_k:].sum(axis=1)

        near_ratio_bid = (bid_near / (bid_far + EPS)).astype(np.float32)
        near_ratio_ask = (ask_near / (ask_far + EPS)).astype(np.float32)

        di_near = ((bid_near - ask_near) / (bid_near + ask_near + EPS)).astype(np.float32)
        di_far = ((bid_far - ask_far) / (bid_far + ask_far + EPS)).astype(np.float32)

        Xa = np.column_stack([
            lr, spread,
            log_buys, log_sells, ofi,
            DI_15,
            DI_L0_4[:, 0], DI_L0_4[:, 1], DI_L0_4[:, 2], DI_L0_4[:, 3], DI_L0_4[:, 4],
            near_ratio_bid, near_ratio_ask,
            di_near, di_far
        ]).astype(np.float32)

        feats.append(Xa)

    X = np.stack(feats, axis=1).astype(np.float32)  # (T,N,F)
    return X, feat_names


X_node_raw, node_feat_names = build_node_tensor(df)
edge_feat = np.nan_to_num(corr_array.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)

T = len(df)
L = CFG["lookback"]
H = CFG["tb_horizon"]

# sample_t: чтобы можно было брать окно [t-L+1 ... t] и иметь будущий TB-exit без выхода за данные
t_min = L - 1
t_max = T - H - 2
sample_t = np.arange(t_min, t_max + 1)
n_samples = len(sample_t)

print("X_node_raw:", X_node_raw.shape, "edge_feat:", edge_feat.shape)
print("node_feat_names:", node_feat_names)
print("n_samples:", n_samples, "t range:", sample_t[0], sample_t[-1])


X_node_raw: (2693, 3, 15) edge_feat: (2693, 3, 5)
node_feat_names: ['lr', 'spread', 'log_buys', 'log_sells', 'ofi', 'DI_15', 'DI_L0', 'DI_L1', 'DI_L2', 'DI_L3', 'DI_L4', 'near_ratio_bid', 'near_ratio_ask', 'di_near', 'di_far']
n_samples: 2597 t range: 83 2679


## Train (folds) - Test split

In [6]:
# ЛОГИЧЕСКИЙ БЛОК: final holdout split (90% CV + 10% final test), time-ordered
# ИСПОЛНЕНИЕ БЛОКА:

def make_final_holdout_split(n_samples: int, final_test_frac: float):
    if not (0.0 < final_test_frac < 0.5):
        raise ValueError("final_test_frac should be in (0, 0.5)")

    n_final = max(1, int(round(final_test_frac * n_samples)))
    n_cv = n_samples - n_final
    if n_cv <= 10:
        raise ValueError("Too few samples left for CV after holdout split.")

    idx_cv = np.arange(0, n_cv, dtype=np.int64)
    idx_final = np.arange(n_cv, n_samples, dtype=np.int64)
    return idx_cv, idx_final, n_cv, n_final

idx_cv_all, idx_final_test, n_samples_cv, n_samples_final = make_final_holdout_split(
    n_samples=n_samples,
    final_test_frac=CFG["final_test_frac"],
)

print("Holdout split:")
print("  n_samples total:", n_samples)
print("  n_samples CV   :", n_samples_cv, f"({100*(n_samples_cv/n_samples):.1f}%)")
print("  n_samples FINAL:", n_samples_final, f"({100*(n_samples_final/n_samples):.1f}%)")
print("  CV range   :", idx_cv_all[0], idx_cv_all[-1])
print("  FINAL range:", idx_final_test[0], idx_final_test[-1])


Holdout split:
  n_samples total: 2597
  n_samples CV   : 2337 (90.0%)
  n_samples FINAL: 260 (10.0%)
  CV range   : 0 2336
  FINAL range: 2337 2596



## 5. walk-forward splits (с глобальными окнами)

In [7]:
# ЛОГИЧЕСКИЙ БЛОК: walk-forward splits (expanding train + fixed val/test) on CV-part only
# ИСПОЛНЕНИЕ БЛОКА:

def make_walk_forward_splits(n_samples: int,
                             train_min_frac: float,
                             val_window_frac: float,
                             test_window_frac: float,
                             step_window_frac: float):
    train_min = int(train_min_frac * n_samples)
    val_w  = max(1, int(val_window_frac * n_samples))
    test_w = max(1, int(test_window_frac * n_samples))
    step_w = max(1, int(step_window_frac * n_samples))

    splits = []
    start = train_min
    while True:
        tr_end = start
        va_end = tr_end + val_w
        te_end = va_end + test_w
        if te_end > n_samples:
            break

        idx_train = np.arange(0, tr_end, dtype=np.int64)
        idx_val   = np.arange(tr_end, va_end, dtype=np.int64)
        idx_test  = np.arange(va_end, te_end, dtype=np.int64)

        splits.append((idx_train, idx_val, idx_test))
        start += step_w

    return splits

# IMPORTANT: строим сплиты только на 90% (CV-part)
walk_splits = make_walk_forward_splits(
    n_samples=n_samples_cv,
    train_min_frac=CFG["train_min_frac"],
    val_window_frac=CFG["val_window_frac"],
    test_window_frac=CFG["test_window_frac"],
    step_window_frac=CFG["step_window_frac"],
)

print("n_folds:", len(walk_splits))
for i, (a, b, c) in enumerate(walk_splits):
    print(f" fold {i+1}: train {len(a)} | val {len(b)} | test {len(c)}")

print("\nFINAL HOLDOUT:")
print(" final_test size:", len(idx_final_test))


n_folds: 4
 fold 1: train 1168 | val 233 | test 233
 fold 2: train 1401 | val 233 | test 233
 fold 3: train 1634 | val 233 | test 233
 fold 4: train 1867 | val 233 | test 233

FINAL HOLDOUT:
 final_test size: 260


## 6. Dataset + scaling 

In [8]:
# ЛОГИЧЕСКИЙ БЛОК: Dataset + scaling (shared)
# ИСПОЛНЕНИЕ БЛОКА:

class LobGraphSequenceDataset2Stage(Dataset):
    """
    Возвращает (x_seq, e_seq, y_trade, y_dir, exit_ret)
    y_dir корректен только когда y_trade==1, но мы возвращаем всегда.
    """
    def __init__(self, X_node, E_feat, y_trade, y_dir, exit_ret, sample_t, indices, lookback):
        self.X_node = X_node
        self.E_feat = E_feat
        self.y_trade = y_trade
        self.y_dir = y_dir
        self.exit_ret = exit_ret
        self.sample_t = sample_t
        self.indices = indices
        self.L = lookback

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        sidx = self.indices[i]
        t = self.sample_t[sidx]
        t0 = t - self.L + 1

        x_seq = self.X_node[t0:t+1]     # (L,N,F)
        e_seq = self.E_feat[t0:t+1]     # (L,E,W)

        yt = self.y_trade[t]
        yd = self.y_dir[t]
        er = self.exit_ret[t]

        return (
            torch.from_numpy(x_seq),
            torch.from_numpy(e_seq),
            torch.tensor(yt, dtype=torch.long),
            torch.tensor(yd, dtype=torch.long),
            torch.tensor(er, dtype=torch.float32),
        )

def collate_fn_2stage(batch):
    xs, es, yts, yds, ers = zip(*batch)
    return (
        torch.stack(xs, 0),   # (B,L,N,F)
        torch.stack(es, 0),   # (B,L,E,W)
        torch.stack(yts, 0),  # (B,)
        torch.stack(yds, 0),  # (B,)
        torch.stack(ers, 0),  # (B,)
    )

def fit_scale_nodes_train_only(X_node_raw, sample_t, idx_train, max_abs=10.0):
    """
    Fit scaler on all times up to last train sample time (без leakage).
    """
    last_train_t = sample_t[idx_train[-1]]
    train_time_mask = np.arange(0, last_train_t + 1)

    X_train_time = X_node_raw[train_time_mask]  # (Ttr,N,F)
    Ttr, N, Fdim = X_train_time.shape

    scaler = RobustScaler(with_centering=True, with_scaling=True, quantile_range=(5.0, 95.0))
    scaler.fit(X_train_time.reshape(-1, Fdim))

    X_scaled = scaler.transform(X_node_raw.reshape(-1, Fdim)).reshape(X_node_raw.shape).astype(np.float32)
    X_scaled = np.clip(X_scaled, -max_abs, max_abs).astype(np.float32)
    X_scaled = np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return X_scaled, scaler

def subset_trade_indices(indices, sample_t, y_trade):
    """
    indices в sample-space -> отфильтровать те, где y_trade[t]==1
    """
    tt = sample_t[indices]
    mask = (y_trade[tt] == 1)
    return indices[mask]


## 7.Model (один класс, n_classes=2) + EdgeGatedMP

In [9]:
# ЛОГИЧЕСКИЙ БЛОК: GNN + Attention Temporal Encoder classifier (универсальный под 2 класса)
# ИСПОЛНЕНИЕ БЛОКА:

import math
import torch
import torch.nn as nn


class EdgeGatedMP(nn.Module):
    def __init__(self, in_dim, hidden, edge_dim, dropout=0.1):
        super().__init__()
        self.node_proj = nn.Linear(in_dim, hidden)
        self.ln0 = nn.LayerNorm(hidden)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden + edge_dim, 2*hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2*hidden, hidden + 1)  # msg(hidden) + gate(1)
        )

        self.upd = nn.Sequential(
            nn.Linear(2*hidden, 2*hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2*hidden, hidden)
        )
        self.ln1 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)

    def forward_once(self, x_t, edge_attr_t, edge_index):
        B, N, _ = x_t.shape
        E = edge_index.shape[0]

        h = self.ln0(self.node_proj(x_t))  # (B,N,H)
        h = torch.nan_to_num(h, nan=0.0, posinf=0.0, neginf=0.0)

        agg = torch.zeros((B, N, h.shape[-1]), device=h.device, dtype=h.dtype)

        for e in range(E):
            src = edge_index[e, 0].item()
            dst = edge_index[e, 1].item()
            h_src = h[:, src, :]
            h_dst = h[:, dst, :]
            ea = edge_attr_t[:, e, :]

            z = torch.cat([h_src, h_dst, ea], dim=-1)
            out = self.edge_mlp(z)
            msg = out[:, :-1]
            gate = torch.sigmoid(out[:, -1:])

            agg[:, dst, :] += msg * gate

        h2 = self.upd(torch.cat([h, agg], dim=-1))
        h2 = self.ln1(h + self.dropout(h2))
        h2 = torch.nan_to_num(h2, nan=0.0, posinf=0.0, neginf=0.0)
        return h2

    def forward(self, x_seq, e_seq, edge_index):
        B, L, N, Fin = x_seq.shape
        h_out = []
        for t in range(L):
            ht = self.forward_once(x_seq[:, t, :, :], e_seq[:, t, :, :], edge_index)
            h_out.append(ht)
        return torch.stack(h_out, dim=1)  # (B,L,N,H)


class SinCosPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1,max_len,d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,L,D)
        L = x.size(1)
        if L > self.pe.size(1):
            # на всякий случай: если L больше max_len, просто без posenc
            return x
        return x + self.pe[:, :L, :].to(dtype=x.dtype, device=x.device)


class TemporalAttnPool(nn.Module):
    """
    score_t = v^T tanh(W h_t)
    weights = softmax(score_t)
    pooled = sum_t weights_t * h_t
    """
    def __init__(self, hidden: int, dropout: float = 0.1):
        super().__init__()
        self.W = nn.Linear(hidden, hidden, bias=True)
        self.v = nn.Parameter(torch.empty(hidden))
        nn.init.normal_(self.v, std=0.02)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h_seq: torch.Tensor):
        # h_seq: (B,L,H)
        u = torch.tanh(self.W(h_seq))                  # (B,L,H)
        score = torch.matmul(u, self.v)                # (B,L)
        w = torch.softmax(score, dim=1)                # (B,L)
        pooled = torch.sum(h_seq * w.unsqueeze(-1), dim=1)  # (B,H)
        pooled = self.dropout(pooled)
        return pooled, w


class TemporalTransformerEncoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        attn_dropout: float,
        use_posenc: bool,
        max_len: int = 512,
        pool: str = "last",  # "last" | "attnpool"
        posenc_learnable: bool = False,
    ):
        super().__init__()
        self.use_posenc = bool(use_posenc)
        self.pool = pool

        if self.use_posenc:
            if posenc_learnable:
                self.pos_emb = nn.Embedding(max_len, d_model)
                self.pos_drop = nn.Dropout(attn_dropout)
            else:
                self.pos_enc = SinCosPositionalEncoding(d_model=d_model, max_len=max_len)
                self.pos_drop = nn.Dropout(attn_dropout)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=attn_dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layers)

        self.attnpool = None
        if pool == "attnpool":
            self.attnpool = TemporalAttnPool(hidden=d_model, dropout=attn_dropout)

    def forward(self, h_seq: torch.Tensor):
        # h_seq: (B,L,D)
        B, L, D = h_seq.shape
        x = h_seq

        if self.use_posenc:
            if hasattr(self, "pos_emb"):
                pos = torch.arange(L, device=x.device)
                x = x + self.pos_emb(pos).unsqueeze(0).to(dtype=x.dtype)  # (1,L,D)
                x = self.pos_drop(x)
            else:
                x = self.pos_drop(self.pos_enc(x))

        x = self.enc(x)  # (B,L,D)

        if self.pool == "attnpool":
            pooled, w = self.attnpool(x)  # pooled (B,D)
            return pooled, w

        # default: last token
        pooled = x[:, -1, :]  # (B,D)
        return pooled, None


class GNN_Attn_Classifier(nn.Module):
    """
    Drop-in replacement for GNN_LSTM_Classifier:
      - spatial: same EdgeGatedMP stack => h (B,L,N,H)
      - target node => h_tgt (B,L,H)
      - temporal: CFG["temporal_mode"] in {"transformer","attnpool"}
      - head => logits (B,2)
    """
    def __init__(
        self,
        node_in,
        edge_dim,
        hidden,
        gnn_layers,
        lstm_hidden,   # keep name for compatibility; this is head/temporal_out dim now
        lstm_layers,   # unused but kept for drop-in
        dropout=0.1,
        target_node=2,
        n_classes=2,
        cfg=None,
    ):
        super().__init__()
        self.target_node = target_node
        self.cfg = cfg if cfg is not None else globals().get("CFG", {})
        self.temporal_mode = self.cfg.get("temporal_mode", "transformer")

        # --- spatial stack (unchanged behavior)
        self.gnns = nn.ModuleList()
        for i in range(gnn_layers):
            in_dim = node_in if i == 0 else hidden
            self.gnns.append(EdgeGatedMP(in_dim=in_dim, hidden=hidden, edge_dim=edge_dim, dropout=dropout))

        # --- temporal encoder
        self.temporal_proj = nn.Identity() if hidden == lstm_hidden else nn.Linear(hidden, lstm_hidden)

        if self.temporal_mode == "transformer":
            n_heads = int(self.cfg.get("n_heads", 4))
            n_layers = int(self.cfg.get("n_layers", 1))
            attn_dropout = float(self.cfg.get("attn_dropout", dropout))
            use_posenc = bool(self.cfg.get("use_posenc", True))
            pool = self.cfg.get("transformer_pool", "last")
            posenc_learnable = bool(self.cfg.get("posenc_learnable", False))
            max_len = int(self.cfg.get("lookback", 512))

            if hidden % n_heads != 0:
                raise ValueError(f"hidden ({hidden}) must be divisible by n_heads ({n_heads}).")

            self.temporal = TemporalTransformerEncoder(
                d_model=hidden,
                n_heads=n_heads,
                n_layers=n_layers,
                attn_dropout=attn_dropout,
                use_posenc=use_posenc,
                max_len=max_len,
                pool=pool,
                posenc_learnable=posenc_learnable,
            )

        elif self.temporal_mode == "attnpool":
            attn_dropout = float(self.cfg.get("attn_dropout", dropout))
            self.temporal = TemporalAttnPool(hidden=hidden, dropout=attn_dropout)

        else:
            raise ValueError("CFG['temporal_mode'] must be 'transformer' or 'attnpool'")

        # --- head (как у тебя, только вход= ltsm_hidden)
        self.head = nn.Sequential(
            nn.LayerNorm(lstm_hidden),
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, lstm_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, n_classes)
        )

        # init (как у тебя)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, e, edge_index):
        h = x
        for gnn in self.gnns:
            h = gnn(h, e, edge_index)  # (B,L,N,H)

        h_tgt = h[:, :, self.target_node, :]  # (B,L,H)

        if self.temporal_mode == "transformer":
            pooled, _ = self.temporal(h_tgt)   # (B,H)
        else:
            pooled, _ = self.temporal(h_tgt)   # (B,H)

        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=0.0, neginf=0.0)
        pooled = self.temporal_proj(pooled)    # (B,lstm_hidden)

        logits = self.head(pooled)             # (B,2)
        return torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)


# --- DROP-IN: чтобы вообще не менять train loops / train_binary_classifier
#GNN_LSTM_Classifier = GNN_Attn_Classifier


# -------------------------
# SMOKE TEST (быстрый прогон)
# -------------------------
with torch.no_grad():
    B = 2
    L = int(CFG.get("lookback", 16))
    N = len(ASSETS)
    node_in = X_node_raw.shape[-1] if "X_node_raw" in globals() else 15
    edge_dim = edge_feat.shape[-1] if "edge_feat" in globals() else 5
    E = EDGE_INDEX.shape[0] if "EDGE_INDEX" in globals() else 3

    x = torch.randn(B, L, N, node_in).to(DEVICE)
    e = torch.randn(B, L, E, edge_dim).to(DEVICE)

    # 1) transformer
    CFG["temporal_mode"] = "transformer"
    m = GNN_Attn_Classifier(
        node_in=node_in, edge_dim=edge_dim,
        hidden=CFG["hidden"], gnn_layers=CFG["gnn_layers"],
        lstm_hidden=CFG["lstm_hidden"], lstm_layers=CFG["lstm_layers"],
        dropout=CFG["dropout"], target_node=TARGET_NODE, n_classes=2,
        cfg=CFG,
    ).to(DEVICE)
    y = m(x, e, EDGE_INDEX.to(DEVICE))
    print("[smoke transformer] logits:", tuple(y.shape))

    # 2) attnpool
    CFG["temporal_mode"] = "attnpool"
    m2 = GNN_Attn_Classifier(
        node_in=node_in, edge_dim=edge_dim,
        hidden=CFG["hidden"], gnn_layers=CFG["gnn_layers"],
        lstm_hidden=CFG["lstm_hidden"], lstm_layers=CFG["lstm_layers"],
        dropout=CFG["dropout"], target_node=TARGET_NODE, n_classes=2,
        cfg=CFG,
    ).to(DEVICE)
    y2 = m2(x, e, EDGE_INDEX.to(DEVICE))
    print("[smoke attnpool] logits:", tuple(y2.shape))

print("Model ready (Attention temporal).")


[smoke transformer] logits: (2, 2)
[smoke attnpool] logits: (2, 2)
Model ready (Attention temporal).




## 8. Training/Eval: Stage A (trade) и Stage B (direction)

In [10]:
# ЛОГИЧЕСКИЙ БЛОК: train/eval helpers for two-stage
# ИСПОЛНЕНИЕ БЛОКА:

@torch.no_grad()
def eval_binary(model, loader, loss_fn, y_key: str = "trade"):
    model.eval()
    ys = []
    probs = []
    ers = []
    total_loss = 0.0
    n = 0

    for x, e, y_trade_b, y_dir_b, er in loader:
        x = x.to(DEVICE).float()
        e = e.to(DEVICE).float()
        y = (y_trade_b if y_key == "trade" else y_dir_b).to(DEVICE).long()

        logits = model(x, e, EDGE_INDEX.to(DEVICE))
        loss = loss_fn(logits, y)

        total_loss += loss.item() * y.size(0)
        n += y.size(0)

        p = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        ys.append(y.detach().cpu().numpy())
        probs.append(p)
        ers.append(er.detach().cpu().numpy())

    ys = np.concatenate(ys) if len(ys) else np.array([], dtype=np.int64)
    probs = np.concatenate(probs) if len(probs) else np.zeros((0, 2), dtype=np.float32)
    ers = np.concatenate(ers) if len(ers) else np.array([], dtype=np.float32)

    if len(ys) == 0:
        return np.nan, np.nan, np.nan, np.nan, None, ys, probs, ers

    y_pred = probs.argmax(axis=1)
    acc = accuracy_score(ys, y_pred)
    f1m = f1_score(ys, y_pred, average="macro")
    auc = roc_auc_score(ys, probs[:, 1]) if len(np.unique(ys)) == 2 else np.nan
    cm = confusion_matrix(ys, y_pred)

    return total_loss / max(n, 1), acc, f1m, auc, cm, ys, probs, ers


@torch.no_grad()
def predict_probs_only(model, loader):
    model.eval()
    probs = []
    ers = []
    for x, e, y_trade_b, y_dir_b, er in loader:
        x = x.to(DEVICE).float()
        e = e.to(DEVICE).float()
        logits = model(x, e, EDGE_INDEX.to(DEVICE))
        p = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        probs.append(p)
        ers.append(er.detach().cpu().numpy())
    probs = np.concatenate(probs) if len(probs) else np.zeros((0, 2), dtype=np.float32)
    ers = np.concatenate(ers) if len(ers) else np.array([], dtype=np.float32)
    return probs, ers


def pnl_proxy_grid_max(prob_trade, prob_dir, exit_ret, thr_trade_grid, thr_dir_grid, cost_bps, min_trades: int = 0):
    """
    Возвращает лучший pnl_mean по grid (per-bar), плюс пороги и статистику.
    min_trades используется как фильтр: комбинации, где сделок меньше, пропускаются.
    Если ни одна комбинация не прошла min_trades — вернём best без фильтра (но это будет fallback-сценарий).
    """
    p_trade = prob_trade[:, 1]
    p_up = prob_dir[:, 1]
    p_dn = 1.0 - p_up
    conf_dir = np.maximum(p_up, p_dn)

    sign = np.where(p_up >= 0.5, 1.0, -1.0).astype(np.float32)
    cost = float(cost_bps) * 1e-4
    N = len(exit_ret)

    best = {
        "pnl_mean": -1e18,
        "pnl_sum": -1e18,
        "thr_trade": None,
        "thr_dir": None,
        "n_trades": 0,
        "trade_rate": 0.0,
        "min_trades_used": int(min_trades),
        "passed_min_trades": False,
    }

    # 1) строгий проход (>=min_trades)
    for thr_t in thr_trade_grid:
        mt = (p_trade >= thr_t)
        for thr_d in thr_dir_grid:
            mask = mt & (conf_dir >= thr_d)
            n_tr = int(mask.sum())
            if n_tr < int(min_trades):
                continue

            pnl = (sign * exit_ret) * mask.astype(np.float32) - cost * mask.astype(np.float32)
            pnl_sum = float(pnl.sum())
            pnl_mean = float(pnl.mean()) if N > 0 else np.nan

            if pnl_mean > best["pnl_mean"]:
                best.update({
                    "pnl_mean": pnl_mean,
                    "pnl_sum": pnl_sum,
                    "thr_trade": float(thr_t),
                    "thr_dir": float(thr_d),
                    "n_trades": n_tr,
                    "trade_rate": float(n_tr / max(1, N)),
                    "passed_min_trades": True,
                })

    # 2) если ничего не прошло min_trades — найдём best без фильтра (для fallback-логов)
    if best["thr_trade"] is None:
        for thr_t in thr_trade_grid:
            mt = (p_trade >= thr_t)
            for thr_d in thr_dir_grid:
                mask = mt & (conf_dir >= thr_d)
                n_tr = int(mask.sum())
                pnl = (sign * exit_ret) * mask.astype(np.float32) - cost * mask.astype(np.float32)
                pnl_sum = float(pnl.sum())
                pnl_mean = float(pnl.mean()) if N > 0 else np.nan

                if pnl_mean > best["pnl_mean"]:
                    best.update({
                        "pnl_mean": pnl_mean,
                        "pnl_sum": pnl_sum,
                        "thr_trade": float(thr_t),
                        "thr_dir": float(thr_d),
                        "n_trades": n_tr,
                        "trade_rate": float(n_tr / max(1, N)),
                        "passed_min_trades": False,
                    })

    return best


def train_binary_classifier(
    X_scaled, edge_feat,
    y_trade_arr, y_dir_arr,
    exit_ret, sample_t,
    idx_train, idx_val, idx_test,
    cfg,
    stage_name: str,
    select_metric: str | None = None,        # "va_auc" | "va_f1m" | "va_pnl_max"
    trade_model_for_pnl=None,                # для stage="dir": фиксированная trade-модель
    idx_val_pnl=None,                        # индексы (sample-space) для pnl-proxy, обычно полный idx_val
):
    """
    Градиентами оптимизируем: CrossEntropyLoss.
    Селектор best checkpoint:
      - trade: обычно va_auc
      - dir:   va_pnl_max, но если best trades < proxy_min_trades => fallback на va_auc
              (реализовано как: best_pnl если были эпохи с trades>=min, иначе best_auc)
    """
    if select_metric is None:
        select_metric = "va_auc"
    if select_metric not in ("va_auc", "va_f1m", "va_pnl_max"):
        raise ValueError("select_metric must be one of: 'va_auc', 'va_f1m', 'va_pnl_max'")

    if select_metric == "va_pnl_max":
        if stage_name != "dir":
            raise ValueError("select_metric='va_pnl_max' supported only for stage_name='dir'")
        if trade_model_for_pnl is None or idx_val_pnl is None:
            raise ValueError("For va_pnl_max you must pass trade_model_for_pnl and idx_val_pnl (full val indices).")

    L = cfg["lookback"]

    tr_ds = LobGraphSequenceDataset2Stage(X_scaled, edge_feat, y_trade_arr, y_dir_arr, exit_ret, sample_t, idx_train, L)
    va_ds = LobGraphSequenceDataset2Stage(X_scaled, edge_feat, y_trade_arr, y_dir_arr, exit_ret, sample_t, idx_val,   L)
    te_ds = LobGraphSequenceDataset2Stage(X_scaled, edge_feat, y_trade_arr, y_dir_arr, exit_ret, sample_t, idx_test,  L)

    tr_loader = DataLoader(tr_ds, batch_size=cfg["batch_size"], shuffle=True,  drop_last=True, collate_fn=collate_fn_2stage)
    va_loader = DataLoader(va_ds, batch_size=cfg["batch_size"], shuffle=False, collate_fn=collate_fn_2stage)
    te_loader = DataLoader(te_ds, batch_size=cfg["batch_size"], shuffle=False, collate_fn=collate_fn_2stage)

    va_pnl_loader = None
    if stage_name == "dir" and (idx_val_pnl is not None):
        va_pnl_ds = LobGraphSequenceDataset2Stage(X_scaled, edge_feat, y_trade_arr, y_dir_arr, exit_ret, sample_t, idx_val_pnl, L)
        va_pnl_loader = DataLoader(va_pnl_ds, batch_size=cfg["batch_size"], shuffle=False, collate_fn=collate_fn_2stage)

    node_in = X_scaled.shape[-1]
    edge_dim = edge_feat.shape[-1]
    model = GNN_Attn_Classifier(
        node_in=node_in, edge_dim=edge_dim,
        hidden=cfg["hidden"], gnn_layers=cfg["gnn_layers"],
        lstm_hidden=cfg["lstm_hidden"], lstm_layers=cfg["lstm_layers"],
        dropout=cfg["dropout"], target_node=TARGET_NODE, n_classes=2,
        cfg=cfg,
    ).to(DEVICE)


    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=4)
    scaler_amp = torch.amp.GradScaler('cuda', enabled=(cfg["use_amp"] and DEVICE.type == "cuda"))

    # --- trade-prob на полном val для PnL proxy (считаем 1 раз)
    prob_trade_val_pnl = None
    if stage_name == "dir" and (trade_model_for_pnl is not None) and (va_pnl_loader is not None):
        prob_trade_val_pnl, _ = predict_probs_only(trade_model_for_pnl, va_pnl_loader)

    thr_trade_grid_proxy = cfg.get("proxy_thr_trade_grid") or cfg.get("thr_trade_grid", [0.5])
    thr_dir_grid_proxy   = cfg.get("proxy_thr_dir_grid")   or cfg.get("thr_dir_grid",   [0.5])
    proxy_min_trades = int(cfg.get("proxy_min_trades", 0))

    # --- best trackers
    best_score = -1e18
    best_state = None
    best_epoch = -1
    best_used = select_metric

    # специальные трекеры для va_pnl_max с fallback
    best_score_auc = -1e18
    best_state_auc = None
    best_epoch_auc = -1

    best_score_pnl = -1e18
    best_state_pnl = None
    best_epoch_pnl = -1

    seen_pnl_ok = False

    patience = 8
    bad = 0

    hist = {
        "tr_loss": [], "va_loss": [],
        "va_f1m": [], "va_auc": [],
        "va_pnl_max": [],
        "va_pnl_thr_trade": [],
        "va_pnl_thr_dir": [],
        "va_pnl_n_trades": [],
        "va_sel": [],
        "va_sel_mode": []
    }

    for ep in range(1, cfg["epochs"] + 1):
        # ---- TRAIN
        model.train()
        tot = 0.0
        n = 0

        for x, e, y_trade_b, y_dir_b, er in tr_loader:
            x = x.to(DEVICE).float()
            e = e.to(DEVICE).float()
            y = (y_trade_b if stage_name == "trade" else y_dir_b).to(DEVICE).long()

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=(cfg["use_amp"] and DEVICE.type == "cuda")):
                logits = model(x, e, EDGE_INDEX.to(DEVICE))
                loss = loss_fn(logits, y)

            if not torch.isfinite(loss):
                continue

            scaler_amp.scale(loss).backward()
            scaler_amp.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            scaler_amp.step(opt)
            scaler_amp.update()

            tot += loss.item() * y.size(0)
            n += y.size(0)

        tr_loss = tot / max(n, 1)

        # ---- VAL classification metrics
        va_loss, va_acc, va_f1m, va_auc, va_cm, va_y, va_prob, va_er = eval_binary(
            model, va_loader, loss_fn, y_key=stage_name
        )

        # ---- VAL PnL proxy (dir only)
        va_pnl_best = {"pnl_mean": np.nan, "thr_trade": np.nan, "thr_dir": np.nan, "n_trades": 0, "trade_rate": np.nan,
                       "passed_min_trades": False, "min_trades_used": proxy_min_trades}

        if stage_name == "dir" and (prob_trade_val_pnl is not None) and (va_pnl_loader is not None):
            prob_dir_val_pnl, er_dir_val_pnl = predict_probs_only(model, va_pnl_loader)

            va_pnl_best = pnl_proxy_grid_max(
                prob_trade=prob_trade_val_pnl,
                prob_dir=prob_dir_val_pnl,
                exit_ret=er_dir_val_pnl,
                thr_trade_grid=thr_trade_grid_proxy,
                thr_dir_grid=thr_dir_grid_proxy,
                cost_bps=cfg["cost_bps"],
                min_trades=proxy_min_trades,
            )

        # ---- selection
        sel_val = np.nan
        sel_mode = select_metric

        if select_metric in ("va_auc", "va_f1m"):
            sel_val = (va_auc if select_metric == "va_auc" else va_f1m)
            if not np.isfinite(sel_val):
                sel_val = -1e18

            # единый best
            prev_best = best_score
            if sel_val > best_score:
                best_score = sel_val
                best_epoch = ep
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                bad = 0
            else:
                bad += 1

        else:
            # select_metric == "va_pnl_max" (dir only) with hard fallback
            pnl_mean = float(va_pnl_best["pnl_mean"])
            n_tr = int(va_pnl_best["n_trades"])
            pnl_ok = (np.isfinite(pnl_mean) and (n_tr >= proxy_min_trades))

            # обновим best_auc (fallback) всегда
            if np.isfinite(va_auc) and (va_auc > best_score_auc):
                best_score_auc = float(va_auc)
                best_epoch_auc = ep
                best_state_auc = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

            # обновим best_pnl только если pnl_ok
            if pnl_ok and (pnl_mean > best_score_pnl):
                best_score_pnl = pnl_mean
                best_epoch_pnl = ep
                best_state_pnl = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

            if pnl_ok:
                seen_pnl_ok = True
                sel_val = pnl_mean
                sel_mode = "va_pnl_max"
            else:
                sel_val = float(va_auc) if np.isfinite(va_auc) else -1e18
                sel_mode = f"va_auc_fallback({n_tr}/{proxy_min_trades})"

            # scheduler всегда по текущему sel_val
            if not np.isfinite(sel_val):
                sel_val = -1e18
            sch.step(float(sel_val))

            # early stop: до первой валидной pnl-эпохи -> по AUC, после -> по PnL
            improved = False
            if not seen_pnl_ok:
                # следим за ростом AUC
                improved = (np.isfinite(va_auc) and (float(va_auc) >= best_score_auc))
            else:
                # следим за ростом PnL (только когда pnl_ok)
                improved = pnl_ok and (pnl_mean >= best_score_pnl)

            if improved:
                bad = 0
            else:
                bad += 1

        # если не pnl-метрика — scheduler тут
        if select_metric != "va_pnl_max":
            sch.step(float(sel_val))

        # ---- logging + hist
        hist["tr_loss"].append(tr_loss)
        hist["va_loss"].append(va_loss)
        hist["va_f1m"].append(va_f1m)
        hist["va_auc"].append(va_auc)

        hist["va_pnl_max"].append(float(va_pnl_best["pnl_mean"]) if np.isfinite(va_pnl_best["pnl_mean"]) else np.nan)
        hist["va_pnl_thr_trade"].append(float(va_pnl_best["thr_trade"]) if va_pnl_best["thr_trade"] is not None else np.nan)
        hist["va_pnl_thr_dir"].append(float(va_pnl_best["thr_dir"]) if va_pnl_best["thr_dir"] is not None else np.nan)
        hist["va_pnl_n_trades"].append(int(va_pnl_best["n_trades"]))
        hist["va_sel"].append(float(sel_val) if np.isfinite(sel_val) else np.nan)
        hist["va_sel_mode"].append(sel_mode)

        lr_now = opt.param_groups[0]["lr"]

        # красивый best_str
        if select_metric == "va_pnl_max":
            if best_state_pnl is not None:
                best_str = f"pnl={best_score_pnl:.6f}@ep{best_epoch_pnl:02d}"
            else:
                best_str = f"auc={best_score_auc:.6f}@ep{best_epoch_auc:02d}"
        else:
            best_str = f"{best_score:.6f}@ep{best_epoch:02d}" if best_epoch > 0 else "none"

        if stage_name == "dir":
            print(
                f"[{stage_name}] ep {ep:02d} lr={lr_now:.2e} "
                f"tr_loss={tr_loss:.4f} va_loss={va_loss:.4f} "
                f"f1m={va_f1m:.3f} auc={va_auc:.3f} "
                f"pnl_max={va_pnl_best['pnl_mean']:.6f} "
                f"thr=({va_pnl_best['thr_trade']:.2f},{va_pnl_best['thr_dir']:.2f}) "
                f"trades={va_pnl_best['n_trades']} "
                f"sel({sel_mode})={float(sel_val):.6f} best={best_str}"
            )
        else:
            print(
                f"[{stage_name}] ep {ep:02d} lr={lr_now:.2e} "
                f"tr_loss={tr_loss:.4f} va_loss={va_loss:.4f} "
                f"f1m={va_f1m:.3f} auc={va_auc:.3f} "
                f"sel({select_metric})={float(sel_val):.6f} best={best_str}"
            )

        if bad >= patience:
            break

    # ---- choose final best state
    if select_metric == "va_pnl_max":
        if best_state_pnl is not None:
            model.load_state_dict(best_state_pnl)
            best_score = best_score_pnl
            best_epoch = best_epoch_pnl
            best_used = "va_pnl_max"
        else:
            model.load_state_dict(best_state_auc)
            best_score = best_score_auc
            best_epoch = best_epoch_auc
            best_used = "va_auc_fallback_only"
    else:
        if best_state is not None:
            model.load_state_dict(best_state)
            best_used = select_metric

    # финальные VAL/TEST по best_state
    va_loss, va_acc, va_f1m, va_auc, va_cm, va_y, va_prob, va_er = eval_binary(
        model, va_loader, loss_fn, y_key=stage_name
    )
    te_loss, te_acc, te_f1m, te_auc, te_cm, te_y, te_prob, te_er = eval_binary(
        model, te_loader, loss_fn, y_key=stage_name
    )

    res = {
        "best_val_score": float(best_score),
        "best_epoch": int(best_epoch),
        "select_metric": select_metric,
        "best_used": best_used,

        "val_loss": va_loss,
        "val_acc": va_acc,
        "val_f1m": va_f1m,
        "val_auc": va_auc,
        "val_cm": va_cm,
        "val_y": va_y,
        "val_prob": va_prob,
        "val_er": va_er,

        "test_loss": te_loss,
        "test_acc": te_acc,
        "test_f1m": te_f1m,
        "test_auc": te_auc,
        "test_cm": te_cm,
        "test_y": te_y,
        "test_prob": te_prob,
        "test_er": te_er,

        "hist": hist,
    }
    return model, res



## 9. Two-stage PnL by confidence thresholds

In [11]:
# ЛОГИЧЕСКИЙ БЛОК: PnL по порогам уверенности (two-stage)
# ИСПОЛНЕНИЕ БЛОКА:

def two_stage_pnl_by_threshold(
    prob_trade,          # (N,2) softmax: [:,1]=p_trade
    prob_dir,            # (N,2) softmax: [:,1]=p_up
    exit_ret,            # (N,) realized log-ret to TB exit
    thr_trade: float,
    thr_dir: float,
    cost_bps: float,
):
    p_trade = prob_trade[:,1]
    p_up = prob_dir[:,1]
    p_dn = 1.0 - p_up
    conf_dir = np.maximum(p_up, p_dn)

    trade_mask = (p_trade >= thr_trade) & (conf_dir >= thr_dir)

    action = np.zeros_like(exit_ret, dtype=np.float32)
    action[trade_mask] = np.where(p_up[trade_mask] >= 0.5, 1.0, -1.0)

    cost = (cost_bps * 1e-4) * trade_mask.astype(np.float32)
    pnl = action * exit_ret - cost

    out = {
        "n": len(exit_ret),
        "n_trades": int(trade_mask.sum()),
        "trade_rate": float(trade_mask.mean()),
        "pnl_sum": float(pnl.sum()),
        "pnl_mean": float(pnl.mean()),
        "pnl_sharpe": float((pnl.mean() / (pnl.std() + 1e-12)) * np.sqrt(288)),
    }
    return out

def sweep_thresholds(prob_trade, prob_dir, exit_ret, cfg):
    rows = []
    for thr_t in cfg["thr_trade_grid"]:
        for thr_d in cfg["thr_dir_grid"]:
            m = two_stage_pnl_by_threshold(
                prob_trade=prob_trade,
                prob_dir=prob_dir,
                exit_ret=exit_ret,
                thr_trade=thr_t,
                thr_dir=thr_d,
                cost_bps=cfg["cost_bps"],
            )
            rows.append({"thr_trade":thr_t, "thr_dir":thr_d, **m})
    return pd.DataFrame(rows).sort_values(["pnl_mean","pnl_sum"], ascending=False)

print("Two-stage PnL threshold utils ready.")


Two-stage PnL threshold utils ready.


In [12]:
# ЛОГИЧЕСКИЙ БЛОК: shared helper for probs on arbitrary indices
# ИСПОЛНЕНИЕ БЛОКА:

@torch.no_grad()
def predict_probs_on_indices(model, X_scaled, edge_feat, indices, cfg):
    ds = LobGraphSequenceDataset2Stage(
        X_scaled, edge_feat, y_trade, y_dir, exit_ret, sample_t, indices, cfg["lookback"]
    )
    loader = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=False, collate_fn=collate_fn_2stage)

    model.eval()
    probs = []
    ers = []
    for x, e, yt, yd, er in loader:
        x = x.to(DEVICE).float()
        e = e.to(DEVICE).float()
        logits = model(x, e, EDGE_INDEX.to(DEVICE))
        p = torch.softmax(logits, dim=-1).cpu().numpy()
        probs.append(p)
        ers.append(er.cpu().numpy())

    return np.concatenate(probs), np.concatenate(ers)


## 10. Run folds: scale once → train trade → filter trades → train dir → PnL sweep

In [13]:
# ЛОГИЧЕСКИЙ БЛОК: run walk-forward folds for two-stage training (ONLY on CV-part)
# ИСПОЛНЕНИЕ БЛОКА:

fold_rows = []
models_trade = []
models_dir = []

for fi, (idx_tr, idx_va, idx_te) in enumerate(walk_splits, 1):
    print("\n" + "="*70)
    print(f"FOLD {fi}/{len(walk_splits)} sizes:", len(idx_tr), len(idx_va), len(idx_te))

    # scale once per fold (fit only on train times)
    X_scaled, _ = fit_scale_nodes_train_only(X_node_raw, sample_t, idx_tr, max_abs=CFG["max_abs_feat"])

    # ---- Stage A: trade/no-trade on all samples (по AUC)
    m_trade, r_trade = train_binary_classifier(
        X_scaled, edge_feat,
        y_trade, y_dir,
        exit_ret, sample_t,
        idx_tr, idx_va, idx_te,
        CFG,
        stage_name="trade",
        select_metric="va_auc",
    )
    models_trade.append(m_trade)

    # ---- Stage B: direction ONLY on trade samples (train/val/test индексы фильтруем)
    idx_tr_T = subset_trade_indices(idx_tr, sample_t, y_trade)
    idx_va_T = subset_trade_indices(idx_va, sample_t, y_trade)
    idx_te_T = subset_trade_indices(idx_te, sample_t, y_trade)

    if len(idx_tr_T) < max(200, CFG["batch_size"]*2) or len(idx_te_T) < 50:
        print("[dir] skip: not enough trade samples in this fold.")
        fold_rows.append({
            "fold": fi,
            "trade_test_f1m": r_trade["test_f1m"],
            "dir_test_f1m": np.nan,
            "best_pnl_mean": np.nan,
            "best_thr_trade": np.nan,
            "best_thr_dir": np.nan,
            "n_trades_best": np.nan,
            "trade_rate_best": np.nan,
        })
        continue

    # dir: учим на trade-only, но PnL-proxy считаем на полном idx_va (full val)
    m_dir, r_dir = train_binary_classifier(
        X_scaled, edge_feat,
        y_trade, y_dir,
        exit_ret, sample_t,
        idx_tr_T, idx_va_T, idx_te_T,
        CFG,
        stage_name="dir",
        select_metric="va_pnl_max",
        trade_model_for_pnl=m_trade,
        idx_val_pnl=idx_va,   # <-- полный val для pnl-proxy
    )
    models_dir.append(m_dir)

    # ---- Two-stage PnL evaluation on fold TEST
    prob_trade_te, er_te = predict_probs_on_indices(m_trade, X_scaled, edge_feat, idx_te, CFG)
    prob_dir_te, _       = predict_probs_on_indices(m_dir,   X_scaled, edge_feat, idx_te, CFG)

    sweep = sweep_thresholds(prob_trade_te, prob_dir_te, er_te, CFG)
    best = sweep.iloc[0].to_dict()

    print("PnL on fold-test:",
          "| thr_trade=", best["thr_trade"],
          "| thr_dir=", best["thr_dir"],
          "| pnl_mean=", best["pnl_mean"],
          "| trades=", best["n_trades"])

    fold_rows.append({
        "fold": fi,
        "trade_test_f1m": r_trade["test_f1m"],
        "dir_test_f1m": r_dir["test_f1m"],
        "best_pnl_mean": best["pnl_mean"],
        "best_thr_trade": best["thr_trade"],
        "best_thr_dir": best["thr_dir"],
        "n_trades_best": best["n_trades"],
        "trade_rate_best": best["trade_rate"],
    })

summary = pd.DataFrame(fold_rows)
display(summary)
print("\nMEAN (fold-test внутри CV-part):")
print(summary.mean(numeric_only=True))



FOLD 1/4 sizes: 1168 233 233
[trade] ep 01 lr=2.00e-04 tr_loss=0.7668 va_loss=0.6946 f1m=0.380 auc=0.371 sel(va_auc)=0.370785 best=0.370785@ep01
[trade] ep 02 lr=2.00e-04 tr_loss=0.7084 va_loss=0.6893 f1m=0.380 auc=0.311 sel(va_auc)=0.311111 best=0.370785@ep01
[trade] ep 03 lr=2.00e-04 tr_loss=0.6913 va_loss=0.6864 f1m=0.380 auc=0.340 sel(va_auc)=0.339549 best=0.370785@ep01
[trade] ep 04 lr=2.00e-04 tr_loss=0.6940 va_loss=0.6756 f1m=0.380 auc=0.328 sel(va_auc)=0.327661 best=0.370785@ep01
[trade] ep 05 lr=2.00e-04 tr_loss=0.6685 va_loss=0.6733 f1m=0.380 auc=0.435 sel(va_auc)=0.435431 best=0.435431@ep05
[trade] ep 06 lr=2.00e-04 tr_loss=0.6625 va_loss=0.6795 f1m=0.380 auc=0.369 sel(va_auc)=0.369464 best=0.435431@ep05
[trade] ep 07 lr=2.00e-04 tr_loss=0.6741 va_loss=0.6705 f1m=0.380 auc=0.451 sel(va_auc)=0.451127 best=0.451127@ep07
[trade] ep 08 lr=2.00e-04 tr_loss=0.6688 va_loss=0.6735 f1m=0.380 auc=0.451 sel(va_auc)=0.450738 best=0.451127@ep07
[trade] ep 09 lr=2.00e-04 tr_loss=0.6519 v

Unnamed: 0,fold,trade_test_f1m,dir_test_f1m,best_pnl_mean,best_thr_trade,best_thr_dir,n_trades_best,trade_rate_best
0,1,0.428922,0.198317,0.0,0.5,0.5,0.0,0.0
1,2,0.233553,0.381679,0.0,0.5,0.5,0.0,0.0
2,3,0.120755,0.345277,0.0,0.5,0.5,0.0,0.0
3,4,0.515032,0.396988,0.0,0.5,0.6,0.0,0.0



MEAN (fold-test внутри CV-part):
fold               2.500000
trade_test_f1m     0.324565
dir_test_f1m       0.330565
best_pnl_mean      0.000000
best_thr_trade     0.500000
best_thr_dir       0.525000
n_trades_best      0.000000
trade_rate_best    0.000000
dtype: float64


## 11. Final test

In [14]:
# ЛОГИЧЕСКИЙ БЛОК: Final train on CV(90%) and evaluate once on FINAL(10%)
# ИСПОЛНЕНИЕ БЛОКА:

print("\n" + "="*70)
print("FINAL TRAIN/TEST (CV=90% | FINAL=10%)")

# 1) final train/val split внутри CV-part (по времени)
val_w_final = max(1, int(CFG["val_window_frac"] * n_samples_cv))
train_end = n_samples_cv - val_w_final

idx_train_final = np.arange(0, train_end, dtype=np.int64)
idx_val_final   = np.arange(train_end, n_samples_cv, dtype=np.int64)
idx_test_final  = idx_final_test.astype(np.int64)  # финальный holdout

print("Final split sizes:")
print("  train_final:", len(idx_train_final))
print("  val_final  :", len(idx_val_final))
print("  FINAL test :", len(idx_test_final))

# 2) scaling (fit only on train_final)
X_scaled_final, _ = fit_scale_nodes_train_only(X_node_raw, sample_t, idx_train_final, max_abs=CFG["max_abs_feat"])

# 6) финальная оценка на holdout (БЕЗ подбора порогов на holdout)
prob_trade_hold, er_hold = predict_probs_on_indices(m_trade, X_scaled_final, edge_feat, idx_test_final, CFG)
prob_dir_hold, _         = predict_probs_on_indices(m_dir,   X_scaled_final, edge_feat, idx_test_final, CFG)

final_metrics = two_stage_pnl_by_threshold(
    prob_trade=prob_trade_hold,
    prob_dir=prob_dir_hold,
    exit_ret=er_hold,
    thr_trade=summary['best_thr_trade'][3],
    thr_dir=summary['best_thr_dir'][3],
    cost_bps=CFG["cost_bps"],
)

print("\nFINAL HOLDOUT RESULT (fixed thresholds from val_final):")
print("  pnl_mean :", final_metrics["pnl_mean"])
print("  pnl_sum  :", final_metrics["pnl_sum"])
print("  n_trades :", final_metrics["n_trades"])
print("  trade_rate:", final_metrics["trade_rate"])
print("  sharpe (per-bar proxy):", final_metrics["pnl_sharpe"])

# (опционально) oracle на holdout — НЕ для выбора, только “потолок”
sweep_hold_oracle = sweep_thresholds(prob_trade_hold, prob_dir_hold, er_hold, CFG)
best_hold_oracle = sweep_hold_oracle.iloc[0].to_dict()
print("\n[ORACLE] best possible on holdout by sweeping thresholds (DO NOT USE for selection):")
print("  thr_trade:", best_hold_oracle["thr_trade"], "thr_dir:", best_hold_oracle["thr_dir"])
print("  pnl_mean :", best_hold_oracle["pnl_mean"], "trades:", best_hold_oracle["n_trades"])



FINAL TRAIN/TEST (CV=90% | FINAL=10%)
Final split sizes:
  train_final: 2104
  val_final  : 233
  FINAL test : 260

FINAL HOLDOUT RESULT (fixed thresholds from val_final):
  pnl_mean : 0.00010919975466094911
  pnl_sum  : 0.028391936793923378
  n_trades : 27
  trade_rate: 0.10384615384615385
  sharpe (per-bar proxy): 0.7988579415838531

[ORACLE] best possible on holdout by sweeping thresholds (DO NOT USE for selection):
  thr_trade: 0.55 thr_dir: 0.6
  pnl_mean : 0.00016489169502165169 trades: 24.0


In [15]:


# 3) train TRADE on train_final, select by AUC on val_final
m_trade_final, r_trade_final = train_binary_classifier(
    X_scaled_final, edge_feat,
    y_trade, y_dir,
    exit_ret, sample_t,
    idx_train_final, idx_val_final, idx_test_final,
    CFG,
    stage_name="trade",
    select_metric="va_auc",
)

# 4) train DIR on trade-only samples (train/val/test filtered),
#    but pnl-proxy computed on full val_final; selector hard-fallback already inside
idx_train_final_T = subset_trade_indices(idx_train_final, sample_t, y_trade)
idx_val_final_T   = subset_trade_indices(idx_val_final,   sample_t, y_trade)
idx_test_final_T  = subset_trade_indices(idx_test_final,  sample_t, y_trade)

print("Trade-only sizes for DIR:")
print("  train_final_T:", len(idx_train_final_T))
print("  val_final_T  :", len(idx_val_final_T))
print("  test_final_T :", len(idx_test_final_T))

m_dir_final, r_dir_final = train_binary_classifier(
    X_scaled_final, edge_feat,
    y_trade, y_dir,
    exit_ret, sample_t,
    idx_train_final_T, idx_val_final_T, idx_test_final_T,
    CFG,
    stage_name="dir",
    select_metric="va_pnl_max",
    trade_model_for_pnl=m_trade_final,
    idx_val_pnl=idx_val_final,   # pnl-proxy на полном val_final
)

# 5) выбрать пороги по val_final (grid sweep)
prob_trade_val, er_val = predict_probs_on_indices(m_trade_final, X_scaled_final, edge_feat, idx_val_final, CFG)
prob_dir_val, _        = predict_probs_on_indices(m_dir_final,   X_scaled_final, edge_feat, idx_val_final, CFG)

sweep_val = sweep_thresholds(prob_trade_val, prob_dir_val, er_val, CFG)
best_val = sweep_val.iloc[0].to_dict()
thr_trade_star = float(best_val["thr_trade"])
thr_dir_star   = float(best_val["thr_dir"])

print("\nChosen thresholds on val_final:")
print("  thr_trade*:", thr_trade_star)
print("  thr_dir*  :", thr_dir_star)
print("  val pnl_mean:", float(best_val["pnl_mean"]), "| val trades:", int(best_val["n_trades"]))

# 6) финальная оценка на holdout (БЕЗ подбора порогов на holdout)
prob_trade_hold, er_hold = predict_probs_on_indices(m_trade_final, X_scaled_final, edge_feat, idx_test_final, CFG)
prob_dir_hold, _         = predict_probs_on_indices(m_dir_final,   X_scaled_final, edge_feat, idx_test_final, CFG)

final_metrics = two_stage_pnl_by_threshold(
    prob_trade=prob_trade_hold,
    prob_dir=prob_dir_hold,
    exit_ret=er_hold,
    thr_trade=thr_trade_star,
    thr_dir=thr_dir_star,
    cost_bps=CFG["cost_bps"],
)

print("\nFINAL HOLDOUT RESULT (fixed thresholds from val_final):")
print("  pnl_mean :", final_metrics["pnl_mean"])
print("  pnl_sum  :", final_metrics["pnl_sum"])
print("  n_trades :", final_metrics["n_trades"])
print("  trade_rate:", final_metrics["trade_rate"])
print("  sharpe (per-bar proxy):", final_metrics["pnl_sharpe"])

# (опционально) oracle на holdout — НЕ для выбора, только “потолок”
sweep_hold_oracle = sweep_thresholds(prob_trade_hold, prob_dir_hold, er_hold, CFG)
best_hold_oracle = sweep_hold_oracle.iloc[0].to_dict()
print("\n[ORACLE] best possible on holdout by sweeping thresholds (DO NOT USE for selection):")
print("  thr_trade:", best_hold_oracle["thr_trade"], "thr_dir:", best_hold_oracle["thr_dir"])
print("  pnl_mean :", best_hold_oracle["pnl_mean"], "trades:", best_hold_oracle["n_trades"])


[trade] ep 01 lr=2.00e-04 tr_loss=0.7465 va_loss=0.6877 f1m=0.525 auc=0.485 sel(va_auc)=0.485082 best=0.485082@ep01
[trade] ep 02 lr=2.00e-04 tr_loss=0.6760 va_loss=0.7058 f1m=0.536 auc=0.489 sel(va_auc)=0.488889 best=0.488889@ep02
[trade] ep 03 lr=2.00e-04 tr_loss=0.6487 va_loss=0.7180 f1m=0.477 auc=0.521 sel(va_auc)=0.520979 best=0.520979@ep03
[trade] ep 04 lr=2.00e-04 tr_loss=0.6296 va_loss=0.7273 f1m=0.499 auc=0.545 sel(va_auc)=0.545221 best=0.545221@ep04
[trade] ep 05 lr=2.00e-04 tr_loss=0.6267 va_loss=0.7286 f1m=0.481 auc=0.545 sel(va_auc)=0.544988 best=0.545221@ep04
[trade] ep 06 lr=2.00e-04 tr_loss=0.6248 va_loss=0.7161 f1m=0.474 auc=0.547 sel(va_auc)=0.546931 best=0.546931@ep06
[trade] ep 07 lr=2.00e-04 tr_loss=0.6215 va_loss=0.6990 f1m=0.540 auc=0.540 sel(va_auc)=0.539705 best=0.546931@ep06
[trade] ep 08 lr=2.00e-04 tr_loss=0.6249 va_loss=0.7194 f1m=0.489 auc=0.543 sel(va_auc)=0.542735 best=0.546931@ep06
[trade] ep 09 lr=2.00e-04 tr_loss=0.6198 va_loss=0.7184 f1m=0.481 auc=0.