In [None]:
# =============================================
# p-out-of-q (ARM, macro-MLP)
# - Prints R2OOS by maturity for each baseline: naive / condmean / cs_yhat
# =============================================
import os
import re
import sys
import time
import warnings
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")

from rolling_framework import Machine

# ---------------- Parameters ----------------
TOP_P   = 3   # p: number of runs whose R2OOS is averaged at the end
TOTAL_Q = 5   # q: number of seeded runs that are trained
# p-out-of-q is only meaningful for 1 <= p <= q. With p < 1 no run would be
# selected and the final per-maturity table would silently be all-NaN.
if TOP_P < 1:
    sys.exit("[ERROR] p must be >= 1")
if TOP_P > TOTAL_Q:
    sys.exit("[ERROR] p must be <= q")

DATA_DIR    = "data/"
Y_FILE      = os.path.join(DATA_DIR, "exrets.csv")
SLOPE_FILE  = os.path.join(DATA_DIR, "slope.csv")
YL_FILE     = os.path.join(DATA_DIR, "yl_all.csv")
MACRO_FILE  = os.path.join(DATA_DIR, "MacroFactors.csv")
CS_YHAT_RAW = os.path.join(DATA_DIR, "cs_yhat.csv")

# Burn-in window and full sample period handed to Machine
# (values look like YYYYMM strings — confirm against Machine's parser).
BURN_START, BURN_END     = "197108", "199001"
PERIOD_START, PERIOD_END = "197108", "202312"
HORIZON = 12  # passed to Machine as forecast_horizon
# Target columns looked up in exrets.csv; missing ones are skipped at load time.
MATURITIES = ["xr_2","xr_3","xr_5","xr_7","xr_10"]

# ---------------- Helpers ----------------
def _load_csv(path):
    try: return pd.read_csv(path, index_col="Time")
    except FileNotFoundError as e: sys.exit(f"[ERROR] missing file → {e.filename}")

def _align_time(*dfs):
    idx = None
    for d in dfs:
        idx = d.index if idx is None else idx.intersection(d.index)
    return [d.loc[idx].sort_index() for d in dfs]

def _clean_cs_to_tmp(cs_df, y_cols, tmp_path):
    cs = cs_df.copy()
    if cs.index.duplicated().any():
        cs = cs[~cs.index.duplicated(keep="last")]
    cs = cs.reindex(columns=y_cols)
    cs.to_csv(tmp_path)
    return tmp_path

# ---------------- Load ----------------
# Load the four inputs in order; the first missing file aborts the script.
y, slope, yl, macro = (
    _load_csv(path) for path in (Y_FILE, SLOPE_FILE, YL_FILE, MACRO_FILE)
)

# Keep only the configured maturities present in exrets.csv, in MATURITIES order.
y_cols = [name for name in MATURITIES if name in y.columns]
if not y_cols:
    sys.exit("[ERROR] MATURITIES not found in exrets.csv")
y = y[y_cols]

# Restrict all inputs to their common dates. yl is not used further in this
# cell, but aligning on it still narrows the shared sample.
y, slope, yl, macro = _align_time(y, slope, yl, macro)
X_macro = pd.concat([slope, macro], axis=1)

# Clean the cs_yhat benchmark into a temporary CSV whose columns match y.
cs_raw = _load_csv(CS_YHAT_RAW)
CS_YHAT_CLEAN = os.path.join(DATA_DIR, "_cs_yhat_clean_tmp.csv")
CS_YHAT_PATH = _clean_cs_to_tmp(cs_raw, list(y.columns), CS_YHAT_CLEAN)

print("✓ data:", dict(y=y.shape, X_macro=X_macro.shape))
print("✓ p-out-of-q:", TOP_P, "/", TOTAL_Q)

# ---------------- Model ----------------
# Options handed to Machine(..., option=...). run_once copies this dict and
# overrides "seed" per run.
BASE_OPT = dict(
    base_on=True,
    base_cols=list(slope.columns),     # slope columns feed the base model
    target_cols=list(y.columns),
    residual_kind="mlp",
    feature_cols=list(macro.columns),  # macro factors feed the residual model
    standardize_res=True,
    mlp_hidden=(16,),
    mlp_dropout=0.1,
    mlp_lr=1e-3,
    mlp_wd=1e-4,
    mlp_epochs=200,
    mlp_patience=20,
    seed=0,
)

