# Neural Granger Crypto Analysis

This notebook:

- Loads processed hourly features from `data/processed/features_<SYMBOL>.parquet` for BTCUSDT, ETHUSDT, and SOLUSDT.
- Intersects the common time range, builds a multi-asset panel, and constructs lagged design matrices for multiple horizons.
- Runs backtests over predefined train/test windows with seeds and sparsity sweeps.
- Trains or reloads checkpoints for models: Last, VARX-LASSO, compact LSTM, NeuralGVAR (gated), and StaticGVAR.
- Computes test metrics (MSE, MAE, SignHit) per run and aggregates them into `metrics_summary.csv`/`.tex`.
- Builds hypothesis tables from `edge_hypothesis_summary`, plus `h1_summary`, `h4a_summary`, and `event_conditioning_summary` when those are computed.
- Exports paper assets into the `paper_assets/` folder, including:
  - `metrics_summary.csv` and `metrics_summary.tex`
  - `edge_key_edges_table.tex` and `key_edges_sign_table.tex` (key edge tables)
  - `mse_*.png` / `mae_*.png` bar plots by task and horizon
  - `h1_cond_*.png` and `h4a_cond_*.png` conditioning plots (if summaries exist)
  - `edge_signfrac_keyedges.png` (sign-fraction vs horizon for key edges)


# Crypto Neural-GVAR Robust Analysis Notebook

In [1]:
# 0) Setup / Config
import os, math, random, warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.linear_model import MultiTaskLasso
from sklearn.preprocessing import StandardScaler

from pathlib import Path

# ---- Run-mode switches -------------------------------------------------------
# FAST_MODE shrinks horizons, windows, seeds, the lambda sweep, epochs and
# bootstrap draws below so the notebook can be smoke-tested quickly.
FAST_MODE = False   # True for quick debug, False for paper runs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PRINT_EVERY = 1  # print training loss every PRINT_EVERY epochs

DATA_DIR = Path("data")

# --- Checkpointing (set True to save trained models so you can reload without re-training)
SAVE_CHECKPOINTS = True
from pathlib import Path
# Checkpoints: saved per-model into separate subfolders for easier tracking.
CHECKPOINT_ROOT = Path('./checkpoints')  
MODEL_CKPT_DIRS = {
    'NeuralGVAR': CHECKPOINT_ROOT / 'NeuralGVAR',
    'StaticGVAR': CHECKPOINT_ROOT / 'StaticGVAR',
    'VARX_LASSO': CHECKPOINT_ROOT / 'VARX_LASSO',
    'LSTM': CHECKPOINT_ROOT / 'LSTM',
    'Last': CHECKPOINT_ROOT / 'Last',
}
# Ensure a stable variable CHECKPOINT_DIR is available for older code paths
CHECKPOINT_DIR = CHECKPOINT_ROOT
# Side effect: creates the checkpoint folders on disk as soon as this cell runs.
if SAVE_CHECKPOINTS:
    for _d in MODEL_CKPT_DIRS.values():
        _d.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
# Folder that holds the per-symbol features_<SYMBOL>.parquet files.
PROCESSED_DIR = DATA_DIR / "processed"

# Asset universe; ASSET_IDX maps each symbol to its position in SYMBOLS.
SYMBOLS = ["BTCUSDT", "ETHUSDT", "SOLUSDT"]
ASSET_IDX = {s:i for i,s in enumerate(SYMBOLS)}


# Column names expected in every per-symbol feature file.
RET_COL  = "ret_1h"
VOL_COL  = "realized_vol"
DIV_COL  = "rsi_div"
FUND_COL = "fundingRate"

# Lags (in rows/hours) used as autoregressive inputs, and forecast horizons.
LAGS_HOURS = [1, 3, 6, 12]
HORIZONS_H = [1, 4, 12, 24] if not FAST_MODE else [1, 12]

# Each window is (train_start, train_end, test_start, test_end).
BACKTEST_WINDOWS = [
    ("2021-01-01", "2022-12-31 23:00", "2023-01-01", "2023-12-31 23:00"),
    ("2021-01-01", "2023-12-31 23:00", "2024-01-01", "2024-12-31 23:00"),
    ("2022-01-01", "2024-12-31 23:00", "2025-01-01", "2025-12-25 23:00"),
]
if FAST_MODE:
    BACKTEST_WINDOWS = [("2021-01-01", "2023-12-31 23:00", "2024-01-01", "2024-12-31 23:00")]

# Random seeds and sparsity-penalty values swept per (task, window).
SEEDS = [1,2,3,4,5] if not FAST_MODE else [1,2]
LAMBDA_SPARSE_SWEEP = [1e-4, 5e-4, 1e-3] if not FAST_MODE else [5e-4]

# Training hyperparameters read by fit_neural_gvar (EPOCHS, BATCH_SIZE, LR).
EPOCHS = 25 if not FAST_MODE else 10
BATCH_SIZE = 256
LR = 3e-3

# Bootstrap / event settings
BOOT_BLOCK = 48
BOOT_N = 800 if not FAST_MODE else 200

# EVENT_PRE_H / EVENT_POST_H define the exclusion window used by
# matched_control_events; EVENT_Q and MATCH_BINS are not referenced in the
# cells visible here (presumably consumed by later cells -- verify).
EVENT_Q = 0.99
MATCH_BINS = 10
EVENT_PRE_H = 6
EVENT_POST_H = 6

print("DEVICE:", DEVICE, "| FAST_MODE:", FAST_MODE)
print("Horizons:", HORIZONS_H, "| windows:", len(BACKTEST_WINDOWS), "| seeds:", len(SEEDS), "| lambdas:", len(LAMBDA_SPARSE_SWEEP))

DEVICE: cpu | FAST_MODE: False
Horizons: [1, 4, 12, 24] | windows: 3 | seeds: 5 | lambdas: 3


## 1) Load features and trim to intersection window

In [5]:

def find_feature_file(symbol: str) -> Path:
    """Locate the processed feature parquet for ``symbol``.

    Looks for ``features_<symbol>.parquet`` under ``PROCESSED_DIR`` and, if
    several paths match, returns the most recently modified one.

    Raises:
        FileNotFoundError: when no matching file exists.
    """
    matches = sorted(PROCESSED_DIR.glob(f"features_{symbol}.parquet"))
    if len(matches) == 0:
        raise FileNotFoundError(f"No parquet found for {symbol} in {PROCESSED_DIR}")
    # max() returns the first of equally-new files, i.e. the same element a
    # stable newest-first sort would have placed at position 0.
    return max(matches, key=lambda candidate: candidate.stat().st_mtime)

def load_features(symbol: str) -> pd.DataFrame:
    """Read the feature parquet for ``symbol`` and return it time-sorted.

    The index is coerced to a ``DatetimeIndex`` when it is not one already,
    and a timezone-naive index is localized to UTC. An index that already
    carries a timezone is left in that timezone.
    """
    frame = pd.read_parquet(find_feature_file(symbol))
    index_is_datetime = isinstance(frame.index, pd.DatetimeIndex)
    if not index_is_datetime:
        frame.index = pd.to_datetime(frame.index, utc=True)
    if frame.index.tz is None:
        # naive timestamps are stamped as UTC (no clock shift is applied)
        frame.index = frame.index.tz_localize("UTC")
    return frame.sort_index()

def require_columns(df: pd.DataFrame, cols: List[str]) -> None:
    """Fail fast when ``df`` lacks any of the requested columns.

    Raises:
        KeyError: listing the missing names plus a preview of up to 40 of the
            columns that are present.
    """
    absent = []
    for name in cols:
        if name not in df.columns:
            absent.append(name)
    if not absent:
        return
    raise KeyError(f"Missing columns {absent}. Example cols: {list(df.columns)[:40]}")

# Load one feature frame per symbol and verify the four required columns exist.
feature_dfs: Dict[str, pd.DataFrame] = {sym: load_features(sym) for sym in SYMBOLS}
for sym, df in feature_dfs.items():
    require_columns(df, [RET_COL, VOL_COL, DIV_COL, FUND_COL])

# For each symbol, find the first/last timestamp at which ALL required columns
# are non-null. NOTE(review): if a symbol has no fully-populated row, min()/max()
# of the empty selection will not be a usable timestamp -- confirm inputs.
required_cols_all = [RET_COL, VOL_COL, DIV_COL, FUND_COL]
first_valid, last_valid = {}, {}
for sym, df in feature_dfs.items():
    mask = df[required_cols_all].notna().all(axis=1)
    first_valid[sym] = df.index[mask].min()
    last_valid[sym]  = df.index[mask].max()

# Common window = latest start and earliest end across symbols.
intersection_start = max(first_valid.values())
intersection_end   = min(last_valid.values())

print("First valid per symbol:", first_valid)
print("Intersection:", intersection_start, "->", intersection_end)

# Trim every frame to the common window (label-based slice, both ends inclusive).
for sym in feature_dfs:
    feature_dfs[sym] = feature_dfs[sym].loc[intersection_start:intersection_end]

# Build the multi-asset panel with two-level columns (symbol, feature).
panel = None
for sym, df in feature_dfs.items():
    tmp = df[[RET_COL, VOL_COL, DIV_COL, FUND_COL]].copy()
    tmp.columns = pd.MultiIndex.from_product([[sym], tmp.columns])
    panel = tmp if panel is None else panel.join(tmp, how="outer")

# The outer join followed by dropna keeps only timestamps where every
# (symbol, feature) cell is present.
panel = panel.dropna()
# Cell display: panel shape and covered time range.
panel.shape, panel.index.min(), panel.index.max()


First valid per symbol: {'BTCUSDT': Timestamp('2020-01-02 00:00:00+0000', tz='UTC'), 'ETHUSDT': Timestamp('2020-01-02 00:00:00+0000', tz='UTC'), 'SOLUSDT': Timestamp('2020-09-15 07:00:00+0000', tz='UTC')}
Intersection: 2020-09-15 07:00:00+00:00 -> 2025-12-25 23:00:00+00:00


((46265, 12),
 Timestamp('2020-09-15 07:00:00+0000', tz='UTC'),
 Timestamp('2025-12-25 23:00:00+0000', tz='UTC'))

## 2) Design matrix builder (lags + horizon)

In [6]:

@dataclass
class TaskSpec:
    """Description of one forecasting task (targets, state inputs, horizon)."""
    # Task label, e.g. "RET_h4" as built further down in this cell.
    name: str
    # (symbol, feature) panel columns that are forecast and also lagged as inputs.
    y_cols: List[Tuple[str,str]]
    # (symbol, feature) panel columns used as contemporaneous state inputs.
    x_cols: List[Tuple[str,str]]
    # Forecast horizon in rows of the panel (hours for this hourly data).
    horizon_h: int

def make_design_matrix(df: pd.DataFrame, y_cols, x_cols, lags_h, horizon_h):
    """Assemble lagged inputs, current state inputs and future targets.

    Column layout of the result (two-level columns ``(symbol, name)``):
      * ``<feat>_lag<k>`` -- ``y_cols`` shifted back by ``k`` rows, for each ``k`` in ``lags_h``
      * ``<feat>_t``      -- ``x_cols`` at the current row
      * ``<feat>_t+<h>``  -- ``y_cols`` shifted forward by ``horizon_h`` rows (the targets)

    Rows containing any NaN (the edges introduced by shifting) are dropped.
    Shifts are by row count, so they equal hours only on a gap-free hourly index.
    """
    def _with_suffix(frame, suffix):
        # Rename the second column level in place and hand the frame back.
        frame.columns = pd.MultiIndex.from_tuples(
            [(asset, f"{feat}{suffix}") for (asset, feat) in frame.columns]
        )
        return frame

    endog = df[y_cols].copy()
    exog = df[x_cols].copy()

    lag_part = pd.concat(
        [_with_suffix(endog.shift(k), f"_lag{k}") for k in lags_h],
        axis=1,
    )
    state_part = _with_suffix(exog.copy(), "_t")
    target_part = _with_suffix(endog.shift(-horizon_h), f"_t+{horizon_h}")

    return pd.concat([lag_part, state_part, target_part], axis=1).dropna()

def split_time(df: pd.DataFrame, train_start, train_end, test_start, test_end):
    """Slice ``df`` into (train, test) by UTC timestamp bounds.

    Each bound is converted to a UTC ``pd.Timestamp``; the label-based slices
    include both endpoints.
    """
    def _utc(stamp):
        return pd.Timestamp(stamp, tz="UTC")

    train_part = df.loc[_utc(train_start):_utc(train_end)]
    test_part = df.loc[_utc(test_start):_utc(test_end)]
    return train_part, test_part

