In [None]:
#!/usr/bin/env python3
"""
Compare: (1) slope->direct + yl_all (1-branch)
         (2) slope->direct + yl_all (branch1) + MacroFactors (branch2)
Requires rolling_framework (Machine) and TorchMultiBranchStrategy ("DNN_NBR").
"""

import os, sys, argparse, warnings, re
import pandas as pd
warnings.filterwarnings("ignore")

from rolling_framework import Machine  # Machine: (X,y, model_type="DNN_NBR", ...)

# -------------------------- USER CONFIG --------------------------
# Folder that holds the input CSV files.
DATA_DIR = "data/"
Y_FILE, SLOPE_FILE, YL_FILE, MACRO_FILE = (
    os.path.join(DATA_DIR, fname)
    for fname in ("exrets.csv", "slope.csv", "yl_all.csv", "MacroFactors.csv")
)

OUTPUT_DIR = "./output"
OUT_CSV = "results_slope_direct_vs_yields.csv"

# Sample / forecast window (edit as needed).
BURN_START, BURN_END = "197108", "199001"
PERIOD_START, PERIOD_END = "197108", "202312"
HORIZON = 12

# Maturities (target columns) to use; only those present in the file are kept.
MATURITIES = [f"xr_{n}" for n in (2, 3, 5, 7, 10)]


# Grid
def _make_param_grid(branch_lrs):
    """Return a fresh hyper-parameter grid; only the per-branch LRs vary."""
    return {
        "dnn__optimizer__lr":           [1e-3],
        "dnn__optimizer__weight_decay": [1e-4],
        "dnn__lr_br":                   [list(branch_lrs)],  # one LR per branch
        "dnn__lr_head":                 [1e-3],
        "dnn__lr_direct":               [5e-4],              # direct map (trainable)
        "dnn__module__head_hidden":     [16],
    }


param_grid_case1 = _make_param_grid([1e-3])        # 1-branch
param_grid_case2 = _make_param_grid([1e-3, 5e-4])  # 2-branches: [branch1, branch2]

# -------------------------- HELPERS --------------------------
def _load_csv(path, name):
    """Read the CSV at *path* into a DataFrame indexed by its "Time" column.

    *name* is a human-readable label used only in the error message.
    If the file does not exist, the process terminates through ``sys.exit``
    with a message naming the missing file.
    """
    try:
        frame = pd.read_csv(path, index_col="Time")
    except FileNotFoundError as err:
        sys.exit(f"[ERROR] missing {name} → {err.filename}")
    else:
        return frame

def _align_by_time(*dfs):
    """Restrict every frame to the index labels shared by all of them.

    Returns a list, in argument order, holding each frame sliced to the
    intersection of all indexes and sorted by index. With no arguments an
    empty list is returned.
    """
    if not dfs:
        return []

    first, *others = dfs
    shared = first.index
    for frame in others:
        shared = shared.intersection(frame.index)

    return [frame.loc[shared].sort_index() for frame in dfs]

def _build_direct_pairs(slope_cols, y_cols):
    """Pair slope columns with target columns that carry the same number.

    Example: slope_2 -> xr_2, slope_3 -> xr_3, ...
    The number used is the first run of digits found in each column name.
    Slope columns without a number, or without a matching target, are
    skipped. If several targets share a number, the last one listed wins.

    Returns a list of ``(slope_col, y_col)`` tuples in ``slope_cols`` order.
    """
    digits = re.compile(r"(\d+)")

    def first_number(label):
        hit = digits.search(label)
        return None if hit is None else hit.group(1)

    target_by_number = {}
    for col in y_cols:
        target_by_number[first_number(col)] = col

    matched = []
    for col in slope_cols:
        number = first_number(col)
        if not number:
            continue
        target = target_by_number.get(number)
        if target:
            matched.append((col, target))
    return matched

# -------------------------- MAIN --------------------------
def _run_case(title, label, X, y, option, params_grid):
    """Train one "DNN_NBR" configuration and collect its out-of-sample metrics.

    Parameters
    ----------
    title : str
        Text printed (after a "▶" marker) just before training starts.
    label : str
        Value stored in the "case" field of every returned row.
    X, y : pandas.DataFrame
        Feature and target frames handed to ``Machine`` unchanged.
    option : dict
        Model options (branches, direct map, ...) forwarded to ``Machine``.
    params_grid : dict
        Hyper-parameter grid forwarded to ``Machine``.

    Returns
    -------
    list[dict]
        One row per entry of ``R2OOS().index`` with the keys
        "case", "maturity", "R2_OOS" and "MSE".
    """
    machine = Machine(
        X, y,
        model_type       = "DNN_NBR",
        option           = option,
        params_grid      = params_grid,
        burn_in_start    = BURN_START,
        burn_in_end      = BURN_END,
        period           = [PERIOD_START, PERIOD_END],
        forecast_horizon = HORIZON,
    )
    print(f"\n▶ {title}")
    machine.training()
    r2 = machine.R2OOS()     # indexed by maturity
    mse = machine.MSEOOS()   # indexed by maturity

    return [
        {"case": label,
         "maturity": mty,
         "R2_OOS": float(r2[mty]),
         "MSE": float(mse[mty])}
        for mty in r2.index
    ]


