In [None]:


# -------------------- Standard library --------------------
import os
import sys
import math
import re
import time
import warnings
from datetime import datetime, timedelta
from calendar import monthrange

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------- PyTorch & PyG --------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    # core data container for PyTorch Geometric
    from torch_geometric.data import Data
    # pick what you actually use; add more as needed
    from torch_geometric.nn import GCNConv  # e.g., GCN; swap/add TAGConv, SAGEConv, etc.
except Exception as e:
    raise ImportError(
        "torch_geometric is not available. Install the versions pinned in requirements.txt "
        "or remove GNN code paths."
    ) from e

# -------------------- Time-series / econometrics --------------------
# VAR is used for FEVD labels, etc.
from statsmodels.tsa.api import VAR

# -------------------- Data sources (market / news) --------------------
# yfinance for prices
import yfinance as yf

# news/sentiment dependencies
try:
    import requests
    from bs4 import BeautifulSoup
    import feedparser
    from dateutil import parser as dtp
    from urllib.parse import urlencode, quote_plus

    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    _TRANSFORMERS_AVAILABLE = True
except Exception:
    _TRANSFORMERS_AVAILABLE = False

# -------------------- Global display & warnings --------------------
# NOTE(review): this suppresses *all* warnings for the rest of the process (including pandas /
# statsmodels deprecation and convergence warnings). Consider narrowing to specific categories.
warnings.filterwarnings("ignore")
# Display-only setting: print up to 200 DataFrame columns; does not affect any computation.
pd.set_option("display.max_columns", 200)


# 1. Data Preprocessing

This section builds the full weekly panel used throughout the dissertation. It:

- Downloads and filters daily adjusted close for the candidate ticker set, then keeps only instruments with nearly complete coverage (strict missingness and start/end tolerances).

- Aggregates to weekly bars (weeks ending Friday) and computes features: returns, rolling volatility (4w, 8w), beta (26w) vs an equal-weight market, size (log-price proxy), turnover (weekly volume ÷ its 26w average), and VIX replicated per node.

- Constructs rolling correlation graphs (104w window, |ρ| ≥ 0.6) and exports degree and eigenvector centralities as weekly features.

- Scrapes/loads headlines, scores them with FinBERT, and produces weekly sentiment mean/vol/change features (shifted by +1 week to avoid leakage). This is an optional step and is not invoked by `run_preprocessing`.

In [2]:
# ============================ Section 1: Data preprocessing ============================
from pathlib import Path
import logging
import numpy as np
import pandas as pd
import yfinance as yf
import networkx as nx
from datetime import timedelta
from typing import Tuple, Dict, List

# --- config ---
# Settings consumed by run_preprocessing (Section 1). "start"/"end" are passed to
# download_prices and filter_full_coverage.
# NOTE(review): "seed" is not read by any function in this cell — confirm it is used elsewhere.
CFG = {
    "seed": 42,
    "data_dir": "data",              # root for artifacts (weekly panels go to <data_dir>/feat)
    "start": "2017-01-01",
    "end":   "2025-01-01",
    "weekly_freq": "W-FRI",          # pandas resample rule for weekly bars
    "cor_win_weeks": 104,            # for graph window (weeks of returns per correlation graph)
    "corr_threshold": 0.60,          # |rho| >= THETA keeps an edge
    "beta_win": 26,                  # weeks
    "turnover_win": 26,              # weeks
    "vol_wins": [4, 8],              # rolling volatility windows (weeks); one panel per entry
    "tickers": [
        "JPM","BAC","WFC","C","GS","MS","PNC","USB","BK","STT",
        "TFC","MTB","NTRS","FITB","HBAN","CMA","ZION","RF","KEY","WAL",
        "ALLY","COF","AXP","CFR",
        # collapsed/absorbed to be filtered by coverage
        "SIVB","SBNY","CS","PACW"
    ],
}

# -------------------- helpers --------------------
def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents; an existing directory is left as is."""
    p.mkdir(exist_ok=True, parents=True)

def save_parquet(df: pd.DataFrame, path: Path, log: logging.Logger) -> None:
    """
    Write ``df`` to ``path`` as Parquet, then log the path, the frame's shape and the
    average (over columns) fraction of null cells.
    """
    df.to_parquet(path)
    null_share = df.isna().mean().mean()
    log.info(f"Saved {path} | shape={df.shape} | null%≈{null_share:.2%}")

# -------------------- prices --------------------
def download_prices(tickers: List[str], start: str, end: str, auto_adjust: bool = True) -> pd.DataFrame:
    """
    Download daily prices via yfinance and return a wide Close matrix (index=date, columns=tickers).

    Parameters
    ----------
    tickers : list of symbols (a single string is also accepted).
    start, end : date strings forwarded to ``yf.download``.
    auto_adjust : forwarded to ``yf.download``.

    Both column layouts are handled:
    - a MultiIndex keyed ``(ticker, "Close")`` (the ``group_by="ticker"`` layout; recent yfinance
      releases may use it even for a single ticker), or ``("Close", ticker)``;
    - a flat frame with a ``"Close"`` column (single ticker, older releases).
    Tickers for which no Close column is found become all-NaN columns, so the coverage filter can
    report them as "no data" instead of the download crashing.
    """
    if isinstance(tickers, str):
        tickers = [tickers]

    data = yf.download(
        tickers, start=start, end=end,
        auto_adjust=auto_adjust, progress=False, group_by="ticker", threads=True
    )

    is_multi = isinstance(data.columns, pd.MultiIndex)
    cols = {}
    for t in tickers:
        if is_multi and (t, "Close") in data.columns:
            cols[t] = data[(t, "Close")]
        elif is_multi and ("Close", t) in data.columns:
            cols[t] = data[("Close", t)]
        elif not is_multi and len(tickers) == 1 and "Close" in data.columns:
            cols[t] = data["Close"]
        else:
            # missing/delisted symbol: keep an all-NaN column aligned to the download index
            cols[t] = pd.Series(dtype=float)

    close = pd.DataFrame(cols, index=data.index)
    return close.sort_index()

def filter_full_coverage(
    px: pd.DataFrame,
    start: str, end: str,
    max_nan_ratio: float = 0.01,
    tol_start_days: int = 10,
    tol_end_days: int = 10
) -> Tuple[List[str], Dict[str, str]]:
    """
    Keep tickers whose first/last valid dates are near the requested window and with low missingness.
    Returns (kept_tickers, dropped_reason).
    """
    start = pd.to_datetime(start)
    end = pd.to_datetime(end)
    kept, dropped = [], {}
    for t in px.columns:
        s = px[t].dropna()
        if s.empty:
            dropped[t] = "no data"; continue
        first_ok = s.index.min() <= (start + timedelta(days=tol_start_days))
        last_ok  = s.index.max() >= (end   - timedelta(days=tol_end_days))
        nan_ratio = float(px[t].isna().mean())
        if not first_ok:
            dropped[t] = f"starts too late (first={s.index.min().date()})"; continue
        if not last_ok:
            dropped[t] = f"ends early (last={s.index.max().date()})"; continue
        if nan_ratio > max_nan_ratio:
            dropped[t] = f"too many NaNs ({nan_ratio:.2%})"; continue
        kept.append(t)
    return kept, dropped

# -------------------- weekly features --------------------
def compute_weekly_core_features(
    px_daily: pd.DataFrame,
    weekly_freq: str,
    beta_win: int,
    vol_wins: List[int],
    turnover_win: int,
    log: logging.Logger
) -> Dict[str, pd.DataFrame]:
    """
    Build weekly features: returns, vols, beta vs EW market, size (log price), turnover, VIX replicated.

    Parameters
    ----------
    px_daily : wide daily close matrix (index = dates, columns = tickers).
    weekly_freq : pandas resample rule for weekly bars (e.g. "W-FRI").
    beta_win : rolling window (weeks) for beta against the equal-weight market return.
    vol_wins : rolling windows (weeks) for return volatility; one "vol{w}" panel per entry.
    turnover_win : rolling window (weeks) for the turnover denominator.
    log : logger used for progress messages.

    Returns
    -------
    Dict of wide weekly panels keyed "ret", "vol{w}" (one per ``vol_wins`` entry), "beta26",
    "size", "turnover" and "vix". NOTE: the beta panel is keyed "beta26" whatever ``beta_win`` is
    (downstream file names depend on that key).

    Volume and VIX are downloaded from yfinance over the date span of ``px_daily``. yfinance treats
    ``end`` as exclusive, so one day is added to include the last date of ``px_daily``.
    """
    tickers = list(px_daily.columns)
    dl_start = px_daily.index.min().date()
    dl_end = (px_daily.index.max() + pd.Timedelta(days=1)).date()  # yfinance `end` is exclusive

    # Weekly close (last obs per period, forward-filled) & simple returns
    px_w = px_daily.resample(weekly_freq).last().ffill().dropna(how="all")
    ret_w = px_w.pct_change().dropna()

    # Rolling vol (sample std of weekly returns)
    feats = {"ret": ret_w}
    for w in vol_wins:
        feats[f"vol{w}"] = ret_w.rolling(w, min_periods=w).std()

    # Beta vs EW market: rolling cov(r_i, m) / var(m). The market is the cross-sectional mean of
    # all columns, so it includes the ticker itself.
    mkt = ret_w.mean(axis=1)
    mkt_var = mkt.rolling(beta_win, min_periods=beta_win).var()
    feats["beta26"] = ret_w.apply(
        lambda col: col.rolling(beta_win, min_periods=beta_win).cov(mkt) / mkt_var
    )

    # Size proxy (log price; zero prices -> NaN)
    feats["size"] = np.log(px_w.replace(0, np.nan))

    # Turnover from volume
    log.info("Downloading volume for turnover…")
    raw = yf.download(tickers, start=dl_start, end=dl_end, auto_adjust=False, progress=False)
    # normalize to wide "Volume" frame
    if isinstance(raw.columns, pd.MultiIndex):
        if "Volume" in raw.columns.levels[0]:
            vol_raw = raw["Volume"]
        else:
            # ticker-first layout
            vols = {t: raw[(t, "Volume")] if (t, "Volume") in raw.columns else pd.Series(dtype=float) for t in tickers}
            vol_raw = pd.DataFrame(vols)
    else:
        if "Volume" in raw.columns:
            colname = tickers[0] if len(tickers) == 1 else "Volume"
            v = raw["Volume"]
            vol_raw = v.to_frame(colname) if len(tickers) == 1 else v
        else:
            vol_raw = pd.DataFrame({t: pd.Series(dtype=float) for t in tickers})

    vol_raw = vol_raw.sort_index().reindex(px_daily.index).ffill(limit=3)
    vol_w = vol_raw.resample(weekly_freq).sum().reindex(px_w.index)
    # Turnover = weekly volume / its trailing turnover_win-week mean (division by zero -> NaN)
    turnover = (vol_w / vol_w.rolling(turnover_win, min_periods=turnover_win).mean()).replace([np.inf, -np.inf], np.nan)

    feats["turnover"] = turnover

    # VIX replicated per node
    log.info("Downloading VIX…")
    vix_raw = yf.download("^VIX", start=dl_start, end=dl_end,
                          auto_adjust=False, progress=False)["Close"]
    # Some yfinance versions return a one-column DataFrame here; reduce it to a Series.
    vix_close = vix_raw.iloc[:, 0] if isinstance(vix_raw, pd.DataFrame) else vix_raw
    vix_w = vix_close.resample(weekly_freq).last().reindex(px_w.index).ffill()
    feats["vix"] = pd.DataFrame(
        np.repeat(vix_w.to_numpy().reshape(-1, 1), len(tickers), axis=1),
        index=vix_w.index, columns=tickers
    )

    return feats

# -------------------- graph centralities --------------------
def compute_centralities(
    ret_w: pd.DataFrame,
    window_weeks: int,
    theta: float,
    log: logging.Logger
) -> Dict[str, pd.DataFrame]:
    """
    For each target week t, compute degree and eigenvector centrality from the correlation
    graph built on returns over [t-window, t).

    Graph for week t: one node per ticker (column of ``ret_w``). An undirected edge (i, j) with
    weight |rho_ij| is added whenever |rho_ij| >= ``theta``. rho is the Pearson correlation over
    the ``window_weeks`` rows preceding t; NaN correlations count as 0, and week t itself is excluded.

    The first ``window_weeks`` weeks (not enough history) get all-zero rows. Weeks where
    ``nx.eigenvector_centrality_numpy`` raises get zero eigenvector centrality. The number of such
    weeks is logged as a warning rather than hidden.

    Returns
    -------
    {"degree": DataFrame, "eigencent": DataFrame}, both indexed by week with one column per ticker.
    """
    tickers = list(ret_w.columns)
    # Upper-triangle pairs (i < j) in row-major order, i.e. the order of a nested i/j loop.
    iu, ju = np.triu_indices(len(tickers), k=1)
    zero_row = dict.fromkeys(tickers, 0.0)
    deg_rows, ec_rows = [], []
    ec_failures = 0

    log.info(f"Computing centralities over {len(ret_w.index)} weeks (W={window_weeks}, θ={theta})…")
    for t_idx, d in enumerate(ret_w.index):
        if t_idx < window_weeks:
            # not enough history yet
            deg_rows.append(pd.Series(zero_row, name=d))
            ec_rows.append(pd.Series(zero_row, name=d))
            continue

        win = ret_w.iloc[t_idx - window_weeks:t_idx]
        C = win.corr().abs().fillna(0.0).to_numpy()
        keep = C[iu, ju] >= theta

        G = nx.Graph()
        G.add_nodes_from(tickers)
        # thresholded weighted edges
        G.add_weighted_edges_from(
            (tickers[i], tickers[j], float(C[i, j])) for i, j in zip(iu[keep], ju[keep])
        )

        deg_rows.append(pd.Series(nx.degree_centrality(G), name=d).reindex(tickers).fillna(0.0))
        try:
            ec = nx.eigenvector_centrality_numpy(G, weight="weight")
        except Exception as exc:  # presumably degenerate (e.g. edgeless) graphs; keep zero fallback
            ec_failures += 1
            log.debug(f"Eigenvector centrality failed for week {d}: {exc!r}")
            ec = {}
        ec_rows.append(pd.Series(ec, name=d, dtype=float).reindex(tickers).fillna(0.0))

    if ec_failures:
        log.warning(f"Eigenvector centrality failed in {ec_failures} week(s); those rows were zero-filled.")

    DEG = pd.DataFrame(deg_rows).sort_index()
    EVC = pd.DataFrame(ec_rows).sort_index()
    return {"degree": DEG, "eigencent": EVC}

# -------------------- optional: weekly FinBERT sentiment --------------------
def build_weekly_sentiment(
    news_df: pd.DataFrame,
    weekly_freq: str
) -> Dict[str, pd.DataFrame]:
    """
    Score headlines with FinBERT and aggregate them into wide weekly sentiment panels.

    Expect columns: ['Ticker','Date','Headline'] at minimum. Rows missing any of these, or with an
    unparseable Date, are dropped.

    Each headline gets score = P(positive) - P(negative), and blank headlines score NaN. Class
    indices are looked up from the model's ``config.id2label`` rather than assumed. Headlines are
    bucketed into ``weekly_freq`` periods (labelled via ``Period.to_timestamp(weekly_freq)``).
    Per (ticker, week), the mean and std of the scores are taken. ``sent_change`` is the change in
    ``sent_mean`` versus the ticker's previous week *that had news*.

    NOTE: no look-ahead shift is applied here. A week's panel uses headlines up to that week, so lag
    the returned panels if they must only use information available before the week.

    Returns wide weekly panels (index "Date", columns = tickers): sent_mean, sent_vol, sent_change,
    with missing entries filled with 0.0.

    Raises RuntimeError if transformers/torch cannot be imported, or if the model's labels do not
    include "positive" and "negative".
    """
    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        import torch
    except Exception as e:
        raise RuntimeError("Transformers not installed. Install optional deps to compute sentiment.") from e

    tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
    model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device); model.eval()

    # Resolve output-class positions from the model config instead of hard-coding a class order.
    label_to_idx = {str(name).lower(): int(idx) for idx, name in model.config.id2label.items()}
    if "positive" not in label_to_idx or "negative" not in label_to_idx:
        raise RuntimeError(
            f"FinBERT config labels {model.config.id2label!r} do not include 'positive' and 'negative'."
        )
    pos_idx, neg_idx = label_to_idx["positive"], label_to_idx["negative"]

    def finbert_scores(texts, batch_size=32, max_len=512):
        """Return P(pos) - P(neg) per text as a float array; blank/non-string texts -> NaN."""
        scores = []
        for i in range(0, len(texts), batch_size):
            batch = [t if isinstance(t, str) and t.strip() else "" for t in texts[i:i+batch_size]]
            enc = tokenizer(batch, return_tensors="pt", truncation=True, padding=True, max_length=max_len)
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                probs = torch.nn.functional.softmax(model(**enc).logits, dim=-1)
            batch_scores = (probs[:, pos_idx] - probs[:, neg_idx]).cpu().numpy().astype(float)
            batch_scores[np.array([b == "" for b in batch], dtype=bool)] = np.nan
            scores.extend(batch_scores.tolist())
        return np.array(scores, dtype=float)

    df = news_df.copy()
    df["Date"] = pd.to_datetime(df["Date"], errors="coerce")
    df = df.dropna(subset=["Ticker","Date","Headline"])
    df["sentiment"] = finbert_scores(df["Headline"].tolist())
    df["WeekEnd"] = df["Date"].dt.to_period(weekly_freq).dt.to_timestamp(weekly_freq)

    weekly = (df.groupby(["Ticker","WeekEnd"])["sentiment"]
                .agg(sent_mean="mean", sent_vol="std").reset_index())
    weekly["sent_change"] = (weekly.sort_values(["Ticker","WeekEnd"])
                                   .groupby("Ticker")["sent_mean"].diff())

    def wide(col):
        w = weekly.pivot(index="WeekEnd", columns="Ticker", values=col).sort_index()
        w.index.name = "Date"; return w.astype("float64")

    return {
        "sent_mean": wide("sent_mean").fillna(0.0),
        "sent_vol":  wide("sent_vol").fillna(0.0),
        "sent_change": wide("sent_change").fillna(0.0),
    }

# -------------------- main runner for preprocessing --------------------
def run_preprocessing(cfg: dict, log: logging.Logger) -> None:
    """
    Run Section 1 end to end and persist its artifacts under ``cfg["data_dir"]``.

    Steps:
      1. Download adjusted daily closes (forward-filled up to 3 days).
      2. Keep tickers that pass ``filter_full_coverage`` and write ``prices_daily.parquet``.
      3. Build weekly features and write each panel ``<name>`` as ``feat/feat_<name>.parquet``.
         This works for any number of ``vol_wins`` entries.
      4. Build correlation-graph centralities and write ``feat/feat_degree.parquet`` and
         ``feat/feat_eigencent.parquet``.

    Optional cfg keys, with defaults equal to the values previously hard-coded here:
    "max_nan_ratio" (0.01), "tol_start_days" (10), "tol_end_days" (10).
    """
    base = Path(cfg["data_dir"]); ensure_dir(base)
    feats_dir = base / "feat"; ensure_dir(feats_dir)

    # Prices
    log.info("Downloading adjusted close prices…")
    px = download_prices(cfg["tickers"], cfg["start"], cfg["end"], auto_adjust=True).ffill(limit=3)
    kept, dropped = filter_full_coverage(
        px, cfg["start"], cfg["end"],
        max_nan_ratio=cfg.get("max_nan_ratio", 0.01),
        tol_start_days=cfg.get("tol_start_days", 10),
        tol_end_days=cfg.get("tol_end_days", 10),
    )
    log.info(f"KEPT ({len(kept)}): {kept}")
    for k, v in dropped.items():
        log.info(f"DROPPED {k}: {v}")
    px_good = px[kept]
    save_parquet(px_good, base / "prices_daily.parquet", log)

    # Weekly features
    feats = compute_weekly_core_features(
        px_good, cfg["weekly_freq"], cfg["beta_win"], cfg["vol_wins"], cfg["turnover_win"], log
    )
    # persist: panel "<name>" -> feat_<name>.parquet (e.g. feat_ret, feat_vol4, feat_beta26)
    for name, df in feats.items():
        save_parquet(df.reindex(columns=kept).astype(float), feats_dir / f"feat_{name}.parquet", log)

    # Centralities (from returns)
    ret_w = feats["ret"].sort_index()
    cents = compute_centralities(ret_w, cfg["cor_win_weeks"], cfg["corr_threshold"], log)
    save_parquet(cents["degree"],    feats_dir / "feat_degree.parquet", log)
    save_parquet(cents["eigencent"], feats_dir / "feat_eigencent.parquet", log)

    log.info("Preprocessing complete.")




# 2. Feature Engineering

This step merges all saved weekly feature panels into a single (Week, Ticker) long table, handles warm-up weeks (zero-only signals), trims to the first fully informative week, and standardises features:

Local, node-varying features (e.g., returns, degree, eigenvector centrality, turnover, vols, beta, size) are z-scored within each week to preserve cross-sectional structure without peeking across time.

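Global, market-wide features (e.g., VIX) are standardised using training-period statistics only to avoid leakage. This requires `TRAIN_END` to be set; when it is left as `None`, whole-sample statistics are used as a fallback.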

In [None]:
# ============================ Section 2: Feature Engineering ============================
from pathlib import Path
import os, json, numpy as np, pandas as pd
from functools import reduce

# ---- config  ----
# Note: the glob in the loading step picks up every *.parquet in FEAT_DIR (except returns/labels),
# not only feat_* files.
FEAT_DIR = Path("data/feat")         # where feat_*.parquet live
OUT_DIR  = Path("data")              # output folder
RET_FILE = "feat_ret.parquet"        # required returns panel
WEEKLY_FREQ = "W-FRI"                # period rule used by to_wfri_index to snap dates to week ends

# To avoid leakage in global z-scales; set last week of training split (e.g. "2023-12-29")
# When left as None, global features are z-scored with full-sample statistics.
TRAIN_END: str | None = None  # e.g., "2023-12-29"

# Side effect at cell execution: creates the output folder.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# -------------------- helpers --------------------
def to_wfri_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a sorted copy of ``df`` whose index is snapped to week-end dates.

    Each index value is parsed with ``pd.to_datetime``, mapped to its ``WEEKLY_FREQ`` period, and
    replaced by that period's ``end_time`` normalised to midnight. The input frame is not modified.
    """
    week_end = pd.to_datetime(df.index).to_period(WEEKLY_FREQ).end_time.normalize()
    snapped = df.copy()
    snapped.index = week_end
    return snapped.sort_index()

def melt_feature(df_wide: pd.DataFrame, name: str) -> pd.DataFrame:
    """
    Reshape a wide (weeks × tickers) panel into long rows with columns Week, Ticker, ``name``.

    The index becomes the ``Week`` column; each original column label becomes a ``Ticker`` value.
    """
    with_week = df_wide.assign(Week=df_wide.index)
    return with_week.melt(id_vars="Week", var_name="Ticker", value_name=name)

def zscore_per_week(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """
    Add cross-sectional z-scores ``<col>_z`` for every column in ``cols``.

    Mean and sample standard deviation are computed within each ``Week`` group. Groups with zero
    spread (or a single row) produce NaN z-scores. ``df`` is modified in place and also returned.
    """
    by_week = df.groupby("Week", sort=False)
    for name in cols:
        grouped = by_week[name]
        center = grouped.transform("mean")
        spread = grouped.transform("std").replace(0, np.nan)
        df[name + "_z"] = (df[name] - center) / spread
    return df

def mark_all_zero_weeks_as_nan(df_long: pd.DataFrame, col: str) -> None:
    """
    In place: set ``col`` to NaN for every week in which it is zero for all tickers.

    NaNs count as zero for this test, so weeks that are a mix of zeros and NaNs are also masked.
    Intended for warm-up periods where a feature is written as zeros. Does nothing if ``col`` is
    not a column of ``df_long``.
    """
    if col not in df_long.columns:
        return
    weekly_peak = df_long[col].fillna(0.0).abs().groupby(df_long["Week"]).max()
    silent_weeks = weekly_peak.index[weekly_peak.eq(0.0).to_numpy()]
    if len(silent_weeks) > 0:
        df_long.loc[df_long["Week"].isin(silent_weeks), col] = np.nan

def first_valid_week(df_long: pd.DataFrame, cols: list[str]) -> pd.Timestamp | None:
    """Latest among the first weeks where each key column becomes available."""
    firsts = []
    for c in cols:
        if c in df_long.columns:
            w = df_long.dropna(subset=[c])["Week"].min()
            if pd.notna(w): firsts.append(pd.to_datetime(w))
    return max(firsts) if firsts else None

# -------------------- 1) Load weekly panels --------------------
# Returns are mandatory: fail fast if Section 1 has not produced them.
ret_path = FEAT_DIR / RET_FILE
if not ret_path.exists():
    raise FileNotFoundError(f"Missing returns parquet: {ret_path}")

# Accept either a "Week" column or a date index, then snap the index to week-end dates.
wk_returns = pd.read_parquet(ret_path)
wk_returns = wk_returns.set_index("Week") if "Week" in wk_returns.columns else wk_returns
wk_returns = to_wfri_index(wk_returns)

# Load all other feat_*.parquet, skipping returns and any labels
# (the glob matches every *.parquet in FEAT_DIR; any file stem containing "labels" is skipped).
# Each panel is keyed by its original-case file stem, which becomes its column name after melting.
feat_dict: dict[str, pd.DataFrame] = {}
for p in FEAT_DIR.glob("*.parquet"):
    stem = p.stem.lower()
    if p.name == RET_FILE or stem.startswith("labels") or "labels" in stem:
        continue
    df = pd.read_parquet(p)
    df = df.set_index("Week") if "Week" in df.columns else df
    df = to_wfri_index(df)
    feat_dict[p.stem] = df

print(f"[INFO] wk_returns shape: {wk_returns.shape} | loaded other panels: {len(feat_dict)}")

# -------------------- 2) Melt & merge into long --------------------
# Every wide panel becomes (Week, Ticker, <panel>) rows; returns are named "feat_ret".
long_parts = [melt_feature(wk_returns, "feat_ret")]
for name, df in feat_dict.items():
    long_parts.append(melt_feature(df, name))

# Outer-join on (Week, Ticker) so no row is lost when panels cover different weeks/tickers.
feat_long = reduce(lambda a, b: a.merge(b, on=["Week", "Ticker"], how="outer"), long_parts)
feat_long = feat_long.sort_values(["Week", "Ticker"]).reset_index(drop=True)

# Drop columns that are entirely NaN
value_cols = [c for c in feat_long.columns if c not in ["Week","Ticker"]]
all_null = [c for c in value_cols if feat_long[c].isna().all()]
if all_null:
    feat_long = feat_long.drop(columns=all_null)
    value_cols = [c for c in feat_long.columns if c not in ["Week","Ticker"]]
    print("[INFO] Dropped all-NaN columns:", all_null)

# Force numeric (non-numeric → NaN)
for c in value_cols:
    feat_long[c] = pd.to_numeric(feat_long[c], errors="coerce")

# -------------------- 3) Handle warm-up all-zero weeks --------------------
# Rolling/graph features can be written as zeros before their window fills (compute_centralities
# emits all-zero rows for the first window_weeks weeks). Such weeks are turned into NaN.
# Note this also masks any later week in which the feature is zero for every ticker
# (e.g. a correlation graph with no edges).
warmup_prefixes = ["feat_vol4", "feat_vol8", "feat_turnover", "feat_degree", "feat_eigencent", "feat_beta26"]
warmup_cols = [c for c in value_cols if any(c.startswith(px) for px in warmup_prefixes)]
for c in warmup_cols:
    mark_all_zero_weeks_as_nan(feat_long, c)

# -------------------- 4) Trim to first fully informative week --------------------
# choose key cols to require (if missing, they’re ignored)
candidate_keys = ["feat_ret","feat_size","feat_vol4","feat_vol8","feat_degree","feat_eigencent","feat_turnover","feat_beta26"]
key_cols = [c for c in candidate_keys if c in feat_long.columns]

# fv = latest of the key columns' first non-null weeks; rows (and wk_returns) before it are dropped.
fv = first_valid_week(feat_long, key_cols)
if fv is not None:
    feat_long = feat_long[feat_long["Week"] >= fv].copy()
    wk_returns = wk_returns[wk_returns.index >= fv].copy()
print("[INFO] First valid week used:", fv)

# -------------------- 5) Detect global features & z-score locals --------------------
# Globals = constant across tickers within a week (e.g., feat_vix)
# A column is treated as global when EVERY week has at most one distinct non-null value.
# NOTE(review): a node-level column that happens to be constant within every week would also be
# classified as global.
grp = feat_long.groupby("Week", sort=False)
global_feats = []
for c in value_cols:
    nun = grp[c].nunique(dropna=True)
    if (nun.fillna(1) <= 1).all():
        global_feats.append(c)
local_feats = [c for c in value_cols if c not in global_feats]

print("[INFO] Global features:", global_feats)
print("[INFO] Local features to per-week z-score: count=", len(local_feats))

# Remove any stale *_z from prior runs
feat_long = feat_long[[c for c in feat_long.columns if not c.endswith("_z")]].copy()

# Cross-sectional (within-week) z for locals - no temporal leakage
feat_long = zscore_per_week(feat_long, local_feats)

# Global z for globals - use TRAIN_END to avoid leakage; else whole-sample (fallback)
# Statistics are taken over long rows, so each week is weighted by its number of tickers.
if TRAIN_END is not None:
    train_cut = pd.to_datetime(TRAIN_END)
    mask_train = pd.to_datetime(feat_long["Week"]) <= train_cut
else:
    mask_train = slice(None)  # all rows

for c in global_feats:
    mu = feat_long.loc[mask_train, c].mean()
    sd = feat_long.loc[mask_train, c].std()
    # Fall back to sd=1 (centering only) when the std is zero or NaN.
    sd = sd if (sd and not np.isnan(sd) and sd != 0) else 1.0
    feat_long[c + "_global_z"] = (feat_long[c] - mu) / sd

# Modeling columns = all *_z (locals) + *_global_z (globals)
# ("_global_z" columns also end with "_z", so one suffix test catches both.)
# Remaining NaN z-scores (missing values, zero-spread weeks) become 0.0, i.e. the mean.
X_cols = [c for c in feat_long.columns if c.endswith("_z")]
feat_long[X_cols] = feat_long[X_cols].fillna(0.0)

# -------------------- 6) Save artifacts --------------------
# full_path: every raw and standardized column; model_path: only Week, Ticker + X_cols;
# X_cols.json: ordered list of modeling columns for later sections.
full_path  = OUT_DIR / "node_features_long.parquet"
model_path = OUT_DIR / "node_features_model.parquet"
feat_long.to_parquet(full_path, index=False)
feat_long[["Week","Ticker"] + X_cols].to_parquet(model_path, index=False)
with open(OUT_DIR / "X_cols.json","w") as f:
    json.dump(X_cols, f, indent=2)

# -------------------- 7) Print summary --------------------
wk_min = pd.to_datetime(feat_long["Week"]).min().date() if len(feat_long) else None
wk_max = pd.to_datetime(feat_long["Week"]).max().date() if len(feat_long) else None
print("\n=== FEATURE PIPELINE SUMMARY ===")
print("Weeks:", wk_min, "→", wk_max,
      "| unique weeks:", feat_long["Week"].nunique(),
      "| tickers:", feat_long["Ticker"].nunique())
print("X_cols (#):", len(X_cols))
print("Saved:")
print(" -", full_path)
print(" -", model_path)
print(" -", OUT_DIR / "X_cols.json")


# 3. Label Assignment

We construct semi-supervised labels by fitting a rolling VAR($l$) on weekly returns over a lookback window and computing the forecast-error variance decomposition (FEVD) at horizon $H$. For each in-window ticker $i$, its “from-others” share is the FEVD row sum excluding its own diagonal term, $\mathrm{FROM}_{i,t} = \sum_{j \neq i} \tilde{\theta}_{ij,t}(H)$, where $\tilde{\theta}$ denotes the row-normalised FEVD shares. We then shift by +1 week to form the targets $y_{i,t+1}$ used in prediction. This forecasts next-week spillover, while the FEVD itself is computed at horizon $H = 12$ weeks.

Safeguards:

- strict coverage filter within each window,

- adaptive pruning (max N, collinearity drop, feasibility bound),

- deterministic “jitter” to avoid singular residual covariance in edge cases.

In [None]:
# ============================ Section 3: Label assignment (VAR–FEVD) ============================
from __future__ import annotations
from pathlib import Path
import logging, math
import numpy as np
import pandas as pd
from statsmodels.tsa.api import VAR
from tqdm.auto import tqdm

# ---- config  ----
# Parameters for make_labels_from_others_hardened (passed one-to-one by the runner below).
CFG_LBL = {
    "out_dir": "data",
    "win": 40,             # rolling window length (weeks)
    "var_lags": 1,
    "fevd_h": 12,          # FEVD horizon H (in weeks)
    "coverage": 0.60,      # per-column non-null coverage threshold inside the window
    "min_t": 30,           # minimum time points after dropping NaNs
    "max_n": 25,           # max variables in the VAR after pruning
    "collinear": 0.90,     # drop one of any pair with |rho| >= this threshold
    "jitter": 1e-8,        # small deterministic noise to improve stability (0 disables)
    "deterministic_jitter": True,  # make jitter deterministic per window (when False, _prune_window falls back to seed 0)
}

def _det_jitter(shape: tuple[int, int], seed_key: int, scale: float) -> np.ndarray:
    """
    Return reproducible Gaussian noise N(0, scale^2) of the given ``shape``.

    The same ``seed_key`` always yields the same array. A non-positive ``scale`` returns zeros.
    """
    no_noise = scale <= 0
    if no_noise:
        return np.zeros(shape)
    generator = np.random.default_rng(seed_key)
    return generator.normal(loc=0.0, scale=scale, size=shape)

def _prune_window(
    W: pd.DataFrame,
    lags: int,
    max_n: int,
    collinear: float,
    jitter: float,
    det_seed: int | None = None
) -> pd.DataFrame:
    """
    Clean/prune a windowed return matrix for stable VAR–FEVD:
      1) drop zero-variance cols
      2) cap dimensionality to top-std columns
      3) iteratively drop near-collinear columns (|rho| >= collinear)
      4) demean
      5) ensure feasibility: T >= (lags+1)*N + 5  (heuristic)
      6) add tiny deterministic jitter (optional)
    """
    # 1) drop zero-variance
    std = W.std()
    W = W.loc[:, std[std > 0].index]
    if W.shape[1] < 2:
        return W

    # 2) cap N
    if W.shape[1] > max_n:
        top = std.sort_values(ascending=False).index[:max_n]
        W = W.loc[:, top]

    # 3) drop near-collinear
    while W.shape[1] > 1:
        C = W.corr().abs().fillna(0.0)
        np.fill_diagonal(C.values, 0.0)
        i, j = np.unravel_index(np.argmax(C.values), C.shape)
        if C.values[i, j] < collinear:
            break
        cols = W.columns.tolist()
        # drop the lower-std of the pair
        drop_col = cols[i] if W[cols[i]].std() < W[cols[j]].std() else cols[j]
        W = W.drop(columns=[drop_col])

    # 4) demean
    W = W - W.mean()

    # 5) feasibility cap (T too small for given N)
    T, N = W.shape
    if T < (lags + 1) * N + 5:
        keep = W.std().sort_values(ascending=False).index
        maxN = max(2, (T - 5) // (lags + 1))
        W = W.loc[:, keep[:maxN]]
        T, N = W.shape

    # 6) deterministic jitter
    if jitter > 0.0:
        seed = (det_seed if det_seed is not None else 0)
        W = W + _det_jitter(W.shape, seed_key=seed, scale=jitter)

    return W

def _fit_var_fevd_stable(
    W: pd.DataFrame,
    fevd_h: int,
    lags: int
) -> tuple[np.ndarray, list[str]]:
    """
    Fit VAR(lags) with a constant and return (FEVD matrix at horizon H, variable names).

    statsmodels stores ``FEVD.decomp`` indexed as (equation, horizon, shock), shape (k, H, k).
    The H-step matrix is therefore ``decomp[:, H-1, :]``. Row i holds the shares of variable i's
    H-step forecast-error variance attributed to each orthogonalised shock j. The previous
    ``decomp[H-1]`` selected one equation's horizon profile instead, and raised for k < H.
    NOTE(review): orthogonalisation uses statsmodels' default ordering, so results depend on the
    column order of W.

    Rows are re-normalized so each sums to 1 (zero rows are left unchanged).

    Raises ValueError if ``fevd_h < 1``. statsmodels fitting errors propagate to the caller.
    """
    if fevd_h < 1:
        raise ValueError(f"fevd_h must be >= 1, got {fevd_h}")

    res = VAR(W).fit(maxlags=lags, ic=None, trend='c')
    names = list(res.names) if hasattr(res, "names") else list(res.model.endog_names)

    decomp = np.asarray(res.fevd(fevd_h).decomp)   # (equation, horizon, shock)
    M = decomp[:, fevd_h - 1, :]                   # H-step decomposition, k x k
    k = min(len(names), M.shape[0], M.shape[1])    # defensive; equal for a well-formed fit
    names, M = names[:k], np.array(M[:k, :k], dtype=float)

    row_sums = M.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    M = M / row_sums
    return M, names

def make_labels_from_others_hardened(
    wk_returns: pd.DataFrame,
    win: int,
    fevd_h: int,
    lags: int,
    coverage: float,
    min_t: int,
    max_n: int,
    collinear: float,
    jitter: float,
    deterministic_jitter: bool,
    log: logging.Logger | None = None
) -> pd.DataFrame:
    """
    Rolling FEVD labels:
      - For each position t >= win, with W_end = wk_returns.index[t], build the window of rows
        [t-win : t). W_end itself is excluded.
      - Keep columns with coverage >= coverage (skip if fewer than 3); drop rows with any NaNs
        (skip if fewer than min_t remain).
      - Prune the window for stability (skip if N < 2 or T < (lags+1)*N + 5).
      - Compute the FEVD at horizon H; “from_others” = row sum minus diagonal.
      - Target y_next for (Week, Ticker) = the same ticker's from_others at the NEXT week of the
        wk_returns calendar. Rows with no label in that week are dropped. A per-ticker row shift
        would silently take a later week's value across gaps (skipped windows, pruned tickers).

    Jitter seeds come from ``zlib.crc32`` of the week label, which is stable across processes.
    The builtin ``hash`` of a str is salted per process and so was not reproducible.

    Returns long DataFrame [Week, Ticker, from_others, y_next].
    """
    import zlib  # process-independent hashing for deterministic jitter seeds

    weeks = wk_returns.index.tolist()
    stats = {"too_few_cols":0, "too_few_rows":0, "feasibility_fail":0, "var_fail":0, "ok":0}
    rows, mismatches = [], 0

    rng_seed_base = 17_003  # base for deterministic jitter seeds

    for t in tqdm(range(win, len(weeks)), desc="FEVD (from-others, hardened)"):
        W_end = weeks[t]
        window = wk_returns.iloc[t-win:t]

        # per-column coverage within the window
        cov = window.notna().mean()
        keep_cols = cov[cov >= coverage].index.tolist()
        if len(keep_cols) < 3:
            stats["too_few_cols"] += 1
            continue

        W = window[keep_cols].dropna(how="any")
        if W.shape[0] < min_t:
            stats["too_few_rows"] += 1
            continue

        det_seed = ((zlib.crc32(str(W_end).encode("utf-8")) ^ rng_seed_base) & 0xFFFFFFFF
                    if deterministic_jitter else None)
        Wp = _prune_window(W, lags=lags, max_n=max_n, collinear=collinear, jitter=jitter, det_seed=det_seed)

        T, N = Wp.shape
        if N < 2 or T < (lags + 1) * N + 5:
            stats["feasibility_fail"] += 1
            continue

        try:
            M, names = _fit_var_fevd_stable(Wp, fevd_h=fevd_h, lags=lags)
            FROM = M.sum(axis=1) - np.diag(M)  # off-diagonal row sums
            k = min(len(names), len(FROM))
            if k < 2:
                stats["var_fail"] += 1
                continue
            if len(names) != len(FROM):
                mismatches += 1
            rows.append(pd.DataFrame({"Week": W_end, "Ticker": names[:k], "from_others": FROM[:k]}))
            stats["ok"] += 1
        except Exception:
            stats["var_fail"] += 1
            continue

    labels = (pd.concat(rows, ignore_index=True)
              if rows else pd.DataFrame(columns=["Week","Ticker","from_others"]))
    labels = labels.sort_values(["Ticker","Week"]).reset_index(drop=True)

    # Target = same ticker's from_others at the next calendar week of wk_returns.
    next_week = dict(zip(weeks[:-1], weeks[1:]))
    value_at = dict(zip(zip(labels["Ticker"], labels["Week"]), labels["from_others"]))
    labels["y_next"] = [
        value_at.get((tic, next_week.get(wk)), np.nan)
        for tic, wk in zip(labels["Ticker"], labels["Week"])
    ]
    labels = labels.dropna(subset=["y_next"]).reset_index(drop=True)

    # reporting
    if log:
        log.info(f"FEVD summary: {stats} | mismatches trimmed: {mismatches}")
        if not labels.empty:
            log.info(f"Label span: {labels['Week'].min()} → {labels['Week'].max()} "
                     f"| rows: {len(labels)} | tickers: {labels['Ticker'].nunique()}")
    else:
        print("---- FEVD summary ----")
        print(stats, "| mismatches trimmed:", mismatches)
        if not labels.empty:
            print("Label span:", labels['Week'].min(), "→", labels['Week'].max(),
                  "| rows:", len(labels), "| tickers:", labels['Ticker'].nunique())
    return labels

# -------------------- runner --------------------
if __name__ == "__main__":
    from utils import get_logger, seed_everything, save_run_metadata  # from earlier
    seed_everything(42)
    log = get_logger("labels")

    out_dir = Path(CFG_LBL["out_dir"]); out_dir.mkdir(parents=True, exist_ok=True)

    # Load weekly returns created in Section 1/2 (wide: weeks x tickers).
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    ret_path = Path("data/feat/feat_ret.parquet")
    if not ret_path.exists():
        raise FileNotFoundError("Missing data/feat/feat_ret.parquet. Run preprocessing first.")
    wk_returns = pd.read_parquet(ret_path)
    # Accept a "Week" column as well as a date index (same convention as Section 2).
    if "Week" in wk_returns.columns:
        wk_returns = wk_returns.set_index("Week")
    wk_returns = wk_returns.sort_index()

    # Compute labels
    labels = make_labels_from_others_hardened(
        wk_returns=wk_returns,
        win=CFG_LBL["win"],
        fevd_h=CFG_LBL["fevd_h"],
        lags=CFG_LBL["var_lags"],
        coverage=CFG_LBL["coverage"],
        min_t=CFG_LBL["min_t"],
        max_n=CFG_LBL["max_n"],
        collinear=CFG_LBL["collinear"],
        jitter=CFG_LBL["jitter"],
        deterministic_jitter=CFG_LBL["deterministic_jitter"],
        log=log
    )

    # Save
    labels_path = out_dir / "labels_from_next.parquet"
    labels.to_parquet(labels_path, index=False)
    log.info(f"Saved labels → {labels_path}")

    # Quick check: coverage by year (the "Year" column is added after saving, so it is not persisted)
    if not labels.empty:
        labels["Year"] = pd.to_datetime(labels["Week"]).dt.year
        log.info("\nLabels per year:\n" + str(labels.groupby("Year").size()))


# 4. Graph construction (semi-supervised weekly graphs)

For each target week *t*, we build a node-attributed graph:

- **Nodes:** tickers with feature vectors $X_{i,t}$  
  (standardized features from Section 2).

- **Edges:** thresholded correlations $|\rho_{ij}| \geq \theta$  
  computed over a rolling history $[t-W,\,t)$ using weekly returns  
  (z-scored if available, else raw).

- **Labels & mask:** targets $y_{i,t}$ (next-week “from-others” labels)  
  are present only for a subset of nodes; we store a boolean **train mask** per graph  
  for semi-supervised training.

We save one **Data** object per week as `graph_YYYY-MM-DD.pt` plus an index (`graphs_index.csv`) with counts.  
Finally, we create time-ordered splits (70/15/15) and an optional quick visual of a chosen week.


In [None]:
# ============================ Section 4: Graph Construction ============================
from pathlib import Path
import json, math
import numpy as np
import pandas as pd
import torch
from itertools import combinations
from torch_geometric.data import Data


import matplotlib.pyplot as plt
import networkx as nx

# ---- config  ----
CFG_G = {
    "data_dir": "data",
    "graphs_dir": "graphs",
    "corr_win": 52,          # history window (weeks) for edges
    "edge_thr": 0.20,        # |corr| threshold
    "min_edges": 2,          # require at least this many undirected edges
    "min_labels": 1,         # require >= this many labeled nodes in week
    "ret_col_pref": ["feat_ret_z", "feat_ret"],  # try z-scored first
}

DATA_DIR      = Path(CFG_G["data_dir"])
OUT_GRAPH_DIR = Path(CFG_G["graphs_dir"])
OUT_GRAPH_DIR.mkdir(parents=True, exist_ok=True)

FEATURES_PATH   = DATA_DIR / "node_features_model.parquet"   # Week, Ticker, *_z / *_global_z only
XCOLS_PATH      = DATA_DIR / "X_cols.json"
LABELS_PATH     = DATA_DIR / "labels_from_next.parquet"
GRAPH_INDEX_CSV = OUT_GRAPH_DIR / "graphs_index.csv"
SPLIT_JSON      = OUT_GRAPH_DIR / "split_weeks.json"

# -------------------- load artifacts --------------------
feat = pd.read_parquet(FEATURES_PATH)      # (Week, Ticker, X_cols)
labels = pd.read_parquet(LABELS_PATH)      # (Week, Ticker, y_next)
with open(XCOLS_PATH, "r", encoding="utf-8") as f:
    X_cols = json.load(f)

# Explicit checks instead of `assert` (asserts are stripped under `python -O`).
if not {"Week", "Ticker"}.issubset(feat.columns):
    raise ValueError("Features must have Week, Ticker")
if not {"Week", "Ticker", "y_next"}.issubset(labels.columns):
    raise ValueError("Labels must have Week, Ticker, y_next")

# choose return column for correlations: first preference present wins
RET_COL = next((c for c in CFG_G["ret_col_pref"] if c in feat.columns), None)
if RET_COL is None:
    raise ValueError("Neither feat_ret_z nor feat_ret present for correlation edges.")

# Keep only modeling columns + Week, Ticker (RET_COL is kept for edges even if it
# is not in X_cols). X_cols absent from the features file are not selected here,
# because selecting them would raise KeyError. The graph loop zero-fills any
# X_cols missing from a week's frame.
wanted_cols = sorted(set(X_cols) | {RET_COL})
missing_x_cols = [c for c in wanted_cols if c not in feat.columns]
if missing_x_cols:
    print(f"[WARN] {len(missing_x_cols)} X_cols not found in features; "
          f"they will be zero-filled: {missing_x_cols}")
use_cols = ["Week", "Ticker"] + [c for c in wanted_cols if c in feat.columns]
feat = feat[use_cols].copy()

# normalize time
feat["Week"]   = pd.to_datetime(feat["Week"])
labels["Week"] = pd.to_datetime(labels["Week"])

# weeks with features AND labels (we want at least some labels per week)
weeks_all   = sorted(feat["Week"].unique())
weeks_label = set(labels["Week"].unique())
weeks = [w for w in weeks_all if w in weeks_label]

# -------------------- helpers --------------------
def build_corr_matrix(df_feat_hist: pd.DataFrame, tickers: list[str], ret_col: str) -> np.ndarray | None:
    """
    Compute correlation over the history window for given tickers using ret_col.
    Returns an NxN numpy array (diag=0), or None if history too short.
    """
    df_ret_hist = df_feat_hist[["Week","Ticker", ret_col]].copy()
    wide = df_ret_hist.pivot(index="Week", columns="Ticker", values=ret_col)
    wide = wide.reindex(columns=tickers)
    wide = wide.dropna(how="all", axis=0)
    if wide.shape[0] < 3:   # too short
        return None
    corr = wide.corr().to_numpy()
    corr = np.nan_to_num(corr, nan=0.0)
    np.fill_diagonal(corr, 0.0)
    return corr

graph_records = []


def _threshold_edges(corr: np.ndarray, thr: float) -> tuple[np.ndarray, np.ndarray]:
    """Turn a symmetric correlation matrix into an undirected, doubly-stored edge list.

    Keeps every pair i < j with |corr[i, j]| >= thr and emits it in both
    directions, (i -> j) immediately followed by (j -> i). Pairs appear in the
    same row-major order as ``itertools.combinations(range(N), 2)``.

    Returns
    -------
    edge_index : int64 array of shape (2, E); shape (2, 0) when nothing passes.
    edge_weight : float64 array of shape (E,), the signed correlation per edge.
    """
    iu, ju = np.triu_indices(corr.shape[0], k=1)
    w = np.asarray(corr, dtype=float)[iu, ju]
    keep = np.abs(w) >= thr
    iu, ju, w = iu[keep], ju[keep], w[keep]
    src = np.column_stack([iu, ju]).reshape(-1)   # i0, j0, i1, j1, ...
    dst = np.column_stack([ju, iu]).reshape(-1)   # j0, i0, j1, i1, ...
    return np.vstack([src, dst]).astype(np.int64), np.repeat(w, 2)


# -------------------- build graphs --------------------
for t in range(CFG_G["corr_win"], len(weeks)):
    week = weeks[t]
    # NOTE(review): `weeks` only holds weeks that also have labels, so this is the
    # previous corr_win *labelled* weeks. It equals [t-W, t) in calendar weeks
    # only when label weeks are contiguous.
    hist_weeks = weeks[t-CFG_G["corr_win"]:t]

    # features for week t
    df_feat_w = feat.loc[feat["Week"] == week].copy()
    if df_feat_w.empty:
        continue

    # deterministic ticker order
    tickers = sorted(df_feat_w["Ticker"].unique().tolist())
    df_feat_w = df_feat_w.set_index("Ticker").reindex(tickers).reset_index()

    # assemble X; X_cols missing from the frame are zero-filled
    # (safe because *_z / *_global_z columns are already standardized)
    for c in X_cols:
        if c not in df_feat_w.columns:
            df_feat_w[c] = 0.0
    X = np.nan_to_num(df_feat_w[X_cols].to_numpy(dtype=np.float32), nan=0.0)

    # labels + mask: unlabeled nodes get y = 0 and train_mask = False
    df_lab_w = labels.loc[labels["Week"] == week, ["Ticker","y_next"]].set_index("Ticker")
    y_series = df_lab_w.reindex(tickers)["y_next"]
    mask = y_series.notna().to_numpy()
    if mask.sum() < CFG_G["min_labels"]:
        continue
    y_vec = y_series.fillna(0.0).to_numpy(dtype=np.float32)

    # correlation edges from history
    df_hist = feat.loc[feat["Week"].isin(hist_weeks)]
    if df_hist.empty:
        continue
    corr = build_corr_matrix(df_hist, tickers, ret_col=RET_COL)
    if corr is None:
        continue

    # threshold to edge list (undirected, store both directions)
    N = len(tickers)
    edge_index_np, edge_wt_np = _threshold_edges(corr, CFG_G["edge_thr"])
    if edge_index_np.shape[1] // 2 < CFG_G["min_edges"]:
        continue

    edge_index = torch.as_tensor(edge_index_np, dtype=torch.long).contiguous()
    edge_weight = torch.as_tensor(edge_wt_np, dtype=torch.float32)
    x = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y_vec, dtype=torch.float32)
    train_mask = torch.tensor(mask, dtype=torch.bool)

    g = Data(x=x, edge_index=edge_index, edge_weight=edge_weight, y=y)
    g.train_mask = train_mask
    g.tickers = tickers           # list[str]
    g.week = str(pd.to_datetime(week).date())  # "YYYY-MM-DD"

    fname = f"graph_{g.week}.pt"
    fpath = OUT_GRAPH_DIR / fname
    torch.save(g, fpath)

    graph_records.append({
        "Week": pd.to_datetime(week),
        "file": fname,
        "num_nodes": N,
        "num_edges": int(edge_index.shape[1]),   # directed count (2 per pair)
        "num_labeled": int(train_mask.sum())
    })

# -------------------- index & splits --------------------
# Explicit columns so that an empty run still yields a well-formed, header-only
# index instead of a KeyError on sort_values("Week").
idx = (pd.DataFrame(graph_records,
                    columns=["Week", "file", "num_nodes", "num_edges", "num_labeled"])
       .sort_values("Week")
       .reset_index(drop=True))
idx.to_csv(GRAPH_INDEX_CSV, index=False)

print(f"Saved {len(idx)} graphs to {OUT_GRAPH_DIR}")
if len(idx):
    print(idx.tail(5))
else:
    print("No graphs built — check corr_win, edge_thr, or label availability.")

# time-ordered 70/15/15 split over the sorted graph weeks (positional slices)
if len(idx) > 0:
    W = len(idx)
    i_tr_end = math.floor(0.70 * W)              # exclusive end of train
    i_va_end = i_tr_end + math.floor(0.15 * W)   # exclusive end of val; rest is test

    week_str = idx["Week"].dt.date.astype(str)
    train_weeks = week_str.iloc[:i_tr_end].tolist()
    val_weeks   = week_str.iloc[i_tr_end:i_va_end].tolist()
    test_weeks  = week_str.iloc[i_va_end:].tolist()

    with open(SPLIT_JSON, "w", encoding="utf-8") as f:
        json.dump({"train_weeks": train_weeks, "val_weeks": val_weeks, "test_weeks": test_weeks}, f, indent=2)

    print("Saved:", SPLIT_JSON)
    print(f"Counts → train {len(train_weeks)}, val {len(val_weeks)}, test {len(test_weeks)}")

# -------------------- Visual --------------------
def _normalize_weeks(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with three lookup keys derived from its ``Week`` column.

    Added columns: ``_week_dt`` (parsed datetime, NaT if unparsable),
    ``_week_date`` (ISO date string of ``_week_dt``) and ``_week_str`` (the raw
    ``Week`` value as a string). The input frame is not modified.
    """
    parsed = pd.to_datetime(df["Week"], errors="coerce")
    return df.assign(
        _week_dt=parsed,
        _week_date=parsed.dt.date.astype(str),
        _week_str=df["Week"].astype(str),
    )

def plot_graph_pretty(week, graph_index: pd.DataFrame, base_dir: Path, topk_per_node: int | None = None,
                      node_color="skyblue", seed: int = 42):
    """Load the saved weekly graph for *week* and draw it with networkx/matplotlib.

    Parameters
    ----------
    week : str or date-like
        Week to plot. Matched against ``graph_index["Week"]`` first as an ISO
        date string (YYYY-MM-DD), then as the raw string, then as a parsed
        datetime.
    graph_index : pd.DataFrame
        Graph index with at least ``Week`` and ``file`` columns (e.g. the
        contents of graphs_index.csv).
    base_dir : Path
        Directory holding the ``.pt`` files named in ``file``.
    topk_per_node : int or None
        If truthy, keep only the ``topk_per_node`` outgoing edges with the
        largest |weight| for each source node before drawing.
    node_color :
        Matplotlib color used for every node.
    seed : int
        Seed for ``nx.spring_layout``. The layout also depends on the order in
        which nodes and edges are added below.

    Returns
    -------
    The loaded graph object, or None when *week* is not in the index. If the
    (possibly pruned) graph has no edges, a message is printed and the graph
    is returned without plotting.
    """
    gi = _normalize_weeks(graph_index.copy())
    # normalize query
    w_dt = pd.to_datetime(week, errors="coerce")
    w_date = (w_dt.date().isoformat() if pd.notna(w_dt) else str(week))
    # Try increasingly loose matches: ISO date, raw string, then datetime equality.
    row = gi.loc[gi["_week_date"] == w_date]
    if row.empty:
        row = gi.loc[gi["_week_str"] == str(week)]
    if row.empty and pd.notna(w_dt):
        row = gi.loc[gi["_week_dt"] == w_dt]
    if row.empty:
        print(f"Week {week} not found. Try: {gi['_week_date'].head().tolist()} …")
        return None

    fpath = base_dir / row.iloc[0]["file"]
    # weights_only=False unpickles arbitrary objects: only load trusted,
    # locally generated graph files.
    g = torch.load(fpath, weights_only=False, map_location="cpu")

    edges_np = g.edge_index.cpu().numpy().T   # one (src, dst) row per edge
    w_np     = g.edge_weight.cpu().numpy()
    tickers  = list(g.tickers)

    # Directed graph: the graph builder stores each undirected edge as both
    # (i, j) and (j, i).
    G = nx.DiGraph()
    for i, tk in enumerate(tickers):
        G.add_node(i, ticker=tk, label=float(g.y[i].item()))

    # optional top-k pruning (per source node, by |weight|)
    edges_iter = list(zip(edges_np, w_np))
    if topk_per_node:
        out_dict = {}
        for (src, dst), w in edges_iter:
            out_dict.setdefault(int(src), []).append((int(dst), float(w)))
        edges_iter = []
        for src, lst in out_dict.items():
            lst_sorted = sorted(lst, key=lambda x: abs(x[1]), reverse=True)[:topk_per_node]
            edges_iter.extend([((src, dst), w) for dst, w in lst_sorted])

    for (s, d), w in edges_iter:
        G.add_edge(int(s), int(d), weight=float(w))

    labels_dict = nx.get_node_attributes(G, "label")
    # Node size grows linearly with |label| from 300 to 2300. max_abs is floored
    # at 1e-9 so an all-zero label vector does not divide by zero.
    # NOTE(review): max() over an empty generator raises ValueError if the graph has no nodes.
    max_abs = max(1e-9, max(abs(v) for v in labels_dict.values()))
    node_sizes = [300 + 2000 * (abs(labels_dict[i]) / max_abs) for i in G.nodes]

    eweights = np.array([G[u][v]["weight"] for u, v in G.edges()], dtype=float)
    if len(eweights) == 0:
        print("Graph has no edges to plot.")
        return g
    # map correlations in [-1, 1] onto [0, 1] for the blue-white-red colormap
    norm_vals = (eweights + 1.0) / 2.0
    edge_colors = plt.cm.bwr(norm_vals)

    fig, ax = plt.subplots(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=seed, k=1.7, iterations=100)

    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_sizes, node_color=node_color,
                           alpha=0.9, edgecolors="black", linewidths=0.8)
    nx.draw_networkx_edges(G, pos, ax=ax, edge_color=edge_colors, alpha=0.6)
    nx.draw_networkx_labels(G, pos, ax=ax, labels={i: G.nodes[i]["ticker"] for i in G.nodes}, font_size=8)

    # colorbar spanning the full correlation range [-1, 1]
    sm = plt.cm.ScalarMappable(cmap="bwr", norm=plt.Normalize(vmin=-1, vmax=1))
    sm.set_array([])
    fig.colorbar(sm, ax=ax, label="Correlation")

    ax.set_title(f"Graph for Week {row.iloc[0]['_week_date']}\n(Node size ∝ |label|, Edge color ∝ corr)")
    ax.axis("off")
    plt.show()
    return g

# Quick visual (skipped gracefully when no graphs were built)
try:
    gi = pd.read_csv(GRAPH_INDEX_CSV)
except pd.errors.EmptyDataError:   # index written from a column-less empty frame
    gi = pd.DataFrame(columns=["Week", "file"])
if gi.empty:
    print("Graph index is empty — nothing to plot.")
    g = None
else:
    g = plot_graph_pretty("2021-01-08", gi, base_dir=OUT_GRAPH_DIR, topk_per_node=5)
if g is not None:
    print("x:", g.x.shape, "| y:", g.y.shape, "| edges:", g.edge_index.shape)


# 5. Baseline models (econometric benchmarks)

We benchmark the GNN against two classical baselines, using the same time splits produced from the graph index (70/15/15, time-ordered):

- **VAR–FEVD:** Rolling VAR($l$) on weekly returns with pruning and coverage guards, producing the “from-others” share at horizon $H$.  
  We evaluate next-week predictions $y_{t+1}$.

- **Scalar BEKK(1,1):** a pure-Python implementation fitted on a capped subset of tickers, with $(a, b)$ chosen by grid search over the Gaussian quasi-likelihood.  
  The conditional covariance $H_t$ is rolled forward one week at a time and mapped to a spillover share.

We report test-set RMSE/MAE/$R^2$, panel-weekly RMSE curves, per-ticker RMSE distributions, and Diebold–Mariano tests (overall and per ticker).  
These results are referenced in the thesis baseline section.


In [None]:
# ============================ Section 5: Baseline Models ============================
from __future__ import annotations
from pathlib import Path
import os, json, warnings, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Tuple
from statsmodels.tsa.api import VAR

# ---------- Paths ----------
BASE_DIR  = Path(".")
DATA_DIR  = BASE_DIR / "data"        # labels_from_next.parquet
FEAT_DIR  = DATA_DIR / "feat"        # feat_ret.parquet
GRAPH_DIR = BASE_DIR / "graphs"      # split_weeks.json
OUT_DIR   = BASE_DIR / "baselines"   # baseline outputs
FIG_DIR   = BASE_DIR / "figures"
TAB_DIR   = BASE_DIR / "tables"
for _dir in (OUT_DIR, FIG_DIR, TAB_DIR):
    _dir.mkdir(parents=True, exist_ok=True)

RET_FILE   = FEAT_DIR / "feat_ret.parquet"
LABEL_FILE = DATA_DIR / "labels_from_next.parquet"
SPLIT_JSON = GRAPH_DIR / "split_weeks.json"

# ---------- VAR–FEVD config ----------
ROLL_WINDOW_WEEKS = 40     # rolling window length in weeks (falsy = all history up to t)
VAR_LAGS  = 1
FEVD_H    = 12             # FEVD horizon
COVERAGE  = 0.70           # min share of non-missing weeks per ticker in a window
MIN_T     = 30             # min complete rows per window
MAX_N     = 25             # max tickers kept by _prune_window
COLLINEAR = 0.90           # |corr| at/above which one ticker of a pair is dropped
JITTER    = 1e-8           # Gaussian jitter added after pruning (0 disables)

# ---------- Scalar BEKK (Python) config ----------
BEKK_MAX_N        = 10
BEKK_MIN_T        = 25
COVERAGE_BEKK     = 0.95
DROP_COLLINEAR_BK = 0.995
RIDGE_H           = 1e-8   # eigenvalue floor used by make_spd
JITTER_BEKK       = 0.0

# Rolling refits on test (uses info up to t to predict t+1)
ALLOW_UPDATE_ON_TEST = True

# Fast smoke test toggle: shrinks the BEKK problem and truncates the test set
FAST_SMOKE_TEST = False
if FAST_SMOKE_TEST:
    BEKK_MAX_N, BEKK_MIN_T, TEST_WEEKS_LIMIT = 6, 20, 4

SEED = 42
np.random.seed(SEED)

# ---------- Helpers ----------
def to_wfri_index(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* re-indexed to week-ending-Friday dates, sorted by index.

    Each index entry is mapped to the (midnight-normalized) Friday that closes
    its W-FRI period. Rows falling in the same W-FRI week end up with the same
    label; they are kept, not aggregated.
    """
    out = df.copy()
    # PeriodIndex.end_time is a DatetimeIndex, which has no ``.dt`` accessor;
    # normalize() must be called on the index itself.
    out.index = pd.to_datetime(out.index).to_period("W-FRI").end_time.normalize()
    return out.sort_index()

def weeks_str_to_wfri(weeks: List[str]) -> pd.DatetimeIndex:
    """Map date strings to their sorted, de-duplicated W-FRI week-ending dates (midnight).

    An empty input yields an empty DatetimeIndex.
    """
    if not weeks:
        return pd.DatetimeIndex([])
    parsed = pd.to_datetime(pd.Series(weeks))
    periods = parsed.dt.to_period("W-FRI")
    week_ends = periods.dt.end_time.dt.normalize()
    unique_ends = week_ends.unique()
    return pd.DatetimeIndex(sorted(unique_ends))

def _prune_window(W: pd.DataFrame,
                  lags: int = VAR_LAGS,
                  max_n: int = MAX_N,
                  collinear: float = COLLINEAR) -> pd.DataFrame:
    """Prune a returns window so that a VAR(``lags``) fit is well-posed.

    Mirrors the pruning used in label construction (per the original author;
    that code is not visible here). Steps, each applied to the previous result:

      1. drop zero-variance columns (returned as-is if fewer than 2 remain);
      2. keep the ``max_n`` highest-std columns;
      3. while the most correlated pair has |corr| >= ``collinear``, drop the
         lower-std column of that pair;
      4. demean every column;
      5. if T < (lags + 1) * N + 5, keep only the max(2, (T - 5) // (lags + 1))
         highest-std columns;
      6. add N(0, JITTER**2) noise when JITTER > 0 (global NumPy RNG).
    """
    std = W.std()
    keep_idx = std[std > 0].index
    W = W.loc[:, keep_idx]
    if W.shape[1] < 2:
        return W
    if W.shape[1] > max_n:
        top = std.loc[keep_idx].sort_values(ascending=False).index[:max_n]
        W = W.loc[:, top]
    while W.shape[1] > 1:
        # Work on a private NumPy copy: writing through DataFrame.values fails
        # under pandas Copy-on-Write, where that array is read-only.
        C = W.corr().abs().to_numpy(copy=True)
        np.fill_diagonal(C, 0.0)
        i, j = np.unravel_index(np.nanargmax(C), C.shape)
        if C[i, j] < collinear:
            break
        cols = W.columns.tolist()
        drop_col = cols[i] if W[cols[i]].std() < W[cols[j]].std() else cols[j]
        W = W.drop(columns=[drop_col])
    W = W - W.mean()
    T, N = W.shape
    if T < (lags + 1) * N + 5:
        # too few observations per VAR parameter: keep the most volatile series
        keep = W.std().sort_values(ascending=False).index
        maxN = max(2, (T - 5) // (lags + 1))
        W = W[keep[:maxN]]
    if JITTER > 0:
        W = W + np.random.normal(scale=JITTER, size=W.shape)
    return W

def fit_var_fevd_matrix(W: pd.DataFrame, fevd_h: int = FEVD_H, lags: int = VAR_LAGS):
    """Fit VAR(``lags``) on ``W`` and return the row-normalized FEVD at horizon ``fevd_h``.

    Returns
    -------
    (M, names)
        ``M[i, j]`` is the share of variable i's ``fevd_h``-step forecast-error
        variance attributed to (orthogonalized) shocks of variable j. Each row
        is normalized to sum to 1. ``names`` gives the row/column order.

    Notes
    -----
    statsmodels stores ``FEVD.decomp`` with axes (equation, horizon, shock), so
    the horizon-H matrix is ``decomp[:, H - 1, :]``. ``decomp[H - 1]`` would
    instead select equation H-1 across all horizons, and raises IndexError
    whenever fewer than H variables are fitted.
    NOTE(review): the label builder (not visible here) must use the same
    (equation, horizon, shock) indexing. Otherwise labels and this baseline
    measure different quantities — confirm.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        res = VAR(W).fit(maxlags=lags, ic=None, trend='c')
        names = list(res.names) if hasattr(res, "names") else list(res.model.endog_names)
        decomp = np.asarray(res.fevd(fevd_h).decomp)   # (equation, horizon, shock)
    M = decomp[:, fevd_h - 1, :]                        # 'at H': (equation, shock)
    # defensive: keep names and matrix dimensions consistent
    k = min(len(names), M.shape[0], M.shape[1])
    names, M = names[:k], M[:k, :k]
    row_sums = M.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    M = M / row_sums
    return M, names

def from_others_from_fevd(M: np.ndarray) -> np.ndarray:
    """Per-row "from others" share of an FEVD matrix: row total minus the own (diagonal) entry."""
    row_totals = np.sum(M, axis=1)
    own_share = np.diagonal(M)
    return row_totals - own_share

def from_others_from_cov(H: np.ndarray) -> np.ndarray:
    """Spillover share from a covariance matrix (BEKK path).

    For row i: sum_{j != i} H_ij^2 / sum_j H_ij^2. The numerator is clipped at 0
    and the denominator gets +1e-12 to avoid division by zero.
    """
    squared = np.square(H)
    totals = squared.sum(axis=1)
    off_diag = np.maximum(totals - np.diagonal(squared), 0.0)
    with np.errstate(divide='ignore', invalid='ignore'):
        return off_diag / (totals + 1e-12)

def rmse(a, b):
    """Root-mean-squared error between targets *a* and predictions *b*; NaN pairs are ignored."""
    a, b = np.asarray(a, float), np.asarray(b, float)
    return float(np.sqrt(np.nanmean((a - b) ** 2)))

def mae(a, b):
    """Mean absolute error between targets *a* and predictions *b*; NaN pairs are ignored."""
    a, b = np.asarray(a, float), np.asarray(b, float)
    return float(np.nanmean(np.abs(a - b)))

def r2(a, b):
    """Coefficient of determination of predictions *b* for targets *a*.

    Only pairs where both values are finite are used, so ybar, SS_res and SS_tot
    come from the same sample. Computing SS_tot over rows with a missing
    prediction inflates R^2. Returns NaN when there is no valid pair.
    """
    a = np.asarray(a, float).ravel()
    b = np.asarray(b, float).ravel()
    ok = np.isfinite(a) & np.isfinite(b)
    if not ok.any():
        return float("nan")
    a, b = a[ok], b[ok]
    ss_res = np.sum((a - b) ** 2)
    ss_tot = np.sum((a - a.mean()) ** 2)
    return float(1 - ss_res / (ss_tot + 1e-12))

from math import erf, sqrt
def _ncdf(x: float) -> float:
    """Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2."""
    z = x / sqrt(2.0)
    return (1.0 + erf(z)) / 2.0

def dm_test(loss_a, loss_b, max_lag: int | None = None):
    """Diebold–Mariano on two loss series, Newey–West long-run variance."""
    a = np.asarray(loss_a, dtype="float64").reshape(-1)
    b = np.asarray(loss_b, dtype="float64").reshape(-1)
    d = a - b
    mask = np.isfinite(d)
    d = d[mask]
    T = d.size
    if T < 10:
        return np.nan, np.nan
    d = d - d.mean()
    if max_lag is None:
        max_lag = max(0, int(T ** (1/3)))
    gamma0 = float(np.dot(d, d) / T)
    var = gamma0
    for k in range(1, max_lag + 1):
        w = 1.0 - k / (max_lag + 1.0)
        cov = float(np.dot(d[:-k], d[k:]) / T)
        var += 2.0 * w * cov
    var_mean = var / T
    if not np.isfinite(var_mean) or var_mean <= 0:
        return np.nan, np.nan
    dm = float(d.mean() / np.sqrt(var_mean))
    p  = float(2 * (1 - _ncdf(abs(dm))))
    return dm, p

def make_spd(H: np.ndarray, eps: float = RIDGE_H) -> np.ndarray:
    """Symmetrize *H* and lift its spectrum so the smallest eigenvalue is at least *eps*.

    If the eigenvalue computation fails or yields a non-finite minimum, a ridge
    of size *eps* is added instead.
    """
    sym = (H + H.T) * 0.5
    try:
        lam_min = float(np.nanmin(np.linalg.eigvalsh(sym)))
    except Exception:
        lam_min = float("nan")
    if np.isfinite(lam_min) and lam_min >= eps:
        return sym
    shift = (eps - lam_min) if np.isfinite(lam_min) else eps
    return sym + shift * np.eye(sym.shape[0])

# ---------- Load data ----------
if not RET_FILE.exists():
    raise FileNotFoundError(f"Missing returns parquet: {RET_FILE}")
wk_returns = pd.read_parquet(RET_FILE)
if "Week" in wk_returns.columns:
    wk_returns = wk_returns.set_index("Week")
wk_returns = to_wfri_index(wk_returns)

if not LABEL_FILE.exists():
    raise FileNotFoundError(f"Missing labels parquet: {LABEL_FILE}")
labels = pd.read_parquet(LABEL_FILE)
# snap every label week to its W-FRI week-ending date (midnight)
labels["Week"] = (
    pd.to_datetime(labels["Week"])
    .dt.to_period("W-FRI")
    .dt.end_time
    .dt.normalize()
)

if not SPLIT_JSON.exists():
    raise FileNotFoundError(f"Missing split file: {SPLIT_JSON}")
with open(SPLIT_JSON, "r") as f:
    split = json.load(f)

# split weeks snapped to W-FRI and restricted to weeks present in the returns panel
train_weeks, val_weeks, test_weeks = (
    weeks_str_to_wfri(split.get(key, [])).intersection(wk_returns.index)
    for key in ("train_weeks", "val_weeks", "test_weeks")
)

pretest_weeks = train_weeks.union(val_weeks)
if len(pretest_weeks) < 10 or len(test_weeks) < 5:
    raise ValueError(f"Split too small: pretest={len(pretest_weeks)}, test={len(test_weeks)}")

if FAST_SMOKE_TEST:
    test_weeks = pd.DatetimeIndex(sorted(test_weeks))[:TEST_WEEKS_LIMIT]

print(f"[SPLIT] train={len(train_weeks)}, val={len(val_weeks)}, test={len(test_weeks)}")

# ---------- VAR–FEVD forecasts ----------
def var_fevd_next_forecasts_long(ret_df: pd.DataFrame,
                                 pretest_idx: pd.DatetimeIndex,
                                 test_idx: pd.DatetimeIndex,
                                 allow_update: bool) -> pd.DataFrame:
    """Rolling VAR–FEVD "from-others" forecasts, one set per test week.

    For each test week t, a window is cut from ``ret_df``. It ends at t, or at
    the last pre-test week when ``allow_update`` is False. It holds up to
    ROLL_WINDOW_WEEKS rows, or all earlier rows if ROLL_WINDOW_WEEKS is falsy.

    Tickers whose non-missing share is below COVERAGE are dropped, incomplete
    rows are removed, and the window is pruned with ``_prune_window``. A
    VAR(VAR_LAGS) is then fitted, and each surviving ticker's FEVD from-others
    share at FEVD_H is stored under Week = t.

    Windows that are too short or narrow are skipped. Windows whose VAR/FEVD
    step raises are skipped too (best effort), but they are counted and
    reported once at the end. Otherwise a systematic failure would silently
    produce an empty result.

    Returns
    -------
    pd.DataFrame with columns Week, Ticker, yhat_VAR (empty if nothing was fitted).
    """
    rows = []
    n_failed, first_error = 0, None
    last_pre = pretest_idx.max()
    for t in test_idx:
        window_end = t if allow_update else min(t, last_pre)
        if window_end not in ret_df.index:
            continue
        end_loc = ret_df.index.get_loc(window_end)
        start_loc = max(0, end_loc - ROLL_WINDOW_WEEKS + 1) if ROLL_WINDOW_WEEKS else 0
        window = ret_df.iloc[start_loc:end_loc + 1]

        cov = window.notna().mean()
        keep_cols = cov[cov >= COVERAGE].index.tolist()
        W = window[keep_cols].dropna(how="any")
        if len(keep_cols) < 3 or W.shape[0] < MIN_T:
            continue

        Wp = _prune_window(W, lags=VAR_LAGS, max_n=MAX_N, collinear=COLLINEAR)
        T, N = Wp.shape
        if N < 2 or T < (VAR_LAGS + 1) * N + 5:
            continue

        try:
            M, names = fit_var_fevd_matrix(Wp, fevd_h=FEVD_H, lags=VAR_LAGS)
            FROM = from_others_from_fevd(M)
            k = min(len(names), len(FROM))
            if k >= 2:
                rows.append(pd.DataFrame({"Week": t, "Ticker": names[:k], "yhat_VAR": FROM[:k]}))
        except Exception as err:  # best effort per week, but never silent
            n_failed += 1
            if first_error is None:
                first_error = err

    if n_failed:
        # printed rather than warned: warnings are filtered to "ignore" at the top of this notebook
        print(f"[VAR] {n_failed} test week(s) skipped because the VAR/FEVD fit raised; "
              f"first error: {first_error!r}")
    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=["Week","Ticker","yhat_VAR"])

print("[Run] VAR–FEVD forecasts…")
ret_fit = wk_returns.loc[pretest_weeks.union(test_weeks)]
yhat_var_long = var_fevd_next_forecasts_long(
    ret_df=ret_fit,
    pretest_idx=pretest_weeks,
    test_idx=test_weeks,
    allow_update=ALLOW_UPDATE_ON_TEST,
)

# ---------- Scalar BEKK(1,1) ----------
def s_bekk_fit_and_forecast_long(ret_df: pd.DataFrame,
                                 pretest_idx: pd.DatetimeIndex,
                                 test_idx: pd.DatetimeIndex) -> pd.DataFrame:
    """Scalar BEKK(1,1) one-step-ahead spillover forecasts for the test weeks.

    Model for demeaned returns e_t, with sample covariance S as the
    unconditional target:

        H_{t+1} = (1 - a^2 - b^2) S + a^2 e_t e_t' + b^2 H_t

    (a, b) is chosen by grid search over the Gaussian quasi-likelihood, fitted
    on the last ROLL_WINDOW_WEEKS pre-test rows. For each test week t with
    complete returns, H_{t+1|t} is mapped to a from-others share with
    ``from_others_from_cov`` and stored under Week = t.

    Returns
    -------
    pd.DataFrame with columns Week, Ticker, yhat_BEKK. It is empty when the
    window is too short or narrow, or when no grid point is admissible.
    """
    out_cols = ["Week", "Ticker", "yhat_BEKK"]

    # ---- estimation window ending at the last pre-test week present in ret_df ----
    pre_end = pretest_idx.max()
    if pre_end not in ret_df.index:
        pos = ret_df.index.searchsorted(pre_end, side="right") - 1
        if pos < 0:
            return pd.DataFrame(columns=out_cols)
        pre_end = ret_df.index[pos]

    end_loc = ret_df.index.get_loc(pre_end)
    start_loc = max(0, end_loc - ROLL_WINDOW_WEEKS + 1) if ROLL_WINDOW_WEEKS else 0
    window = ret_df.iloc[start_loc:end_loc + 1]

    # ---- ticker selection: well covered, highest std, capped at BEKK_MAX_N ----
    cov = window.notna().mean()
    keep_cov = cov[cov >= COVERAGE_BEKK].index.tolist()
    if len(keep_cov) < 2:
        keep_cov = cov[cov >= 0.90].index.tolist()

    std_all = window[keep_cov].std().sort_values(ascending=False)
    use_cols = std_all.index[:min(len(std_all), BEKK_MAX_N)]
    sub = window[use_cols].dropna(how="any")
    if sub.shape[1] < 2 or sub.shape[0] < BEKK_MIN_T:
        return pd.DataFrame(columns=out_cols)

    # ---- drop near-duplicate series (|corr| >= DROP_COLLINEAR_BK) ----
    if sub.shape[1] >= 3:
        while True:
            # private copy: DataFrame.values is read-only under pandas Copy-on-Write
            Cabs = sub.corr().abs().to_numpy(copy=True)
            np.fill_diagonal(Cabs, 0.0)
            if not np.nanmax(Cabs) >= DROP_COLLINEAR_BK:
                break
            i, j = np.unravel_index(np.nanargmax(Cabs), Cabs.shape)
            cols = sub.columns.tolist()
            drop_col = cols[i] if sub[cols[i]].std() < sub[cols[j]].std() else cols[j]
            sub = sub.drop(columns=[drop_col])
            if sub.shape[1] < 2:
                return pd.DataFrame(columns=out_cols)

    if JITTER_BEKK > 0:
        sub = sub.astype("float64") + np.random.normal(scale=JITTER_BEKK, size=sub.shape)

    X = sub.values.astype("float64")
    mu = X.mean(axis=0)
    E  = X - mu
    S  = np.cov(E.T)      # unconditional covariance target

    def bekk_step(H, et, a, b, constS):
        """H_next = (1 - a^2 - b^2) S + a^2 e_t e_t' + b^2 H, projected to SPD."""
        return make_spd(constS + (a * a) * (et @ et.T) + (b * b) * H, RIDGE_H)

    def qml_negloglik(a, b, E, S):
        """Gaussian QML negative log-likelihood, up to constants and a factor 1/2."""
        const = 1.0 - a * a - b * b
        if const <= 1e-8:
            return np.inf
        constS = const * S
        H = make_spd(S, RIDGE_H)   # H_1: start from the unconditional covariance
        nll = 0.0
        for row in E:
            et = row.reshape(-1, 1)
            # Score e_t under H_t, which depends only on e_1..e_{t-1}. Updating H
            # with e_t before scoring would leak e_t into its own conditional variance.
            try:
                L = np.linalg.cholesky(H)
            except np.linalg.LinAlgError:
                return np.inf
            y = np.linalg.solve(L, et)          # y'y = e_t' H_t^{-1} e_t
            nll += 2.0 * np.sum(np.log(np.diag(L))) + (y.T @ y).item()
            H = bekk_step(H, et, a, b, constS)
        return nll

    a_grid = np.array([0.02, 0.04, 0.06, 0.08, 0.10, 0.12])
    b_grid = np.array([0.80, 0.88, 0.92, 0.94, 0.96, 0.975])
    pairs  = [(a, b) for a in a_grid for b in b_grid if (a * a + b * b) < 0.995]

    best, best_val = None, np.inf
    for a, b in pairs:
        val = qml_negloglik(a, b, E, S)
        if val < best_val:
            best_val, best = val, (a, b)
    if best is None:
        return pd.DataFrame(columns=out_cols)
    a_hat, b_hat = best

    # ---- filter through the estimation window: H_T ends as H_{T+1|T} ----
    constS = (1.0 - a_hat * a_hat - b_hat * b_hat) * S
    H_T = S.copy()
    for row in E:
        H_T = bekk_step(H_T, row.reshape(-1, 1), a_hat, b_hat, constS)

    # ---- roll through test weeks: at week t, forecast H_{t+1|t} ----
    rows = []
    tickers = list(sub.columns)
    for t in sorted(w for w in test_idx if w in ret_df.index):
        r = ret_df.loc[t, tickers]
        if r.isna().any():
            continue
        et = (r.values.astype("float64") - mu).reshape(-1, 1)
        H_T = bekk_step(H_T, et, a_hat, b_hat, constS)
        rows.append(pd.DataFrame({"Week": t, "Ticker": tickers,
                                  "yhat_BEKK": from_others_from_cov(H_T)}))

    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=out_cols)

print("[Run] Scalar BEKK(1,1) forecasts…")
yhat_bekk_long = s_bekk_fit_and_forecast_long(
    ret_df=wk_returns, pretest_idx=pretest_weeks, test_idx=test_weeks
)

# Align VAR universe to S-BEKK: re-run VAR–FEVD on the S-BEKK tickers so both
# baselines cover the same cross-section (only when S-BEKK produced forecasts).
if len(yhat_bekk_long) > 0:
    s_tickers = sorted(yhat_bekk_long["Ticker"].unique().tolist())
    wk_returns_var = wk_returns[s_tickers].copy()
    yhat_var_long = var_fevd_next_forecasts_long(
        ret_df=wk_returns_var.loc[pretest_weeks.union(test_weeks)],
        pretest_idx=pretest_weeks,
        test_idx=test_weeks,
        allow_update=ALLOW_UPDATE_ON_TEST,
    )

# ---------- Evaluation ----------
Y = labels[["Week","Ticker","y_next"]].copy()
# Test-week labels left-joined with each model's forecasts. Rows where neither
# model produced a forecast are dropped.
yhat = (
    Y.loc[Y["Week"].isin(test_weeks)]
    .merge(yhat_var_long, on=["Week", "Ticker"], how="left")
    .merge(yhat_bekk_long, on=["Week", "Ticker"], how="left")
    .dropna(subset=["yhat_VAR", "yhat_BEKK"], how="all")
    .reset_index(drop=True)
)

print(f"[Eval] rows={len(yhat)} | weeks={yhat['Week'].nunique()} | tickers={yhat['Ticker'].nunique()}")

# Overall metrics: one row per model that has at least one forecast
rows = [
    {
        "model": name,
        "rmse": rmse(yhat["y_next"], yhat[col]),
        "mae":  mae(yhat["y_next"], yhat[col]),
        "r2":   r2(yhat["y_next"], yhat[col]),
        "n_obs": int(yhat[col].notna().sum()),
    }
    for name, col in (("VAR-FEVD", "yhat_VAR"), ("S-BEKK(1,1)", "yhat_BEKK"))
    if col in yhat and not yhat[col].isna().all()
]
if rows:
    overall = pd.DataFrame(rows).sort_values("rmse")
else:
    overall = pd.DataFrame(columns=["model","rmse","mae","r2","n_obs"])
overall_path = OUT_DIR / "metrics_overall.csv"
overall.to_csv(overall_path, index=False)

# Per-ticker metrics (rows with a forecast from the given model only)
per_ticker = []
for name, col in (("VAR-FEVD", "yhat_VAR"), ("S-BEKK(1,1)", "yhat_BEKK")):
    if col in yhat and not yhat[col].isna().all():
        scored = yhat[["Ticker", "y_next", col]].dropna(subset=[col])
        per_ticker.extend(
            {
                "model": name, "ticker": tk,
                "rmse": rmse(grp["y_next"], grp[col]),
                "mae":  mae(grp["y_next"], grp[col]),
                "r2":   r2(grp["y_next"], grp[col]),
                "n_obs": int(len(grp)),
            }
            for tk, grp in scored.groupby("Ticker")
        )
by_ticker = pd.DataFrame(per_ticker)
by_ticker_path = OUT_DIR / "metrics_by_ticker.csv"
by_ticker.to_csv(by_ticker_path, index=False)

# Save long predictions + a unified long format for plotting
preds_path = OUT_DIR / "predictions_long.parquet"
yhat.to_parquet(preds_path, index=False)

# unified format: one row per (Week, Ticker, model) with a common "yhat" column
preds_unified = [
    yhat[["Week", "Ticker", "y_next", col]]
    .dropna(subset=[col])
    .rename(columns={col: "yhat"})
    .assign(model=name)
    for name, col in (("VAR-FEVD", "yhat_VAR"), ("S-BEKK(1,1)", "yhat_BEKK"))
    if col in yhat
]
if preds_unified:
    preds_unified = pd.concat(preds_unified, ignore_index=True)
else:
    preds_unified = pd.DataFrame(columns=["Week","Ticker","y_next","yhat","model"])
preds_unified_path = OUT_DIR / "preds_unified_long.parquet"
preds_unified.to_parquet(preds_unified_path, index=False)

# DM test on weekly panel MSE: for each week and model, the mean squared error
# across tickers; the two weekly loss series are then compared with dm_test.
# (Averaging y and yhat first would measure the squared error of the
# cross-sectional means instead of the MSE.)
_model_loss_cols = {"VAR-FEVD": "se_VAR", "S-BEKK(1,1)": "se_BEKK"}
if preds_unified.empty:
    panel_rmse = pd.DataFrame(columns=list(_model_loss_cols.values()), dtype=float)
else:
    _sq_err = (preds_unified["y_next"].astype(float) - preds_unified["yhat"].astype(float)) ** 2
    # rows = weeks, columns = se_VAR / se_BEKK. A model with no forecasts gets an
    # all-NaN column, and dropna then leaves no weeks.
    panel_rmse = (_sq_err.groupby([preds_unified["Week"], preds_unified["model"]]).mean()
                  .unstack("model")
                  .reindex(columns=list(_model_loss_cols))
                  .rename(columns=_model_loss_cols)
                  .dropna()
                  .sort_index())
dm_stat, p_val = dm_test(panel_rmse["se_VAR"].values, panel_rmse["se_BEKK"].values, max_lag=None)
dm_tbl = pd.DataFrame([{
    "model_A": "VAR-FEVD", "model_B": "S-BEKK(1,1)",
    "loss": "MSE (panel mean per week)",
    "DM_stat": dm_stat, "p_value": p_val, "n_weeks": int(len(panel_rmse))
}])
dm_path = OUT_DIR / "dm_test.csv"
dm_tbl.to_csv(dm_path, index=False)

# ---------- Plots & tables ----------
# Table: overall metrics (CSV + LaTeX), upper-cased headers, best RMSE first
overall_sorted = overall.sort_values("rmse")
overall_sorted = overall_sorted.rename(columns=str.upper)
overall_sorted.to_csv(TAB_DIR / "tab_baseline_overall.csv", index=False)
with open(TAB_DIR / "tab_baseline_overall.tex", "w") as f:
    f.write(
        overall_sorted.to_latex(
            index=False,
            float_format="%.4f",
            caption="Baseline performance on the test set (lower is better).",
            label="tab:baseline_overall",
        )
    )

# Weekly panel RMSE figure
def panel_losses(df):
    """Per-(Week, model) RMSE and MAE across tickers, sorted by Week.

    Parameters
    ----------
    df : DataFrame with columns Week, model, y_next, yhat.

    Returns
    -------
    DataFrame with columns Week, model, RMSE, MAE (NaN errors are skipped).

    Uses named ``agg`` rather than ``groupby(...).apply``. apply on the
    grouping columns is deprecated (pandas 2.2) and its ``as_index`` handling
    varies across versions.
    """
    err = df["y_next"].astype(float) - df["yhat"].astype(float)
    out = (df[["Week", "model"]]
           .assign(_se=err ** 2, _ae=err.abs())
           .groupby(["Week", "model"], as_index=False)
           .agg(_mse=("_se", "mean"), MAE=("_ae", "mean")))
    out.insert(2, "RMSE", np.sqrt(out.pop("_mse")))
    return out.sort_values("Week")

panel_curves = panel_losses(preds_unified)
fig_weekly_rmse = FIG_DIR / "fig_baseline_weekly_panel_rmse.png"

plt.figure(figsize=(10, 5))
for model_name, curve in panel_curves.groupby("model"):
    plt.plot(pd.to_datetime(curve["Week"]), curve["RMSE"], label=model_name, linewidth=2)
plt.title("Weekly Panel RMSE by Baseline Model")
plt.xlabel("Week")
plt.ylabel("RMSE")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(fig_weekly_rmse, dpi=180)
plt.show()

# Per-ticker RMSE boxplot (one box per model)
fig_box_rmse = FIG_DIR / "fig_baseline_box_rmse_by_ticker.png"
if by_ticker.empty:
    print("No per-ticker metrics — skipping RMSE boxplot.")
else:
    order = sorted(by_ticker["model"].unique())
    data = [by_ticker.loc[by_ticker["model"]==m, "rmse"].values for m in order]
    plt.figure(figsize=(7,5))
    # Tick labels are set via xticks: boxplot's `labels=` was renamed to
    # `tick_labels` in Matplotlib 3.9 and is deprecated. Boxes sit at x = 1..n.
    plt.boxplot(data, showfliers=False)
    plt.xticks(range(1, len(order) + 1), order)
    plt.title("Per-Ticker RMSE Distribution (Baselines)")
    plt.ylabel("RMSE"); plt.grid(True, axis="y", alpha=0.3); plt.tight_layout()
    plt.savefig(fig_box_rmse, dpi=180); plt.show()

# Actual vs Predicted scatter, with the 45-degree line over the joint data range
plt.figure(figsize=(6, 6))
for model_name in sorted(preds_unified["model"].unique()):
    pts = preds_unified.loc[preds_unified["model"] == model_name]
    plt.scatter(pts["y_next"], pts["yhat"], s=10, alpha=0.45, label=model_name)
lo = min(preds_unified["y_next"].min(), preds_unified["yhat"].min())
hi = max(preds_unified["y_next"].max(), preds_unified["yhat"].max())
lims = [lo, hi]
plt.plot(lims, lims, linewidth=2)
plt.xlim(lims)
plt.ylim(lims)
plt.xlabel("Actual spillover (y_next)")
plt.ylabel("Predicted spillover (ŷ)")
plt.title("Actual vs Predicted — Baselines")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
fig_scatter = FIG_DIR / "fig_baseline_scatter_actual_vs_pred.png"
plt.savefig(fig_scatter, dpi=180)
plt.show()

# Per-ticker DM tests
def newey_west_var(d, max_lag=None):
    """Bartlett-weighted (Newey–West) long-run variance of the series *d*.

    The caller is expected to pass an already centered series. Default
    bandwidth: floor(4 * (T / 100) ** (2 / 9)).
    """
    x = np.asarray(d, dtype=float)
    n = len(x)
    if max_lag is None:
        max_lag = int(np.floor(4 * (n / 100.0) ** (2 / 9)))
    total = np.dot(x, x) / n
    for lag in range(1, max_lag + 1):
        weight = 1.0 - lag / (max_lag + 1)
        autocov = np.dot(x[lag:], x[:-lag]) / n
        total += 2 * weight * autocov
    return total

def dm_test_series(e1, e2, max_lag=None):
    """Diebold–Mariano test of equal squared-error loss for two forecast-error series.

    The loss differential is ``d_t = e1_t**2 - e2_t**2``; its mean is scaled by the
    Newey–West (Bartlett) long-run variance from ``newey_west_var``. A positive
    statistic means the first series has larger squared errors.

    Args:
        e1, e2: forecast errors of equal length (any array-like).
        max_lag: HAC truncation lag; None lets ``newey_west_var`` choose it.

    Returns:
        (stat, p_value, n): two-sided normal-approximation p-value. ``stat`` and
        ``p_value`` are NaN when fewer than two observations are given or the HAC
        variance is not positive.
    """
    d = np.asarray(e1, dtype=float) ** 2 - np.asarray(e2, dtype=float) ** 2
    n = len(d)
    if n < 2:
        return np.nan, np.nan, n
    d_bar = d.mean()
    var_nw = newey_west_var(d - d_bar, max_lag=max_lag)
    # math.* rather than bare sqrt/erf so this does not depend on a
    # `from math import ...` somewhere else in the notebook.
    stat = d_bar / math.sqrt(var_nw / n) if var_nw > 0 else np.nan
    # Two-sided p = 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2)); erfc avoids the
    # cancellation that rounds small tail probabilities to exactly 0.
    p = math.erfc(abs(stat) / math.sqrt(2)) if np.isfinite(stat) else np.nan
    return stat, p, n

# Wide table: one row per (Week, Ticker), columns named "<value>_<model>".
pl = preds_unified.pivot_table(index=["Week","Ticker"], columns="model", values=["y_next","yhat"])
pl.columns = [f"{a}_{b}" for a,b in pl.columns]
needed = {"y_next_VAR-FEVD","yhat_VAR-FEVD","y_next_S-BEKK(1,1)","yhat_S-BEKK(1,1)"}
# Keep only (Week, Ticker) rows where both models have a forecast, so the DM loss
# differential is computed on a common sample.
pl = pl.dropna(subset=list(needed), how="any").reset_index()

rows_dm = []
for tk, g in pl.groupby("Ticker"):
    y  = g["y_next_VAR-FEVD"].values
    e1 = (y - g["yhat_VAR-FEVD"].values)
    e2 = (y - g["yhat_S-BEKK(1,1)"].values)
    # Skip tickers with fewer than 12 overlapping weeks.
    if len(e1) >= 12:
        # Positive DM stat => VAR-FEVD has larger squared errors than S-BEKK.
        stat, p, n = dm_test_series(e1, e2, max_lag=None)
        rows_dm.append({"ticker": tk, "DM_stat_VAR_vs_BEKK": stat, "p_value": p, "n_weeks": n})
# Explicit columns keep sort_values/to_csv working when no ticker qualifies
# (an empty DataFrame() has no "p_value" column and sort_values would raise).
dm_by_ticker = pd.DataFrame(
    rows_dm, columns=["ticker", "DM_stat_VAR_vs_BEKK", "p_value", "n_weeks"]
).sort_values("p_value")
dm_by_ticker.to_csv(TAB_DIR / "tab_dm_per_ticker.csv", index=False)

# ---------- Summary ----------
print("\n=== BASELINES SUMMARY ===")
print("Saved:")
print(" - Overall metrics ->", overall_path)
print(" - By-ticker metrics ->", by_ticker_path)
print(" - Predictions (long) ->", preds_path)
print(" - Unified preds (long) ->", preds_unified_path)
print(" - DM test (overall panel) ->", dm_path)
print(" - Tables:", TAB_DIR / "tab_baseline_overall.csv", TAB_DIR / "tab_baseline_overall.tex", TAB_DIR / "tab_dm_per_ticker.csv")
print(" - Figures:", fig_weekly_rmse, fig_box_rmse, fig_scatter)
print(f"[Config] UPDATE_ON_TEST={ALLOW_UPDATE_ON_TEST} | WIN={ROLL_WINDOW_WEEKS} | LAGS={VAR_LAGS} | FEVD_H={FEVD_H}")
print(f"[S-BEKK] MAX_N={BEKK_MAX_N} | MIN_T={BEKK_MIN_T} | COVERAGE_BEKK={COVERAGE_BEKK}")


# 6. MLP baseline (feature-only, no graph)

A simple feed-forward regressor that uses the standardized node features  
$X_{i,t}$ (from Section 2) to predict next-week spillover $y_{i,t}$  
(labels from Section 3).

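We train on the same time splits as the graph models (from `graphs/split_weeks.json`; if that file is missing, the cell falls back to a chronological 70/15/15 split over the weeks present in the features).  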
The loss and metrics are masked to labeled nodes only (semi-supervised setting),  
but we also emit predictions for all test nodes to compare coverage.


In [None]:
# ============================ Section 6: MLP Baseline ============================
from __future__ import annotations
from pathlib import Path
import os, json, math
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import r2_score

# ---- Repro helpers  ----
def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

    PYTHONHASHSEED is exported as well. Python reads it only at interpreter
    start-up, so it affects child processes, not the current one.
    """
    import os
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fixed seed so weight initialisation and dropout masks repeat across runs.
seed_everything(42)

# -------------------- PATHS (repo-relative) --------------------
BASE_DIR     = Path(".")
DATA_DIR     = BASE_DIR / "data"
GRAPHS_DIR   = BASE_DIR / "graphs"
RESULTS_DIR  = BASE_DIR / "results"          # unified results folder
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

FEATURES_PATH = DATA_DIR / "node_features_model.parquet"   # (Week, Ticker, *_z features)
XCOLS_PATH    = DATA_DIR / "X_cols.json"                   # list of feature columns to use
LABELS_PATH   = DATA_DIR / "labels_from_next.parquet"      # (Week, Ticker, y_next)
SPLIT_JSON    = GRAPHS_DIR / "split_weeks.json"            # weeks for train/val/test
SUMMARY_CSV   = RESULTS_DIR / "summary.csv"                # one row per model (upserted at the end)

MODEL_NAME    = "MLP"

# -------------------- LOAD FEATURES & LABELS --------------------
feat = pd.read_parquet(FEATURES_PATH).copy()
# Snap every date to the Friday (midnight) ending its W-FRI week so features,
# labels and split weeks share one weekly key.
feat["Week"] = pd.to_datetime(feat["Week"]).dt.to_period("W-FRI").dt.end_time.dt.normalize()

with open(XCOLS_PATH, "r") as f:
    X_cols = json.load(f)

labels = pd.read_parquet(LABELS_PATH).copy()
# Same weekly key as the features so the merge below aligns rows.
labels["Week"] = pd.to_datetime(labels["Week"]).dt.to_period("W-FRI").dt.end_time.dt.normalize()
labels = labels.rename(columns={"y_next": "y_true"})  # align with evaluation naming

# Keep only needed columns from features
feat = feat[["Week","Ticker"] + X_cols].copy()

# -------------------- ATTACH SPLITS (from split_weeks.json if present) --------------------
def _weeks_str_to_wfri(weeks):
    if not weeks: return set()
    w = pd.to_datetime(pd.Series(weeks))
    w = set(w.dt.to_period("W-FRI").dt.end_time.dt.normalize().tolist())
    return w

# Split weeks: prefer the shared split file so the MLP uses exactly the graph
# models' weeks; otherwise derive a chronological split locally.
if SPLIT_JSON.exists():
    with open(SPLIT_JSON, "r") as f:
        sp = json.load(f)
    train_weeks = _weeks_str_to_wfri(sp.get("train_weeks", []))
    val_weeks   = _weeks_str_to_wfri(sp.get("val_weeks", []))
    test_weeks  = _weeks_str_to_wfri(sp.get("test_weeks", []))
else:
    # fallback: chronological 70/15/15 over weeks present in features
    wk_sorted = sorted(feat["Week"].unique())
    W = len(wk_sorted)
    i_tr = math.floor(0.70 * W)
    i_va = i_tr + math.floor(0.15 * W)
    train_weeks = set(wk_sorted[:i_tr])
    val_weeks   = set(wk_sorted[i_tr:i_va])
    test_weeks  = set(wk_sorted[i_va:])

# -------------------- MERGE & PREP DATAFRAME --------------------
# Left join: every feature row is kept; rows without a label get y_true = NaN
# (unlabelled nodes, excluded from loss/metrics via the mask later).
df = feat.merge(labels[["Week","Ticker","y_true"]], on=["Week","Ticker"], how="left")
# Nested np.where gives precedence train > val > test if a week were in several sets.
df["split"] = np.where(df["Week"].isin(train_weeks), "train",
               np.where(df["Week"].isin(val_weeks),   "val",
               np.where(df["Week"].isin(test_weeks),  "test", "other")))
# Weeks not assigned to any split are dropped.
df = df[df["split"].isin(["train","val","test"])].copy()
df = df.sort_values(["Week","Ticker"]).reset_index(drop=True)

# numeric cast + impute zeros (already z-standardized upstream)
# NOTE(review): 0 equals the standardized mean only if the z-scoring used these
# same rows' statistics; confirm upstream.
for c in X_cols:
    df[c] = pd.to_numeric(df[c], errors="coerce")
df[X_cols] = df[X_cols].fillna(0.0)

# -------------------- BUILD TENSORS (masked by labels for scoring) --------------------
def make_tensor_split(dfall: pd.DataFrame, split: str, x_cols=None):
    """Slice one split out of ``dfall`` and convert it to tensors.

    Args:
        dfall: frame with a ``split`` column, a ``y_true`` column (NaN = unlabelled)
            and the feature columns.
        split: value of ``dfall["split"]`` to keep (e.g. "train", "val", "test").
        x_cols: feature columns to use. Defaults to the notebook-level ``X_cols``
            loaded from X_cols.json, so existing calls behave as before.

    Returns:
        (d, X, y, mask): the sliced frame (row order preserved), float32 features
        of shape (rows, len(x_cols)), float32 targets with NaN replaced by 0.0, and
        a bool mask that is True where the target was present. Losses and metrics
        must use the mask; the 0.0 placeholders carry no information.
    """
    cols = X_cols if x_cols is None else list(x_cols)
    d = dfall[dfall["split"] == split].copy()
    X = d[cols].values.astype(np.float32)
    y = d["y_true"].values.astype(np.float32)
    mask = ~np.isnan(y)
    y = np.nan_to_num(y, nan=0.0).astype(np.float32)  # placeholders ignored by mask
    return d, torch.tensor(X), torch.tensor(y), torch.tensor(mask)

# Per-split frames/tensors; rows of df_* and of X_*/y_*/m_* are in the same order,
# which the prediction export relies on.
df_tr, X_tr, y_tr, m_tr = make_tensor_split(df, "train")
df_va, X_va, y_va, m_va = make_tensor_split(df, "val")
df_te, X_te, y_te, m_te = make_tensor_split(df, "test")

print("Rows — train:", len(df_tr), "| val:", len(df_va), "| test:", len(df_te))
# NOTE: despite the "%" in the label, these are labelled fractions in [0, 1]
# (NaN for an empty split).
print("Labeled % — train:", float(m_tr.float().mean()), "| val:", float(m_va.float().mean()), "| test:", float(m_te.float().mean()))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -------------------- DEFINE MLP --------------------
class MLPReg(nn.Module):
    """Feed-forward regressor: two ReLU+Dropout hidden layers, then a scalar head.

    ``forward`` maps inputs of shape (..., in_dim) to (...,) by squeezing the
    final unit dimension. Layer order inside ``self.net`` is
    Linear/ReLU/Dropout x2 then Linear, matching the saved checkpoint keys.
    """

    def __init__(self, in_dim: int, hidden: int = 256, dropout: float = 0.10):
        super().__init__()
        layers = []
        for fan_in in (in_dim, hidden):
            layers += [nn.Linear(fan_in, hidden), nn.ReLU(), nn.Dropout(dropout)]
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores.squeeze(-1)

# -------------------- TRAINING --------------------
# Hyperparameters for the full-batch MLP baseline.
HIDDEN    = 256
DROPOUT   = 0.10
LR        = 1e-3
WD        = 1e-4    # passed to Adam as weight_decay
EPOCHS    = 120
PATIENCE  = 20      # epochs without val-RMSE improvement before stopping
CLIP_NORM = 2.0     # max global gradient norm per step

model = MLPReg(in_dim=len(X_cols), hidden=HIDDEN, dropout=DROPOUT).to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

# Whole splits moved to the device once; training below is full-batch.
X_tr_t, y_tr_t, m_tr_t = X_tr.to(device), y_tr.to(device), m_tr.to(device)
X_va_t, y_va_t, m_va_t = X_va.to(device), y_va.to(device), m_va.to(device)
X_te_t, y_te_t, m_te_t = X_te.to(device), y_te.to(device), m_te.to(device)

def masked_rmse(y_pred, y_true, mask):
    """Root-mean-squared error over entries where ``mask`` is True; NaN if none are."""
    if int(mask.sum()) == 0:
        return np.nan
    pred_np, true_np = (v[mask].detach().cpu().numpy() for v in (y_pred, y_true))
    sq_err = (pred_np - true_np) ** 2
    return float(np.sqrt(sq_err.mean()))

def masked_mae(y_pred, y_true, mask):
    """Mean absolute error over entries where ``mask`` is True; NaN if none are."""
    if int(mask.sum()) == 0:
        return np.nan
    pred_np, true_np = (v[mask].detach().cpu().numpy() for v in (y_pred, y_true))
    abs_err = np.abs(pred_np - true_np)
    return float(abs_err.mean())

best = {"val_rmse": np.inf, "state": None, "epoch": -1}
logs = []

# Full-batch training: one forward/backward pass over all training rows per epoch,
# with the loss restricted to labelled rows (m_tr_t). Early stopping on val RMSE.
for ep in range(1, EPOCHS + 1):
    # train
    model.train()
    opt.zero_grad()
    pred_tr = model(X_tr_t)
    if m_tr_t.sum() > 0:
        loss = F.mse_loss(pred_tr[m_tr_t], y_tr_t[m_tr_t])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        opt.step()
    else:
        # No labelled training rows: a constant loss has no grad_fn, so calling
        # backward() on it would raise. Skip the parameter update instead.
        loss = torch.tensor(0.0, device=device)

    # validate
    model.eval()
    with torch.no_grad():
        pred_va = model(X_va_t)
    # Train RMSE uses this epoch's pre-step predictions (dropout active).
    tr_rmse = masked_rmse(pred_tr, y_tr_t, m_tr_t)
    va_rmse = masked_rmse(pred_va, y_va_t, m_va_t)

    logs.append({"epoch": ep, "train_rmse": tr_rmse, "val_rmse": va_rmse})
    print(f"[{MODEL_NAME}] epoch {ep:03d} | train RMSE={tr_rmse:.4f} | val RMSE={va_rmse:.4f}")

    cur = va_rmse
    if not np.isnan(cur) and cur < best["val_rmse"] - 1e-6:
        # clone(): when the model lives on the CPU, .cpu() returns the same storage
        # as the parameters, so without a copy this "best" snapshot would silently
        # follow every later optimizer step.
        best.update({"val_rmse": cur, "state": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}, "epoch": ep})
    elif ep - best["epoch"] >= PATIENCE:
        print(f"[{MODEL_NAME}] Early stop at epoch {ep}. Best val RMSE={best['val_rmse']:.4f} (epoch {best['epoch']})")
        break

# restore best
if best["state"] is not None:
    model.load_state_dict({k: v.to(device) for k, v in best["state"].items()})

# save checkpoint & logs
torch.save(model.state_dict(), RESULTS_DIR / f"{MODEL_NAME}_best.pt")
pd.DataFrame(logs).to_csv(RESULTS_DIR / f"{MODEL_NAME}_logs.csv", index=False)

# -------------------- TEST METRICS (on labeled only) --------------------
# Uses the restored best-validation weights; dropout disabled by eval().
model.eval()
with torch.no_grad():
    pred_te = model(X_te_t)

test_rmse = masked_rmse(pred_te, y_te_t, m_te_t)
test_mae  = masked_mae(pred_te, y_te_t, m_te_t)
# R² via sklearn on labelled test rows only; NaN when there are none.
if m_te_t.sum().item() > 0:
    p = pred_te[m_te_t].detach().cpu().numpy()
    t = y_te_t[m_te_t].detach().cpu().numpy()
    test_r2 = float(r2_score(t, p))
else:
    test_r2 = np.nan

print(f"[{MODEL_NAME}] TEST: RMSE={test_rmse:.4f} | MAE={test_mae:.4f} | R2={test_r2:.4f}")

# -------------------- SAVE PREDICTIONS --------------------
# (A) labeled-only predictions for TEST — same schema others use
# df_te and m_te come from the same make_tensor_split call, so the boolean row
# selection lines up with pred_te[m_te_t].
te_labeled = df_te[m_te.numpy()].copy()
te_labeled["y_pred"] = pred_te[m_te_t].detach().cpu().numpy()
out_a = RESULTS_DIR / f"{MODEL_NAME}_test_preds.parquet"
te_labeled[["Week","Ticker","y_true","y_pred"]].to_parquet(out_a, index=False)

# (B) full TEST predictions (all nodes)
# Unlabelled rows keep y_true = NaN and has_label = False.
te_full = df_te.copy()
te_full["y_pred"] = pred_te.detach().cpu().numpy()
te_full["has_label"] = ~te_full["y_true"].isna()
out_b = RESULTS_DIR / f"{MODEL_NAME}_test_preds_full.parquet"
te_full[["Week","Ticker","y_true","y_pred","has_label"]].to_parquet(out_b, index=False)

print("Saved:")
print(" -", RESULTS_DIR / f"{MODEL_NAME}_best.pt")
print(" -", RESULTS_DIR / f"{MODEL_NAME}_logs.csv")
print(" -", out_a)
print(" -", out_b)

# -------------------- SUMMARY --------------------
# One summary row for this model; non-finite metrics (e.g. inf best val RMSE when
# validation never produced a score) are stored as NaN.
row = pd.DataFrame([{
    "Model": MODEL_NAME,
    "Val_RMSE": round(float(best["val_rmse"]), 6) if np.isfinite(best["val_rmse"]) else np.nan,
    "Test_RMSE": round(float(test_rmse), 6) if np.isfinite(test_rmse) else np.nan,
    "Test_MAE":  round(float(test_mae),  6) if np.isfinite(test_mae)  else np.nan,
    "Test_R2":   round(float(test_r2),   6) if np.isfinite(test_r2)   else np.nan
}])

# Upsert into the shared summary: drop any previous row for this model, append
# the new one, then rank all models by test RMSE.
if SUMMARY_CSV.exists():
    sm = pd.read_csv(SUMMARY_CSV)
    sm = sm[sm["Model"] != MODEL_NAME]
    sm = pd.concat([sm, row], ignore_index=True)
else:
    sm = row
sm = sm.sort_values("Test_RMSE").reset_index(drop=True)
sm.to_csv(SUMMARY_CSV, index=False)
print("Updated summary →", SUMMARY_CSV)
print(sm)


# 7. Graph Neural Network Architectures

We now train a family of GNN models on the semi-supervised weekly graphs.  
Only nodes with available labels contribute to the loss, but all nodes receive predictions.  
The following architectures are included:

- **WSAGE**: Weighted GraphSAGE with edge-weight normalization.  
- **GAT**: Graph Attention Network (multi-head, 4 heads by default); this model does not use the edge weights.  
- **TAG**: Topology Adaptive GCN (K=3).  
- **Chebyshev GCN**: Chebyshev polynomial spectral GCN (K=3).  
- **ECC**: Edge-Conditioned Convolution: the absolute edge weight is passed through an MLP that generates each edge's filter matrix, and messages are mean-aggregated.  
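- **Temporal-GCN**: a one-hop graph convolution (`TAGConv`, K=1) applied to each of the last `t_hist` weekly graphs (default 4), followed by a GRU over the resulting node-embedding sequence; predictions are read from the final week.  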

Each model is trained with Adam (lr = 1e-3, wd = 1e-4), hidden dimension = 96, dropout = 0.2,  
and early stopping (patience = 20 epochs). Best validation checkpoints are saved,  
and predictions are exported in the same schema as the baselines.


In [None]:

import os
import json
import math
import argparse
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from tqdm.auto import tqdm

from torch_geometric.data import Data
from torch_geometric.nn import (
    GATConv, TAGConv, ChebConv, NNConv
)
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add


# -------------------- CLI --------------------
def get_args(argv=None):
    """Parse command-line options for GNN training and evaluation.

    Args:
        argv: optional list of argument strings. None (the default) parses
            ``sys.argv[1:]`` exactly as before. Pass ``[]`` (or explicit flags) when
            calling from a notebook, where ``sys.argv`` typically carries the kernel
            launcher's own flags, which argparse would reject.

    Returns:
        argparse.Namespace with paths, training hyperparameters, temporal history
        length and debug/seed options.
    """
    p = argparse.ArgumentParser(description="Train & evaluate GNNs on weekly graphs.")
    p.add_argument("--graphs_dir", type=str, default="./graphs", help="Folder with saved weekly Data objects")
    p.add_argument("--graph_index_csv", type=str, default="./graphs/graphs_index.csv", help="Index CSV with Week,file")
    p.add_argument("--split_json", type=str, default="./graphs/split_weeks.json", help="Optional explicit split JSON")
    p.add_argument("--results_dir", type=str, default="./results_gnn", help="Where to save checkpoints/logs/preds")

    # training hparams
    p.add_argument("--hidden", type=int, default=96)
    p.add_argument("--dropout", type=float, default=0.20)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=120)
    p.add_argument("--patience", type=int, default=20)

    # temporal
    p.add_argument("--t_hist", type=int, default=4, help="Temporal history length (for Temporal GCN)")

    # misc
    p.add_argument("--debug", action="store_true", help="Verbose per-graph logs for first few epochs")
    p.add_argument("--debug_epochs", type=int, default=2)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args(argv)


# -------------------- SEEDING --------------------
def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    trading some throughput for run-to-run reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# -------------------- MODELS (only the active ones) --------------------
class WeightedSAGEConv(MessagePassing):
    """Neighbor aggregation with edge weights (normalized per destination).

    With the default ``aggr='add'`` the layer computes, for each node i::

        out_i = lin_self(x_i) + sum_{j -> i} (w_ji / sum_{k -> i} w_ki) * lin_neigh(x_j)

    The sums run over incoming edges plus, by default, one self-loop per node. With
    non-negative weights the neighbor term is therefore a weighted mean of the
    transformed features, and x_i contributes both through its self-loop and
    through ``lin_self``.
    """
    def __init__(self, in_channels, out_channels, aggr='add', bias=True):
        super().__init__(aggr=aggr)
        # Neighbor transform has no bias; the optional bias lives on the self term.
        self.lin_neigh = nn.Linear(in_channels, out_channels, bias=False)
        self.lin_self  = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x, edge_index, edge_weight=None, use_abs_weight=True,
                add_self_loops=True, self_loop_weight=1.0):
        """Apply one weighted-SAGE layer.

        Args:
            x: node features; rows are nodes (N = x.size(0)), columns in_channels.
            edge_index: 2 x E tensor; row 0 is the source and row 1 the destination
                that weights are normalized over.
            edge_weight: optional per-edge weights of length E (default: all ones).
            use_abs_weight: take |w| before normalizing, so per-destination sums
                stay positive for signed weights.
            add_self_loops: append an (i, i) edge for every node.
            self_loop_weight: weight assigned to those self-loops.

        Returns:
            Tensor with one row per node and out_channels columns.
        """
        N = x.size(0)
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=x.device, dtype=x.dtype)
        if use_abs_weight:
            edge_weight = edge_weight.abs()

        if add_self_loops:
            loop_index = torch.arange(N, device=x.device)
            loop_ei = torch.stack([loop_index, loop_index], dim=0)
            edge_index = torch.cat([edge_index, loop_ei], dim=1)
            loop_w = torch.full((N,), float(self_loop_weight), device=x.device, dtype=x.dtype)
            edge_weight = torch.cat([edge_weight, loop_w], dim=0)

        # `src` is not used below; only destinations drive the normalization.
        src, dst = edge_index[0], edge_index[1]
        # Total incoming weight per destination. clamp_min avoids division by zero
        # for nodes whose incoming weights sum to 0 (possible when self-loops are
        # off, or when signed weights are kept with use_abs_weight=False).
        w_sum = scatter_add(edge_weight, dst, dim=0, dim_size=N).clamp_min(1e-12)
        norm_w = edge_weight / w_sum[dst]

        x_neigh = self.lin_neigh(x)
        # PyG matches `weight=` here to the `weight` parameter of message() by name.
        out = self.propagate(edge_index, x=x_neigh, weight=norm_w)
        out = out + self.lin_self(x)
        return out

    def message(self, x_j, weight):
        # Scale each gathered (transformed) source feature by its normalized edge weight.
        return weight.view(-1, 1) * x_j


class WeightedSAGEReg(nn.Module):
    """Two weighted-SAGE layers (ReLU + dropout after each), then a linear head.

    Returns one scalar prediction per node.
    """
    def __init__(self, in_dim, hidden=64, dropout=0.2):
        super().__init__()
        self.conv1 = WeightedSAGEConv(in_dim, hidden)
        self.conv2 = WeightedSAGEConv(hidden, hidden)
        self.lin = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index, edge_weight=edge_weight)
            h = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class GATReg(nn.Module):
    """Two multi-head GAT layers (ELU + dropout after each), then a linear head.

    Each layer concatenates ``heads`` outputs of size ``max(1, hidden // heads)``.
    ``edge_weight`` is accepted for interface compatibility but not used.
    """
    def __init__(self, in_dim, hidden=64, heads=4, dropout=0.2):
        super().__init__()
        n_heads = int(heads)
        per_head = max(1, int(hidden) // n_heads)
        width = per_head * n_heads
        self.gat1 = GATConv(in_dim, per_head, heads=n_heads, dropout=dropout)
        self.gat2 = GATConv(width, per_head, heads=n_heads, dropout=dropout)
        self.lin = nn.Linear(width, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        h = x
        for layer in (self.gat1, self.gat2):
            h = layer(h, edge_index)
            h = F.dropout(F.elu(h), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class TAGReg(nn.Module):
    """Two TAGConv layers (K hops each; ReLU + dropout), then a linear head.

    Edge weights, when given, are used in absolute value.
    """
    def __init__(self, in_dim, hidden=64, K=3, dropout=0.2):
        super().__init__()
        self.conv1 = TAGConv(in_dim, hidden, K=K)
        self.conv2 = TAGConv(hidden, hidden, K=K)
        self.lin = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        ew = None if edge_weight is None else edge_weight.abs()
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index, ew)
            h = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class ChebReg(nn.Module):
    """Two Chebyshev spectral conv layers (order K; ReLU + dropout), then a linear head.

    Edge weights, when given, are used in absolute value.
    """
    def __init__(self, in_dim, hidden=64, K=3, dropout=0.2):
        super().__init__()
        self.conv1 = ChebConv(in_dim, hidden, K=K)
        self.conv2 = ChebConv(hidden, hidden, K=K)
        self.lin = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        ew = None if edge_weight is None else edge_weight.abs()
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index, ew)
            h = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class ECCReg(nn.Module):
    """Two edge-conditioned convolution (NNConv) layers, then a linear head.

    For every edge, a small MLP maps the scalar |edge weight| to a full filter
    matrix (in_dim x hidden for layer 1, hidden x hidden for layer 2). Messages are
    mean-aggregated. Returns one scalar prediction per node.
    """
    def __init__(self, in_dim, hidden=64, dropout=0.2):
        super().__init__()
        self.edge_net1 = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, in_dim * hidden))
        self.edge_net2 = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, hidden * hidden))
        self.conv1 = NNConv(in_dim, hidden, self.edge_net1, aggr='mean')
        self.conv2 = NNConv(hidden, hidden, self.edge_net2, aggr='mean')
        self.lin   = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        if edge_weight is None:
            # NNConv needs an attribute for every edge (its edge MLP would be called
            # on None otherwise). Treat missing weights as unit weights, the same
            # default WeightedSAGEConv uses.
            edge_attr = torch.ones((edge_index.size(1), 1), device=x.device, dtype=x.dtype)
        else:
            edge_attr = edge_weight.abs().view(-1, 1)
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x); x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x); x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x).squeeze(-1)


class TemporalGCN(nn.Module):
    """Per-week graph convolution followed by a GRU across weeks.

    ``forward`` takes a sequence of graphs sharing one node order (oldest first),
    embeds each week with a shared one-hop TAGConv (K=1, |edge weights|), runs a
    GRU with nodes as the batch dimension, and predicts from the final week's
    hidden state. Returns one scalar per node.
    """
    def __init__(self, in_dim, hidden=64, dropout=0.2):
        super().__init__()
        self.gcn = TAGConv(in_dim, hidden, K=1)
        self.gru = nn.GRU(hidden, hidden, batch_first=True)
        self.lin = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, seq):
        weekly = []
        for g in seq:
            ew = None if g.edge_weight is None else g.edge_weight.abs()
            h = F.relu(self.gcn(g.x, g.edge_index, ew))
            weekly.append(F.dropout(h, p=self.dropout, training=self.training))
        H = torch.stack(weekly, dim=1)          # [N, T, hidden]
        gru_out, _ = self.gru(H)                # [N, T, hidden]
        last_step = gru_out[:, -1, :]
        return self.lin(last_step).squeeze(-1)


# -------------------- METRICS --------------------
def masked_metrics(pred, y, mask):
    """Return RMSE, MAE and R² of ``pred`` vs ``y`` over entries where ``mask`` is True.

    All three are NaN when the mask selects nothing. R² alone is NaN when the
    selected targets have (near-)zero variance.
    """
    if mask.sum().item() == 0:
        return {"rmse": np.nan, "mae": np.nan, "r2": np.nan}
    p, t = (v[mask].detach().cpu().numpy() for v in (pred, y))
    err = p - t
    sq_err = err ** 2
    mse = float(sq_err.mean())
    if float(np.var(t)) > 1e-12:
        r2 = float(1.0 - (np.sum(sq_err) / (np.sum((t - t.mean()) ** 2) + 1e-12)))
    else:
        r2 = np.nan
    return {"rmse": float(np.sqrt(mse)), "mae": float(np.abs(err).mean()), "r2": r2}


# -------------------- DATA HELPERS --------------------
def attach_split_masks_if_missing(g: "Data", week_ts: pd.Timestamp, idx_df: pd.DataFrame):
    """Ensure ``g`` carries per-node week-level split flags.

    If ``g`` already has ``is_train_week``, ``is_val_week`` and ``is_test_week`` it is
    returned unchanged. Otherwise the week's row *position* in ``idx_df`` (assumed to
    be in chronological order; confirm at the call site) is mapped onto a
    chronological 70/15/15 train/val/test split. The three boolean flags (one entry
    per node, all equal) are attached to ``g`` in place, and ``g`` is returned.

    Raises:
        IndexError: if ``week_ts`` does not occur in ``idx_df["Week"]``.
    """
    if all(hasattr(g, a) for a in ("is_train_week", "is_val_week", "is_test_week")):
        return g

    n = g.x.size(0)
    # Positional lookup: the split boundaries below are positions 0..W-1, so the
    # row position matters, not the index label (which need not be a RangeIndex).
    hits = np.flatnonzero((idx_df["Week"] == week_ts).to_numpy())
    if hits.size == 0:
        raise IndexError(f"week {week_ts} not found in graph index")
    pos = int(hits[0])
    W = len(idx_df)
    i_train_end = math.floor(0.70 * W)
    i_val_end = i_train_end + math.floor(0.15 * W)

    def _flag(value):
        # One identical boolean per node.
        return torch.full((n,), bool(value), dtype=torch.bool)

    g.is_train_week = _flag(pos < i_train_end)
    g.is_val_week = _flag(i_train_end <= pos < i_val_end)
    g.is_test_week = _flag(pos >= i_val_end)
    return g


def align_graph_to_tickers(g: Data, target_tickers: list):
    """Re-index a weekly graph onto a fixed ticker order.

    Node j of the result corresponds to ``target_tickers[j]``. Tickers absent from
    ``g`` get zero features, a zero target and ``label_mask=False``. Edges whose
    source or destination is absent are dropped together with their weights, and
    surviving edges are renumbered. Week-level split flags are broadcast from
    ``g``'s first node, and ``tickers``/``week`` are carried over. New tensors are
    created with default (CPU) placement, and ``g.edge_index`` is read via ``.numpy()``.
    """
    pos_in_g = {tk: i for i, tk in enumerate(g.tickers)}
    n_target = len(target_tickers)
    x_new = torch.zeros((n_target, g.x.size(1)), dtype=g.x.dtype)
    y_new = torch.zeros((n_target,), dtype=g.y.dtype)
    lbl_new = torch.zeros((n_target,), dtype=torch.bool)
    g_has_labels = hasattr(g, "label_mask")

    for j, tk in enumerate(target_tickers):
        i = pos_in_g.get(tk)
        if i is None:
            continue
        x_new[j] = g.x[i]
        y_new[j] = g.y[i]
        lbl_new[j] = g.label_mask[i] if g_has_labels else False

    # old node index -> new node index, for tickers present in both orders
    remap = {pos_in_g[tk]: j for j, tk in enumerate(target_tickers) if tk in pos_in_g}

    ew_new = None
    if g.edge_index.numel() > 0:
        src_old, dst_old = g.edge_index.numpy()
        keep = [(s in remap) and (d in remap) for s, d in zip(src_old, dst_old)]
        new_src = [remap[s] for s, k in zip(src_old, keep) if k]
        new_dst = [remap[d] for d, k in zip(dst_old, keep) if k]
        ei_new = torch.tensor([new_src, new_dst], dtype=torch.long)
        if hasattr(g, "edge_weight") and g.edge_weight is not None:
            ew_new = g.edge_weight[torch.tensor(keep, dtype=torch.bool)]
    else:
        ei_new = torch.empty((2, 0), dtype=torch.long)

    g2 = Data(x=x_new, edge_index=ei_new, edge_weight=ew_new, y=y_new)
    g2.label_mask = lbl_new
    for flag in ("is_train_week", "is_val_week", "is_test_week"):
        setattr(g2, flag, torch.tensor([getattr(g, flag)[0]] * n_target, dtype=torch.bool))
    g2.tickers = target_tickers
    g2.week = g.week
    return g2


def collect_sequences(graphs_all, split_tag, t_hist=4):
    """Build length-``t_hist`` rolling windows of graphs ending at each target week.

    ``graphs_all`` is an ordered list of ``(week, graph)`` pairs. A target week is
    kept when its graph belongs to ``split_tag`` ("train"/"val"/"test"; any other
    tag disables the filter) and at least ``t_hist - 1`` earlier graphs exist. Each
    window's graphs are re-aligned to the target week's ticker order. Returns a
    list of ``(target_week, [graphs oldest -> newest])``.
    """
    flag_for_split = {"train": "is_train_week", "val": "is_val_week", "test": "is_test_week"}
    flag_name = flag_for_split.get(split_tag)
    seqs = []
    for t, (week_t, g_t) in enumerate(graphs_all):
        if flag_name is not None and not getattr(g_t, flag_name)[0]:
            continue
        start = t - (t_hist - 1)
        if start < 0:
            continue
        tgt = g_t.tickers
        window = [align_graph_to_tickers(g_prev, tgt) for _, g_prev in graphs_all[start:t + 1]]
        seqs.append((week_t, window))
    return seqs


# -------------------- EPOCH LOOPS --------------------
def run_epoch_static(model, graphs_list, device, optimizer=None, epoch_idx=0,
                     model_name="", debug=False, debug_epochs=2, clip_norm=2.0):
    """Run one pass over ``(week, graph)`` pairs for a single-graph GNN.

    Training mode when ``optimizer`` is given: one optimizer step per week, with the
    loss on labelled nodes of train weeks and gradients clipped to ``clip_norm``.
    Evaluation mode otherwise: each graph is scored on its val or test week flags,
    with autograd disabled.

    The labelled-node mask comes from ``label_mask``, else ``train_mask``, else no
    node is scored.

    Returns:
        (mean_loss, mean_metrics): ``mean_loss`` is the MSE pooled over all scored
        nodes; ``mean_metrics`` averages per-week rmse/mae/r2 (NaN-aware). Both are
        NaN when nothing was scored.
    """
    train = optimizer is not None
    model.train(train)
    total_loss, total_count = 0.0, 0
    agg = {"rmse": [], "mae": [], "r2": []}

    for week_ts, g in graphs_list:
        x  = g.x.to(device)
        ei = g.edge_index.to(device)
        ew = g.edge_weight.to(device) if hasattr(g, "edge_weight") and g.edge_weight is not None else None
        y  = g.y.to(device).float()

        lb = getattr(g, "label_mask", None)
        if lb is None and hasattr(g, "train_mask"):
            lb = g.train_mask
        if lb is None:
            lb = torch.zeros_like(g.y, dtype=torch.bool)

        if train:
            mask = (lb & g.is_train_week).to(device)
        else:
            # Evaluation graphs are scored on whichever of val/test their week is in.
            split_mask = g.is_val_week if g.is_val_week[0] else g.is_test_week
            mask = (lb & split_mask).to(device)

        used = int(mask.sum().item()); loss_val = np.nan
        # Build the autograd graph only when we will back-propagate.
        with torch.set_grad_enabled(train):
            pred = model(x, ei, edge_weight=ew)
            if used > 0:
                loss = F.mse_loss(pred[mask], y[mask])
                loss_val = float(loss.detach().cpu().item())
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    optimizer.step()
                total_loss += loss_val * used
                total_count += used
                m = masked_metrics(pred, y, mask)
                agg["rmse"].append(m["rmse"]); agg["mae"].append(m["mae"]); agg["r2"].append(m["r2"])

        if debug and epoch_idx < debug_epochs:
            e_cnt = int(ei.size(1))
            ew_min = float(ew.min().item()) if ew is not None and ew.numel() else float("nan")
            ew_max = float(ew.max().item()) if ew is not None and ew.numel() else float("nan")
            print(f"[{model_name}][ep{epoch_idx:02d}] {str(week_ts)[:10]} N={x.size(0)} E={e_cnt} "
                  f"used={used} loss={loss_val:.6f} ew=[{ew_min:.3f},{ew_max:.3f}]")

    mean_loss = (total_loss / total_count) if total_count > 0 else np.nan
    mean_metrics = {k: (np.nanmean(v) if len(v) else np.nan) for k, v in agg.items()}
    return mean_loss, mean_metrics


def run_epoch_temporal(model, seqs_list, device, optimizer=None, epoch_idx=0,
                       model_name="", debug=False, debug_epochs=2, clip_norm=2.0):
    """Run one pass over ``(target_week, [graphs])`` sequences for the temporal GNN.

    The target is the last graph of each sequence. Training mode when ``optimizer``
    is given: one step per sequence, with the loss on labelled nodes of the target's
    train week and gradients clipped to ``clip_norm``. Evaluation mode otherwise:
    the target is scored on its val or test flags, with autograd disabled.

    Returns:
        (mean_loss, mean_metrics) with the same meaning as ``run_epoch_static``:
        pooled MSE over scored nodes, and NaN-aware means of per-week rmse/mae/r2.
    """
    train = optimizer is not None
    model.train(train)
    total_loss, total_count = 0.0, 0
    agg = {"rmse": [], "mae": [], "r2": []}

    for week_t, seq in seqs_list:
        # move each graph in the sequence to device
        seq_dev = []
        for g in seq:
            g_dev = Data(
                x=g.x.to(device),
                edge_index=g.edge_index.to(device),
                edge_weight=(g.edge_weight.to(device) if g.edge_weight is not None else None),
                y=g.y.to(device).float()
            )
            g_dev.label_mask    = g.label_mask.to(device)
            g_dev.is_train_week = g.is_train_week.to(device)
            g_dev.is_val_week   = g.is_val_week.to(device)
            g_dev.is_test_week  = g.is_test_week.to(device)
            g_dev.tickers = g.tickers
            g_dev.week    = g.week
            seq_dev.append(g_dev)

        gT = seq_dev[-1]
        y  = gT.y
        lb = gT.label_mask

        if train:
            mask = (lb & gT.is_train_week)
        else:
            split_mask = gT.is_val_week if gT.is_val_week[0] else gT.is_test_week
            mask = (lb & split_mask)

        used = int(mask.sum().item()); loss_val = np.nan
        # Build the autograd graph only when we will back-propagate.
        with torch.set_grad_enabled(train):
            pred = model(seq_dev)
            if used > 0:
                loss = F.mse_loss(pred[mask], y[mask])
                loss_val = float(loss.detach().cpu().item())
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    optimizer.step()
                total_loss += loss_val * used
                total_count += used
                m = masked_metrics(pred, y, mask)
                agg["rmse"].append(m["rmse"]); agg["mae"].append(m["mae"]); agg["r2"].append(m["r2"])

        if debug and epoch_idx < debug_epochs:
            print(f"[{model_name}][ep{epoch_idx:02d}] {str(week_t)[:10]} used={used} loss={loss_val:.6f} (temporal)")

    mean_loss = (total_loss / total_count) if total_count > 0 else np.nan
    mean_metrics = {k: (np.nanmean(v) if len(v) else np.nan) for k, v in agg.items()}
    return mean_loss, mean_metrics


# -------------------- PREDICTION EXPORT --------------------
def predict_static_on_split(model, graphs_list, device, split="test", save_all_nodes=False):
    """Export per-node predictions of a single-graph GNN for one split.

    Graphs whose week is not in ``split`` ("train"/"val"/"test") are skipped.

    If save_all_nodes=False: return only rows evaluated (labelled nodes of the
    split's weeks).
    If save_all_nodes=True: return ALL nodes; y_true is NaN where the node has no
    label, and ``has_label`` marks availability.

    Returns:
        DataFrame with columns Week (YYYY-MM-DD string), Ticker, y_true, y_pred,
        has_label. The columns are present even when no row qualifies.
    """
    rows = []
    model.eval()
    for week_ts, g in graphs_list:
        # split filter
        if split == "test" and not g.is_test_week[0]:   continue
        if split == "val"  and not g.is_val_week[0]:    continue
        if split == "train" and not g.is_train_week[0]: continue

        x  = g.x.to(device)
        ei = g.edge_index.to(device)
        ew = g.edge_weight.to(device) if hasattr(g, "edge_weight") and g.edge_weight is not None else None

        lb = getattr(g, "label_mask", None)
        if lb is None and hasattr(g, "train_mask"):
            lb = g.train_mask
        if lb is None:
            lb = torch.zeros_like(g.y, dtype=torch.bool)

        split_mask = (g.is_test_week if split == "test" else g.is_val_week if split == "val" else g.is_train_week)
        mask = (lb & split_mask)

        with torch.no_grad():
            pred = model(x, ei, edge_weight=ew)

        # One device->host copy per week instead of a blocking .item() per node.
        pred_np = pred.detach().cpu().numpy()
        y_np = g.y.float().cpu().numpy()
        lb_np = lb.cpu().numpy().astype(bool)
        if save_all_nodes:
            use_idx = range(g.x.size(0))
        else:
            use_idx = np.flatnonzero(mask.cpu().numpy())
        week_str = str(g.week)[:10]

        for i in use_idx:
            i = int(i)
            has_label = bool(lb_np[i])
            rows.append({
                "Week": week_str,
                "Ticker": g.tickers[i],
                "y_true": float(y_np[i]) if has_label else np.nan,
                "y_pred": float(pred_np[i]),
                "has_label": has_label
            })
    return pd.DataFrame(rows, columns=["Week", "Ticker", "y_true", "y_pred", "has_label"])


def predict_temporal_on_split(model, graphs_all, device, split="test", t_hist=4, save_all_nodes=False):
    """Export per-node predictions of a temporal model for one split.

    Sequences come from ``collect_sequences(graphs_all, split, t_hist)``. Each
    sequence is copied to ``device`` and passed to ``model(seq)``. The target
    week is the LAST graph of the sequence, and rows are emitted for that graph
    only.

    If save_all_nodes=False: only nodes with a label in the requested split.
    If save_all_nodes=True:  all nodes of the target graph; y_true is NaN where
    the node has no label and ``has_label`` marks availability.

    Returns:
        pd.DataFrame with columns Week, Ticker, y_true, y_pred, has_label.

    Raises:
        ValueError: if ``split`` is not "train", "val" or "test".
    """
    split_attr = {"train": "is_train_week", "val": "is_val_week", "test": "is_test_week"}.get(split)
    if split_attr is None:
        raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")

    seqs = collect_sequences(graphs_all, split, t_hist)
    rows = []
    model.eval()
    for _week_t, seq in seqs:
        # to device
        seq_dev = []
        for g in seq:
            # tolerate graphs without an edge_weight attribute (as the static path does)
            ew = getattr(g, "edge_weight", None)
            g_dev = Data(
                x=g.x.to(device),
                edge_index=g.edge_index.to(device),
                edge_weight=(ew.to(device) if ew is not None else None),
                y=g.y.to(device).float()
            )
            g_dev.label_mask    = g.label_mask.to(device)
            g_dev.is_train_week = g.is_train_week.to(device)
            g_dev.is_val_week   = g.is_val_week.to(device)
            g_dev.is_test_week  = g.is_test_week.to(device)
            g_dev.tickers = g.tickers
            g_dev.week    = g.week
            seq_dev.append(g_dev)

        with torch.no_grad():
            pred = model(seq_dev)

        gT = seq_dev[-1]
        # One device->host copy per sequence instead of one .item() sync per node.
        lb = gT.label_mask.cpu()
        mask = lb & getattr(gT, split_attr).cpu()
        pred_cpu = pred.detach().cpu()
        y_cpu = gT.y.detach().cpu()
        week_str = str(gT.week)[:10]

        if save_all_nodes:
            use_idx = range(gT.x.size(0))
        else:
            use_idx = torch.nonzero(mask, as_tuple=True)[0].tolist()

        for i in use_idx:
            has_label = bool(lb[i])
            rows.append({
                "Week": week_str,
                "Ticker": gT.tickers[i],
                "y_true": float(y_cpu[i]) if has_label else np.nan,
                "y_pred": float(pred_cpu[i]),
                "has_label": has_label
            })
    return pd.DataFrame(rows)


# -------------------- TRAINERS --------------------
def train_static(name, model_ctor, feat_dim, device, train_graphs, val_graphs, test_graphs,
                 results_dir, epochs, hidden, dropout, lr, wd, patience, debug, debug_epochs):
    """Train one static GNN with early stopping on validation RMSE and export results.

    Each epoch runs one training pass over ``train_graphs`` and one evaluation
    pass over ``val_graphs``. The weights with the lowest validation RMSE are
    kept and restored at the end.

    The following files are written to ``results_dir``; save failures are
    reported, not raised:
        - ``{name}_best.pt``
        - ``{name}_logs.csv``
        - ``{name}_test_preds.parquet``
        - ``{name}_test_preds_full.parquet``

    Returns:
        (model, {"val_rmse": best validation RMSE, "test": test-split metrics dict})
    """
    model = model_ctor(feat_dim, hidden, dropout=dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best = {"val_rmse": np.inf, "state": None, "epoch": -1}
    logs = []

    for ep in range(1, epochs + 1):
        tr_loss, tr_m = run_epoch_static(model, train_graphs, device, optimizer=opt,
                                         epoch_idx=ep, model_name=name, debug=debug, debug_epochs=debug_epochs)
        vl_loss, vl_m = run_epoch_static(model, val_graphs, device, optimizer=None,
                                         epoch_idx=ep, model_name=name, debug=debug, debug_epochs=debug_epochs)
        logs.append({
            "epoch": ep,
            "train_loss": float(tr_loss), "val_loss": float(vl_loss),
            "train_rmse": float(tr_m["rmse"]), "val_rmse": float(vl_m["rmse"]),
            "train_mae":  float(tr_m["mae"]),  "val_mae":  float(vl_m["mae"]),
            "train_r2":   float(tr_m["r2"]),   "val_r2":   float(vl_m["r2"]),
        })

        print(f"[{name}] epoch {ep:03d} | "
              f"train RMSE={tr_m['rmse']:.4f} (MSE={tr_loss:.5f}) | "
              f"val RMSE={vl_m['rmse']:.4f} (MSE={vl_loss:.5f})")

        cur = vl_m["rmse"]
        # NaN validation RMSE (no scored nodes) never counts as an improvement
        if not np.isnan(cur) and cur < best["val_rmse"] - 1e-6:
            best.update({"val_rmse": cur,
                         "state": {k: v.detach().cpu() for k, v in model.state_dict().items()},
                         "epoch": ep})
        elif ep - best["epoch"] >= patience:
            print(f"[{name}] Early stop at epoch {ep}. Best val RMSE={best['val_rmse']:.4f} (epoch {best['epoch']})")
            break

    # restore best
    if best["state"] is not None:
        model.load_state_dict({k: v.to(device) for k, v in best["state"].items()})

    # save checkpoint + logs
    os.makedirs(results_dir, exist_ok=True)
    try:
        torch.save(model.state_dict(), os.path.join(results_dir, f"{name}_best.pt"))
        pd.DataFrame(logs).to_csv(os.path.join(results_dir, f"{name}_logs.csv"), index=False)
    except Exception as e:
        print(f"[{name}] WARN: saving checkpoint/logs failed: {e}")

    # test metrics
    _, ts_m = run_epoch_static(model, test_graphs, device, optimizer=None, epoch_idx=0, model_name=name)

    # predictions: use the test graphs passed in (previously this read a
    # module-level `graphs` global, which only exists when run as a script)
    try:
        preds_scored = predict_static_on_split(model, test_graphs, device, split="test", save_all_nodes=False)
        preds_scored["Model"] = name
        preds_scored.to_parquet(os.path.join(results_dir, f"{name}_test_preds.parquet"), index=False)

        preds_full = predict_static_on_split(model, test_graphs, device, split="test", save_all_nodes=True)
        preds_full["Model"] = name
        preds_full.to_parquet(os.path.join(results_dir, f"{name}_test_preds_full.parquet"), index=False)

        print(f"[{name}] saved preds: {len(preds_scored)} scored, {len(preds_full)} full")
    except Exception as e:
        print(f"[{name}] WARN: saving predictions failed: {e}")

    return model, {"val_rmse": best["val_rmse"], "test": ts_m}


def train_temporal(name, device, graphs, results_dir, epochs, hidden, dropout, lr, wd, patience,
                   t_hist, debug, debug_epochs):
    """Train the TemporalGCN on rolling ``t_hist``-week graph sequences.

    Train/val/test sequences come from ``collect_sequences``. Early stopping
    tracks validation RMSE, and the best weights are restored before the
    following are written to ``results_dir``:
        - the checkpoint (``{name}_best.pt``)
        - the epoch log (``{name}_logs.csv``)
        - scored-only and full test predictions (parquet)

    Save failures are reported, not raised.

    Returns:
        (model, {"val_rmse": best validation RMSE, "test": test-split metrics dict})
    """
    seqs_by_split = {part: collect_sequences(graphs, part, t_hist) for part in ("train", "val", "test")}
    train_seqs, val_seqs, test_seqs = seqs_by_split["train"], seqs_by_split["val"], seqs_by_split["test"]
    print(f"[{name}] sequences → train={len(train_seqs)}, val={len(val_seqs)}, test={len(test_seqs)}")

    in_dim = graphs[0][1].x.size(1)
    model = TemporalGCN(in_dim, hidden=hidden, dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_rmse, best_state, best_epoch = np.inf, None, -1
    history = []
    loop_kwargs = dict(model_name=name, debug=debug, debug_epochs=debug_epochs)

    for ep in range(1, epochs + 1):
        tr_loss, tr_m = run_epoch_temporal(model, train_seqs, device, optimizer=optimizer,
                                           epoch_idx=ep, **loop_kwargs)
        vl_loss, vl_m = run_epoch_temporal(model, val_seqs, device, optimizer=None,
                                           epoch_idx=ep, **loop_kwargs)

        # column order: epoch, train/val loss, then train/val pairs per metric
        record = {"epoch": ep, "train_loss": float(tr_loss), "val_loss": float(vl_loss)}
        for metric in ("rmse", "mae", "r2"):
            record[f"train_{metric}"] = float(tr_m[metric])
            record[f"val_{metric}"] = float(vl_m[metric])
        history.append(record)

        print(f"[{name}] epoch {ep:03d} | "
              f"train RMSE={tr_m['rmse']:.4f} (MSE={tr_loss:.5f}) | "
              f"val RMSE={vl_m['rmse']:.4f} (MSE={vl_loss:.5f})")

        current = vl_m["rmse"]
        if not np.isnan(current) and current < best_rmse - 1e-6:
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            best_rmse, best_epoch = current, ep
        elif ep - best_epoch >= patience:
            print(f"[{name}] Early stop at epoch {ep}. Best val RMSE={best_rmse:.4f} (epoch {best_epoch})")
            break

    if best_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

    os.makedirs(results_dir, exist_ok=True)
    try:
        torch.save(model.state_dict(), os.path.join(results_dir, f"{name}_best.pt"))
        pd.DataFrame(history).to_csv(os.path.join(results_dir, f"{name}_logs.csv"), index=False)
    except Exception as e:
        print(f"[{name}] WARN: saving checkpoint/logs failed: {e}")

    _, test_metrics = run_epoch_temporal(model, test_seqs, device, optimizer=None, epoch_idx=0, model_name=name)

    # scored-only export first, then the full-graph export
    try:
        row_counts = {}
        for save_all, suffix in ((False, "_test_preds.parquet"), (True, "_test_preds_full.parquet")):
            frame = predict_temporal_on_split(model, graphs, device, split="test",
                                              t_hist=t_hist, save_all_nodes=save_all)
            frame["Model"] = name
            frame.to_parquet(os.path.join(results_dir, f"{name}{suffix}"), index=False)
            row_counts[save_all] = len(frame)
        print(f"[{name}] saved preds: {row_counts[False]} scored, {row_counts[True]} full")
    except Exception as e:
        print(f"[{name}] WARN: saving predictions failed: {e}")

    return model, {"val_rmse": best_rmse, "test": test_metrics}


# -------------------- MAIN --------------------
if __name__ == "__main__":
    # `json` and `Path` are used below; import them here so this script does
    # not rely on another cell having imported them first.
    import json
    from pathlib import Path

    args = get_args()
    seed_everything(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    GRAPHS_DIR      = args.graphs_dir
    GRAPH_INDEX_CSV = args.graph_index_csv
    SPLIT_JSON      = args.split_json
    RESULTS_DIR     = args.results_dir
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Load index & optional split weeks
    idx = pd.read_csv(GRAPH_INDEX_CSV)
    idx["Week"] = pd.to_datetime(idx["Week"]).dt.tz_localize(None)
    idx = idx.sort_values("Week").reset_index(drop=True)

    split_weeks = None
    if Path(SPLIT_JSON).exists():
        with open(SPLIT_JSON, "r") as f:
            sp = json.load(f)
        split_weeks = {
            "train": set(pd.to_datetime(sp["train_weeks"]).tz_localize(None)),
            "val":   set(pd.to_datetime(sp["val_weeks"]).tz_localize(None)),
            "test":  set(pd.to_datetime(sp["test_weeks"]).tz_localize(None)),
        }

    # Load all graphs
    graphs = []
    for _, row in idx.iterrows():
        fpath = os.path.join(GRAPHS_DIR, row["file"])
        if not os.path.exists(fpath):
            print(f"[WARN] Missing graph file: {fpath} (skipping)")
            continue
        # NOTE: weights_only=False performs full unpickling; only load graph
        # files produced by this pipeline / from a trusted source.
        g = torch.load(fpath, map_location="cpu", weights_only=False)
        # tolerate legacy 'train_mask' as label mask
        if not hasattr(g, "label_mask") and hasattr(g, "train_mask"):
            g.label_mask = g.train_mask

        # attach split flags if missing
        if split_weeks is not None:
            n = g.x.size(0)
            def mask(allw): return torch.tensor([row["Week"] in allw] * n, dtype=torch.bool)
            g.is_train_week = mask(split_weeks["train"])
            g.is_val_week   = mask(split_weeks["val"])
            g.is_test_week  = mask(split_weeks["test"])
        else:
            g = attach_split_masks_if_missing(g, row["Week"], idx)

        # basic sanity
        assert torch.isfinite(g.x).all() and torch.isfinite(g.y).all(), "Found non-finite x/y"
        if hasattr(g, "edge_weight") and g.edge_weight is not None:
            assert torch.isfinite(g.edge_weight).all(), "Found non-finite edge_weight"

        graphs.append((row["Week"], g))

    assert graphs, "No graphs found."
    feat_dim = graphs[0][1].x.size(1)
    print(f"Loaded {len(graphs)} graphs; feature dim = {feat_dim}")

    # Split lists
    train_graphs = [(w, g) for (w, g) in graphs if g.is_train_week[0]]
    val_graphs   = [(w, g) for (w, g) in graphs if g.is_val_week[0]]
    test_graphs  = [(w, g) for (w, g) in graphs if g.is_test_week[0]]
    print(f"Split → train={len(train_graphs)}, val={len(val_graphs)}, test={len(test_graphs)}")

    # Active model constructors (only the ones you left uncommented)
    STATIC_MODELS = {
        "WSAGE": lambda in_dim, hidden, dropout: WeightedSAGEReg(in_dim, hidden, dropout=dropout),
        "GAT":   lambda in_dim, hidden, dropout: GATReg(in_dim, hidden, heads=4, dropout=dropout),
        "TAG":   lambda in_dim, hidden, dropout: TAGReg(in_dim, hidden, K=3, dropout=dropout),
        "CHEB":  lambda in_dim, hidden, dropout: ChebReg(in_dim, hidden, K=3, dropout=dropout),
        "ECC":   lambda in_dim, hidden, dropout: ECCReg(in_dim, hidden, dropout=dropout),
    }

    results, trained = [], {}

    # Train static models
    for name, ctor in STATIC_MODELS.items():
        try:
            m, stats = train_static(
                name, ctor, feat_dim, device,
                train_graphs, val_graphs, test_graphs,
                RESULTS_DIR,
                epochs=args.epochs, hidden=args.hidden, dropout=args.dropout,
                lr=args.lr, wd=args.wd, patience=args.patience,
                debug=args.debug, debug_epochs=args.debug_epochs
            )
            trained[name] = m
            results.append({
                "Model": name,
                "Val_RMSE": round(stats["val_rmse"], 6),
                "Test_RMSE": round(stats["test"]["rmse"], 6),
                "Test_MAE":  round(stats["test"]["mae"], 6),
                "Test_R2":   round(stats["test"]["r2"], 6)
            })
        except Exception as e:
            print(f"[{name}] ERROR during training: {e}")

    # Train temporal model
    try:
        m, stats = train_temporal(
            "TEMPORAL_GCN", device, graphs, RESULTS_DIR,
            epochs=args.epochs, hidden=args.hidden, dropout=args.dropout,
            lr=args.lr, wd=args.wd, patience=args.patience,
            t_hist=args.t_hist, debug=args.debug, debug_epochs=args.debug_epochs
        )
        trained["TEMPORAL_GCN"] = m
        results.append({
            "Model": "TEMPORAL_GCN",
            "Val_RMSE": round(stats["val_rmse"], 6),
            "Test_RMSE": round(stats["test"]["rmse"], 6),
            "Test_MAE":  round(stats["test"]["mae"], 6),
            "Test_R2":   round(stats["test"]["r2"], 6)
        })
    except Exception as e:
        print(f"[TEMPORAL_GCN] ERROR during training: {e}")

    # Summary (an empty `results` has no "Test_RMSE" column, so sort would raise)
    summary_path = os.path.join(RESULTS_DIR, "summary.csv")
    if results:
        summary = pd.DataFrame(results).sort_values("Test_RMSE").reset_index(drop=True)
        summary.to_csv(summary_path, index=False)
        print("\n=== SUMMARY (lower is better) ===")
        print(summary)
        print(f"\nSaved summary → {summary_path}")
    else:
        summary = pd.DataFrame(columns=["Model", "Val_RMSE", "Test_RMSE", "Test_MAE", "Test_R2"])
        print("\n[WARN] No model finished training; summary.csv not written.")

    # The combined outputs written below ("all_models_test_preds*.parquet") also
    # match the per-model suffixes; exclude them so a re-run does not read the
    # previous combined file back in and duplicate every row.
    COMBINED_PREFIX = "all_models_"

    # Combine predictions: scored-only
    pred_files = sorted(p for p in os.listdir(RESULTS_DIR)
                        if p.endswith("_test_preds.parquet") and not p.startswith(COMBINED_PREFIX))
    all_preds = [pd.read_parquet(os.path.join(RESULTS_DIR, p)) for p in pred_files] if pred_files else []
    if all_preds:
        preds_all = pd.concat(all_preds, ignore_index=True)
        preds_all["Week"]     = pd.to_datetime(preds_all["Week"], errors="coerce").dt.tz_localize(None)
        preds_all["Ticker"]   = preds_all["Ticker"].astype("string")
        preds_all["y_true"]   = pd.to_numeric(preds_all["y_true"], errors="coerce")
        preds_all["y_pred"]   = pd.to_numeric(preds_all["y_pred"], errors="coerce")
        preds_all["has_label"]= preds_all["has_label"].astype("boolean")
        preds_all = preds_all.sort_values(["Week", "Ticker"]).reset_index(drop=True)
        out_path = os.path.join(RESULTS_DIR, "all_models_test_preds.parquet")
        preds_all.to_parquet(out_path, index=False)
        print("Saved combined test preds →", out_path)
    else:
        print("No scored-only prediction files found to combine.")

    # Combine predictions: full-graph
    pred_files_full = sorted(p for p in os.listdir(RESULTS_DIR)
                             if p.endswith("_test_preds_full.parquet") and not p.startswith(COMBINED_PREFIX))
    all_preds_full = [pd.read_parquet(os.path.join(RESULTS_DIR, p)) for p in pred_files_full] if pred_files_full else []
    if all_preds_full:
        preds_all_full = pd.concat(all_preds_full, ignore_index=True)
        preds_all_full["Week"]      = pd.to_datetime(preds_all_full["Week"], errors="coerce").dt.tz_localize(None)
        preds_all_full["Ticker"]    = preds_all_full["Ticker"].astype("string")
        preds_all_full["y_true"]    = pd.to_numeric(preds_all_full["y_true"], errors="coerce")
        preds_all_full["y_pred"]    = pd.to_numeric(preds_all_full["y_pred"], errors="coerce")
        preds_all_full["has_label"] = preds_all_full["has_label"].astype("boolean")
        preds_all_full = preds_all_full.sort_values(["Week", "Ticker"]).reset_index(drop=True)
        out_path_full = os.path.join(RESULTS_DIR, "all_models_test_preds_full.parquet")
        preds_all_full.to_parquet(out_path_full, index=False)
        print("Saved combined FULL test preds →", out_path_full)
    else:
        print("No full-graph prediction files found to combine.")


# 8. Hyper-parameter Tuning (Top-3: ECC, TAG, WSAGE)

This section grid-searches compact, sensible hyper-parameter spaces for our three best static GNNs:

- **ECC:** edge-conditioned convolution using signed or absolute edge weights, where small edge MLPs generate per-edge filters.  
- **TAG:** with polynomial order $K \in \{2,3,4\}$.  
- **Weighted GraphSAGE (WSAGE):** with normalized edge-weighted aggregation.

We minimise validation RMSE (with early stopping), log every trial, and export:

- Per-trial logs and checkpoints under `results_gnn_tune_top3/`.  
- Model-specific tuning tables (`tune_ECC.csv`, `tune_TAG.csv`, `tune_WSAGE.csv`).  
- A consolidated leaderboard across all three (`top3_tuning_all.csv`).  
- The overall best-of-three checkpoint + JSON metadata (`best_of_top3_best.pt`, `best_of_top3_meta.json`).  
- Scored-only and full-graph test predictions for the best trial.


In [None]:

import os, json, math, argparse, random, itertools
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import TAGConv, NNConv
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add


# -------------------- CLI --------------------
def get_args(argv=None):
    """Parse command-line options for the ECC/TAG/WSAGE hyper-parameter search.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Unrecognised arguments are ignored instead of aborting. A Jupyter kernel
    starts with its own ``-f <connection-file>`` argument, which would
    otherwise make ``parse_args`` exit when this cell runs as ``__main__``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(description="Hyper-tune ECC/TAG/WSAGE on weekly graphs.")
    # data IO
    p.add_argument("--graphs_dir", type=str, default="./graphs")
    p.add_argument("--graph_index_csv", type=str, default="./graphs/graphs_index.csv")
    p.add_argument("--split_json", type=str, default="./graphs/split_weeks.json")
    # outputs
    p.add_argument("--results_dir", type=str, default="./results_gnn")
    p.add_argument("--tune_dir", type=str, default="./results_gnn_tune_top3")
    # base hparams / early stop
    p.add_argument("--epochs", type=int, default=120)
    p.add_argument("--patience", type=int, default=20)
    # seed & debug
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--debug", action="store_true")
    p.add_argument("--debug_epochs", type=int, default=2)
    # optional throttle
    p.add_argument("--max_trials_per_model", type=int, default=None,
                   help="Optional cap on common-grid trials per model (random subset).")
    args, _unknown = p.parse_known_args(argv)
    return args


def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and force deterministic cuDNN.

    cuDNN autotuning (``benchmark``) is disabled because it can select
    different kernels from run to run.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# -------------------- MODELS (only active: WSAGE, TAG, ECC-Signed) --------------------
class WeightedSAGEConv(MessagePassing):
    """Edge-weighted neighbor aggregation with per-destination normalization.

    With the default ``aggr='add'`` the layer computes, for each destination node i:

        out_i = lin_self(x_i) + sum_{j -> i} (w_ji / sum_{k -> i} w_ki) * lin_neigh(x_j)

    i.e. incoming edge weights are normalised to sum to 1 per destination node
    (optionally including an added self-loop), and a separate self/root
    transform is added on top.
    """
    def __init__(self, in_channels, out_channels, aggr='add', bias=True):
        super().__init__(aggr=aggr)
        # transform applied to source-node features before they are aggregated
        self.lin_neigh = nn.Linear(in_channels, out_channels, bias=False)
        # root/self transform added after aggregation; the only term carrying a bias
        self.lin_self  = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x, edge_index, edge_weight=None, use_abs_weight=True,
                add_self_loops=True, self_loop_weight=1.0):
        """Apply one weighted aggregation step.

        Args:
            x: node feature matrix; its row count is used as the number of nodes N.
            edge_index: 2 x E tensor; row 0 is the source, row 1 the destination.
            edge_weight: optional E-length weights; missing weights default to 1.
            use_abs_weight: take |w| before normalising.
            add_self_loops: append one (i, i) edge per node with weight
                ``self_loop_weight``. This is done unconditionally, even if
                ``edge_index`` already contains self-loops.
            self_loop_weight: weight given to the appended self-loops.

        Returns:
            Aggregated neighbour messages plus ``lin_self(x)``.
        """
        N = x.size(0)
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=x.device, dtype=x.dtype)
        if use_abs_weight:
            edge_weight = edge_weight.abs()

        if add_self_loops:
            loop = torch.arange(N, device=x.device)
            loop_ei = torch.stack([loop, loop], dim=0)
            edge_index = torch.cat([edge_index, loop_ei], dim=1)
            loop_w = torch.full((N,), float(self_loop_weight), device=x.device, dtype=x.dtype)
            edge_weight = torch.cat([edge_weight, loop_w], dim=0)

        src, dst = edge_index[0], edge_index[1]
        # per-destination weight sums; clamp avoids division by zero for isolated nodes
        # NOTE(review): with use_abs_weight=False, signed weights can sum to <= 0 and the
        # clamp then turns the denominator into 1e-12, blowing up the normalised weights.
        w_sum = scatter_add(edge_weight, dst, dim=0, dim_size=N).clamp_min(1e-12)
        norm_w = edge_weight / w_sum[dst]

        x_neigh = self.lin_neigh(x)
        # `weight=` must match the parameter name in `message` (PyG routes kwargs by name)
        out = self.propagate(edge_index, x=x_neigh, weight=norm_w) + self.lin_self(x)
        return out

    def message(self, x_j, weight):  # j: source index
        # scale each source message by its normalised edge weight
        return weight.view(-1, 1) * x_j


class WeightedSAGEReg(nn.Module):
    """Two-layer edge-weighted SAGE regressor producing one scalar per node.

    Each layer applies WeightedSAGEConv, then ReLU, then dropout. A linear head
    maps the hidden state to a single output, which is squeezed on the last dim.
    """
    def __init__(self, in_dim, hidden=64, dropout=0.2):
        super().__init__()
        # registration order (conv1, conv2, lin) is kept so parameter init and
        # state_dict keys match existing checkpoints
        self.conv1 = WeightedSAGEConv(in_dim, hidden)
        self.conv2 = WeightedSAGEConv(hidden, hidden)
        self.lin   = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index, edge_weight=edge_weight)
            h = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class TAGReg(nn.Module):
    """Two-layer TAGConv regressor (polynomial order ``K``) producing one scalar per node.

    Edge weights are passed to TAGConv as absolute values; each layer is
    followed by ReLU and dropout before the linear head.
    """
    def __init__(self, in_dim, hidden=64, K=3, dropout=0.2):
        super().__init__()
        self.conv1 = TAGConv(in_dim, hidden, K=K)
        self.conv2 = TAGConv(hidden, hidden, K=K)
        self.lin   = nn.Linear(hidden, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        abs_weight = None if edge_weight is None else edge_weight.abs()
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.dropout(F.relu(conv(h, edge_index, abs_weight)), p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


class ECCRegSigned(nn.Module):
    """
    ECC (edge-conditioned convolution via PyG ``NNConv``) regressor for tuning.

    The edge sign can optionally be preserved (keep_sign=True), or absolute
    weights used (keep_sign=False). The network is two NNConv layers with mean
    aggregation, each followed by ReLU and dropout, and a linear head that
    produces one scalar per node.

    Args:
        in_dim: node feature dimension.
        hidden: width of both NNConv layers.
        dropout: dropout probability after each layer.
        keep_sign: feed the raw signed edge weight to the edge MLPs instead of |w|.
        edge_mlp: hidden width of the edge MLPs that emit per-edge weight matrices.
    """
    def __init__(self, in_dim, hidden=64, dropout=0.2, keep_sign=False, edge_mlp=128):
        super().__init__()
        self.keep_sign = bool(keep_sign)
        self.dropout   = float(dropout)
        # edge MLPs map a 1-d edge attribute to a flattened (in x out) filter matrix
        self.edge_net1 = nn.Sequential(nn.Linear(1, edge_mlp), nn.ReLU(), nn.Linear(edge_mlp, in_dim * hidden))
        self.edge_net2 = nn.Sequential(nn.Linear(1, edge_mlp), nn.ReLU(), nn.Linear(edge_mlp, hidden * hidden))
        self.conv1 = NNConv(in_dim, hidden, self.edge_net1, aggr='mean')
        self.conv2 = NNConv(hidden, hidden, self.edge_net2, aggr='mean')
        self.lin   = nn.Linear(hidden, 1)

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node predictions (last dim squeezed).

        When ``edge_weight`` is None every edge gets attribute 1.0. NNConv's
        edge MLP cannot be evaluated on a missing edge attribute.
        """
        if edge_weight is None:
            edge_attr = x.new_ones((edge_index.size(1), 1))
        else:
            ea = edge_weight.view(-1, 1)
            edge_attr = ea if self.keep_sign else ea.abs()
        h = self.conv1(x, edge_index, edge_attr); h = torch.relu(h); h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.conv2(h, edge_index, edge_attr); h = torch.relu(h); h = F.dropout(h, p=self.dropout, training=self.training)
        return self.lin(h).squeeze(-1)


# -------------------- METRICS --------------------
def masked_metrics(pred, y, mask):
    """Compute RMSE, MAE and R² over the entries selected by a boolean ``mask``.

    Returns a dict with keys "rmse", "mae" and "r2". All three are NaN when the
    mask selects nothing. R² is NaN when the selected targets have
    (near-)zero variance.
    """
    if mask.sum().item() == 0:
        return {"rmse": np.nan, "mae": np.nan, "r2": np.nan}

    preds = pred[mask].detach().cpu().numpy()
    targets = y[mask].detach().cpu().numpy()
    resid = preds - targets
    sq_resid = resid ** 2

    mse = float(sq_resid.mean())
    r2 = np.nan
    if float(np.var(targets)) > 1e-12:
        ss_tot = np.sum((targets - targets.mean()) ** 2)
        r2 = float(1.0 - (np.sum(sq_resid) / (ss_tot + 1e-12)))

    return {"rmse": float(np.sqrt(mse)), "mae": float(np.abs(resid).mean()), "r2": r2}


# -------------------- DATA HELPERS --------------------
def attach_split_masks_if_missing(g: Data, week_ts: pd.Timestamp, idx_df: pd.DataFrame):
    """Give ``g`` per-node week-level split flags from a 70/15/15 chronological split.

    If ``g`` already has all three flags (``is_train_week``, ``is_val_week``,
    ``is_test_week``), it is returned unchanged. Otherwise the week's row
    position in ``idx_df`` determines its split:
        - the first floor(0.70 * W) weeks are train;
        - the next floor(0.15 * W) are val;
        - the rest are test.
    Each flag is a bool tensor with one entry per node.
    """
    n_nodes = g.x.size(0)
    if all(hasattr(g, attr) for attr in ("is_train_week", "is_val_week", "is_test_week")):
        return g

    position = int(idx_df.index[idx_df["Week"] == week_ts][0])
    n_weeks = len(idx_df)
    train_end = math.floor(0.70 * n_weeks)
    val_end = train_end + math.floor(0.15 * n_weeks)

    def _per_node(flag):
        return torch.tensor([flag] * n_nodes, dtype=torch.bool)

    g.is_train_week = _per_node(position < train_end)
    g.is_val_week   = _per_node(train_end <= position < val_end)
    g.is_test_week  = _per_node(position >= val_end)
    return g


def load_graph_bundle(graphs_dir, graph_index_csv, split_json=None):
    """Load all weekly graphs listed in the index CSV and tag each with split flags.

    Index handling:
        - The index must have ``Week`` and ``file`` columns.
        - Weeks are parsed to naive timestamps and sorted.
        - Files listed in the index but missing on disk are skipped with a warning.

    Split flags:
        - If ``split_json`` exists, its ``train_weeks`` / ``val_weeks`` /
          ``test_weeks`` lists define the splits.
        - Otherwise ``attach_split_masks_if_missing`` applies a chronological split.

    Returns:
        (idx, graphs, feat_dim, train_graphs, val_graphs, test_graphs), where
        each graph list holds ``(week, graph)`` pairs.
    """
    idx = pd.read_csv(graph_index_csv)
    idx["Week"] = pd.to_datetime(idx["Week"]).dt.tz_localize(None)
    idx = idx.sort_values("Week").reset_index(drop=True)

    split_names = ("train", "val", "test")
    split_weeks = None
    if split_json and Path(split_json).exists():
        with open(split_json, "r") as fh:
            payload = json.load(fh)
        split_weeks = {
            part: set(pd.to_datetime(payload[f"{part}_weeks"]).tz_localize(None))
            for part in split_names
        }

    graphs = []
    for _, row in idx.iterrows():
        week = row["Week"]
        fpath = os.path.join(graphs_dir, row["file"])
        if not os.path.exists(fpath):
            print(f"[WARN] Missing graph file: {fpath} (skipping)")
            continue
        # NOTE: weights_only=False fully unpickles the file; only load trusted graph files.
        g = torch.load(fpath, map_location="cpu", weights_only=False)

        # legacy graphs stored their label mask as `train_mask`
        if not hasattr(g, "label_mask") and hasattr(g, "train_mask"):
            g.label_mask = g.train_mask

        if split_weeks is None:
            g = attach_split_masks_if_missing(g, week, idx)
        else:
            n_nodes = g.x.size(0)
            for part in split_names:
                in_part = week in split_weeks[part]
                setattr(g, f"is_{part}_week", torch.tensor([in_part] * n_nodes, dtype=torch.bool))

        assert torch.isfinite(g.x).all() and torch.isfinite(g.y).all(), "Found non-finite x/y"
        ew = getattr(g, "edge_weight", None)
        if ew is not None:
            assert torch.isfinite(ew).all(), "Found non-finite edge_weight"

        graphs.append((week, g))

    assert graphs, "No graphs found."
    feat_dim = graphs[0][1].x.size(1)
    by_flag = {
        flag: [(w, gr) for (w, gr) in graphs if getattr(gr, flag)[0]]
        for flag in ("is_train_week", "is_val_week", "is_test_week")
    }
    return idx, graphs, feat_dim, by_flag["is_train_week"], by_flag["is_val_week"], by_flag["is_test_week"]


# -------------------- EPOCH LOOPS & PREDICTION --------------------
def run_epoch_static(model, graphs_list, device, optimizer=None, epoch_idx=0,
                     model_name="", debug=False, debug_epochs=2):
    """Run one pass of a static GNN over a list of ``(week, graph)`` pairs.

    Training mode (``optimizer`` given):
        - For each graph, MSE on labelled nodes of a train week.
        - Then zero_grad / backward / grad-norm clip at 2.0 / step.

    Evaluation mode (``optimizer`` is None):
        - No autograd graph is built.
        - Scores labelled nodes of the graph's val week, or of its test week
          when the graph is not a val week.

    Returns:
        (mean_loss, mean_metrics):
            - mean_loss: MSE weighted by the number of scored nodes.
            - mean_metrics: per-graph rmse/mae/r2 averaged over graphs with
              ``np.nanmean``.
        Each is NaN when no node was scored.
    """
    train = optimizer is not None
    if train:
        model.train()
    else:
        model.eval()
    total_loss, total_count = 0.0, 0
    agg = {"rmse": [], "mae": [], "r2": []}

    for week_ts, g in graphs_list:
        x  = g.x.to(device)
        ei = g.edge_index.to(device)
        ew = g.edge_weight.to(device) if hasattr(g, "edge_weight") and g.edge_weight is not None else None
        y  = g.y.to(device).float()

        # Label mask: `label_mask`, else legacy `train_mask`, else nothing labelled.
        # Explicit `is None` checks: `tensor or tensor` raises for multi-element masks.
        lb = getattr(g, "label_mask", None)
        if lb is None:
            lb = getattr(g, "train_mask", None)
        if lb is None:
            lb = torch.zeros_like(g.y, dtype=torch.bool)

        if train:
            mask = (lb & g.is_train_week).to(device)
        else:
            split_mask = g.is_val_week if g.is_val_week[0] else g.is_test_week
            mask = (lb & split_mask).to(device)

        used = int(mask.sum().item())
        loss_val = np.nan
        # only track gradients when this pass will backpropagate
        with torch.set_grad_enabled(train):
            pred = model(x, ei, edge_weight=ew)
            if used > 0:
                loss = F.mse_loss(pred[mask], y[mask])
                loss_val = float(loss.detach().cpu().item())
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                    optimizer.step()
                total_loss += loss_val * used
                total_count += used
                m = masked_metrics(pred, y, mask)
                agg["rmse"].append(m["rmse"]); agg["mae"].append(m["mae"]); agg["r2"].append(m["r2"])

        if debug and epoch_idx < debug_epochs:
            e_cnt = int(ei.size(1))
            ew_min = float(ew.min().item()) if ew is not None and ew.numel() else float("nan")
            ew_max = float(ew.max().item()) if ew is not None and ew.numel() else float("nan")
            print(f"[{model_name}][ep{epoch_idx:02d}] {str(week_ts)[:10]} N={x.size(0)} E={e_cnt} "
                  f"used={used} loss={loss_val:.6f} ew=[{ew_min:.3f},{ew_max:.3f}]")

    mean_loss = (total_loss / total_count) if total_count > 0 else np.nan
    mean_metrics = {k: (np.nanmean(v) if len(v) else np.nan) for k, v in agg.items()}
    return mean_loss, mean_metrics


def predict_static_on_split(model, graphs_list, device, split="test", save_all_nodes=False):
    """Export per-node predictions of a static (single-graph) model for one split.

    Graphs whose week is not flagged for ``split`` (``g.is_<split>_week[0]``) are
    skipped. The label mask is ``g.label_mask``, else legacy ``g.train_mask``,
    else all-False. Explicit ``is None`` checks are used because
    ``tensor or tensor`` raises for multi-element masks.

    If save_all_nodes=False: only nodes with a label in the requested split.
    If save_all_nodes=True:  every node; y_true is NaN where the node has no
    label and ``has_label`` marks availability.

    Returns:
        pd.DataFrame with columns Week, Ticker, y_true, y_pred, has_label.

    Raises:
        ValueError: if ``split`` is not "train", "val" or "test".
    """
    split_attr = {"train": "is_train_week", "val": "is_val_week", "test": "is_test_week"}.get(split)
    if split_attr is None:
        raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")

    rows = []
    model.eval()
    for _week_ts, g in graphs_list:
        split_mask = getattr(g, split_attr)
        if not bool(split_mask[0]):
            continue

        x  = g.x.to(device)
        ei = g.edge_index.to(device)
        ew_raw = getattr(g, "edge_weight", None)
        ew = ew_raw.to(device) if ew_raw is not None else None

        lb = getattr(g, "label_mask", None)
        if lb is None:
            lb = getattr(g, "train_mask", None)
        if lb is None:
            lb = torch.zeros_like(g.y, dtype=torch.bool)
        lb = lb.cpu()
        mask = lb & split_mask.cpu()

        with torch.no_grad():
            pred = model(x, ei, edge_weight=ew)

        # One device->host copy per graph instead of one .item() sync per node.
        pred_cpu = pred.detach().cpu()
        y_cpu = g.y.detach().cpu().float()
        week_str = str(g.week)[:10]

        if save_all_nodes:
            use_idx = range(g.x.size(0))
        else:
            use_idx = torch.nonzero(mask, as_tuple=True)[0].tolist()

        for i in use_idx:
            has_label = bool(lb[i])
            rows.append({
                "Week": week_str,
                "Ticker": g.tickers[i],
                "y_true": float(y_cpu[i]) if has_label else np.nan,
                "y_pred": float(pred_cpu[i]),
                "has_label": has_label
            })
    return pd.DataFrame(rows)


# -------------------- TRIAL TRAINER --------------------
def train_static_hp(name, model_ctor, device, feat_dim, train_graphs, val_graphs, test_graphs,
                    *, hidden, dropout, lr, wd, epochs, patience,
                    log_prefix="", save_dir="./results_gnn_tune_top3",
                    debug=False, debug_epochs=2):
    """Train a single hyper-parameter trial of a static GNN with early stopping.

    The model is built as ``model_ctor(feat_dim, hidden, dropout)`` and trained
    with Adam. The best-validation-RMSE weights are restored. The checkpoint and
    epoch log are saved under ``save_dir``, named by a tag that encodes the
    trial's hyper-parameters; save failures are reported, not raised.

    Returns:
        (model, best, test_metrics, tag), where
            - best: dict with keys "val_rmse", "state" and "epoch";
            - tag: the file-name stem used for the saved artefacts.
    """
    model = model_ctor(feat_dim, int(hidden), float(dropout)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(lr), weight_decay=float(wd))
    best = {"val_rmse": np.inf, "state": None, "epoch": -1}
    history = []
    loop_kwargs = dict(model_name=name, debug=debug, debug_epochs=debug_epochs)
    max_epochs, max_wait = int(epochs), int(patience)

    for ep in range(1, max_epochs + 1):
        tr_loss, tr_m = run_epoch_static(model, train_graphs, device, optimizer=optimizer,
                                         epoch_idx=ep, **loop_kwargs)
        vl_loss, vl_m = run_epoch_static(model, val_graphs, device, optimizer=None,
                                         epoch_idx=ep, **loop_kwargs)

        # column order: epoch, train/val loss, then train/val pairs per metric
        record = {"epoch": ep, "train_loss": float(tr_loss), "val_loss": float(vl_loss)}
        for metric in ("rmse", "mae", "r2"):
            record[f"train_{metric}"] = float(tr_m[metric])
            record[f"val_{metric}"] = float(vl_m[metric])
        history.append(record)

        val_rmse = vl_m["rmse"]
        if not np.isnan(val_rmse) and val_rmse < best["val_rmse"] - 1e-9:
            snapshot = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            best["val_rmse"], best["state"], best["epoch"] = val_rmse, snapshot, ep
        elif ep - best["epoch"] >= max_wait:
            break

    if best["state"] is not None:
        model.load_state_dict({k: v.to(device) for k, v in best["state"].items()})

    os.makedirs(save_dir, exist_ok=True)
    tag = f"{log_prefix}{name}_H{hidden}_do{dropout}_lr{lr}_wd{wd}"
    try:
        torch.save(model.state_dict(), os.path.join(save_dir, f"{tag}_best.pt"))
        pd.DataFrame(history).to_csv(os.path.join(save_dir, f"{tag}_logs.csv"), index=False)
    except Exception as e:
        print(f"[{name}] WARN: save failed: {e}")

    _, test_metrics = run_epoch_static(model, test_graphs, device, optimizer=None, epoch_idx=0, model_name=name)
    return model, best, test_metrics, tag


# -------------------- GRID SEARCH RUNNER --------------------
def spec_to_tag(spec: dict):
    """Encode a model-specific hyper-parameter spec as a compact filename tag.

    Each item becomes ``<key><value>`` and items are joined with "_"; booleans are
    abbreviated to "T"/"F". An empty or missing spec yields "base".

    Example: ``{"keep_sign": True, "edge_mlp": 64}`` -> ``"keep_signT_edge_mlp64"``.
    """
    if not spec:
        return "base"

    def _short(value):
        # Booleans become single letters to keep checkpoint/log names short.
        if isinstance(value, bool):
            return "T" if value else "F"
        return value

    return "_".join(f"{key}{_short(val)}" for key, val in spec.items())


def run_grid_for_model(model_name, ctor_factory, model_specific_grid, common_grid, device, feat_dim,
                       train_graphs, val_graphs, test_graphs, epochs, patience, save_dir,
                       max_trials=None, debug=False, debug_epochs=2):
    """Grid-search one model family and return its tuning table plus the best trial.

    For every spec in ``model_specific_grid`` (``None`` = no extra constructor
    arguments) and every ``(hidden, dropout, lr, wd)`` tuple in ``common_grid``, a
    model is trained with ``train_static_hp`` and scored on the val/test splits.

    Args:
        model_name: Label used in the results table and in checkpoint/log file names.
        ctor_factory: ``ctor_factory(**spec)`` (or ``ctor_factory()`` when ``spec`` is
            None) must return ``ctor(in_dim, hidden, dropout) -> nn.Module``.
        model_specific_grid: Iterable of spec dicts (or ``None``).
        common_grid: Iterable of ``(hidden, dropout, lr, wd)`` tuples.
        device, feat_dim, train_graphs, val_graphs, test_graphs, epochs, patience,
        debug, debug_epochs: Forwarded to ``train_static_hp``.
        save_dir: Directory for checkpoints, logs and ``tune_{model_name}.csv``.
        max_trials: If set and smaller than ``len(common_grid)``, a reproducible random
            subset (seed 42) of the common grid is used for every spec.

    Returns:
        ``(df, best_global)``: ``df`` is the per-trial table sorted by ``Val_RMSE``
        (header-only if no trial ran); ``best_global`` describes the trial with the
        lowest validation RMSE. Its ``bundle`` stays None if no trial beat an infinite
        validation RMSE (e.g. empty grid).
    """
    import random  # local: only needed for optional sub-sampling of the grid

    rows = []
    best_global = {"val_rmse": np.inf, "bundle": None, "test": None, "ckpt": None, "tag": None}

    common = list(common_grid)
    if (max_trials is not None) and (max_trials < len(common)):
        # Private RNG: same subset as `random.seed(42); random.sample(...)`, but without
        # resetting the global `random` state that other code may depend on.
        common = random.Random(42).sample(common, max_trials)

    os.makedirs(save_dir, exist_ok=True)
    for spec in model_specific_grid:
        ctor = ctor_factory(**spec) if spec is not None else ctor_factory()
        spec_tag = spec_to_tag(spec or {})
        trial_prefix = f"{model_name}_{spec_tag}__"  # invariant across the inner loop
        for (hidden, dropout, lr, wd) in common:
            model, best, ts_m, tag = train_static_hp(
                name=model_name, model_ctor=ctor, device=device, feat_dim=feat_dim,
                train_graphs=train_graphs, val_graphs=val_graphs, test_graphs=test_graphs,
                hidden=hidden, dropout=dropout, lr=lr, wd=wd,
                epochs=epochs, patience=patience, log_prefix=trial_prefix, save_dir=save_dir,
                debug=debug, debug_epochs=debug_epochs
            )
            row = {
                "Model": model_name, **{f"spec_{k}": v for k, v in (spec or {}).items()},
                "hidden": hidden, "dropout": dropout, "lr": lr, "wd": wd,
                "Val_RMSE": float(best["val_rmse"]),
                "Test_RMSE": float(ts_m["rmse"]), "Test_MAE": float(ts_m["mae"]), "Test_R2": float(ts_m["r2"]),
                "best_epoch": int(best["epoch"]), "ckpt_tag": tag
            }
            rows.append(row)

            if best["val_rmse"] < best_global["val_rmse"]:
                best_global = {"val_rmse": best["val_rmse"], "bundle": (spec, hidden, dropout, lr, wd),
                               "test": ts_m, "ckpt": os.path.join(save_dir, f"{tag}_best.pt"), "tag": tag}

    df = pd.DataFrame(rows)
    if df.empty:
        # No trials ran: emit a header-only table rather than failing on the sort below.
        df = pd.DataFrame(columns=["Model", "hidden", "dropout", "lr", "wd", "Val_RMSE",
                                   "Test_RMSE", "Test_MAE", "Test_R2", "best_epoch", "ckpt_tag"])
    df = df.sort_values("Val_RMSE").reset_index(drop=True)
    out_csv = os.path.join(save_dir, f"tune_{model_name}.csv")
    df.to_csv(out_csv, index=False)
    print(f"[{model_name}] Saved tuning table → {out_csv}")
    print(f"[{model_name}] BEST (by Val RMSE):", best_global["bundle"], "Val_RMSE=", best_global["val_rmse"])
    return df, best_global


# -------------------- MAIN --------------------
if __name__ == "__main__":
    # Section 8 driver: tune ECC / TAG / WSAGE over a shared hyper-parameter grid,
    # write per-model and consolidated leaderboards, then promote the model with the
    # lowest validation RMSE (checkpoint, metadata JSON and test-split predictions).
    args = get_args()
    seed_everything(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    os.makedirs(args.results_dir, exist_ok=True)
    os.makedirs(args.tune_dir, exist_ok=True)

    # Load graphs
    idx, graphs, feat_dim, train_graphs, val_graphs, test_graphs = load_graph_bundle(
        args.graphs_dir, args.graph_index_csv, split_json=args.split_json if args.split_json else None
    )
    print(f"Loaded {len(graphs)} graphs; feature dim = {feat_dim}")
    print(f"Split → train={len(train_graphs)}, val={len(val_graphs)}, test={len(test_graphs)}")

    # ----- Search spaces (compact, sensible) -----
    HIDDEN_SPACE   = [64, 96, 128]
    DROPOUT_SPACE  = [0.1, 0.2, 0.3]
    LR_SPACE       = [1e-3, 5e-4]
    WD_SPACE       = [1e-4, 5e-5]
    COMMON_GRID = list(itertools.product(HIDDEN_SPACE, DROPOUT_SPACE, LR_SPACE, WD_SPACE))

    ECC_SPECS   = [{"keep_sign": s, "edge_mlp": w} for s in [False, True] for w in [64, 128]]
    TAG_SPECS   = [{"K": K} for K in [2, 3, 4]]
    WSAGE_SPECS = [None]  # only common hparams

    # ----- Ctor factories -----
    # Each factory binds model-specific options and returns ctor(in_dim, hidden, dropout).
    make_tag_ctor   = lambda K: (lambda in_dim, hidden, dropout: TAGReg(in_dim, hidden, K=int(K), dropout=float(dropout)))
    make_wsage_ctor = lambda : (lambda in_dim, hidden, dropout: WeightedSAGEReg(in_dim, hidden, dropout=float(dropout)))
    make_ecc_ctor   = lambda keep_sign=False, edge_mlp=128: (
        lambda in_dim, hidden, dropout: ECCRegSigned(in_dim, hidden, dropout=float(dropout),
                                                     keep_sign=bool(keep_sign), edge_mlp=int(edge_mlp))
    )

    # ----- Run tuning for each model -----
    ecc_df, ecc_best = run_grid_for_model(
        "ECC",  make_ecc_ctor,  ECC_SPECS,  COMMON_GRID, device, feat_dim,
        train_graphs, val_graphs, test_graphs,
        epochs=args.epochs, patience=args.patience, save_dir=args.tune_dir,
        max_trials=args.max_trials_per_model, debug=args.debug, debug_epochs=args.debug_epochs
    )
    tag_df, tag_best = run_grid_for_model(
        "TAG",  make_tag_ctor,  TAG_SPECS,  COMMON_GRID, device, feat_dim,
        train_graphs, val_graphs, test_graphs,
        epochs=args.epochs, patience=args.patience, save_dir=args.tune_dir,
        max_trials=args.max_trials_per_model, debug=args.debug, debug_epochs=args.debug_epochs
    )
    ws_df, ws_best = run_grid_for_model(
        "WSAGE", make_wsage_ctor, WSAGE_SPECS, COMMON_GRID, device, feat_dim,
        train_graphs, val_graphs, test_graphs,
        epochs=args.epochs, patience=args.patience, save_dir=args.tune_dir,
        max_trials=args.max_trials_per_model, debug=args.debug, debug_epochs=args.debug_epochs
    )

    # ----- Consolidated leaderboard -----
    leader = pd.concat([
        ecc_df.assign(ModelGroup="ECC"),
        tag_df.assign(ModelGroup="TAG"),
        ws_df.assign(ModelGroup="WSAGE")
    ], ignore_index=True).sort_values("Val_RMSE").reset_index(drop=True)

    leader_path = os.path.join(args.tune_dir, "top3_tuning_all.csv")
    leader.to_csv(leader_path, index=False)
    print("\nSaved consolidated leaderboard →", leader_path)
    print(leader.head(10))

    # ----- Pick best-of-three and persist -----
    all_bests = [("ECC", ecc_best), ("TAG", tag_best), ("WSAGE", ws_best)]
    best_name, best_info = min(all_bests, key=lambda x: x[1]["val_rmse"])
    if best_info["bundle"] is None:
        # No trial of any model improved on an infinite validation RMSE, so there is
        # no checkpoint/bundle to promote (the metadata below would fail on None).
        raise RuntimeError("No tuning trial produced a finite validation RMSE; nothing to persist.")

    best_ckpt   = best_info["ckpt"]
    best_tag    = best_info["tag"]
    best_bundle = best_info["bundle"]  # (spec, hidden, dropout, lr, wd)
    best_test   = best_info["test"]

    print(f"\n[BEST OF TOP3] {best_name}  Val_RMSE={best_info['val_rmse']:.6f}  Test={best_test}")

    best_clean_ckpt = os.path.join(args.tune_dir, "best_of_top3_best.pt")
    # Bind `state` unconditionally (None on failure) so the prediction step below can
    # check it instead of raising NameError when the checkpoint could not be loaded.
    state = None
    try:
        state = torch.load(best_ckpt, map_location="cpu")
    except Exception as e:
        print("WARN: could not load best ckpt:", e)
    if state is not None:
        try:
            torch.save(state, best_clean_ckpt)
        except Exception as e:
            print("WARN: could not re-save best ckpt:", e)

    meta = {
        "best_model_name": best_name,
        "ckpt_source_tag": best_tag,
        "bundle": {
            "spec": best_bundle[0], "hidden": best_bundle[1], "dropout": best_bundle[2],
            "lr": best_bundle[3], "wd": best_bundle[4]
        },
        "val_rmse": float(best_info["val_rmse"]),
        "test_metrics": {k: float(v) for k, v in best_test.items()}
    }
    with open(os.path.join(args.tune_dir, "best_of_top3_meta.json"), "w") as f:
        json.dump(meta, f, indent=2)
    print("Saved best-of-top3 checkpoint + meta in:", args.tune_dir)

    # ----- Save test predictions for the best-of-three (scored-only + full) -----
    if state is None:
        print("WARN: skipping best-of-top3 predictions: checkpoint could not be loaded.")
    else:
        spec, hidden, dropout, lr, wd = best_bundle
        if best_name == "ECC":
            ctor = make_ecc_ctor(**(spec or {}))
        elif best_name == "TAG":
            ctor = make_tag_ctor(**(spec or {}))
        else:
            ctor = make_wsage_ctor()

        best_model = ctor(feat_dim, int(hidden), float(dropout)).to(device)
        best_model.load_state_dict({k: v.to(device) for k, v in state.items()})

        try:
            preds_scored = predict_static_on_split(best_model, graphs, device, split="test", save_all_nodes=False)
            preds_scored["Model"] = f"BEST_OF_TOP3_{best_name}"
            preds_scored.to_parquet(os.path.join(args.tune_dir, "best_of_top3_test_preds.parquet"), index=False)

            preds_full = predict_static_on_split(best_model, graphs, device, split="test", save_all_nodes=True)
            preds_full["Model"] = f"BEST_OF_TOP3_{best_name}"
            preds_full.to_parquet(os.path.join(args.tune_dir, "best_of_top3_test_preds_full.parquet"), index=False)

            print(f"Saved best-of-top3 predictions → {args.tune_dir}")
        except Exception as e:
            print("WARN: saving best-of-top3 predictions failed:", e)


# 9. Evaluation Figures & Tables

This section compiles all evaluation artifacts for the **baselines** and **GNNs**:

**Baselines** (from the baseline pipeline):
- Table: overall metrics (CSV + LaTeX).  
- Weekly panel RMSE line plot (e.g., VAR vs. BEKK).  
- Per-ticker RMSE boxplot.  
- Actual vs. predicted scatter plots.  
- Diebold–Mariano (DM) tests: the overall panel test (read from `dm_test.csv`) and per-ticker VAR-FEVD vs. S-BEKK(1,1) tests (computed here).  

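**GNNs** (from Section 7/8 outputs):
- Grid of learning curves across the trained GNNs (the MLP is excluded from this grid).  
- Learning curves of the best tuned run for each model.  
- Bar chart: Test RMSE before vs. after hyperparameter tuning (top-3 models: ECC, TAG, WSAGE).  
- Optional: MLP learning curve and actual-vs-predicted scatter, if MLP artifacts exist.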



In [None]:


import os
import argparse
import glob
import json
from math import sqrt
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ----------------------------- CLI ----------------------------- #
def get_args(argv=None):
    """Parse command-line options for the evaluation figures & tables step.

    Uses ``parse_known_args`` so this also works inside Jupyter, whose kernel passes
    its own flags (e.g. ``-f kernel.json``) that make ``parse_args`` abort with
    SystemExit; unrecognised arguments are ignored. ``allow_abbrev=False`` stops a
    kernel flag such as ``--f=...`` from being taken as an abbreviation of ``--fig_dir``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the directory options defined below.
    """
    p = argparse.ArgumentParser("Make evaluation figures & tables", allow_abbrev=False)
    p.add_argument("--base_dir", type=str, default=".")
    p.add_argument("--baselines_dir", type=str, default="./Baselines",
                   help="Where baseline pipeline artifacts live")
    p.add_argument("--results_gnn", type=str, default="./results_gnn",
                   help="Where Section 7 training artifacts live")
    p.add_argument("--tune_dir", type=str, default="./results_gnn_tune_top3",
                   help="Where Section 8 tuning artifacts live")
    p.add_argument("--fig_dir", type=str, default="./figures")
    p.add_argument("--tab_dir", type=str, default="./tables")
    args, _unknown = p.parse_known_args(argv)
    return args


# -------------------------- Utilities -------------------------- #
def ensure_dirs(*paths):
    """Create each directory in ``paths`` (with parents); existing ones are left alone."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)


def _safe_read_csv(path):
    if os.path.exists(path):
        return pd.read_csv(path)
    print(f"[WARN] Missing CSV: {path}")
    return None


def _safe_read_parquet(path):
    if os.path.exists(path):
        return pd.read_parquet(path)
    print(f"[WARN] Missing Parquet: {path}")
    return None


# -------------------- Baseline: Tables & Plots ------------------ #
def make_overall_baseline_table(overall_df: pd.DataFrame, tab_dir: str):
    """Write baseline test-set metrics to CSV and LaTeX, best model first.

    Rows are ordered by ascending ``rmse`` and every column name is upper-cased
    before ``tab_baseline_overall.csv`` and ``tab_baseline_overall.tex`` are written
    into ``tab_dir``.

    Returns:
        The sorted, upper-cased DataFrame that was written.
    """
    table = overall_df.sort_values("rmse")
    table = table.rename(columns=str.upper)
    csv_path = os.path.join(tab_dir, "tab_baseline_overall.csv")
    tex_path = os.path.join(tab_dir, "tab_baseline_overall.tex")

    table.to_csv(csv_path, index=False)
    with open(tex_path, "w") as tex_file:
        latex = table.to_latex(
            index=False,
            float_format="%.4f",
            caption="Baseline performance on test set (lower is better).",
            label="tab:baseline_overall",
        )
        tex_file.write(latex)
    print("Saved:", csv_path, "|", tex_path)
    return table


def _panel_losses(preds_long: pd.DataFrame):
    """Weekly RMSE/MAE aggregation across panel."""
    g = preds_long.groupby(["Week", "model"], as_index=False).apply(
        lambda x: pd.Series({
            "RMSE": np.sqrt(np.mean((x["y_next"] - x["yhat"])**2)),
            "MAE":  np.mean(np.abs(x["y_next"] - x["yhat"]))
        })
    ).reset_index(drop=True)
    return g.sort_values("Week")


def fig_weekly_panel_rmse(preds_long: pd.DataFrame, fig_dir: str):
    """Plot the weekly cross-sectional RMSE of each baseline model and save it as PNG.

    Args:
        preds_long: Long prediction table consumed by ``_panel_losses``
            (``Week``, ``model``, ``y_next``, ``yhat``).
        fig_dir: Output directory for ``fig_baseline_weekly_panel_rmse.png``.
    """
    weekly = _panel_losses(preds_long)
    out_path = os.path.join(fig_dir, "fig_baseline_weekly_panel_rmse.png")

    plt.figure(figsize=(10, 5))
    for model_name, series in weekly.groupby("model"):
        weeks = pd.to_datetime(series["Week"])
        plt.plot(weeks, series["RMSE"], label=model_name, linewidth=2)
    plt.title("Weekly Panel RMSE by Baseline Model")
    plt.xlabel("Week")
    plt.ylabel("RMSE")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=180)
    plt.show()
    print("Saved:", out_path)


def fig_per_ticker_rmse_box(by_ticker: pd.DataFrame, fig_dir: str):
    """Box plot of per-ticker RMSE for each baseline model (outliers hidden).

    Args:
        by_ticker: Table with one row per (model, ticker) and ``model``/``rmse`` columns.
        fig_dir: Output directory for ``fig_baseline_box_rmse_by_ticker.png``.
    """
    order = sorted(by_ticker["model"].unique())
    data = [by_ticker.loc[by_ticker["model"] == m, "rmse"].values for m in order]
    plt.figure(figsize=(7, 5))
    try:
        # Matplotlib >= 3.9 renamed `labels` to `tick_labels` (the old name is deprecated).
        plt.boxplot(data, tick_labels=order, showfliers=False)
    except TypeError:
        # Older Matplotlib does not know `tick_labels`.
        plt.boxplot(data, labels=order, showfliers=False)
    plt.title("Per-Ticker RMSE Distribution (Baselines)")
    plt.ylabel("RMSE"); plt.grid(True, axis="y", alpha=0.3); plt.tight_layout()
    out = os.path.join(fig_dir, "fig_baseline_box_rmse_by_ticker.png")
    plt.savefig(out, dpi=180); plt.show()
    print("Saved:", out)


def fig_actual_vs_pred_scatter(preds_long: pd.DataFrame, fig_dir: str):
    """Scatter actual ``y_next`` against predicted ``yhat`` for every baseline model.

    A 45-degree reference line spans the joint min/max of both columns, and both axes
    are clipped to that range. Saved as ``fig_baseline_scatter_actual_vs_pred.png``.
    """
    plt.figure(figsize=(6, 6))
    for model_name in sorted(preds_long["model"].unique()):
        subset = preds_long[preds_long["model"] == model_name]
        plt.scatter(subset["y_next"], subset["yhat"], s=10, alpha=0.45, label=model_name)

    actual, predicted = preds_long["y_next"], preds_long["yhat"]
    lims = [
        min(actual.min(), predicted.min()),
        max(actual.max(), predicted.max()),
    ]
    plt.plot(lims, lims, linewidth=2)  # 45-degree "perfect forecast" line
    plt.xlim(lims)
    plt.ylim(lims)
    plt.xlabel("Actual spillover (y_next)")
    plt.ylabel("Predicted spillover (ŷ)")
    plt.title("Actual vs Predicted — Baselines")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    out_path = os.path.join(fig_dir, "fig_baseline_scatter_actual_vs_pred.png")
    plt.savefig(out_path, dpi=180)
    plt.show()
    print("Saved:", out_path)


# ------------------------- DM Test Utils ------------------------ #
def _newey_west_var(d, max_lag=None):
    d = np.asarray(d, dtype=float)
    T = len(d)
    if T == 0:
        return np.nan
    if max_lag is None:
        max_lag = int(np.floor(4 * (T / 100.0)**(2/9)))
    gamma0 = np.dot(d, d) / T
    s = gamma0
    for h in range(1, max_lag + 1):
        w = 1.0 - h / (max_lag + 1)
        gamma = np.dot(d[h:], d[:-h]) / T
        s += 2 * w * gamma
    return s


def _dm_test_series(e1, e2, max_lag=None):
    d = (e1**2 - e2**2)
    d_bar = d.mean()
    var_nw = _newey_west_var(d - d.mean(), max_lag=max_lag)
    if not np.isfinite(var_nw) or var_nw <= 0:
        return np.nan, np.nan, len(d)
    stat = d_bar / sqrt(var_nw / len(d))
    from math import erf
    def norm_cdf(z): return 0.5 * (1 + erf(z / np.sqrt(2)))
    p = 2 * (1 - norm_cdf(abs(stat)))
    return stat, p, len(d)


def dm_per_ticker_VAR_vs_BEKK(preds_long: pd.DataFrame, tab_dir: str, dm_panel_df: pd.DataFrame | None):
    # Pivot into per-model columns and align
    pl = preds_long.pivot_table(index=["Week", "Ticker"], columns="model", values=["y_next", "yhat"])
    pl.columns = [f"{a}_{b}" for a, b in pl.columns]  # e.g., y_next_VAR-FEVD
    needed = ["y_next_VAR-FEVD", "yhat_VAR-FEVD", "y_next_S-BEKK(1,1)", "yhat_S-BEKK(1,1)"]
    pl = pl.dropna(subset=[c for c in needed if c in pl.columns], how="any").reset_index()

    rows = []
    if all(c in pl.columns for c in needed):
        for tk, g in pl.groupby("Ticker"):
            y = g["y_next_VAR-FEVD"].values  # same as BEKK y_next after alignment
            e_var  = (y - g["yhat_VAR-FEVD"].values)
            e_bekk = (y - g["yhat_S-BEKK(1,1)"].values)
            if len(e_var) >= 12:
                stat, p, n = _dm_test_series(e_var, e_bekk, max_lag=None)
                rows.append({"ticker": tk, "DM_stat_VAR_vs_BEKK": stat, "p_value": p, "n_weeks": n})
    else:
        print("[WARN] Could not find both VAR-FEVD and S-BEKK(1,1) columns for DM per-ticker test.")

    dm_by_ticker = pd.DataFrame(rows).sort_values("p_value") if rows else pd.DataFrame(columns=["ticker","DM_stat_VAR_vs_BEKK","p_value","n_weeks"])
    out_per = os.path.join(tab_dir, "tab_dm_per_ticker.csv")
    dm_by_ticker.to_csv(out_per, index=False)
    print("Saved per-ticker DM:", out_per)

    if dm_panel_df is not None:
        out_panel = os.path.join(tab_dir, "tab_dm_overall_panel.csv")
        dm_panel_df.to_csv(out_panel, index=False)
        print("Saved overall DM (panel):", out_panel)

    return dm_by_ticker


# -------------------- GNN Learning Curves (Grid) ------------------- #
def learning_curves_grid(results_dir: str, fig_out_png: str, fig_out_pdf: str,
                         smooth_alpha: float = 0.2, ncols: int = 3, sharey: bool = True):
    """Draw a grid of train/val RMSE learning curves, one panel per non-MLP run.

    Reads every ``*_logs.csv`` in ``results_dir``, skipping files whose name contains
    "MLP". Adds EMA-smoothed copies of ``train_rmse``/``val_rmse`` and saves the grid
    as PNG and PDF. Preferred models come first, in a fixed order.

    Args:
        results_dir: Directory with ``<name>_logs.csv`` files containing an ``epoch`` column.
        fig_out_png, fig_out_pdf: Output paths.
        smooth_alpha: EMA smoothing factor passed to ``Series.ewm(alpha=...)``.
        ncols: Number of panel columns.
        sharey: If True, all panels share one padded y-range computed from finite values.
    """
    log_files = sorted(glob.glob(os.path.join(results_dir, "*_logs.csv")))
    if not log_files:
        print(f"[WARN] No '*_logs.csv' files found in {results_dir}")
        return

    # drop MLP explicitly (its curves are handled by optional_mlp_plots)
    log_files = [f for f in log_files if "MLP" not in os.path.basename(f)]
    if not log_files:
        # Only MLP logs: plt.subplots(0, ncols) would raise, so stop here.
        print(f"[WARN] Only MLP logs found in {results_dir}; no GNN learning curves to plot.")
        return

    order_pref = ["ECC", "TAG", "WSAGE", "CHEB", "GAT", "TEMPORAL_GCN"]
    name2path = {os.path.basename(p).replace("_logs.csv", ""): p for p in log_files}
    names = [n for n in order_pref if n in name2path] + [n for n in name2path if n not in order_pref]

    n = len(names)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

    # Load logs, add EMA columns, and collect finite per-run y-extents for a shared axis.
    curve_cols = ["train_rmse", "val_rmse", "train_rmse_ema", "val_rmse_ema"]
    ymins, ymaxs, dfs = [], [], {}
    for name in names:
        df = pd.read_csv(name2path[name])
        for col in ["train_rmse", "val_rmse"]:
            if col in df:
                df[f"{col}_ema"] = df[col].ewm(alpha=smooth_alpha).mean()
        dfs[name] = df
        ys = [df[col].to_numpy(dtype=float) for col in curve_cols if col in df]
        if ys:
            arr = np.concatenate(ys)
            arr = arr[np.isfinite(arr)]  # NaN/inf limits would make set_ylim raise
            if arr.size:
                ymins.append(arr.min())
                ymaxs.append(arr.max())

    ylo = yhi = None
    if sharey and ymins:
        ylo, yhi = float(min(ymins)), float(max(ymaxs))
        pad = 0.05 * (yhi - ylo + 1e-9)
        ylo, yhi = ylo - pad, yhi + pad

    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3.8 * nrows), squeeze=False)
    for idx, name in enumerate(names):
        r, c = divmod(idx, ncols)
        ax = axes[r][c]
        df = dfs[name]
        if "train_rmse" in df: ax.plot(df["epoch"], df["train_rmse"], label="train RMSE", linewidth=1.2)
        if "val_rmse"   in df: ax.plot(df["epoch"], df["val_rmse"],   label="val RMSE",   linewidth=1.2)
        if "train_rmse_ema" in df: ax.plot(df["epoch"], df["train_rmse_ema"], label="train RMSE (EMA)", linestyle="--")
        if "val_rmse_ema"   in df: ax.plot(df["epoch"], df["val_rmse_ema"],   label="val RMSE (EMA)",   linestyle="--")
        ax.set_title(name, fontsize=14, pad=6)
        ax.set_xlabel("Epoch"); ax.set_ylabel("RMSE")
        ax.grid(True, alpha=0.3)
        if ylo is not None:
            ax.set_ylim(ylo, yhi)
        if idx == 0: ax.legend(loc="upper right", fontsize=10)

    # remove empty panels
    for k in range(len(names), nrows * ncols):
        r, c = divmod(k, ncols)
        fig.delaxes(axes[r][c])

    fig.suptitle("Learning Curves — Selected GNNs", fontsize=16, y=0.995)
    plt.tight_layout(rect=[0, 0.00, 1, 0.97])
    fig.savefig(fig_out_png, dpi=220)
    fig.savefig(fig_out_pdf)
    plt.show()
    print("Saved:", fig_out_png, "|", fig_out_pdf)


# ---------------- Best tuned curves + tuned-vs-default ------------- #
def best_tuned_curves(tune_dir: str, results_gnn: str):
    """Plot train/val RMSE curves of each model's best tuned run.

    Reads ``tune_dir/best_by_model.csv``. For every row whose ``logs_path`` is an
    existing file, it plots that log and saves
    ``results_gnn/fig_tuned_<Model>_learning_curve.png``. Rows without usable logs
    are reported and skipped. The function does nothing beyond a warning if the CSV
    is absent.
    """
    csv_path = os.path.join(tune_dir, "best_by_model.csv")
    if not os.path.exists(csv_path):
        print(f"[WARN] Missing best_by_model.csv at {csv_path}; skipping best tuned curves.")
        return

    for _, entry in pd.read_csv(csv_path).iterrows():
        log_csv = entry.get("logs_path")
        has_logs = isinstance(log_csv, str) and os.path.exists(log_csv)
        if not has_logs:
            print(f"Skip plotting (no logs): {log_csv}")
            continue

        history = pd.read_csv(log_csv)
        plt.figure()
        for col, label in (("train_rmse", "train RMSE"), ("val_rmse", "val RMSE")):
            if col in history:
                plt.plot(history["epoch"], history[col], label=label)
        plt.title(f"Tuned {entry['Model']} — RMSE (best run)")
        plt.xlabel("Epoch")
        plt.ylabel("RMSE")
        plt.legend()
        plt.grid(True, alpha=0.3)
        out_path = os.path.join(results_gnn, f"fig_tuned_{entry['Model']}_learning_curve.png")
        plt.savefig(out_path, dpi=180)
        plt.show()
        print("Saved:", out_path)


def tuned_vs_default_bar(results_gnn: str, tune_dir: str, top3=None):
    """Bar chart of Test RMSE for default vs. tuned runs of the selected models.

    Reads ``results_gnn/summary.csv`` (default runs) and ``tune_dir/best_by_model.csv``
    (tuned runs); both need ``Model`` and ``Test_RMSE`` columns. The figure is saved to
    ``tune_dir/fig_rmse_before_after_tuning.png``. The function returns early with a
    warning if either file is missing or neither contains any requested model.

    Args:
        results_gnn: Directory holding ``summary.csv``.
        tune_dir: Directory holding ``best_by_model.csv``; also the figure destination.
        top3: Model names to compare; defaults to ["ECC", "TAG", "WSAGE"].
    """
    if top3 is None:
        top3 = ["ECC", "TAG", "WSAGE"]
    default_path = os.path.join(results_gnn, "summary.csv")
    tuned_path   = os.path.join(tune_dir, "best_by_model.csv")

    if not (os.path.exists(default_path) and os.path.exists(tuned_path)):
        print("[WARN] Missing summary.csv or best_by_model.csv; skipping tuned-vs-default plot.")
        return

    keep = ["Model", "Test_RMSE"]
    default = pd.read_csv(default_path)
    tuned   = pd.read_csv(tuned_path)
    # Select + .assign builds new frames, avoiding chained assignment on filtered views.
    default = default.loc[default["Model"].isin(top3), keep].assign(Version="Default")
    tuned   = tuned.loc[tuned["Model"].isin(top3), keep].assign(Version="Tuned")

    df = pd.concat([default, tuned])
    if df.empty:
        # Nothing to plot: DataFrame.plot fails on an empty pivot.
        print(f"[WARN] None of {list(top3)} found in summary.csv / best_by_model.csv; skipping tuned-vs-default plot.")
        return

    plt.figure(figsize=(7, 5))
    # DataFrame.pivot raises ValueError if a (Model, Version) pair occurs more than once.
    df.pivot(index="Model", columns="Version", values="Test_RMSE").plot(kind="bar", ax=plt.gca(), width=0.7)
    plt.ylabel("Test RMSE")
    plt.title("Test RMSE before vs after hyperparameter tuning (Top-3)")
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()

    out = os.path.join(tune_dir, "fig_rmse_before_after_tuning.png")
    plt.savefig(out, dpi=180); plt.show()
    print("Saved:", out)


# ---------------------------  MLP -------------------------- #
def optional_mlp_plots(results_gnn: str):
    """If MLP artifacts exist in ``results_gnn``, plot its learning curve and a scatter.

    * ``MLP_logs.csv`` -> ``fig_mlp_learning_curve.png``. Raw RMSE curves are drawn
      plus EMA (alpha=0.2) versions; EMA columns are also added for the loss columns
      when present.
    * ``MLP_test_preds.parquet`` (``y_true``/``y_pred``) ->
      ``fig_mlp_scatter_actual_vs_pred.png`` with a 45-degree reference line.
    Missing artifacts are silently skipped.
    """
    # --- learning curve(s) ---
    curve_specs = (
        ("train_rmse", "train RMSE", {"linewidth": 1}),
        ("val_rmse", "val RMSE", {"linewidth": 1}),
        ("train_rmse_ema", "train RMSE (EMA)", {"linestyle": "--"}),
        ("val_rmse_ema", "val RMSE (EMA)", {"linestyle": "--"}),
    )
    for log_csv in glob.glob(os.path.join(results_gnn, "MLP_logs.csv")):
        hist = pd.read_csv(log_csv)
        for metric in ("train_rmse", "val_rmse", "train_loss", "val_loss"):
            if metric in hist:
                hist[f"{metric}_ema"] = hist[metric].ewm(alpha=0.2).mean()
        plt.figure()
        for col, label, style in curve_specs:
            if col in hist:
                plt.plot(hist["epoch"], hist[col], label=label, **style)
        plt.title("MLP — RMSE")
        plt.xlabel("Epoch")
        plt.ylabel("RMSE")
        plt.legend()
        plt.grid(True, alpha=0.3)
        curve_png = os.path.join(results_gnn, "fig_mlp_learning_curve.png")
        plt.savefig(curve_png, dpi=180)
        plt.show()
        print("Saved:", curve_png)

    # --- actual vs predicted scatter ---
    preds_path = os.path.join(results_gnn, "MLP_test_preds.parquet")
    if not os.path.exists(preds_path):
        return
    preds = pd.read_parquet(preds_path)  # labeled rows only
    actual, predicted = preds["y_true"].values, preds["y_pred"].values
    lims = [min(actual.min(), predicted.min()), max(actual.max(), predicted.max())]
    plt.figure(figsize=(5.5, 5.5))
    plt.scatter(actual, predicted, s=10, alpha=0.5)
    plt.plot(lims, lims, linewidth=2)  # 45-degree reference
    plt.xlim(lims)
    plt.ylim(lims)
    plt.xlabel("Actual spillover (y)")
    plt.ylabel("Predicted (ŷ)")
    plt.title("Actual vs Predicted — MLP")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    scatter_png = os.path.join(results_gnn, "fig_mlp_scatter_actual_vs_pred.png")
    plt.savefig(scatter_png, dpi=180)
    plt.show()
    print("Saved:", scatter_png)


# ------------------------------ MAIN ------------------------------ #
def main():
    """Regenerate all Section 9 evaluation tables and figures from saved artifacts.

    Baseline inputs come from ``--baselines_dir``, GNN logs from ``--results_gnn``,
    and tuning outputs from ``--tune_dir``. A step is skipped when its inputs are
    missing; the helpers print a warning.
    """
    args = get_args()
    fig_dir, tab_dir = args.fig_dir, args.tab_dir
    ensure_dirs(fig_dir, tab_dir)

    # ---- Baseline artifacts ----
    def _baseline_file(fname):
        return os.path.join(args.baselines_dir, fname)

    overall_df    = _safe_read_csv(_baseline_file("metrics_overall.csv"))
    by_ticker_df  = _safe_read_csv(_baseline_file("metrics_by_ticker.csv"))
    preds_long    = _safe_read_parquet(_baseline_file("preds_unified_long.parquet"))
    dm_panel_path = _baseline_file("dm_test.csv")
    # The panel DM table is optional: read it only when present, so no warning is printed.
    dm_panel_df   = _safe_read_csv(dm_panel_path) if os.path.exists(dm_panel_path) else None

    if overall_df is not None:
        make_overall_baseline_table(overall_df, tab_dir)
    if preds_long is not None:
        fig_weekly_panel_rmse(preds_long, fig_dir)
        fig_actual_vs_pred_scatter(preds_long, fig_dir)
    if by_ticker_df is not None:
        fig_per_ticker_rmse_box(by_ticker_df, fig_dir)
    if preds_long is not None:
        dm_per_ticker_VAR_vs_BEKK(preds_long, tab_dir, dm_panel_df)

    # ---- GNN learning curves (grid across models, MLP excluded) ----
    grid_stem = os.path.join(args.results_gnn, "fig_learning_curves_grid_noMLP")
    learning_curves_grid(args.results_gnn, grid_stem + ".png", grid_stem + ".pdf",
                         smooth_alpha=0.2, ncols=3, sharey=True)

    # ---- Tuned-model outputs ----
    # NOTE(review): both steps read tune_dir/best_by_model.csv. The Section 8 grid
    # runner in this notebook writes tune_<model>.csv and top3_tuning_all.csv, not that
    # file. Confirm which step produces it.
    best_tuned_curves(args.tune_dir, args.results_gnn)
    tuned_vs_default_bar(args.results_gnn, args.tune_dir, top3=["ECC", "TAG", "WSAGE"])

    # ---- Optional: MLP curves/scatter if its artifacts exist ----
    optional_mlp_plots(args.results_gnn)

    print("\nAll done.")


# Script entry point. In a notebook cell `__name__` is also "__main__",
# so this runs whenever the cell is executed.
if __name__ == "__main__":
    main()


# 10. Final Aggregation & Graph Visualisation

This section consolidates all evaluation results into a single table (CSV + LaTeX): the baselines, the default GNN runs from Section 7, and the hyperparameter-tuned top-3 models from Section 8.
It also loads the best model discovered during tuning (by Test RMSE), locates its test predictions parquet, and renders a weekly graph where node colors/sizes reflect predicted spillover for that week.
If multiple tuned models’ predictions are available, you can optionally produce a side-by-side panel to compare their spatial risk patterns.

In [None]:


import os, re, glob, argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import networkx as nx

# ----------------------------- CLI ----------------------------- #
def get_args(argv=None):
    """Parse command-line options for the final tables + graph-overlay step.

    Uses ``parse_known_args`` so this also works inside Jupyter, whose kernel passes
    its own flags (e.g. ``-f kernel.json``) that make ``parse_args`` abort with
    SystemExit; unrecognised arguments are ignored. ``allow_abbrev=False`` stops a
    kernel flag such as ``--f=...`` from being taken as an abbreviation of ``--fig_dir``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the directory/path options and ``save_panels`` flag.
    """
    p = argparse.ArgumentParser("Final outputs: tables + graph overlays", allow_abbrev=False)
    p.add_argument("--base_dir", type=str, default=".")
    p.add_argument("--baselines_dir", type=str, default="./Baselines")
    p.add_argument("--results_gnn", type=str, default="./results_gnn")
    p.add_argument("--tune_dir", type=str, default="./results_gnn_tune_top3")
    p.add_argument("--graphs_dir", type=str, default="./graphs")
    p.add_argument("--graph_index_csv", type=str, default="./graphs/graphs_index.csv")
    p.add_argument("--fig_dir", type=str, default="./figures")
    p.add_argument("--save_panels", action="store_true", help="Save multi-model panel as PNG")
    args, _unknown = p.parse_known_args(argv)
    return args


# ----------------------- Utilities & I/O ----------------------- #
def ensure_dirs(*paths):
    """Make sure every directory in ``paths`` exists, creating parents as needed."""
    for target in paths:
        # exist_ok: calling this repeatedly on the same directories is harmless.
        os.makedirs(target, exist_ok=True)

def first_existing(paths):
    """Return the first non-empty path in ``paths`` that exists on disk, or None."""
    return next((p for p in paths if p and os.path.exists(p)), None)

def _normalize_weeks(df):
    w = pd.to_datetime(df["Week"], errors="coerce")
    df = df.copy()
    df["_week_dt"]   = w
    df["_week_date"] = w.dt.date.astype("string")
    df["_week_str"]  = df["Week"].astype("string")
    return df

def tokens(name):
    """Split the lower-cased basename of ``name`` on runs of non-[a-z0-9] characters.

    Example: ``"/runs/ECC_best-v2.parquet"`` -> ``["ecc", "best", "v2", "parquet"]``.
    Leading/trailing separators yield empty-string tokens.
    """
    basename_lc = os.path.basename(name).lower()
    return re.split(r"[^a-z0-9]+", basename_lc)

def infer_model_from_filename(path):
    """Guess which tuned model ("WSAGE", "ECC" or "TAG") a file belongs to from its name.

    Works on whole tokens from ``tokens(path)``. WSAGE is checked first (a "wsage"
    token, or both "weighted" and "sage"), then ECC, then TAG. Returns None when
    nothing matches.
    """
    parts = tokens(path)
    is_wsage = "wsage" in parts or ("weighted" in parts and "sage" in parts)
    if is_wsage:
        return "WSAGE"
    if "ecc" in parts:
        return "ECC"
    # NOTE(review): a "best" token suppresses the TAG match (intended to avoid
    # "best_of_top3"), but that name has no "tag" token anyway. This also hides
    # names like "TAG_best_...", and ECC/WSAGE have no such exclusion. Confirm intent.
    if "tag" in parts and "best" not in parts:
        return "TAG"
    return None

def find_preds_file_for_model(tune_dir, model_raw, ckpt_tag=None):
    """Locate a test-predictions parquet for ``model_raw`` inside ``tune_dir``.

    A candidate is a ``*.parquet`` whose name contains "test" and "pred"
    (case-insensitive). It must also either have an inferred model
    (``infer_model_from_filename``) equal to the upper-cased ``model_raw``, or
    contain ``ckpt_tag`` in its name. Names ending in ``test_preds_full.parquet`` are
    preferred; remaining ties are broken alphabetically. Returns None if ``tune_dir``
    is not a directory or nothing matches.
    """
    wanted = str(model_raw).upper().strip()
    if not os.path.isdir(tune_dir):
        return None

    matches = []
    for fname in os.listdir(tune_dir):
        lower = fname.lower()
        is_pred_file = lower.endswith(".parquet") and "test" in lower and "pred" in lower
        if not is_pred_file:
            continue
        full_path = os.path.join(tune_dir, fname)
        if infer_model_from_filename(full_path) == wanted:
            matches.append(full_path)
        elif ckpt_tag and ckpt_tag.lower() in lower:
            matches.append(full_path)

    # prefer *full* predictions (all nodes) over scored-only files
    full_preds = sorted(p for p in matches if p.lower().endswith("test_preds_full.parquet"))
    if full_preds:
        return full_preds[0]
    return min(matches) if matches else None

def load_and_normalize_preds(parquet_path, model_label=None):
    """Read a predictions parquet and coerce it to a standard long schema.

    Column names are stripped of surrounding whitespace. The first matching
    prediction column (``yhat``, ``y_pred``, ``pred``, ...) is renamed to ``yhat``,
    and the first matching target column (``y_next``, ``y_true``, ``label``, ...) to
    ``y_next``. Matching is case-insensitive and follows candidate-list order.
    A ``Model`` column is added (``model_label`` or "MODEL") when absent. ``Week``
    is parsed to datetime, and duplicate (Week, Ticker, Model) rows are averaged.

    Raises:
        ValueError: if no prediction column, or no Week/Ticker column, is present.
    """
    pred_candidates = ["yhat", "y_hat", "y_pred", "pred", "prediction", "preds", "pred_mean", "yhat_pred", "y_pred_mean"]
    true_candidates = ["y_next", "ytrue", "y_true", "label", "target", "y"]

    frame = pd.read_parquet(parquet_path)
    frame.columns = [str(c).strip() for c in frame.columns]
    by_lower = {c.lower(): c for c in frame.columns}

    def _first_present(candidates):
        return next((by_lower[c.lower()] for c in candidates if c.lower() in by_lower), None)

    pred_col = _first_present(pred_candidates)
    true_col = _first_present(true_candidates)
    if pred_col is None:
        raise ValueError(f"No prediction column found in {parquet_path}.")

    renames = {pred_col: "yhat"}
    if true_col is not None:
        renames[true_col] = "y_next"
    frame = frame.rename(columns=renames)

    if "Model" not in frame.columns:
        frame["Model"] = model_label if model_label else "MODEL"
    if "Week" not in frame.columns or "Ticker" not in frame.columns:
        raise ValueError(f"Missing Week/Ticker columns in {parquet_path}")
    frame["Week"] = pd.to_datetime(frame["Week"], errors="coerce")

    # Average accidental duplicate rows per (Week, Ticker, Model).
    aggregations = {"yhat": "mean"}
    if "y_next" in frame.columns:
        aggregations["y_next"] = "mean"
    return frame.groupby(["Week", "Ticker", "Model"], as_index=False).agg(aggregations)

def load_all_model_preds_if_available(tune_dir, rows_from_best_csv):
    """Load and stack normalised test predictions for each row of a best-by-model table.

    Each row's ``Model`` (upper-cased) and optional ``ckpt_tag`` are passed to
    ``find_preds_file_for_model``. Missing or unreadable parquets are reported and
    skipped. Returns the concatenated predictions, or None if nothing was loaded.
    """
    loaded = []
    for _, record in rows_from_best_csv.iterrows():
        model = str(record["Model"]).upper().strip()
        has_tag = "ckpt_tag" in record and pd.notna(record["ckpt_tag"])
        ckpt_tag = str(record["ckpt_tag"]) if has_tag else None

        preds_path = find_preds_file_for_model(tune_dir, model, ckpt_tag=ckpt_tag)
        if preds_path is None:
            print(f"[WARN] No preds parquet for {model}.")
            continue
        try:
            frame = load_and_normalize_preds(preds_path, model_label=model)
            loaded.append(frame)
            print(f"[INFO] Loaded preds for {model}: {os.path.basename(preds_path)}")
        except Exception as e:
            print(f"[WARN] Failed to load preds for {model}: {e}")

    if not loaded:
        return None
    return pd.concat(loaded, ignore_index=True)


# --------------------- Combined Evaluation Table -------------------- #
def assign_family(m):
    """Map a model name to its reporting family for the combined evaluation table.

    "S-BEKK" and "VAR-FEVD" map to "Baseline", "MLP" to "MLP", and a fixed list of
    GNN names to "GNN (Selected)". Anything else is "Other". Matching is exact and
    case-sensitive.
    """
    # NOTE(review): the DM helper elsewhere in this notebook labels the BEKK model
    # "S-BEKK(1,1)", which this exact match classifies as "Other". Confirm intended.
    baseline_names = ("S-BEKK", "VAR-FEVD")
    gnn_names = ("ECC", "TAG", "WSAGE", "CHEB", "GAT", "GCN", "SAGE", "GraphSAGE",
                 "ChebNet", "GATv2", "TEMPORAL_GCN", "TAGConv")
    if m in baseline_names:
        return "Baseline"
    if m == "MLP":
        return "MLP"
    if m in gnn_names:
        return "GNN (Selected)"
    return "Other"

def build_eval_table(baselines_dir, results_gnn, tune_dir):
    """Assemble one evaluation table across baselines, GNN runs and tuned GNNs.

    Sources:
      * ``<baselines_dir>/metrics_overall.csv``: columns ``model``/``rmse``/
        ``mae``/``r2`` are renamed to the shared schema. Every row is tagged
        "Baseline" and has no validation RMSE.
      * ``<results_gnn>/summary.csv``: the family is assigned per row via
        ``assign_family`` on its ``Model`` column.
      * ``<tune_dir>/best_by_model.csv`` (optional): tagged
        "GNN (Hyper-tuned)", with " (tuned)" appended to each model name.

    Returns:
        DataFrame with columns Family, Model, Val_RMSE, Test_RMSE, Test_MAE
        and Test_R2. Metric columns are numeric, and a metric missing from a
        source is NaN. Exact duplicate rows are removed. Rows are sorted by
        family order, then Test_RMSE, then Test_MAE.

    Raises:
        KeyError: if the tuned CSV exists but has no ``Model`` column.
    """
    cols = ["Family", "Model", "Val_RMSE", "Test_RMSE", "Test_MAE", "Test_R2"]
    metric_cols = cols[2:]
    # "Other" must be a category: assign_family can return it, and a value
    # outside the categories would otherwise be turned into NaN.
    family_order = ["Baseline", "MLP", "GNN (Selected)", "GNN (Hyper-tuned)", "Other"]

    # Load sources
    metrics_overall = pd.read_csv(os.path.join(baselines_dir, "metrics_overall.csv"))
    summary = pd.read_csv(os.path.join(results_gnn, "summary.csv"))
    tuned_csv = os.path.join(tune_dir, "best_by_model.csv")
    tuned = pd.read_csv(tuned_csv) if os.path.exists(tuned_csv) else None

    # Normalise baseline columns to the shared schema.
    metrics_overall = metrics_overall.rename(
        columns={"model": "Model", "rmse": "Test_RMSE", "mae": "Test_MAE", "r2": "Test_R2"}
    )
    # Baselines have no validation split; NaN (not None) keeps the column numeric.
    metrics_overall["Val_RMSE"] = np.nan
    metrics_overall["Family"] = "Baseline"

    summary["Family"] = summary["Model"].map(assign_family)

    frames = [metrics_overall, summary]
    if tuned is not None:
        if "Model" not in tuned.columns:
            raise KeyError(f"'Model' column missing from {tuned_csv}")
        tuned = tuned.reindex(columns=["Model"] + metric_cols)
        tuned["Family"] = "GNN (Hyper-tuned)"
        tuned["Model"] = tuned["Model"].astype(str) + " (tuned)"
        frames.append(tuned)

    # reindex tolerates a source lacking a metric column (filled with NaN).
    eval_all = pd.concat([f.reindex(columns=cols) for f in frames], ignore_index=True)
    for c in metric_cols:
        eval_all[c] = pd.to_numeric(eval_all[c], errors="coerce")
    eval_all = eval_all.drop_duplicates().reset_index(drop=True)

    eval_all["Family"] = eval_all["Family"].astype(pd.CategoricalDtype(family_order, ordered=True))
    return eval_all.sort_values(["Family", "Test_RMSE", "Test_MAE"]).reset_index(drop=True)


# ---------------------- Graph Loading & Drawing --------------------- #
def load_graph_index(graph_index_csv):
    """Read the weekly graph-index CSV and pass it through ``_normalize_weeks``.

    ``_normalize_weeks`` is defined elsewhere in this file. NOTE(review):
    ``load_week_graph`` and ``pick_example_week`` read ``_week_date``,
    ``_week_str`` and ``_week_dt`` from the returned frame, so presumably
    ``_normalize_weeks`` adds them. Confirm against its definition.
    """
    raw_index = pd.read_csv(graph_index_csv)
    normalised = _normalize_weeks(raw_index)
    return normalised

def load_week_graph(graphs_dir, graph_index, week):
    """Load the saved graph for ``week`` and convert it to a networkx DiGraph.

    The row is looked up in ``graph_index`` by trying, in order:
      1. the ISO date string in ``_week_date``,
      2. ``str(week)`` in ``_week_str``,
      3. the parsed timestamp in ``_week_dt`` (only if ``week`` parses).
    The first match's ``file`` entry (relative to ``graphs_dir``) is loaded
    with ``torch.load`` onto the CPU.

    Each node ``i`` carries ``ticker=g.tickers[i]`` and ``label=g.y[i]``
    (NaN when ``g.y`` is absent or too short). Each edge carries
    ``weight`` from ``g.edge_weight``, or 0.0 when the graph has none.

    Returns:
        ``(G, tickers, g)``: the DiGraph, the ticker list and the raw loaded
        object.

    Raises:
        FileNotFoundError: if no index row matches ``week``.
    """
    week_ts = pd.to_datetime(week, errors="coerce")
    parsed = pd.notna(week_ts)
    week_iso = week_ts.date().isoformat() if parsed else str(week)
    week_text = str(week)

    # Try progressively looser keys until one hits.
    match = graph_index.loc[graph_index["_week_date"] == week_iso]
    if match.empty:
        match = graph_index.loc[graph_index["_week_str"] == week_text]
    if match.empty and parsed:
        match = graph_index.loc[graph_index["_week_dt"] == week_ts]
    if match.empty:
        raise FileNotFoundError(f"Week {week} not found in index.")

    graph_path = os.path.join(graphs_dir, match.iloc[0]["file"])
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    g = torch.load(graph_path, weights_only=False, map_location="cpu")

    edge_pairs = g.edge_index.cpu().numpy().T
    raw_weights = getattr(g, "edge_weight", None)
    if raw_weights is not None:
        weights = raw_weights.cpu().numpy()
    else:
        weights = np.zeros(len(edge_pairs))
    tickers = list(g.tickers)
    node_targets = getattr(g, "y", None)

    G = nx.DiGraph()
    for idx, ticker in enumerate(tickers):
        if node_targets is not None and idx < len(node_targets):
            node_label = float(node_targets[idx].item())
        else:
            node_label = np.nan
        G.add_node(idx, ticker=ticker, label=node_label)
    for (src, dst), w in zip(edge_pairs, weights):
        G.add_edge(int(src), int(dst), weight=float(w))
    return G, tickers, g

def spring_layout_cached(G, seed=42, *, k=1.7, iterations=100):
    """Return a seeded spring layout (node -> position) for ``G``.

    Despite the name, nothing is cached: every call recomputes the layout.
    Callers should compute it once per graph and reuse the returned mapping
    across plots, so node positions stay comparable between figures.

    Args:
        G: networkx graph to lay out.
        seed: RNG seed passed to ``nx.spring_layout``, for reproducible
            positions.
        k: optimal node distance passed to ``nx.spring_layout`` (default 1.7,
            the value previously hard-coded).
        iterations: number of spring iterations (default 100, previously
            hard-coded).
    """
    return nx.spring_layout(G, seed=seed, k=k, iterations=iterations)

def overlay_node_values(G, pos, values_by_idx, *, title, cmap="viridis", size_scale=1600, alpha=0.95,
                        vmin=None, vmax=None, edge_alpha=0.25, save_path=None,
                        cbar_label="Predicted spillover"):
    """Draw ``G`` with node colour and size driven by one value per node.

    Args:
        G: networkx graph. Nodes should carry a ``ticker`` attribute; nodes
            without one are labelled by their node id.
        pos: node -> position mapping (e.g. from ``spring_layout_cached``).
        values_by_idx: node -> value. Missing nodes are treated as NaN.
        title: axes title.
        cmap: matplotlib colormap name for node colours.
        size_scale: extra marker area added on top of the base size of 300 at
            the top of the colour range.
        alpha: node opacity.
        vmin, vmax: colour and size limits. When omitted they come from the
            non-NaN values, falling back to 0..1 if every value is NaN.
        edge_alpha: edge opacity. Edge colour maps ``(weight + 1) / 2`` through
            the ``bwr`` colormap.
        save_path: if given, the figure is saved there at 200 dpi.
        cbar_label: colour-bar label.
    """
    vals = np.array([values_by_idx.get(i, np.nan) for i in G.nodes], dtype=float)
    present = vals[~np.isnan(vals)]
    # np.nanmin/nanmax warn and return NaN on all-NaN input, which would
    # make every marker size NaN; use a neutral 0..1 range instead.
    if vmin is None:
        vmin = float(present.min()) if present.size else 0.0
    if vmax is None:
        vmax = float(present.max()) if present.size else 1.0
    # Clip so an explicit vmin/vmax narrower than the data cannot produce
    # negative or oversized markers; NaN values get the base size.
    frac = np.clip(np.nan_to_num(vals - vmin) / (vmax - vmin + 1e-9), 0.0, 1.0)
    sizes = 300 + size_scale * frac

    eweights = np.array([G[u][v].get("weight", 0.0) for u, v in G.edges()], dtype=float)
    edge_colors = plt.cm.bwr((eweights + 1.0) / 2.0)

    fig, ax = plt.subplots(figsize=(10, 8))
    nx.draw_networkx_edges(G, pos, ax=ax, edge_color=edge_colors, alpha=edge_alpha)
    sc = nx.draw_networkx_nodes(
        G, pos, ax=ax, node_color=vals, cmap=cmap,
        node_size=sizes, alpha=alpha, edgecolors="black", linewidths=0.8,
        vmin=vmin, vmax=vmax,
    )
    labels = {i: G.nodes[i].get("ticker", str(i)) for i in G.nodes}
    nx.draw_networkx_labels(G, pos, ax=ax, labels=labels, font_size=8)

    cb = fig.colorbar(sc, ax=ax)
    cb.set_label(cbar_label)
    ax.set_title(title)
    ax.axis("off")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=200)
    plt.show()


# -------------------- Best Model & Predictions --------------------- #
def pick_example_week(preds_full, graph_index):
    """Choose the week to visualise.

    Returns the latest week present in both ``preds_full["Week"]`` and
    ``graph_index["_week_dt"]``. If they share no week, it falls back to the
    latest prediction week (not the last one encountered: ``unique()`` keeps
    order of appearance, not chronological order). NaT entries on either side
    are ignored.

    Returns:
        ``pd.Timestamp``.

    Raises:
        ValueError: if ``preds_full`` contains no valid (non-NaT) week.
    """
    ws_preds = pd.DatetimeIndex(pd.to_datetime(preds_full["Week"].unique())).dropna()
    ws_graph = pd.DatetimeIndex(pd.to_datetime(graph_index["_week_dt"].unique())).dropna()
    if ws_preds.empty:
        raise ValueError("No valid prediction weeks to choose an example week from.")
    shared = ws_preds.intersection(ws_graph)
    return shared.max() if not shared.empty else ws_preds.max()

def node_values_from_preds(preds_df, model_name, week_dt, tickers, kind="yhat"):
    """Map node index -> mean ``kind`` value for one model and one week.

    Rows are kept when ``Model`` matches ``model_name`` case-insensitively and
    ``Week`` equals ``pd.to_datetime(week_dt)``. The kept rows are averaged
    per ``Ticker``.

    Args:
        preds_df: DataFrame with ``Model``, ``Week``, ``Ticker`` and ``kind``
            columns.
        model_name: model to select.
        week_dt: week to select (anything ``pd.to_datetime`` accepts).
        tickers: ticker list; position ``i`` gives node ``i``.
        kind: value column to average.

    Returns:
        A dict with keys ``0..len(tickers)-1``. Tickers without a matching row
        map to NaN.
    """
    target_model = str(model_name).upper()
    target_week = pd.to_datetime(week_dt)
    keep = (preds_df["Model"].astype(str).str.upper().eq(target_model)
            & preds_df["Week"].eq(target_week))
    subset = preds_df.loc[keep]
    if subset.empty:
        return dict.fromkeys(range(len(tickers)), np.nan)

    per_ticker = subset.groupby("Ticker")[kind].mean()
    known = per_ticker.index
    return {
        idx: (float(per_ticker.loc[tk]) if tk in known else np.nan)
        for idx, tk in enumerate(tickers)
    }


# ------------------------------- MAIN ------------------------------ #
def _plot_model_panel(G, pos, tickers, combined, models, week, fig_dir, save):
    """Draw one node overlay per model side by side, sharing one colour scale.

    Args:
        G, pos, tickers: graph, layout and node-index -> ticker list for
            ``week``.
        combined: stacked predictions with ``Model``, ``Week``, ``Ticker`` and
            ``yhat`` columns.
        models: model names in panel order; matching ignores case.
        week: ``pd.Timestamp`` of the week to draw.
        fig_dir: directory for the saved PNG.
        save: if truthy, the figure is saved as
            ``graph_overlay_panel_<date>.png``.
    """
    if not models:
        return
    week_mask = combined["Week"] == week
    upper_models = combined["Model"].astype(str).str.upper()
    per_model = [(m, combined[week_mask & (upper_models == m.upper())]) for m in models]

    # Shared colour limits come from every raw prediction this week. NaNs are
    # ignored so one missing value cannot turn the whole scale into NaN.
    pooled = np.concatenate([d["yhat"].to_numpy(dtype=float) for _, d in per_model])
    pooled = pooled[~np.isnan(pooled)]
    if pooled.size:
        vmin, vmax = float(pooled.min()), float(pooled.max())
    else:
        vmin, vmax = None, None

    k = len(per_model)
    # squeeze=False always yields a 2-D array, so k == 1 needs no special case.
    fig, axes = plt.subplots(1, k, figsize=(5 * k, 8), squeeze=False)
    axes = axes.ravel()

    eweights = np.array([G[u][v].get("weight", 0.0) for u, v in G.edges()], dtype=float)
    edge_colors = plt.cm.bwr((eweights + 1.0) / 2.0)
    node_labels = {i: G.nodes[i]["ticker"] for i in G.nodes}

    sc = None
    for ax, (m, d) in zip(axes, per_model):
        if d.empty:
            ax.set_axis_off()
            ax.set_title(f"{m}\n(no preds)")
            continue
        means = d.groupby("Ticker")["yhat"].mean()
        vmap = {i: float(means.loc[tk]) if tk in means.index else np.nan for i, tk in enumerate(tickers)}
        vals = np.array([vmap.get(i, np.nan) for i in G.nodes], dtype=float)
        if vmin is None:
            # Nothing finite to scale by: draw every node at the base size.
            sizes = np.full(len(vals), 300.0)
        else:
            sizes = 300 + 1600 * np.clip(np.nan_to_num(vals - vmin) / (vmax - vmin + 1e-9), 0.0, 1.0)
        nx.draw_networkx_edges(G, pos, ax=ax, edge_color=edge_colors, alpha=0.25)
        sc = nx.draw_networkx_nodes(
            G, pos, ax=ax, node_color=vals, cmap="viridis",
            node_size=sizes, alpha=0.95, edgecolors="black", linewidths=0.8,
            vmin=vmin, vmax=vmax,
        )
        nx.draw_networkx_labels(G, pos, ax=ax, labels=node_labels, font_size=7)
        ax.set_title(f"{m}\nWeek {week.date()}")
        ax.axis("off")

    # A colour bar needs at least one drawn node collection.
    if sc is not None:
        cbar = fig.colorbar(sc, ax=axes, fraction=0.02, pad=0.02)
        cbar.set_label("Predicted spillover")
    else:
        print("[WARN] No model has predictions for this week; panel drawn without colour bar.")
    plt.tight_layout()
    if save:
        outp = os.path.join(fig_dir, f"graph_overlay_panel_{week.date()}.png")
        fig.savefig(outp, dpi=200)
        print("Saved:", outp)
    plt.show()


def main():
    """Build the combined evaluation table, then render graph overlays.

    Steps:
      1. Merge baseline, GNN and tuned metrics (``build_eval_table``) and save
         them as CSV and LaTeX under ``args.results_gnn``.
      2. Rank ``best_by_model.csv``: lowest Test_RMSE first, then lowest
         Val_RMSE, then highest Test_R2. Only columns that are present are
         used.
      3. Load the top model's prediction parquet, or the best-of-top3 fallback.
      4. Pick an example week (``pick_example_week``) and load its graph.
      5. Draw a single-model overlay.
      6. Draw a multi-panel overlay of every tuned model with predictions.

    Visualisation steps are skipped with a ``[WARN]`` message when their
    inputs are missing.
    """
    args = get_args()
    ensure_dirs(args.fig_dir)

    # 1) Build & save the combined evaluation table
    eval_all = build_eval_table(args.baselines_dir, args.results_gnn, args.tune_dir)
    out_csv = os.path.join(args.results_gnn, "table_eval_all_models_incl_mlp_tuned.csv")
    out_tex = os.path.join(args.results_gnn, "table_eval_all_models_incl_mlp_tuned.tex")
    eval_all.to_csv(out_csv, index=False)
    with open(out_tex, "w", encoding="utf-8") as f:
        f.write(eval_all.to_latex(index=False, float_format="%.4f",
                                  caption="Evaluation metrics on the test set (baselines, MLP, selected GNNs, and hyper-tuned GNNs).",
                                  label="tab:eval_all_models_tuned"))
    print("Saved:", out_csv, "|", out_tex)
    print(eval_all.head(12))

    # 2) Load best tuned model row (if available)
    best_csv_path = first_existing([
        os.path.join(args.tune_dir, "best_by_model.csv"),
        "/mnt/data/best_by_model.csv",
    ])
    if best_csv_path is None:
        print("[WARN] No best_by_model.csv found; skipping graph visualisation.")
        return
    best_df = pd.read_csv(best_csv_path)
    if best_df.empty or "Model" not in best_df.columns:
        print(f"[WARN] {best_csv_path} has no model rows; skipping graph visualisation.")
        return
    for col in ["Val_RMSE", "Test_RMSE", "Test_MAE", "Test_R2"]:
        if col in best_df.columns:
            best_df[col] = pd.to_numeric(best_df[col], errors="coerce")
    # Rank only on the columns present, so a CSV without Val_RMSE still works.
    ranking = [(c, asc) for c, asc in (("Test_RMSE", True), ("Val_RMSE", True), ("Test_R2", False))
               if c in best_df.columns]
    if ranking:
        best_df = best_df.sort_values(by=[c for c, _ in ranking],
                                      ascending=[asc for _, asc in ranking], kind="mergesort")
    best_df = best_df.reset_index(drop=True)
    best_model_raw = str(best_df.loc[0, "Model"]).upper().strip()
    has_tag = "ckpt_tag" in best_df.columns and pd.notna(best_df.loc[0, "ckpt_tag"])
    best_ckpt_tag = str(best_df.loc[0, "ckpt_tag"]) if has_tag else None
    print(f"[INFO] Best model by Test_RMSE: {best_model_raw}")

    # 3) Load predictions parquet for best model (fallback: best_of_top3 file)
    preds_path = find_preds_file_for_model(args.tune_dir, best_model_raw, ckpt_tag=best_ckpt_tag)
    if preds_path:
        preds_full = load_and_normalize_preds(preds_path, model_label=best_model_raw)
        print(f"[INFO] Using predictions: {os.path.basename(preds_path)}")
    else:
        fallback = os.path.join(args.tune_dir, "best_of_top3_test_preds_full.parquet")
        if not os.path.exists(fallback):
            print("[WARN] No predictions parquet found; skipping visualisation.")
            return
        preds_full = load_and_normalize_preds(fallback, model_label="UNKNOWN")
        print(f"[WARN] Using fallback predictions: {os.path.basename(fallback)}")
    # load_and_normalize_preds has a `return None` path.
    if preds_full is None or preds_full.empty:
        print("[WARN] Predictions parquet is empty or unreadable; skipping visualisation.")
        return

    # node_values_from_preds filters rows by model label. The fallback parquet
    # may be labelled "UNKNOWN" or with another model's name, which would leave
    # every node NaN, so overlay a label that actually exists.
    labels_in_preds = sorted(preds_full["Model"].astype(str).str.upper().unique())
    overlay_model = best_model_raw
    if overlay_model not in labels_in_preds:
        overlay_model = labels_in_preds[0]
        print(f"[WARN] {best_model_raw} not in predictions; overlaying {overlay_model} instead.")

    # 4) Load graph index + a target week
    graph_index = load_graph_index(args.graph_index_csv)
    week = pd.to_datetime(pick_example_week(preds_full, graph_index))
    try:
        G, tickers, _graph_data = load_week_graph(args.graphs_dir, graph_index, week)
    except FileNotFoundError as err:
        # The chosen week may exist only in the predictions, not in the graph index.
        print(f"[WARN] {err} Skipping visualisation.")
        return
    pos = spring_layout_cached(G, seed=42)
    week_label = week.date()

    # 5) Overlay: single best model
    vals_best = node_values_from_preds(preds_full, overlay_model, week, tickers, kind="yhat")
    overlay_node_values(
        G, pos, vals_best,
        title=f"{overlay_model} — predictions (Week {week_label})",
        cmap="viridis", size_scale=1600,
        save_path=os.path.join(args.fig_dir, f"graph_overlay_{overlay_model}_{week_label}.png"),
    )

    # 6) Multi-panel comparison across tuned models (if preds exist).
    # best_df is already ranked, so its Model order is the panel order.
    combined = load_all_model_preds_if_available(args.tune_dir, best_df.copy())
    if combined is not None:
        _plot_model_panel(G, pos, tickers, combined, best_df["Model"].astype(str).tolist(),
                          week, args.fig_dir, args.save_panels)

    print("\nDone.")

# Script entry point: run the evaluation-table + visualisation pipeline only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