# Build one return-forecast task and one volatility-forecast task per horizon.
# Both task families share the same state inputs: the DIV_COL and FUND_COL
# columns of every symbol; only the target column (RET_COL vs VOL_COL) differs.
RET_TASKS, VOL_TASKS = [], []
for h in HORIZONS_H:
    RET_TASKS.append(TaskSpec(
        name=f"RET_h{h}",
        y_cols=[(s, RET_COL) for s in SYMBOLS],
        x_cols=[(s, DIV_COL) for s in SYMBOLS] + [(s, FUND_COL) for s in SYMBOLS],
        horizon_h=h
    ))
    VOL_TASKS.append(TaskSpec(
        name=f"VOL_h{h}",
        y_cols=[(s, VOL_COL) for s in SYMBOLS],
        x_cols=[(s, DIV_COL) for s in SYMBOLS] + [(s, FUND_COL) for s in SYMBOLS],
        horizon_h=h
    ))


## 3) Metrics + block bootstrap

In [7]:

def mse(y_true, y_pred):
    """Mean squared error over all elements, returned as a Python float."""
    residual = y_true - y_pred
    return float(np.mean(residual ** 2))
def mae(y_true, y_pred):
    """Mean absolute error over all elements, returned as a Python float."""
    residual = y_true - y_pred
    return float(np.mean(np.abs(residual)))
def sign_hit(y_true, y_pred):
    """Fraction of elements whose sign matches between truth and prediction.

    Signs are compared via ``np.sign``, so an exact zero only matches an
    exact zero.
    """
    matches = np.sign(y_true) == np.sign(y_pred)
    return float(np.mean(matches))

def block_bootstrap_ci(series: np.ndarray, stat_fn, block_size: int, n_boot: int, alpha: float = 0.05, seed: int = 0):
    """Percentile confidence interval from a moving-block bootstrap.

    Each replicate concatenates ``ceil(n / block_size)`` contiguous blocks with
    uniformly drawn start positions, truncates to ``n`` observations and
    evaluates ``stat_fn`` on the resampled series.

    Returns:
        ``(lower, upper)`` quantiles at ``alpha/2`` and ``1 - alpha/2``, or
        ``(nan, nan)`` when the series is shorter than one block.
    """
    rng = np.random.default_rng(seed)
    n = len(series)
    if n < block_size:
        return np.nan, np.nan

    blocks_needed = int(np.ceil(n / block_size))
    start_upper = n - block_size + 1  # exclusive upper bound for block starts
    within_block = np.arange(block_size)

    replicates = []
    for _ in range(n_boot):
        # one scalar draw per block keeps the random stream unchanged
        starts = [rng.integers(0, start_upper) for _b in range(blocks_needed)]
        resample_idx = np.concatenate([first + within_block for first in starts])[:n]
        replicates.append(stat_fn(series[resample_idx]))

    lower = float(np.quantile(replicates, alpha / 2))
    upper = float(np.quantile(replicates, 1 - alpha / 2))
    return lower, upper


## 4) Baselines

In [8]:
class NumpyDataset(Dataset):
    """Torch dataset over a pair of NumPy arrays, stored as float32 tensors."""

    def __init__(self, X, y):
        features, targets = torch.from_numpy(X), torch.from_numpy(y)
        self.X = features.float()
        self.y = targets.float()

    def __len__(self):
        # number of samples = size of the leading dimension of X
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class CompactLSTM(nn.Module):
    """Single-layer LSTM followed by a linear head on the final time step."""

    def __init__(self, input_dim, hidden=32, out_dim=3):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, out_dim)

    def forward(self, x):
        # x is batch-first: (batch, steps, input_dim)
        hidden_states, _state = self.lstm(x)
        final_step = hidden_states[:, -1, :]
        return self.fc(final_step)

def get_target_cols(y_cols, horizon_h):
    """Return the design-matrix target column keys ``(symbol, '<feat>_t+<h>')``."""
    suffix = f"_t+{horizon_h}"
    named = []
    for sym, feat in y_cols:
        named.append((sym, f"{feat}{suffix}"))
    return named

def run_last(full_test, y_cols, horizon_h):
    """Naive "Last" baseline: use a lagged value of the target as the forecast.

    The lag column used is ``lag<horizon_h>`` when that horizon is one of
    ``LAGS_HOURS``, otherwise ``lag1``.

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays taken from ``full_test``.
    """
    # NOTE(review): the chosen lag therefore differs across horizons
    # (e.g. lag12 for h=12 but lag1 for h=4) -- confirm this is intended.
    lag = 1
    if horizon_h in LAGS_HOURS:
        lag = horizon_h
    naive_cols = [(sym, f"{feat}_lag{lag}") for (sym, feat) in y_cols]
    y_true = full_test[get_target_cols(y_cols, horizon_h)].to_numpy()
    y_naive = full_test[naive_cols].to_numpy()
    return y_true, y_naive

# helper to build LSTM sequences from the design-matrix style dataframe
def build_seq(full, y_cols, x_cols, lags_h, horizon_h):
    """Turn a design-matrix frame into LSTM inputs and targets.

    Returns:
        ``(seq, Y)`` where ``seq`` has one step per entry of ``lags_h`` (in the
        order given) and each step holds the lagged targets followed by the
        current state features, repeated on every step; ``Y`` holds the
        horizon targets. Both are float32.
    """
    targets = full[get_target_cols(y_cols, horizon_h)].to_numpy().astype(np.float32)

    lag_blocks = [
        full[[(sym, f"{feat}_lag{k}") for (sym, feat) in y_cols]].to_numpy().astype(np.float32)
        for k in lags_h
    ]

    state_cols = [(sym, f"{feat}_t") for (sym, feat) in x_cols]
    state = full[state_cols].to_numpy().astype(np.float32)

    # (rows, steps, n_targets): step order follows lags_h as passed in
    lag_seq = np.stack(lag_blocks, axis=1)
    n_steps = lag_seq.shape[1]
    state_tiled = np.repeat(state[:, None, :], n_steps, axis=1)

    return np.concatenate([lag_seq, state_tiled], axis=2), targets

def fit_varx_lasso(full_train, full_test, y_cols, horizon_h, alpha=1e-4, *,
                   return_model=False, checkpoint_path=None, meta=None):
    """VARX baseline (MultiTaskLasso) with optional checkpointing.

    Every non-target column of the design matrix is used as a regressor.
    Regressors are standardized with statistics fitted on the training split
    only; targets are left on their raw scale.

    Args:
        full_train, full_test: design-matrix frames (see ``make_design_matrix``).
        y_cols: ``(symbol, feature)`` target columns.
        horizon_h: forecast horizon used to name the target columns.
        alpha: MultiTaskLasso regularization strength.
        return_model: also return the fitted model and scaler.
        checkpoint_path: if given, a joblib bundle ``{"model", "scaler", "meta"}``
            is written there. Saving is best-effort: failures are printed,
            not raised.
        meta: optional metadata dict stored in the checkpoint.

    Returns:
      - default: (Yte, Yhat)
      - if return_model=True: (Yte, Yhat, fitted_model, fitted_scaler)
    """
    target_cols = get_target_cols(y_cols, horizon_h)
    Xtr = full_train.drop(columns=target_cols).to_numpy()
    Ytr = full_train[target_cols].to_numpy()
    Xte = full_test.drop(columns=target_cols).to_numpy()
    Yte = full_test[target_cols].to_numpy()

    xs = StandardScaler()
    Xtr_s = xs.fit_transform(Xtr)
    Xte_s = xs.transform(Xte)  # reuse train statistics to avoid look-ahead

    model = MultiTaskLasso(alpha=alpha, fit_intercept=True, max_iter=10000)
    model.fit(Xtr_s, Ytr)
    Yhat = model.predict(Xte_s)

    if checkpoint_path is not None:
        import joblib  # only needed when a checkpoint is actually written
        try:
            # Path(...).parent is "." for a bare filename, so this also works
            # when checkpoint_path has no directory component
            # (os.makedirs(os.path.dirname(...)) fails on the empty string).
            Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
            ckpt = {"model": model, "scaler": xs, "meta": meta or {}}
            joblib.dump(ckpt, checkpoint_path)
            print(f"[checkpoint] saved: {checkpoint_path}")
        except Exception as e:
            print(f"[checkpoint] WARNING: failed to save {checkpoint_path}: {e}")

    if return_model:
        return Yte, Yhat, model, xs
    return Yte, Yhat