def main():
    """Run both model configurations and write their OOS metrics to a CSV.

    Case 1 uses the slope columns as direct inputs plus one branch fed with
    the yl_all columns. Case 2 adds a second branch fed with the MacroFactors
    columns. Results are saved to ``OUTPUT_DIR/OUT_CSV`` and, when an IPython
    ``display`` function is available, also rendered as a table.

    Exits via ``sys.exit`` if an input file is missing or none of
    ``MATURITIES`` exists in the target file.
    """
    # Local import: only needed for the notebook display lookup below.
    import builtins

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    out_csv = os.path.join(OUTPUT_DIR, OUT_CSV)

    # 1) Load
    y     = _load_csv(Y_FILE,      "exrets (target)")
    slope = _load_csv(SLOPE_FILE,  "slope (direct inputs)")
    yl    = _load_csv(YL_FILE,     "yl_all (branch1)")
    macro = _load_csv(MACRO_FILE,  "MacroFactors (branch2)")

    # 2) Columns & align
    y_cols = [c for c in MATURITIES if c in y.columns]
    if not y_cols:
        sys.exit(f"[ERROR] None of {MATURITIES} found in exrets columns.")
    y = y[y_cols]

    # Align on the common time index. macro is included here so that both
    # cases are evaluated on exactly the same sample.
    y, slope, yl, macro = _align_by_time(y, slope, yl, macro)

    # 3) Direct map pairs (slope_* -> xr_*)
    direct_pairs = _build_direct_pairs(slope.columns.tolist(), y_cols)
    if not direct_pairs:
        print("[WARN] No direct pairs were matched from slope to y; training will proceed without direct residuals.")

    def _features(*frames):
        # Column-wise union of the given frames, keeping the first occurrence
        # of any repeated column name. Applied to both cases so they are
        # built the same way.
        joined = pd.concat(list(frames), axis=1)
        return joined.loc[:, ~joined.columns.duplicated()].copy()

    def _branch(frame):
        # A fresh dict per call so the two cases never share option objects.
        return {"cols": frame.columns.tolist(), "hidden": (16,), "drop": 0.1}

    # 4) Build and run the two cases
    rows = []

    # ---- Case 1: slope direct + yl_all (branch1) ----
    opt_case1 = {
        "branches": [_branch(yl)],          # branch1
        "direct_map": direct_pairs,         # [('slope_2','xr_2'), ...]
        "freeze_direct": False,             # author's note: OLS-initialised, then trainable
        "head_hidden": 16,
    }
    rows.extend(_run_case(
        "Case 1: slope→direct + yl_all (1-branch)",
        "direct+slope | yl_all",
        _features(slope, yl), y, opt_case1, param_grid_case1,
    ))

    # ---- Case 2: slope direct + yl_all (branch1) + MacroFactors (branch2) ----
    # X is the union of the features used by the direct map and both branches.
    opt_case2 = {
        "branches": [_branch(yl), _branch(macro)],  # branch1, branch2
        "direct_map": direct_pairs,
        "freeze_direct": False,
        "head_hidden": 16,
    }
    rows.extend(_run_case(
        "Case 2: slope→direct + yl_all(branch1) + MacroFactors(branch2)",
        "direct+slope | yl_all + macro",
        _features(slope, yl, macro), y, opt_case2, param_grid_case2,  # lr_br has length 2
    ))

    # 5) Save
    out_df = (pd.DataFrame(rows)
                .sort_values(["maturity", "case"])
                .reset_index(drop=True))
    out_df.to_csv(out_csv, index=False)
    print(f"\n★ Saved → {out_csv}")

    # IPython exposes `display` through the builtins module rather than the
    # module globals, so check both places. Outside a notebook nothing is shown.
    show = globals().get("display") or getattr(builtins, "display", None)
    if callable(show):
        show(out_df)


if __name__ == "__main__":
    # Tolerate the extra command-line argument that Jupyter adds (per the
    # original note): parse_known_args does not fail on unknown options.
    # NOTE(review): "--out" is accepted, but the parsed value is discarded
    # and never reaches main().
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--out", default=None)
    parser.parse_known_args()
    main()

DNN_DUAL rolling:   4%|▍         | 20/520 [03:47<1:34:56, 11.39s/it]