# Hyper-parameter grid handed to Machine(..., params_grid=...); keys look like
# sklearn/skorch nested-parameter names.
# NOTE(review): the grid's dropout (0.2) and weight_decay (1e-2) differ from
# BASE_OPT's mlp_dropout (0.1) and mlp_wd (1e-4); which setting takes effect
# depends on Machine — confirm this is intended.
BASE_GRID = dict(
    arm__residual_model__module__hidden=[(16,)],
    arm__residual_model__module__dropout=[0.2],
    arm__residual_model__optimizer__lr=[1e-3],
    arm__residual_model__optimizer__weight_decay=[1e-2],
)

# ---------------- Run ----------------
def run_once(seed: int) -> Dict[str, Any]:
    """Train one ARM model with ``seed`` and print/return its R2OOS.

    Seeds torch (best-effort) and NumPy's global RNG, then builds a
    ``Machine`` on the module-level ``X_macro``/``y`` with a copy of
    ``BASE_OPT`` (``"seed"`` overridden) and ``BASE_GRID``. It trains the
    model and evaluates R2OOS against the ``naive``, ``condmean`` and
    ``cs_yhat`` baselines.

    Args:
        seed: Random seed for this run.

    Returns:
        Dict with keys ``"seed"``, ``"r2_naive"``, ``"r2_cond"``, ``"r2_cs"``.
        The R2 values are whatever ``Machine.R2OOS`` returns (presumably a
        per-maturity Series — TODO confirm).
    """
    # Banner first, so any notice below is attributed to this run.
    print(f"\n▶ Run seed={seed}")

    # Shallow copy so the shared BASE_OPT template is never mutated.
    opt = dict(BASE_OPT)
    opt["seed"] = int(seed)

    try:
        import torch
        torch.manual_seed(seed)
    except Exception as err:
        # Stay best-effort, but do not hide that torch-side randomness is not
        # pinned by `seed`. print() rather than warnings.warn because this
        # script calls warnings.filterwarnings("ignore").
        print(f"  [warn] torch seeding skipped (seed={seed}): {err!r}")
    np.random.seed(seed)

    m = Machine(
        X_macro, y, "ARM",
        option=opt, params_grid=BASE_GRID,
        burn_in_start=BURN_START, burn_in_end=BURN_END,
        period=[PERIOD_START, PERIOD_END], forecast_horizon=HORIZON,
    )
    t0 = time.perf_counter()  # monotonic clock for measuring durations
    m.training()
    elapsed = time.perf_counter() - t0

    r2_naive = m.R2OOS(baseline="naive")
    r2_cond  = m.R2OOS(baseline="condmean")
    r2_cs    = m.R2OOS(baseline="cs_yhat", cs_path=CS_YHAT_PATH)

    print("  R2OOS (naive):")
    print(r2_naive.round(4))
    print("  R2OOS (condmean):")
    print(r2_cond.round(4))
    print("  R2OOS (cs_yhat):")
    print(r2_cs.round(4))
    print(f"  elapsed: {elapsed/60:.2f} min")

    return {"seed": seed, "r2_naive": r2_naive, "r2_cond": r2_cond, "r2_cs": r2_cs}

# ---------------- p-out-of-q -----------
# Train one model per seed 1..TOTAL_Q (q runs).
results = [run_once(s) for s in range(1, TOTAL_Q + 1)]

# List the completed runs. No ranking metric is computed: the frame holds
# only the seeds, in the order they were run.
# NOTE(review): if a genuine "best p of q" ranking is intended, rank on an
# in-sample / validation criterion. Ranking on the R2OOS values printed by
# run_once would select runs using out-of-sample results.
rank_df = pd.DataFrame({
    "seed": [r["seed"] for r in results]
})
print("\n=== Completed runs ===")
print(rank_df)

# "Top-p" here is simply the first TOP_P seeds in run order (seeds
# 1..TOP_P), not a performance-based selection.
top_seeds = rank_df.head(TOP_P)["seed"].tolist()
print(f"\nTop-{TOP_P} seeds:", top_seeds)

def _avg_over_top(key: str) -> pd.Series:
    arr = [r[key].reindex(y.columns) for r in results if r["seed"] in top_seeds]
    return pd.concat(arr, axis=1).mean(axis=1) if arr else pd.Series(np.nan, index=y.columns)

# Per-maturity mean of each baseline's R2OOS over the selected top-p runs.
avg_naive, avg_cond, avg_cs = (
    _avg_over_top(result_key) for result_key in ("r2_naive", "r2_cond", "r2_cs")
)

print("\n=== Per-maturity average R2OOS across top-p runs ===")
out_tbl = pd.DataFrame(
    dict(
        avg_R2OOS_naive=avg_naive,
        avg_R2OOS_condmean=avg_cond,
        avg_R2OOS_cs_yhat=avg_cs,
    )
)
print(out_tbl.round(4))

DNN_DUAL rolling:   4%|▍         | 20/520 [03:47<1:34:56, 11.39s/it]