def fit_compact_lstm(full_train, full_test, y_cols, x_cols, lags_h, horizon_h,
                     epochs=12, lr=3e-3, seed=0, *,
                     return_model=False, checkpoint_path=None, meta=None, hidden=32):
    """Compact LSTM baseline with optional checkpointing.

    Args:
        full_train, full_test: design-matrix frames (see ``make_design_matrix``).
        y_cols, x_cols: ``(symbol, feature)`` target / state columns.
        lags_h: lags used to build the input sequence (one step per lag).
        horizon_h: forecast horizon used to name the target columns.
        epochs, lr: Adam training settings.
        seed: seed applied to torch, numpy and ``random``.
        return_model: also return the trained module.
        checkpoint_path: if given, weights plus the config needed to rebuild
            the module are written there. Saving is best-effort: failures are
            printed, not raised.
        meta: optional metadata dict stored in the checkpoint.
        hidden: LSTM hidden size.

    Returns:
      - default: (Yte, Yhat)
      - if return_model=True: (Yte, Yhat, model)

    Notes:
      - Expects the same design-matrix construction as the rest of the notebook.
      - Inputs and targets are fed to the network without standardization.
    """
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)

    Xtr_seq, Ytr = build_seq(full_train, y_cols, x_cols, lags_h, horizon_h)
    Xte_seq, Yte = build_seq(full_test,  y_cols, x_cols, lags_h, horizon_h)

    input_dim = Xtr_seq.shape[-1]
    out_dim = Ytr.shape[-1]

    model = CompactLSTM(input_dim=input_dim, hidden=hidden, out_dim=out_dim).to(DEVICE)
    dl = DataLoader(NumpyDataset(Xtr_seq, Ytr), batch_size=256, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    model.train()
    for ep in range(epochs):
        total=0.0
        for xb,yb in dl:
            xb,yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            pred = model(xb)
            loss = loss_fn(pred, yb)
            loss.backward()
            opt.step()
            # weight by batch size so the epoch figure is a per-sample mean
            total += loss.item()*len(xb)
        if (ep+1) % PRINT_EVERY == 0:
            print(f"LSTM ep {ep+1}/{epochs} loss={total/len(dl.dataset):.6f}")

    model.eval()
    with torch.no_grad():
        # the whole test split is scored in a single forward pass
        Yhat = model(torch.from_numpy(Xte_seq).float().to(DEVICE)).cpu().numpy()

    if checkpoint_path is not None:
        try:
            # Path(...).parent is "." for a bare filename, so this also works
            # when checkpoint_path has no directory component
            # (os.makedirs(os.path.dirname(...)) fails on the empty string).
            Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
            ckpt = {
                "state_dict": model.state_dict(),
                "config": {"input_dim": int(input_dim), "hidden": int(hidden), "out_dim": int(out_dim)},
                "meta": meta or {},
            }
            torch.save(ckpt, checkpoint_path)
            print(f"[checkpoint] saved: {checkpoint_path}")
        except Exception as e:
            print(f"[checkpoint] WARNING: failed to save {checkpoint_path}: {e}")

    if return_model:
        return Yte, Yhat, model
    return Yte, Yhat

## 5) Neural GVAR (gated + static) — raw-scale metrics

In [9]:

class NeuralGVAR(nn.Module):
    """Vector autoregression whose lag matrices can be gated by state inputs.

    ``S`` holds one ``d x d`` matrix per lag. With gating enabled, a small MLP
    on ``x_t`` produces per-sample sigmoid gates that multiply ``S``
    element-wise; with gating disabled the same ``S`` is used for every
    sample. An optional bias-free linear map of ``x_t`` (``Wc``) and a bias
    vector are added to the prediction.
    """

    def __init__(self, d, x_dim, lags, hidden=64, use_gating=True, use_concept=True):
        super().__init__()
        self.d = d
        self.lags = lags
        self.K = len(lags)
        self.use_gating = use_gating

        # base lag matrices, small random init
        self.S = nn.Parameter(torch.randn(self.K, d, d) * 0.01)

        if use_gating:
            self.gate = nn.Sequential(
                nn.Linear(x_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, self.K * d * d),
            )
        else:
            self.gate = None

        if use_concept:
            self.Wc = nn.Linear(x_dim, d, bias=False)
        else:
            self.Wc = None

        self.bias = nn.Parameter(torch.zeros(d))

    def forward(self, y_lags, x_t):
        """Return ``(yhat, A)``: predictions ``(B, d)`` and lag matrices ``(B, K, d, d)``."""
        batch = y_lags.shape[0]
        base = self.S[None, :, :, :]
        if self.use_gating:
            gate_logits = self.gate(x_t).view(batch, self.K, self.d, self.d)
            A = base * torch.sigmoid(gate_logits)
        else:
            A = base.expand(batch, -1, -1, -1)

        yhat = torch.zeros(batch, self.d, device=y_lags.device)
        for ki in range(self.K):
            # A[:, ki] @ y_{lag ki}: row index = target, column index = source
            lag_vec = y_lags[:, ki, :].unsqueeze(-1)
            yhat = yhat + torch.matmul(A[:, ki, :, :], lag_vec).squeeze(-1)

        if self.Wc is not None:
            yhat = yhat + self.Wc(x_t)
        yhat = yhat + self.bias
        return yhat, A

def group_lasso_edges(S):
    """Group-lasso penalty: sum over edges of the L2 norm across the first (lag) axis.

    A small constant (1e-12) is added under the square root.
    """
    squared_over_lags = torch.sum(S ** 2, dim=0)
    edge_norms = torch.sqrt(squared_over_lags + 1e-12)
    return edge_norms.sum()

def temporal_smoothness(A_seq):
    """Mean squared difference between consecutive entries along the first axis."""
    step_change = A_seq[1:] - A_seq[:-1]
    return torch.mean(step_change ** 2)

def fit_neural_gvar(full_train, full_test, y_cols, x_cols, horizon_h, use_gating=True, lam_sparse=5e-4, lam_smooth=1e-4, seed=0, checkpoint_path=None):
    """Train a NeuralGVAR on the train split and evaluate it on the test split.

    Args:
        full_train, full_test: design-matrix frames (see ``make_design_matrix``).
        y_cols, x_cols: ``(symbol, feature)`` target / state columns.
        horizon_h: forecast horizon used to name the target columns.
        use_gating: gate the lag matrices by the state inputs (False = static).
        lam_sparse: weight of the group-lasso penalty on ``model.S``.
        lam_smooth: weight of the smoothness penalty on the per-sample lag
            matrices within a batch; skipped when 0 or for single-row batches.
        seed: seed applied to torch, numpy and ``random``.
        checkpoint_path: if given, weights, config, scalers and metadata are
            saved there (best-effort: failures are printed, not raised).

    Returns:
        ``(metrics, extras)``: raw-scale test MSE / MAE / SignHitRate, and a
        dict with the test-time lag matrices, ``Wc`` weights, the test index,
        state column names, test date range and horizon.

    Lags, epochs, batch size, learning rate and hidden size come from the
    module-level ``LAGS_HOURS``, ``EPOCHS``, ``BATCH_SIZE``, ``LR`` and
    ``FAST_MODE``.
    """
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)

    target_cols = get_target_cols(y_cols, horizon_h)
    x_state_cols = [(sym, f"{feat}_t") for (sym, feat) in x_cols]

    # Stack the lagged targets into (rows, n_lags, n_targets) for train ...
    Ytr_lags=[]
    for k in LAGS_HOURS:
        Ytr_lags.append(full_train[[(sym, f"{feat}_lag{k}") for (sym, feat) in y_cols]].to_numpy().astype(np.float32))
    Ytr_lags = np.stack(Ytr_lags, axis=1)
    Ytr = full_train[target_cols].to_numpy().astype(np.float32)
    Xtr = full_train[x_state_cols].to_numpy().astype(np.float32)

    # ... and for test.
    Yte_lags=[]
    for k in LAGS_HOURS:
        Yte_lags.append(full_test[[(sym, f"{feat}_lag{k}") for (sym, feat) in y_cols]].to_numpy().astype(np.float32))
    Yte_lags = np.stack(Yte_lags, axis=1)
    Yte = full_test[target_cols].to_numpy().astype(np.float32)
    Xte = full_test[x_state_cols].to_numpy().astype(np.float32)

    # All standardization statistics are computed on the training split only
    # and then applied to the test split.
    Xmu, Xsd = Xtr.mean(0, keepdims=True), Xtr.std(0, keepdims=True) + 1e-6
    Xtr_s, Xte_s = (Xtr-Xmu)/Xsd, (Xte-Xmu)/Xsd

    # Lag statistics are pooled over rows and lags: one mean/std per target column.
    Lmu = Ytr_lags.reshape(-1, Ytr_lags.shape[-1]).mean(0, keepdims=True)
    Lsd = Ytr_lags.reshape(-1, Ytr_lags.shape[-1]).std(0, keepdims=True) + 1e-6
    Ytr_lags_s, Yte_lags_s = (Ytr_lags-Lmu)/Lsd, (Yte_lags-Lmu)/Lsd

    y_mu, y_sd = Ytr.mean(0, keepdims=True), Ytr.std(0, keepdims=True) + 1e-6
    Ytr_s, Yte_s = (Ytr-y_mu)/y_sd, (Yte-y_mu)/y_sd

    d = Ytr.shape[1]; x_dim = Xtr.shape[1]
    hidden_dim = 64 if not FAST_MODE else 48
    model = NeuralGVAR(d, x_dim, LAGS_HOURS, hidden=hidden_dim, use_gating=use_gating, use_concept=True).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.MSELoss()

    # Batches are drawn over row positions WITHOUT shuffling, so each batch
    # holds consecutive rows of the training frame; the smoothness penalty
    # below compares neighbouring rows of A within a batch.
    idx = np.arange(len(Ytr_s))
    dl = DataLoader(idx, batch_size=BATCH_SIZE, shuffle=False)

    model.train()
    for ep in range(EPOCHS):
        total=0.0
        for b_idx in dl:
            b = b_idx.numpy()
            y_lags = torch.from_numpy(Ytr_lags_s[b]).float().to(DEVICE)
            x_t = torch.from_numpy(Xtr_s[b]).float().to(DEVICE)
            y_true = torch.from_numpy(Ytr_s[b]).float().to(DEVICE)

            opt.zero_grad()
            y_pred, A = model(y_lags, x_t)
            # loss = standardized-target MSE + sparsity on S (+ smoothness on A)
            loss = loss_fn(y_pred, y_true)
            loss = loss + lam_sparse * group_lasso_edges(model.S)
            if lam_smooth>0 and len(b)>1:
                loss = loss + lam_smooth * temporal_smoothness(A)
            loss.backward()
            opt.step()
            total += loss.item()*len(b)

        if (ep+1) % PRINT_EVERY == 0:
            print(f"NeuralGVAR({'gated' if use_gating else 'static'}) ep {ep+1}/{EPOCHS} loss={total/len(Ytr_s):.6f}")

    model.eval()
    with torch.no_grad():
        # Score the test split in chunks of 2048 rows and collect both the
        # standardized predictions and the per-row lag matrices.
        yhat_s_list=[]; A_list=[]
        for i in range(0, len(Yte_s), 2048):
            sl = slice(i, min(i+2048, len(Yte_s)))
            y_lags = torch.from_numpy(Yte_lags_s[sl]).float().to(DEVICE)
            x_t = torch.from_numpy(Xte_s[sl]).float().to(DEVICE)
            yhat_s, A = model(y_lags, x_t)
            yhat_s_list.append(yhat_s.cpu().numpy())
            A_list.append(A.cpu().numpy())
        Yhat_s = np.concatenate(yhat_s_list, axis=0)
        A_te = np.concatenate(A_list, axis=0)

    # Undo the target standardization so metrics are on the raw scale.
    Yhat = Yhat_s * y_sd + y_mu
    metrics = {"MSE": mse(Yte, Yhat), "MAE": mae(Yte, Yhat), "SignHitRate": sign_hit(Yte, Yhat)}

    # model.Wc is always present here because the model is built with use_concept=True.
    extras = {
        "A_test": A_te,
        "Wc": model.Wc.weight.detach().cpu().numpy(),
        "test_index": full_test.index.values,
        "x_state_cols": x_state_cols,
        "test_start": full_test.index.min(),
        "test_end": full_test.index.max(),
        "horizon_h": horizon_h
    }

    # ---- Optional checkpoint save (model + scalers + metadata) ----
    if checkpoint_path is not None:
        try:
            # "config" carries what load_gvar_checkpoint needs to rebuild the
            # module; "scalers" are NumPy arrays used to (de)standardize inputs
            # and outputs at prediction time.
            ckpt = {
                "model_class": "NeuralGVAR",
                "model_state_dict": model.state_dict(),
                "config": {
                    "d": int(d),
                    "x_dim": int(x_dim),
                    "lags_hours": list(LAGS_HOURS),
                    "hidden": int(hidden_dim),
                    "use_gating": bool(use_gating),
                    "use_concept": True,
                },
                "scalers": {
                    "Xmu": Xmu, "Xsd": Xsd,
                    "Lmu": Lmu, "Lsd": Lsd,
                    "y_mu": y_mu, "y_sd": y_sd,
                },
                "meta": {
                    "y_cols": y_cols,
                    "x_cols": x_cols,
                    "x_state_cols": x_state_cols,
                    "horizon_h": int(horizon_h),
                    "seed": int(seed),
                    "lam_sparse": float(lam_sparse),
                    "lam_smooth": float(lam_smooth),
                    "train_start": str(full_train.index.min()),
                    "train_end": str(full_train.index.max()),
                    "test_start": str(full_test.index.min()),
                    "test_end": str(full_test.index.max()),
                },
            }
            os.makedirs(str(Path(checkpoint_path).parent), exist_ok=True)
            torch.save(ckpt, checkpoint_path)
            print(f"[checkpoint] saved: {checkpoint_path}")
        except Exception as e:
            # best-effort: a failed save must not abort the experiment
            print(f"[checkpoint] WARNING: failed to save {checkpoint_path}: {e}")

    return metrics, extras


# --- StaticGVAR convenience wrapper (baseline) ---
# StaticGVAR = same architecture as NeuralGVAR but with gating disabled (A_k does not depend on x_t).
def fit_static_gvar(full_train, full_test, y_cols, x_cols, horizon_h,
                    lam_sparse=1e-3, seed=0, checkpoint_path=None):
    """Fit the ungated (static) GVAR baseline.

    Thin wrapper around ``fit_neural_gvar`` with ``use_gating=False`` and the
    smoothness penalty switched off (``lam_smooth=0.0``).

    Args:
        full_train, full_test, y_cols, x_cols, horizon_h: forwarded unchanged.
        lam_sparse: group-lasso weight (default 1e-3).
        seed: random seed.
        checkpoint_path: optional path forwarded to ``fit_neural_gvar`` so the
            static model can be checkpointed too; ``None`` saves nothing.

    Returns:
        ``(metrics, extras)`` exactly as returned by ``fit_neural_gvar``.
    """
    return fit_neural_gvar(full_train, full_test, y_cols, x_cols, horizon_h,
                           use_gating=False, lam_sparse=lam_sparse, lam_smooth=0.0,
                           seed=seed, checkpoint_path=checkpoint_path)


# ---- Load helper: rebuild a (Neural/Static) GVAR from a saved checkpoint ----
def load_gvar_checkpoint(path, map_location=None):
    """Rebuild a (Neural/Static) GVAR from a checkpoint file.

    Args:
        path: checkpoint file holding at least ``"config"`` and
            ``"model_state_dict"`` (the layout written by ``fit_neural_gvar``).
        map_location: device for the loaded tensors and the rebuilt model;
            defaults to the notebook-level ``DEVICE``.

    Returns:
        ``(model, ckpt)``: the module in eval mode and the raw checkpoint dict.
    """
    path = str(path)
    target = map_location if map_location is not None else DEVICE
    # These checkpoints store NumPy scaler arrays next to the weights, which
    # the weights-only unpickler (the default in recent PyTorch releases)
    # rejects, so full unpickling is requested explicitly.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoint files you created yourself.
    try:
        ckpt = torch.load(path, map_location=target, weights_only=False)
    except TypeError:
        # older PyTorch without the weights_only keyword
        ckpt = torch.load(path, map_location=target)
    cfg = ckpt["config"]
    model = NeuralGVAR(
        d=cfg["d"],
        x_dim=cfg["x_dim"],
        lags=cfg["lags_hours"],
        hidden=cfg["hidden"],
        use_gating=cfg["use_gating"],
        use_concept=cfg.get("use_concept", True),
    ).to(target)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model, ckpt


## 6) Hypothesis + event utilities

In [10]:

def edge_timeseries(A_test, source_idx, target_idx):
    """Per-row strength of the edge source -> target, summed over axis 1 (the lags).

    ``A_test`` is indexed as ``[row, lag, target, source]``.
    """
    per_lag = A_test[:, :, target_idx, source_idx]
    return per_lag.sum(axis=1)

def edge_stats_with_ci(A_test, source_sym, target_sym, block_size=48, n_boot=400):
    """Sign fraction and median of one edge series, each with a block-bootstrap CI.

    Returns:
        ``(sign_frac, sign_frac_lo, sign_frac_hi, median, median_lo, median_hi)``
        where ``sign_frac`` is the share of rows with a strictly positive edge.
    """
    def _positive_share(values):
        return np.mean(values > 0)

    def _median(values):
        return np.median(values)

    edge = edge_timeseries(A_test, ASSET_IDX[source_sym], ASSET_IDX[target_sym])
    edge_arr = np.asarray(edge)

    sign_frac = float(np.mean(edge > 0))
    median = float(np.median(edge))

    # the two intervals use fixed, different bootstrap seeds (0 and 1)
    sign_lo, sign_hi = block_bootstrap_ci(edge_arr, _positive_share, block_size, n_boot=n_boot, seed=0)
    median_lo, median_hi = block_bootstrap_ci(edge_arr, _median, block_size, n_boot=n_boot, seed=1)
    return sign_frac, sign_lo, sign_hi, median, median_lo, median_hi

def diag_concept_weight(extras, concept_feat_base: str, asset: str) -> float:
    """Return the ``Wc`` weight linking an asset's own state feature to its own target.

    Args:
        extras: dict with ``"Wc"`` (2-D weights indexed ``[target, state column]``,
            or ``None``) and ``"x_state_cols"`` (state column keys in weight order).
        concept_feat_base: feature name without the ``_t`` suffix.
        asset: symbol; must be a key of ``ASSET_IDX`` when a weight is looked up.

    Returns:
        The weight as a float, or NaN when ``Wc`` is ``None`` (a model without
        the concept layer) or the ``(asset, '<feature>_t')`` column is absent.
    """
    Wc = extras["Wc"]
    if Wc is None:
        # a model built without the concept layer has no weights to report
        return np.nan
    key = (asset, f"{concept_feat_base}_t")
    col_index = {c: i for i, c in enumerate(extras["x_state_cols"])}
    if key not in col_index:
        return np.nan
    return float(Wc[ASSET_IDX[asset], col_index[key]])

def conditioning_delta_on_test(test_df: pd.DataFrame, horizon_h: int, concept_feat: str, y_feat: str, q: float, mode: str):
    """Per-asset lift in the future outcome when the concept feature is extreme.

    For every symbol, rows whose concept value (absolute value when
    ``mode == "abs"``) is at or above its ``q`` quantile are selected, and the
    mean of ``y_feat`` shifted ``horizon_h`` rows ahead on those rows is
    compared with its unconditional mean.

    Returns:
        dict ``symbol -> (conditional mean - overall mean)``; NaN when no
        complete rows exist for that symbol.
    """
    deltas = {}
    for a in SYMBOLS:
        signal_raw = test_df[(a, concept_feat)].astype(float)
        outcome_raw = test_df[(a, y_feat)].shift(-horizon_h).astype(float)
        pair = pd.concat([signal_raw, outcome_raw], axis=1).dropna()
        if len(pair) == 0:
            deltas[a] = np.nan
            continue

        # positional access: column 0 = concept, column 1 = future outcome
        signal = pair.iloc[:, 0]
        if mode == "abs":
            signal = signal.abs()
        extreme = signal >= signal.quantile(q)

        conditional_mean = pair.loc[extreme].iloc[:, 1].mean()
        overall_mean = pair.iloc[:, 1].mean()
        deltas[a] = float(conditional_mean - overall_mean)
    return deltas

def define_events(series: pd.Series, q=0.99, mode="abs"):
    """Flag timestamps where the series reaches its ``q`` quantile.

    With ``mode == "abs"`` the threshold and comparison use absolute values;
    any other mode uses the raw values (upper tail only).

    Returns:
        ``(events, threshold)``: a ``DatetimeIndex`` of flagged timestamps and
        the threshold as a float.
    """
    magnitude = series.abs() if mode == "abs" else series
    thr = magnitude.quantile(q)
    flagged = series.index[magnitude >= thr]
    return pd.DatetimeIndex(flagged), float(thr)

def event_pre_post_delta(e_series: pd.Series, events: pd.DatetimeIndex, pre=6, post=6):
    """Change in mean absolute level of ``e_series`` around each event.

    For an event at ``t`` the pre-window is ``[t - pre h, t - 1 h]`` and the
    post-window is ``[t + 1 h, t + post h]`` (the event hour itself is
    excluded). Events with an empty window on either side are skipped.

    Returns:
        NumPy array of ``mean|post| - mean|pre|``, one entry per kept event.
    """
    one_hour = pd.Timedelta(hours=1)
    pre_span = pd.Timedelta(hours=pre)
    post_span = pd.Timedelta(hours=post)

    changes = []
    for t in events:
        before = e_series.loc[t - pre_span: t - one_hour]
        after = e_series.loc[t + one_hour: t + post_span]
        if len(before) and len(after):
            changes.append(after.abs().mean() - before.abs().mean())
    return np.array(changes)

def matched_control_events(events, match_stat: pd.Series, exclude, n_bins=10, seed=0):
    """Draw one control timestamp per event from the same quantile bin of ``match_stat``.

    Timestamps within ``[-EVENT_PRE_H, +EVENT_POST_H]`` hours of any entry of
    ``exclude`` are removed from the candidate pool. Each event is assigned the
    bin of its nearest remaining timestamp, and a control is drawn uniformly
    from that bin. Events absent from the non-null ``match_stat`` index are
    skipped.

    Returns:
        ``DatetimeIndex`` of controls; empty when fewer than 10 candidates remain.
    """
    rng = np.random.default_rng(seed)
    stat = match_stat.dropna()

    window_offsets = [pd.Timedelta(hours=h) for h in range(-EVENT_PRE_H, EVENT_POST_H + 1)]
    blocked = {t + offset for t in exclude for offset in window_offsets}

    eligible = stat.loc[~stat.index.isin(list(blocked))]
    if len(eligible) < 10:
        return pd.DatetimeIndex([])

    # quantile-bin the eligible values; bins are compared via their string labels
    bin_labels = pd.qcut(eligible, q=n_bins, duplicates="drop").astype(str)
    bin_by_time = pd.Series(bin_labels.values, index=eligible.index)

    picks = []
    for t in events:
        if t not in stat.index:
            continue
        nearest_pos = eligible.index.get_indexer([t], method="nearest")[0]
        event_bin = bin_by_time.iloc[nearest_pos]
        same_bin = bin_by_time[bin_by_time == event_bin].index
        if len(same_bin) == 0:
            continue
        picks.append(rng.choice(same_bin))
    return pd.DatetimeIndex(picks)

def placebo_shift(events, hours=-6):
    """Return the event timestamps displaced by ``hours`` (default: 6 h earlier)."""
    offset = pd.Timedelta(hours=hours)
    shifted = []
    for t in events:
        shifted.append(t + offset)
    return pd.DatetimeIndex(shifted)

def random_events(index: pd.DatetimeIndex, n: int, seed=0):
    """Sample up to ``n`` distinct timestamps from ``index`` without replacement.

    The draw is reproducible for a given ``seed``; when ``n`` exceeds the index
    length, every timestamp is returned (in random order).
    """
    rng = np.random.default_rng(seed)
    n_draw = min(n, len(index))
    positions = rng.choice(np.arange(len(index)), size=n_draw, replace=False)
    return pd.DatetimeIndex(index[positions])


## 7) Run experiments

In [11]:

def load_varx_checkpoint(path):
    """Load a VARX-LASSO checkpoint bundle written with joblib.

    Returns:
        ``(model, scaler, meta)``; ``meta`` defaults to ``{}`` when absent.
    """
    # NOTE: joblib files are pickles -- only load checkpoints you trust.
    import joblib
    bundle = joblib.load(path)
    model, scaler = bundle["model"], bundle["scaler"]
    return model, scaler, bundle.get("meta", {})

def load_lstm_checkpoint(path, map_location="cpu"):
    """Rebuild a ``CompactLSTM`` from a checkpoint file.

    ``map_location`` only controls where the stored tensors are loaded; the
    rebuilt module is always placed on the notebook-level ``DEVICE``.

    Returns:
        ``(model, meta)`` with the model in eval mode; ``meta`` defaults to ``{}``.
    """
    saved = torch.load(path, map_location=map_location)
    arch = saved["config"]
    net = CompactLSTM(
        input_dim=arch["input_dim"],
        hidden=arch["hidden"],
        out_dim=arch["out_dim"],
    ).to(DEVICE)
    net.load_state_dict(saved["state_dict"])
    net.eval()
    return net, saved.get("meta", {})

def _compat_lags(meta_or_cfg, expected_lags):
    lags = None
    if isinstance(meta_or_cfg, dict):
        lags = meta_or_cfg.get("lags_hours") or meta_or_cfg.get("lags")
    if lags is None:
        return True
    return list(lags) == list(expected_lags)

def _compat_horizon(meta_or_cfg, expected_h):
    h = None
    if isinstance(meta_or_cfg, dict):
        h = meta_or_cfg.get("horizon_h") or meta_or_cfg.get("h")
    if h is None:
        return True
    return int(h) == int(expected_h)

def _predict_varx_from_checkpoint(path, full_test, y_cols, horizon_h):
    """Score ``full_test`` with a saved VARX-LASSO model.

    Raises:
        ValueError: when the checkpoint metadata records a different horizon.

    Returns:
        ``(Yte, Yhat)`` as NumPy arrays.
    """
    model, scaler, meta = load_varx_checkpoint(path)
    if not _compat_horizon(meta, horizon_h):
        raise ValueError(f"VARX checkpoint horizon mismatch: {meta.get('horizon_h', meta.get('h'))} != {horizon_h}")

    target_cols = get_target_cols(y_cols, horizon_h)
    regressors = full_test.drop(columns=target_cols).to_numpy()
    Yte = full_test[target_cols].to_numpy()
    # apply the scaler stored with the model, then predict
    Yhat = model.predict(scaler.transform(regressors))
    return Yte, Yhat

def _predict_lstm_from_checkpoint(path, full_test, y_cols, x_cols, lags_h, horizon_h):
    """Score ``full_test`` with a saved CompactLSTM.

    Raises:
        ValueError: when the checkpoint metadata records a different horizon.

    Returns:
        ``(Yte, Yhat)`` as NumPy arrays.
    """
    model, meta = load_lstm_checkpoint(path, map_location=DEVICE)
    if not _compat_horizon(meta, horizon_h):
        raise ValueError(f"LSTM checkpoint horizon mismatch: {meta.get('horizon_h', meta.get('h'))} != {horizon_h}")

    Xte_seq, Yte = build_seq(full_test, y_cols, x_cols, lags_h, horizon_h)
    with torch.no_grad():
        batch = torch.from_numpy(Xte_seq).float().to(DEVICE)
        Yhat = model(batch).cpu().numpy()
    return Yte, Yhat

def _predict_gvar_from_checkpoint(path, full_test, y_cols, x_cols, horizon_h):
    """Score ``full_test`` with a saved (Neural/Static) GVAR.

    Inputs are standardized with the scalers stored in the checkpoint and the
    predictions are mapped back to the raw target scale.

    Raises:
        ValueError: when the checkpoint's lags differ from ``LAGS_HOURS`` or
            its recorded horizon differs from ``horizon_h``.

    Returns:
        ``(Yte, Yhat, extras)`` where ``extras`` mirrors the dict produced by
        ``fit_neural_gvar`` (``"Wc"`` is ``None`` for a model without the
        concept layer).
    """
    model, ckpt = load_gvar_checkpoint(path, map_location=DEVICE)
    cfg, meta = ckpt.get("config", {}), ckpt.get("meta", {})
    if not _compat_lags(cfg, LAGS_HOURS):
        raise ValueError(f"GVAR checkpoint lag mismatch: {cfg.get('lags_hours')} != {list(LAGS_HOURS)}")
    if not _compat_horizon(meta, horizon_h):
        raise ValueError(f"GVAR checkpoint horizon mismatch: {meta.get('horizon_h', meta.get('h'))} != {horizon_h}")
    scalers = ckpt["scalers"]

    def _as_f32(columns):
        return full_test[columns].to_numpy().astype(np.float32)

    target_cols = get_target_cols(y_cols, horizon_h)
    x_state_cols = [(sym, f"{feat}_t") for (sym, feat) in x_cols]

    # (rows, n_lags, n_targets)
    Yte_lags = np.stack(
        [_as_f32([(sym, f"{feat}_lag{k}") for (sym, feat) in y_cols]) for k in LAGS_HOURS],
        axis=1,
    )
    Yte = _as_f32(target_cols)
    Xte = _as_f32(x_state_cols)

    Xte_s = (Xte - scalers["Xmu"]) / scalers["Xsd"]
    Yte_lags_s = (Yte_lags - scalers["Lmu"]) / scalers["Lsd"]

    model.eval()
    pred_chunks, A_chunks = [], []
    n_rows = len(Yte_lags_s)
    with torch.no_grad():
        # score in chunks of 2048 rows
        for start in range(0, n_rows, 2048):
            rows = slice(start, min(start + 2048, n_rows))
            y_lags = torch.from_numpy(Yte_lags_s[rows]).float().to(DEVICE)
            x_t = torch.from_numpy(Xte_s[rows]).float().to(DEVICE)
            yhat_s, A = model(y_lags, x_t)
            pred_chunks.append(yhat_s.cpu().numpy())
            A_chunks.append(A.cpu().numpy())
        Yhat_s = np.concatenate(pred_chunks, axis=0)
        A_te = np.concatenate(A_chunks, axis=0)

    # back to the raw target scale
    Yhat = Yhat_s * scalers["y_sd"] + scalers["y_mu"]

    if model.Wc is not None:
        concept_weights = model.Wc.weight.detach().cpu().numpy()
    else:
        concept_weights = None

    extras = {
        "A_test": A_te,
        "Wc": concept_weights,
        "test_index": full_test.index.values,
        "x_state_cols": x_state_cols,
        "test_start": full_test.index.min(),
        "test_end": full_test.index.max(),
        "horizon_h": horizon_h,
    }
    return Yte, Yhat, extras

def _run_one_metrics(y_true, y_pred):
    """Return the MSE / MAE / SignHitRate dict that `run_one` reports for every model."""
    return {"MSE": mse(y_true, y_pred), "MAE": mae(y_true, y_pred), "SignHitRate": sign_hit(y_true, y_pred)}


def _run_one_load_or_fit(ckpt_path, label, load_fn, fit_fn):
    """Reuse a checkpoint when one exists on disk, otherwise train.

    Parameters
    ----------
    ckpt_path : str or None
        Checkpoint location; ``None``/empty means checkpointing is disabled.
    label : str
        Model name used in the retraining message.
    load_fn, fit_fn : callable
        Zero-argument callables producing the same result shape.

    A ``ValueError`` raised by ``load_fn`` is reported and followed by
    ``fit_fn()``; any other exception propagates unchanged.
    """
    if ckpt_path and Path(ckpt_path).exists():
        print(f"[checkpoint] loading: {ckpt_path}")
        try:
            return load_fn()
        except ValueError as e:
            print(f"[checkpoint] warning: {e}; retraining {label}")
    return fit_fn()


def run_one(task: TaskSpec, tr_s, tr_e, te_s, te_e, seed, lam_sparse):
    """Run every model for one (task, window, seed, sparsity) combination.

    Builds the lagged design matrix for ``task``, splits it into the train and
    test windows, then evaluates Last, VARX-LASSO, LSTM, NeuralGVAR (gated) and
    StaticGVAR on the test split. When ``SAVE_CHECKPOINTS`` is truthy each
    model gets a per-run checkpoint path; an existing checkpoint is loaded
    instead of retraining.

    Returns
    -------
    dict
        Run identifiers (``task``, ``h``, window bounds, ``seed``,
        ``lam_sparse``), ``metrics`` keyed by model name, ``extras`` for the
        two GVAR variants, and ``checkpoint_paths``.

    Raises
    ------
    ValueError
        If the train split has fewer than 2000 rows or the test split fewer
        than 500.
    """
    full = make_design_matrix(panel, task.y_cols, task.x_cols, LAGS_HOURS, task.horizon_h)
    tr, te = split_time(full, tr_s, tr_e, te_s, te_e)
    if len(tr) < 2000 or len(te) < 500:
        raise ValueError(f"too little data: train={len(tr)} test={len(te)}")

    # ---- checkpoint paths (optional; per-model folders) ----
    ckpt_paths = {}
    meta = {
        "task": task.name,
        "h": int(task.horizon_h),
        "train_start": str(tr_s),
        "train_end": str(tr_e),
        "test_start": str(te_s),
        "test_end": str(te_e),
        "seed": int(seed),
        "lam_sparse": float(lam_sparse),
        "lags_hours": list(LAGS_HOURS),
    }
    if 'SAVE_CHECKPOINTS' in globals() and SAVE_CHECKPOINTS:
        # ':' is replaced so timestamps with a time part stay filename-safe.
        tr_tag = str(tr_s).replace(':','-')
        te_tag = str(te_s).replace(':','-')
        run_id = f"{task.name}_h{task.horizon_h}_tr{tr_tag}_te{te_tag}_seed{seed}_lam{lam_sparse:.1e}"

        ckpt_paths["Last"] = str(MODEL_CKPT_DIRS["Last"] / f"{run_id}.json")
        ckpt_paths["VARX-LASSO"] = str(MODEL_CKPT_DIRS["VARX_LASSO"] / f"{run_id}.pkl")
        ckpt_paths["LSTM"] = str(MODEL_CKPT_DIRS["LSTM"] / f"{run_id}.pt")
        ckpt_paths["NeuralGVAR"] = str(MODEL_CKPT_DIRS["NeuralGVAR"] / f"{run_id}.pt")
        ckpt_paths["StaticGVAR"] = str(MODEL_CKPT_DIRS["StaticGVAR"] / f"{run_id}.pt")

    # ---- Last baseline (deterministic) ----
    Y_last, Yhat_last = run_last(te, task.y_cols, task.horizon_h)
    m_last = _run_one_metrics(Y_last, Yhat_last)
    if "Last" in ckpt_paths:
        # Best effort: a failed write is reported but never aborts the run.
        try:
            import json, os
            os.makedirs(os.path.dirname(ckpt_paths["Last"]), exist_ok=True)
            with open(ckpt_paths["Last"], "w") as f:
                json.dump({"type": "Last", "meta": meta, "metrics": m_last}, f, indent=2)
            print(f"[checkpoint] saved: {ckpt_paths['Last']}")
        except Exception as e:
            print(f"[checkpoint] WARNING: failed to save {ckpt_paths['Last']}: {e}")

    # ---- VARX-LASSO baseline ----
    varx_ckpt = ckpt_paths.get("VARX-LASSO")
    Y_varx, Yhat_varx = _run_one_load_or_fit(
        varx_ckpt, "VARX-LASSO",
        load_fn=lambda: _predict_varx_from_checkpoint(varx_ckpt, te, task.y_cols, task.horizon_h),
        fit_fn=lambda: fit_varx_lasso(
            tr, te, task.y_cols, task.horizon_h,
            checkpoint_path=varx_ckpt,
            meta=meta,
        ),
    )
    m_varx = _run_one_metrics(Y_varx, Yhat_varx)

    # ---- LSTM baseline ----
    lstm_ckpt = ckpt_paths.get("LSTM")
    Y_lstm, Yhat_lstm = _run_one_load_or_fit(
        lstm_ckpt, "LSTM",
        load_fn=lambda: _predict_lstm_from_checkpoint(
            lstm_ckpt, te, task.y_cols, task.x_cols, LAGS_HOURS, task.horizon_h
        ),
        fit_fn=lambda: fit_compact_lstm(
            tr, te, task.y_cols, task.x_cols, LAGS_HOURS, task.horizon_h,
            epochs=12 if not FAST_MODE else 5, seed=seed,
            hidden=32 if not FAST_MODE else 24,
            checkpoint_path=lstm_ckpt,
            meta=meta,
        ),
    )
    m_lstm = _run_one_metrics(Y_lstm, Yhat_lstm)

    # ---- NeuralGVAR + StaticGVAR ----
    # NOTE(review): `meta` is not forwarded to fit_neural_gvar (unlike the
    # VARX/LSTM fits); this mirrors the existing calls -- confirm it is intended.
    def _gvar(ckpt_path, label, use_gating, lam_smooth):
        """Load-or-fit one GVAR variant; returns (metrics, extras)."""
        def _load():
            y_true, y_pred, extras = _predict_gvar_from_checkpoint(
                ckpt_path, te, task.y_cols, task.x_cols, task.horizon_h
            )
            return _run_one_metrics(y_true, y_pred), extras

        def _fit():
            return fit_neural_gvar(
                tr, te, task.y_cols, task.x_cols, task.horizon_h,
                use_gating=use_gating, lam_sparse=lam_sparse, lam_smooth=lam_smooth, seed=seed,
                checkpoint_path=ckpt_path,
            )

        return _run_one_load_or_fit(ckpt_path, label, _load, _fit)

    m_ng, ex_ng = _gvar(ckpt_paths.get("NeuralGVAR"), "NeuralGVAR", use_gating=True, lam_smooth=1e-4)
    m_static, ex_static = _gvar(ckpt_paths.get("StaticGVAR"), "StaticGVAR", use_gating=False, lam_smooth=0.0)

    return {
        "task": task.name, "h": task.horizon_h,
        "train_start": tr_s, "train_end": tr_e,
        "test_start": te_s, "test_end": te_e,
        "seed": seed, "lam_sparse": lam_sparse,
        "metrics": {"Last": m_last, "VARX-LASSO": m_varx, "LSTM": m_lstm, "NeuralGVAR": m_ng, "StaticGVAR": m_static},
        "extras": {"NeuralGVAR": ex_ng, "StaticGVAR": ex_static},
        "checkpoint_paths": ckpt_paths,
    }

# Sweep the full grid: backtest window x task x sparsity penalty x seed.
# Each combination contributes one run dict to `results`.
results = []
all_tasks = RET_TASKS + VOL_TASKS
for tr_s, tr_e, te_s, te_e in BACKTEST_WINDOWS:
    window_label = (tr_s, tr_e, "->", te_s, te_e)
    for task in all_tasks:
        for lam in LAMBDA_SPARSE_SWEEP:
            for seed in SEEDS:
                print("\n====", task.name, window_label, "seed", seed, "lam", lam, "====")
                run_result = run_one(task, tr_s, tr_e, te_s, te_e, seed, lam)
                results.append(run_result)

print("Runs completed:", len(results))



==== RET_h1 ('2021-01-01', '2022-12-31 23:00', '->', '2023-01-01', '2023-12-31 23:00') seed 1 lam 0.0001 ====
[checkpoint] saved: checkpoints/Last/RET_h1_h1_tr2021-01-01_te2023-01-01_seed1_lam1.0e-04.json
[checkpoint] saved: checkpoints/VARX_LASSO/RET_h1_h1_tr2021-01-01_te2023-01-01_seed1_lam1.0e-04.pkl
LSTM ep 1/12 loss=0.000891
LSTM ep 2/12 loss=0.000152
LSTM ep 3/12 loss=0.000152
LSTM ep 4/12 loss=0.000152
LSTM ep 5/12 loss=0.000152
LSTM ep 6/12 loss=0.000151
LSTM ep 7/12 loss=0.000152
LSTM ep 8/12 loss=0.000152
LSTM ep 9/12 loss=0.000152
LSTM ep 10/12 loss=0.000152
LSTM ep 11/12 loss=0.000152
LSTM ep 12/12 loss=0.000152
[checkpoint] saved: checkpoints/LSTM/RET_h1_h1_tr2021-01-01_te2023-01-01_seed1_lam1.0e-04.pt
NeuralGVAR(gated) ep 1/25 loss=1.321641
NeuralGVAR(gated) ep 2/25 loss=1.147918
NeuralGVAR(gated) ep 3/25 loss=1.063271
NeuralGVAR(gated) ep 4/25 loss=1.025690
NeuralGVAR(gated) ep 5/25 loss=1.008858
NeuralGVAR(gated) ep 6/25 loss=1.001040
NeuralGVAR(gated) ep 7/25 loss=0.9

## 8) Aggregate into per-horizon tables (metrics, hypotheses, test-window conditioning, event conditioning)

In [12]:
# ---- metrics_summary.csv (aggregated test metrics across runs) ----
import numpy as np
import pandas as pd
import os
from pathlib import Path

# Flatten the nested per-run metrics into one long table, then aggregate it
# per (task, h, model) and write the result to <PAPER_ASSETS>/metrics_summary.csv.
# Metrics are stored nested per run:
#   results[i]["metrics"] == {model_name: {"MSE":..., "MAE":..., ...}, ...}
if "results" not in globals() or results is None or len(results) == 0:
    raise RuntimeError("No `results` found. Run the experiment loop first to populate `results`.")

# One output row per (run, model).
rows = []
for r in results:
    task = r.get("task", None)
    h = int(r.get("h", -1))  # -1 marks a run that did not record its horizon
    seed = r.get("seed", None)
    lam_sparse = r.get("lam_sparse", None)

    # Human-readable window label; stored per row but not used as a grouping key below.
    win = f'{r.get("train_start","")}-{r.get("train_end","")}->' \
          f'{r.get("test_start","")}-{r.get("test_end","")}'

    metrics = r.get("metrics", {}) or {}
    for model, md in metrics.items():
        if md is None:
            continue
        rows.append({
            "task": task,
            "h": h,
            "model": str(model),
            "MSE": float(md.get("MSE", np.nan)),
            "MAE": float(md.get("MAE", np.nan)),
            "SignHitRate": float(md.get("SignHitRate", np.nan)) if ("SignHitRate" in md) else np.nan,
            "seed": seed,
            "lam_sparse": lam_sparse,
            "window": win,
        })

met_df = pd.DataFrame(rows)
# Rows lacking an identifier or an MSE/MAE value are dropped silently, so `n`
# below can be smaller than the number of runs.
met_df = met_df.dropna(subset=["task", "h", "model", "MSE", "MAE"]).copy()

# Aggregate across runs (mean ± std, and n)
metrics_summary = (
    met_df.groupby(["task", "h", "model"], as_index=False)
          .agg(
              MSE_mean=("MSE", "mean"),
              MSE_std=("MSE", "std"),
              MAE_mean=("MAE", "mean"),
              MAE_std=("MAE", "std"),
              SignHit_mean=("SignHitRate", "mean"),
              SignHit_std=("SignHitRate", "std"),
              n=("MSE", "count"),
          )
)

# std is NaN when n==1; replace with 0 for plotting
for col in ["MSE_std", "MAE_std", "SignHit_std"]:
    if col in metrics_summary.columns:
        metrics_summary[col] = metrics_summary[col].fillna(0.0)

# Stable ordering in tables/plots; models not listed here sort last (rank 99).
model_order = {"VARX-LASSO": 0, "StaticGVAR": 1, "NeuralGVAR": 2, "LSTM": 3, "Last": 4}
metrics_summary["model_rank"] = metrics_summary["model"].map(model_order).fillna(99).astype(int)
metrics_summary = (
    metrics_summary.sort_values(["task", "h", "model_rank"])
                   .drop(columns=["model_rank"])
                   .reset_index(drop=True)
)

# Output locations: honour a PAPER_ASSETS global if one was set, else "paper_assets".
PAPER_ASSETS = globals().get("PAPER_ASSETS", "paper_assets")
Path(PAPER_ASSETS).mkdir(parents=True, exist_ok=True)

out1 = os.path.join(PAPER_ASSETS, "metrics_summary.csv")

metrics_summary.to_csv(out1, index=False)

# Sanity check: do we have all expected models for every (task, h)?
expected = set(model_order.keys())
missing = []
for (t, hh), g in metrics_summary.groupby(["task", "h"]):
    present = set(g["model"].astype(str))
    miss = sorted(expected - present)
    if miss:
        missing.append((t, int(hh), miss))

print(f"Saved metrics_summary to: {out1}")

if missing:
    print("WARNING: missing models for some (task, h):")
    # Only the first 30 offending groups are listed to keep the output short.
    for t, hh, miss in missing[:30]:
        print(f"  - {t} h={hh}: missing {miss}")
else:
    print("All expected models present for each (task, h).")

metrics_summary.head(15)


Saved metrics_summary to: paper_assets/metrics_summary.csv
All expected models present for each (task, h).


Unnamed: 0,task,h,model,MSE_mean,MSE_std,MAE_mean,MAE_std,SignHit_mean,SignHit_std,n
0,RET_h1,1,VARX-LASSO,5.5e-05,4e-06,0.004635,0.000384,0.507089,0.004829,45
1,RET_h1,1,StaticGVAR,5.5e-05,4e-06,0.004642,0.000383,0.501226,0.004698,45
2,RET_h1,1,NeuralGVAR,5.5e-05,4e-06,0.004654,0.000382,0.501542,0.002444,45
3,RET_h1,1,LSTM,5.6e-05,4e-06,0.004732,0.000418,0.499789,0.006422,45
4,RET_h1,1,Last,0.000111,9e-06,0.006848,0.000534,0.481683,0.002747,45
5,RET_h12,12,VARX-LASSO,5.5e-05,4e-06,0.004637,0.000384,0.502132,0.003462,45
6,RET_h12,12,StaticGVAR,5.5e-05,4e-06,0.004647,0.000384,0.497047,0.004586,45
7,RET_h12,12,NeuralGVAR,5.5e-05,4e-06,0.004655,0.000384,0.498808,0.004004,45
8,RET_h12,12,LSTM,5.6e-05,4e-06,0.004734,0.000365,0.50057,0.004604,45
9,RET_h12,12,Last,0.000112,7e-06,0.006879,0.000521,0.495704,0.003007,45


In [13]:
# ---- Sanity checks for metrics_summary.csv ----
# 1) every (task, h) group contains all five models,
# 2) run counts are consistent across groups,
# 3) MSE magnitudes look plausible side by side.
expected_models = ["Last", "VARX-LASSO", "StaticGVAR", "NeuralGVAR", "LSTM"]

missing = []
for (task, h), g in metrics_summary.groupby(["task", "h"]):
    present = set(g["model"])
    miss = [name for name in expected_models if name not in present]
    if len(miss) > 0:
        missing.append((task, h, miss))

if not missing:
    print("OK: All expected models present for each (task,h) group.")
else:
    print("WARNING: Missing models for some (task,h):")
    for task, h, miss in missing:
        print(f"  - {task}, h={h}: missing {miss}")

print("\nRun counts (n) by (task,h):")
run_counts = metrics_summary.groupby(["task", "h"])["n"].describe()
display(run_counts)

print("\nExample scale checks (MSE_mean):")
mse_by_model = metrics_summary.pivot_table(
    index=["task","h"], columns="model", values="MSE_mean", aggfunc="first"
)
display(mse_by_model.head(10))

OK: All expected models present for each (task,h) group.

Run counts (n) by (task,h):


Unnamed: 0_level_0,Unnamed: 1_level_0,count,mean,std,min,25%,50%,75%,max
task,h,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
RET_h1,1,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
RET_h12,12,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
RET_h24,24,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
RET_h4,4,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
VOL_h1,1,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
VOL_h12,12,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
VOL_h24,24,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0
VOL_h4,4,5.0,45.0,0.0,45.0,45.0,45.0,45.0,45.0



Example scale checks (MSE_mean):


Unnamed: 0_level_0,model,LSTM,Last,NeuralGVAR,StaticGVAR,VARX-LASSO
task,h,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
RET_h1,1,5.6e-05,0.0001107119,5.527532e-05,5.502608e-05,5.494302e-05
RET_h12,12,5.6e-05,0.0001117733,5.515893e-05,5.50843e-05,5.492275e-05
RET_h24,24,5.6e-05,0.000109034,5.522256e-05,5.503406e-05,5.497226e-05
RET_h4,4,5.6e-05,0.0001097901,5.518056e-05,5.499949e-05,5.490185e-05
VOL_h1,1,4e-06,6.955136e-07,7.738413e-07,7.680758e-07,7.266366e-07
VOL_h12,12,8e-06,1.195155e-05,5.320602e-06,5.128999e-06,5.216912e-06
VOL_h24,24,1.2e-05,1.222126e-05,1.019587e-05,9.809325e-06,9.980701e-06
VOL_h4,4,5e-06,1.933092e-06,1.940437e-06,1.901008e-06,1.890049e-06


In [14]:
def run_event_conditioning_for_run(
    r,
    model_key="NeuralGVAR",
    edge_source="BTCUSDT",
    edge_target="ETHUSDT",
    event_mode="jump",
):
    """Summarise pre/post event changes of one edge for a single run.

    The edge series is taken from the run's ``A_test`` extras and indexed by
    the run's test timestamps (localised to UTC when naive). Events are defined
    on ``edge_source``: on ``RET_COL`` with mode ``"abs"`` when ``event_mode``
    is ``"jump"``, otherwise on ``VOL_COL`` with mode ``"high"``. The same
    pre/post delta is computed for the real events, for matched controls, for
    events shifted by -6 hours, and for random events.

    Returns a dict with run identifiers, the event count and threshold, and
    the mean delta of each of the four event sets (NaN when a set is empty).
    """
    extras = r["extras"][model_key]

    test_idx = pd.to_datetime(extras["test_index"])
    if getattr(test_idx, "tz", None) is None:
        test_idx = test_idx.tz_localize("UTC")
    test_idx = pd.DatetimeIndex(test_idx)

    # Edge strength over the test window.
    edge_values = edge_timeseries(extras["A_test"], ASSET_IDX[edge_source], ASSET_IDX[edge_target])
    edge_series = pd.Series(edge_values, index=test_idx)

    def _source_feature(col):
        # One feature of the source asset, aligned to the test timestamps.
        return panel[(edge_source, col)].reindex(test_idx).astype(float)

    def _delta(event_set):
        return event_pre_post_delta(edge_series, event_set, pre=EVENT_PRE_H, post=EVENT_POST_H)

    def _mean_or_nan(deltas):
        return float(np.mean(deltas)) if len(deltas) else np.nan

    # Controls are always matched on the source asset's volatility.
    match_series = _source_feature(VOL_COL)

    trigger_col, trigger_mode = (RET_COL, "abs") if event_mode == "jump" else (VOL_COL, "high")
    events, thr = define_events(_source_feature(trigger_col), q=EVENT_Q, mode=trigger_mode)

    d_real = _delta(events)
    controls = matched_control_events(events, match_series, exclude=events, n_bins=MATCH_BINS, seed=r["seed"])
    d_ctrl = _delta(controls)
    d_shift = _delta(placebo_shift(events, hours=-6))
    d_rand = _delta(random_events(test_idx, n=len(events), seed=r["seed"]))

    return {
        "model": model_key,
        "task": r["task"],
        "h": r["h"],
        "edge": f"{edge_source}â†’{edge_target}",
        "event_mode": event_mode,
        "events_n": int(len(events)),
        "threshold": thr,
        "real_mean": _mean_or_nan(d_real),
        "ctrl_mean": _mean_or_nan(d_ctrl),
        "shift_mean": _mean_or_nan(d_shift),
        "rand_mean": _mean_or_nan(d_rand),
    }


# Event conditioning for both GVAR variants:
#   RET tasks -> BTC->ETH edge around return jumps
#   VOL tasks -> BTC->ETH and BTC->SOL edges around high-volatility events
ev_rows = []
for r in results:
    edge_specs = []
    if r["task"].startswith("RET"):
        edge_specs.append(("ETHUSDT", "jump"))
    if r["task"].startswith("VOL"):
        edge_specs.extend([("ETHUSDT", "vol"), ("SOLUSDT", "vol")])

    for model_key in ["StaticGVAR", "NeuralGVAR"]:
        ev_rows.extend(
            run_event_conditioning_for_run(r, model_key, "BTCUSDT", edge_target, event_mode)
            for edge_target, event_mode in edge_specs
        )

ev_df = pd.DataFrame(ev_rows)

# Mean and std across runs for each of the four event sets, plus the run count.
_ev_agg = {"events_n_mean": ("events_n", "mean")}
for _kind in ["real", "ctrl", "shift", "rand"]:
    _ev_agg[f"{_kind}_mean"] = (f"{_kind}_mean", "mean")
    _ev_agg[f"{_kind}_std"] = (f"{_kind}_mean", "std")
_ev_agg["n"] = ("real_mean", "count")

_ev_grouped = ev_df.groupby(["model","task","h","edge","event_mode"]).agg(**_ev_agg)
event_conditioning_summary = _ev_grouped.reset_index().sort_values(
    ["task","h","edge","event_mode","model"]
)

event_conditioning_summary


Unnamed: 0,model,task,h,edge,event_mode,events_n_mean,real_mean,real_std,ctrl_mean,ctrl_std,shift_mean,shift_std,rand_mean,rand_std,n
0,NeuralGVAR,RET_h1,1,BTCUSDTâ†’ETHUSDT,jump,87.666667,0.000738,0.00302,-0.001436426,0.001986818,0.001453,0.00283,-0.0003922443,0.001722071,45
12,StaticGVAR,RET_h1,1,BTCUSDTâ†’ETHUSDT,jump,87.666667,0.0,0.0,-1.88146e-12,1.262122e-11,0.0,0.0,0.0,0.0,45
1,NeuralGVAR,RET_h12,12,BTCUSDTâ†’ETHUSDT,jump,87.666667,-0.000348,0.003409,4.666585e-05,0.002451936,0.000789,0.003516,0.00048298,0.002840973,45
13,StaticGVAR,RET_h12,12,BTCUSDTâ†’ETHUSDT,jump,87.666667,0.0,0.0,-1.763869e-13,8.738799e-13,0.0,0.0,0.0,9.129133e-12,45
2,NeuralGVAR,RET_h24,24,BTCUSDTâ†’ETHUSDT,jump,87.333333,-0.000121,0.001885,-0.0006102835,0.002379232,0.001486,0.002186,0.0005867133,0.001050947,45
14,StaticGVAR,RET_h24,24,BTCUSDTâ†’ETHUSDT,jump,87.333333,0.0,0.0,4.70365e-13,3.155304e-12,0.0,0.0,0.0,0.0,45
3,NeuralGVAR,RET_h4,4,BTCUSDTâ†’ETHUSDT,jump,87.666667,0.000141,0.001545,-0.0004916068,0.002904238,0.003138,0.003216,0.0009112988,0.002385576,45
15,StaticGVAR,RET_h4,4,BTCUSDTâ†’ETHUSDT,jump,87.666667,0.0,0.0,0.0,4.512696e-12,0.0,0.0,-4.757714e-13,2.230989e-12,45
4,NeuralGVAR,VOL_h1,1,BTCUSDTâ†’ETHUSDT,vol,87.666667,-0.000217,0.000969,6.506202e-05,0.0008575459,-0.000325,0.001743,-0.000104981,0.0006187444,45
16,StaticGVAR,VOL_h1,1,BTCUSDTâ†’ETHUSDT,vol,87.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45


In [15]:

# 8.3 H1 + H4a: test-window only conditioning and Wc diagonal, per horizon
# For every run this collects, per asset, the NeuralGVAR concept weight returned
# by diag_concept_weight and the conditioning delta computed on the test window.
# RET tasks feed the H1 table (DIV_COL vs RET_COL); VOL tasks feed the H4a
# table (FUND_COL vs VOL_COL).
h1_rows=[]
h4a_rows=[]

for r in results:
    task = r["task"]; h = r["h"]
    # Restrict the panel to this run's test window (bounds interpreted as UTC).
    te_start = pd.Timestamp(r["test_start"], tz="UTC")
    te_end   = pd.Timestamp(r["test_end"], tz="UTC")
    test_panel = panel.loc[te_start:te_end].copy()
    ex = r["extras"]["NeuralGVAR"]

    if task.startswith("RET"):
        # NOTE(review): q=0.99 / mode="abs" presumably selects extreme |divergence| -- confirm in conditioning_delta_on_test.
        cond = conditioning_delta_on_test(test_panel, h, concept_feat=DIV_COL, y_feat=RET_COL, q=0.99, mode="abs")
        for a in SYMBOLS:
            h1_rows.append({
                "task": task, "h": h,
                "hyp": "H1 divergenceâ†’returns (âˆ’)",
                "asset": a,
                "Wc_diag": diag_concept_weight(ex, DIV_COL, a),
                "cond_delta": cond[a],
            })

    if task.startswith("VOL"):
        # NOTE(review): q=0.99 / mode="high" presumably selects the top funding tail -- confirm in conditioning_delta_on_test.
        cond = conditioning_delta_on_test(test_panel, h, concept_feat=FUND_COL, y_feat=VOL_COL, q=0.99, mode="high")
        for a in SYMBOLS:
            h4a_rows.append({
                "task": task, "h": h,
                "hyp": "H4a fundingâ†’volatility (+)",
                "asset": a,
                "Wc_diag": diag_concept_weight(ex, FUND_COL, a),
                "cond_delta": cond[a],
            })

h1_df = pd.DataFrame(h1_rows)
h4a_df = pd.DataFrame(h4a_rows)

# Aggregate across runs: mean/std of the weight and of the conditioning delta, plus the run count.
h1_summary = (h1_df.groupby(["task","h","hyp","asset"])
              .agg(Wc_mean=("Wc_diag","mean"), Wc_std=("Wc_diag","std"),
                   cond_delta_mean=("cond_delta","mean"), cond_delta_std=("cond_delta","std"),
                   n=("Wc_diag","count"))
              .reset_index().sort_values(["task","asset"]))
h4a_summary = (h4a_df.groupby(["task","h","hyp","asset"])
               .agg(Wc_mean=("Wc_diag","mean"), Wc_std=("Wc_diag","std"),
                    cond_delta_mean=("cond_delta","mean"), cond_delta_std=("cond_delta","std"),
                    n=("Wc_diag","count"))
               .reset_index().sort_values(["task","asset"]))

h1_summary, h4a_summary


(       task   h                        hyp    asset   Wc_mean    Wc_std  \
 0    RET_h1   1  H1 divergenceâ†’returns (âˆ’)  BTCUSDT  0.012037  0.006983   
 1    RET_h1   1  H1 divergenceâ†’returns (âˆ’)  ETHUSDT  0.009365  0.006849   
 2    RET_h1   1  H1 divergenceâ†’returns (âˆ’)  SOLUSDT -0.005072  0.010439   
 3   RET_h12  12  H1 divergenceâ†’returns (âˆ’)  BTCUSDT  0.005904  0.003198   
 4   RET_h12  12  H1 divergenceâ†’returns (âˆ’)  ETHUSDT  0.007817  0.002457   
 5   RET_h12  12  H1 divergenceâ†’returns (âˆ’)  SOLUSDT -0.005777  0.002090   
 6   RET_h24  24  H1 divergenceâ†’returns (âˆ’)  BTCUSDT  0.001647  0.005777   
 7   RET_h24  24  H1 divergenceâ†’returns (âˆ’)  ETHUSDT -0.004433  0.004614   
 8   RET_h24  24  H1 divergenceâ†’returns (âˆ’)  SOLUSDT  0.000430  0.005702   
 9    RET_h4   4  H1 divergenceâ†’returns (âˆ’)  BTCUSDT  0.012150  0.007299   
 10   RET_h4   4  H1 divergenceâ†’returns (âˆ’)  ETHUSDT -0.004508  0.006711   
 11   RET_h4   4  H1 divergenceâ†’returns (â

## 9) Export CSV + LaTeX + PNG figures

In [21]:
# 9) Export paper_assets (tables + figures)
# Writes plots/tables into ./paper_assets and prints where they were saved.

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Export folder. Honour the PAPER_ASSETS override used by the aggregation cell
# so both cells write to the same place; defaults to ./paper_assets.
assets_dir = Path(globals().get("PAPER_ASSETS", "paper_assets"))
assets_dir.mkdir(parents=True, exist_ok=True)

def _task_prefix(task: str) -> str:
    t = str(task).upper()
    if "RET" in t:
        return "RET"
    if "VOL" in t:
        return "VOL"
    return t.replace(" ", "_")

def _ensure_task_and_h(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure columns: task (RET/VOL) and h (int). Accepts legacy task strings like RET_h1."""
    df = df.copy()
    if "task" not in df.columns:
        raise KeyError("metrics_summary must contain a 'task' column (or be created in this notebook).")
    if "h" not in df.columns:
        # legacy format: task like 'RET_h1'
        h = df["task"].astype(str).str.extract(r"h(\d+)")[0]
        if h.isna().any():
            raise KeyError("Could not parse horizon `h` from task names. Expected strings like 'RET_h1'.")
        df["h"] = h.astype(int)
        df["task"] = df["task"].astype(str).str.replace(r"_h\d+$", "", regex=True)
    else:
        df["h"] = df["h"].astype(int)
    return df

def _safe_to_latex(df: pd.DataFrame, path: Path) -> None:
    """Write ``df`` to ``path`` as a LaTeX table, without the index and with floats formatted as ``%.6g``."""
    latex_options = {"index": False, "float_format": "%.6g"}
    df.to_latex(path, **latex_options)

# -------------------- Save summary CSV + LaTeX --------------------
# Every summary table that exists in the notebook namespace is written twice:
# <name>.csv and <name>.tex. Tables that were never computed are skipped.
_summary_names = [
    "metrics_summary",
    "edge_hypothesis_summary",
    "h1_summary",
    "h4a_summary",
    "event_conditioning_summary",
]
for name in _summary_names:
    if name not in globals():
        continue
    _table = globals()[name]
    _table.to_csv(assets_dir / f"{name}.csv", index=False)
    _safe_to_latex(_table, assets_dir / f"{name}.tex")

# -------------------- Plots: baselines from metrics_summary --------------------
def plot_metric_bars(metrics_df: pd.DataFrame, task: str, h: int, metric: str) -> Path | None:
    """Save a bar chart of ``<metric>_mean`` (with ``<metric>_std`` error bars) per model.

    Parameters
    ----------
    metrics_df : DataFrame with ``task``, ``h`` and ``<metric>_mean`` columns.
    task, h    : select the rows to plot.
    metric     : metric name, e.g. ``"MSE"``; also used in the file name.

    Returns
    -------
    Path of the written PNG inside ``assets_dir``, or ``None`` when there are
    no matching rows or the mean column is absent.
    """
    df = metrics_df[(metrics_df["task"] == task) & (metrics_df["h"] == h)].copy()
    mean_col = f"{metric}_mean"
    std_col  = f"{metric}_std"
    if df.empty or mean_col not in df.columns:
        return None

    order = ["Last", "VARX-LASSO", "StaticGVAR", "NeuralGVAR", "LSTM"]
    if "model" in df.columns:
        model_names = df["model"].astype(str)
        # Models outside the canonical order are appended (sorted) instead of
        # being turned into NaN categories, which used to render as "nan" ticks.
        extra_models = sorted(set(model_names) - set(order))
        df["model"] = pd.Categorical(model_names, categories=order + extra_models, ordered=True)
        df = df.sort_values("model")
        tick_labels = df["model"].astype(str).tolist()
    else:
        tick_labels = df.index.astype(str).tolist()

    means = df[mean_col].to_numpy()
    stds  = df[std_col].to_numpy() if std_col in df.columns else np.zeros_like(means)
    positions = np.arange(len(df))

    fig, ax = plt.subplots(figsize=(7.0, 3.2))
    try:
        ax.bar(positions, means, yerr=stds, capsize=3)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=20)
        ax.set_ylabel(f"Test {metric}")
        ax.set_title(f"{_task_prefix(task)} — horizon h={h}")
        fig.tight_layout()

        out = assets_dir / f"{metric.lower()}_{_task_prefix(task)}_h{h}.png"
        fig.savefig(out, dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return out

# If the user didn't run the aggregation cell, fall back to reading an existing CSV.
if "metrics_summary" in globals():
    _metrics_df = metrics_summary
else:
    # The notebook writes metrics_summary.csv into the assets folder, so look
    # there first; the bare file name is kept as a legacy fallback.
    _csv_candidates = [assets_dir / "metrics_summary.csv", Path("metrics_summary.csv")]
    _csv = next((_cand for _cand in _csv_candidates if _cand.exists()), None)
    _metrics_df = pd.read_csv(_csv) if _csv is not None else None

# Paths of every figure written by this cell (reported at the end of the cell).
saved = []

if _metrics_df is not None and not _metrics_df.empty:
    _metrics_df = _ensure_task_and_h(_metrics_df)

    # MSE and MAE are always plotted; QLIKE only when the summary carries it.
    _metrics_to_plot = ["MSE", "MAE"]
    if "QLIKE_mean" in _metrics_df.columns:
        _metrics_to_plot.append("QLIKE")

    for task in sorted(_metrics_df["task"].unique()):
        for h in sorted(_metrics_df[_metrics_df["task"].eq(task)]["h"].unique()):
            h = int(h)
            for metric in _metrics_to_plot:
                out = plot_metric_bars(_metrics_df, task, h, metric)
                if out is not None:
                    saved.append(out)

# -------------------- Plots: concept outcome conditioning (H1/H4a) --------------------
def plot_conditioning(summary_df: pd.DataFrame, task: str, h: int,
                      value_col: str, ylabel: str, title: str, fname: Path) -> Path | None:
    """Save a bar chart of ``value_col`` for the rows matching ``task`` and ``h``.

    Bars are labelled from the first of ``target``, ``symbol``, ``var``,
    ``series``, ``node`` or ``asset`` that exists in the frame, falling back to
    the first column. A dashed zero line is drawn for reference.

    Returns ``fname`` after writing the PNG, or ``None`` when no rows match or
    ``value_col`` is missing.
    """
    df = summary_df[(summary_df["task"] == task) & (summary_df["h"] == h)].copy()
    if df.empty or value_col not in df.columns:
        return None

    # "asset" is what the H1/H4a summaries built in this notebook use; without
    # it the fallback picked the first column ("task") and every bar got the
    # same label. It is listed last so frames with an earlier candidate are unaffected.
    label_candidates = ["target", "symbol", "var", "series", "node", "asset"]
    label_col = next((cand for cand in label_candidates if cand in df.columns), df.columns[0])

    labels = df[label_col].astype(str).tolist()
    vals = df[value_col].to_numpy()
    positions = np.arange(len(vals))

    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    try:
        ax.bar(positions, vals)
        ax.axhline(0.0, linestyle="--")
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=0)
        ax.set_ylabel(ylabel)
        ax.set_title(f"{title} — {_task_prefix(task)} — h={h}")
        fig.tight_layout()
        fig.savefig(fname, dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return fname

def _export_conditioning_plots(summary_df, stem, ylabel, title):
    """Plot ``cond_delta_mean`` for every (task, h) in ``summary_df``.

    Files are named ``<stem>_cond_<TASK>_h<h>.png`` inside ``assets_dir`` and
    each written path is appended to ``saved``. Returns the summary after
    ``_ensure_task_and_h`` so the caller can keep the normalised frame.
    """
    summary_df = _ensure_task_and_h(summary_df)
    for task_name in sorted(summary_df["task"].unique()):
        task_rows = summary_df[summary_df["task"].eq(task_name)]
        for horizon in sorted(task_rows["h"].unique()):
            horizon = int(horizon)
            target = assets_dir / f"{stem}_cond_{_task_prefix(task_name)}_h{horizon}.png"
            written = plot_conditioning(
                summary_df, task_name, horizon,
                value_col="cond_delta_mean",
                ylabel=ylabel,
                title=title,
                fname=target
            )
            if written is not None:
                saved.append(written)
    return summary_df


# H1: returns conditioned on extreme divergence.
if "h1_summary" in globals():
    h1_summary = _export_conditioning_plots(
        h1_summary, "h1",
        r"$\mathbb{E}[r_{t+h}\mid |div|\ extreme]-\mathbb{E}[r_{t+h}]$",
        "H1 conditioning",
    )

# H4a: volatility conditioned on extreme funding.
if "h4a_summary" in globals():
    h4a_summary = _export_conditioning_plots(
        h4a_summary, "h4a",
        r"$\mathbb{E}[\sigma_{t+h}\mid funding\ extreme]-\mathbb{E}[\sigma_{t+h}]$",
        "H4a conditioning",
    )

# Final report: where the assets live and which figures were produced.
print("Saved paper_assets to:", assets_dir.resolve())
if not saved:
    print("No plots were saved. (Did you run the evaluation/aggregation cells above?)")
else:
    print("Saved plots:")
    for p in sorted(set(saved)):
        print("  -", p.name)


Saved paper assets to: /Users/Shana/Desktop/neural-granger-crypto/paper_assets
Saved plots:
  - h1_cond_RET_h1.png
  - h1_cond_RET_h12.png
  - h1_cond_RET_h24.png
  - h1_cond_RET_h4.png
  - h4a_cond_VOL_h1.png
  - h4a_cond_VOL_h12.png
  - h4a_cond_VOL_h24.png
  - h4a_cond_VOL_h4.png
  - mae_RET_h1.png
  - mae_RET_h12.png
  - mae_RET_h24.png
  - mae_RET_h4.png
  - mae_VOL_h1.png
  - mae_VOL_h12.png
  - mae_VOL_h24.png
  - mae_VOL_h4.png
  - mse_RET_h1.png
  - mse_RET_h12.png
  - mse_RET_h24.png
  - mse_RET_h4.png
  - mse_VOL_h1.png
  - mse_VOL_h12.png
  - mse_VOL_h24.png
  - mse_VOL_h4.png


In [23]:
import os
import pandas as pd

ASSETDIR = "paper_assets"  # must match your LaTeX \assetdir
os.makedirs(ASSETDIR, exist_ok=True)

# ---- Hypotheses shown in the main-text table, with their LaTeX row labels ----
label_map = {
    "H2 BTCretâ†’ETHUSDTret (+)": r"BTC ret $\to$ ETH ret",
    "H2 BTCretâ†’SOLUSDTret (+)": r"BTC ret $\to$ SOL ret",
    "H4b BTCvolâ†’ETHUSDTvol (+)": r"BTC vol $\to$ ETH vol",
    "H4b BTCvolâ†’SOLUSDTvol (+)": r"BTC vol $\to$ SOL vol",
}
want = list(label_map)

# Use the in-memory summary when available; otherwise reload the exported CSV.
if "edge_hypothesis_summary" not in globals():
    csv_path = os.path.join(ASSETDIR, "edge_hypothesis_summary.csv")
    if os.path.exists(csv_path):
        edge_hypothesis_summary = pd.read_csv(csv_path)
    else:
        raise FileNotFoundError(
            f"edge_hypothesis_summary not defined and CSV not found at {csv_path}. "
            "Run the edge aggregation cell first or export the CSV."
        )

is_key_edge = edge_hypothesis_summary["hyp"].isin(want)
key = edge_hypothesis_summary.loc[is_key_edge].copy()
key["Edge"] = key["hyp"].map(label_map)


def _sign_median_cell(row):
    """Format one table cell as 'sign fraction (median)'."""
    return f'{row["sign_frac_mean"]:.3f} ({row["median_mean"]:+.3f})'


key["cell"] = key.apply(_sign_median_cell, axis=1)

# One row per edge, one column per horizon, in a fixed horizon order.
horizon_headers = {1: r"$h=1$", 4: r"$h=4$", 12: r"$h=12$", 24: r"$h=24$"}
tab = key.pivot(index="Edge", columns="h", values="cell")
tab = tab.reindex(columns=list(horizon_headers))
tab.columns = list(horizon_headers.values())

latex_tabular = tab.to_latex(
    index=True,
    escape=False,
    column_format="lcccc",
    bold_rows=False,
)

out_path = os.path.join(ASSETDIR, "edge_key_edges_table.tex")
with open(out_path, "w") as f:
    f.write(latex_tabular)

print("Wrote:", out_path)


Wrote: paper_assets/edge_key_edges_table.tex


In [24]:

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- config ---
assetdir = Path("paper_assets")   # match your \assetdir
assetdir.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------
# Load edge_hypothesis_summary
# Use the frame already computed in this kernel when it exists; otherwise fall back
# to the exported CSV so this cell also runs on a fresh kernel.
if "edge_hypothesis_summary" not in globals():
    summary_csv = assetdir / "edge_hypothesis_summary.csv"
    if not summary_csv.exists():
        raise FileNotFoundError(
            f"edge_hypothesis_summary not defined and CSV not found at {summary_csv}. "
            "Run the edge aggregation cell first or export the CSV."
        )
    edge_hypothesis_summary = pd.read_csv(summary_csv)

df = edge_hypothesis_summary.copy()

# Expected columns (from your aggregation):
# task, h, hyp, sign_frac_mean, sign_ci_lo_mean, sign_ci_hi_mean, median_mean, med_ci_lo_mean, med_ci_hi_mean, n
# Only the columns this cell actually reads are enforced; the CI columns are optional.
required_cols = {"hyp", "h", "sign_frac_mean", "median_mean"}
missing_cols = sorted(required_cols - set(df.columns))
if missing_cols:
    raise KeyError(f"edge_hypothesis_summary is missing required columns: {missing_cols}")

# ---------------------------------------------------------------------
# Key edges discussed in the main text, mapped to nicer labels for the paper.
# The dict keys double as the selection list, so the two cannot drift apart.
# NOTE(review): the arrow inside these keys looks mojibake-encoded; the keys must match
# the `hyp` values of edge_hypothesis_summary exactly -- confirm before normalising them.
label_map = {
    "H2 BTCretâ†’ETHUSDTret (+)": r"BTC ret $\rightarrow$ ETH ret",
    "H2 BTCretâ†’SOLUSDTret (+)": r"BTC ret $\rightarrow$ SOL ret",
    "H4b BTCvolâ†’ETHUSDTvol (+)": r"BTC vol $\rightarrow$ ETH vol",
    "H4b BTCvolâ†’SOLUSDTvol (+)": r"BTC vol $\rightarrow$ SOL vol",
}
key_hyps = list(label_map)

key = df[df["hyp"].isin(key_hyps)].copy()
key["Edge"] = key["hyp"].map(label_map)

# Ensure horizons sorted numerically
key["h"] = key["h"].astype(int)
key = key.sort_values(["Edge", "h"])

# ---------------------------------------------------------------------
# Build a wide table: each cell "signfrac (median)" or with CI if available
has_ci = {"sign_ci_lo_mean", "sign_ci_hi_mean"}.issubset(key.columns)

def fmt_cell(row):
    """Render one summary row as a table cell string.

    Reads ``sign_frac_mean`` and ``median_mean`` from ``row``. When the
    module-level flag ``has_ci`` is true, ``sign_ci_lo_mean`` and
    ``sign_ci_hi_mean`` are also read and shown in square brackets.

    Returns ``"<sf> [<lo>,<hi>] (<median>)"`` with CI bounds, otherwise
    ``"<sf> (<median>)"``; all numbers use three decimals and the median
    always carries an explicit sign.
    """
    sign_part = f'{row["sign_frac_mean"]:.3f}'
    median_part = f'({row["median_mean"]:+.3f})'
    if not has_ci:
        return f"{sign_part} {median_part}"
    ci_part = f'[{row["sign_ci_lo_mean"]:.3f},{row["sign_ci_hi_mean"]:.3f}]'
    return f"{sign_part} {ci_part} {median_part}"

# Format every selected row into its display cell ("signfrac [CI] (median)").
key["cell"] = key.apply(fmt_cell, axis=1)

# Horizons shown as table columns, in display order.
table_horizons = [1, 4, 12, 24]

# One row per edge, one column per horizon. reindex() enforces the column order and
# tolerates horizons that are absent from the data (plain column selection would raise
# KeyError); absent cells are rendered as "--".
wide = (
    key.pivot(index="Edge", columns="h", values="cell")
       .reindex(columns=table_horizons)
       .fillna("--")
)

# ---------------------------------------------------------------------
# Export a LaTeX table that you can \input{} to replace Table~\ref{tab:key_edges}
# Keep it simple & robust in 2-column format.
# NOTE(review): the caption hard-codes the lag set {1,3,6,12} and n=45 runs per
# horizon -- confirm both still match the run that produced edge_hypothesis_summary.
latex_lines = [
    r"\begin{table}[t]",
    r"\centering",
    r"\small",
    r"\setlength{\tabcolsep}{3.5pt}",
    r"\caption{Key spillover edges by horizon. Each cell reports sign fraction (block-bootstrap 95\% CI if available) and median in parentheses, computed from the lag-aggregated adjacency $\bar A^{(h)}(t)=\sum_{k\in\{1,3,6,12\}}A_k^{(h)}(t)$ and aggregated over $n=45$ runs per horizon.}",
    r"\label{tab:key_edges}",
    r"\begin{tabular}{l" + "c" * len(table_horizons) + "}",
    r"\toprule",
    r"\textbf{Edge} & " + " & ".join(f"$h={h}$" for h in table_horizons) + r" \\",
    r"\midrule",
]

# Body rows: edge label followed by its cells in horizon order.
for edge, cells in wide.iterrows():
    latex_lines.append(" & ".join([str(edge), *map(str, cells.tolist())]) + r" \\")

latex_lines.extend([
    r"\bottomrule",
    r"\end{tabular}",
    r"\end{table}",
])

out_table = assetdir / "key_edges_sign_table.tex"
out_table.write_text("\n".join(latex_lines), encoding="utf-8")
print("Wrote:", out_table)

# ---------------------------------------------------------------------
# FIGURE: Sign fraction vs horizon for the key edges (more visually appealing)
# We'll plot mean sign fraction with CI bands if present.
fig, ax = plt.subplots(figsize=(7.5, 3.2))  # good for figure* spanning both cols

horizons = np.array([1, 4, 12, 24])  # tick positions on the x axis

for hyp in key_hyps:
    sub = key[key["hyp"] == hyp].sort_values("h")
    if sub.empty:
        print("No rows for", hyp, "- skipped in figure")
        continue

    # Use the horizons actually present for this edge as x values, so an edge with a
    # missing horizon is still drawn instead of failing on an x/y length mismatch.
    x = sub["h"].to_numpy()
    y = sub["sign_frac_mean"].to_numpy()
    (line,) = ax.plot(x, y, marker="o", linewidth=2, label=label_map[hyp])

    if has_ci:
        lo = sub["sign_ci_lo_mean"].to_numpy()
        hi = sub["sign_ci_hi_mean"].to_numpy()
        # Tie the band colour to its line explicitly rather than relying on the
        # line and patch colour cycles staying in step.
        ax.fill_between(x, lo, hi, alpha=0.15, facecolor=line.get_color())

ax.axhline(0.5, linestyle="--", linewidth=1)  # 0.5 reference line
ax.set_xticks(horizons)
ax.set_xlabel("Forecast horizon $h$ (hours)")
ax.set_ylabel("Sign fraction  (share of $t$ with edge $>0$)")
ax.set_ylim(0.0, 1.0)
ax.legend(loc="best", fontsize=8, frameon=True)
# NOTE(review): n=45 is hard-coded in the title -- confirm it matches the runs aggregated.
ax.set_title("Key spillover edges: sign stability across horizons (mean over $n=45$ runs/horizon)")

fig.tight_layout()
out_fig = assetdir / "edge_signfrac_keyedges.png"
fig.savefig(out_fig, dpi=200)
plt.close(fig)
print("Wrote:", out_fig)

Wrote: paper_assets/key_edges_sign_table.tex
Wrote: paper_assets/edge_signfrac_keyedges.png
