In [2]:
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets

from dataclasses import dataclass, field, asdict, is_dataclass
from typing import List, Dict, Optional, Any, Union, TypedDict, Tuple
from collections import Counter
from datetime import datetime, date
from pandas.testing import assert_series_equal


# pd.set_option('display.max_rows', None)  display all rows
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 1000)
pd.set_option("display.max_colwidth", 50)
pd.set_option("display.precision", 4)


# ==============================================================================
# GLOBAL SETTINGS: The "Control Panel" for the Strategy
# ==============================================================================

GLOBAL_SETTINGS = {
    # ENVIRONMENT (The "Where")
    "benchmark_ticker": "SPY",
    "calendar_ticker": "SPY",  # Used as the "Master Clock" for trading days
    # DATA SANITIZER (The "Glitches & Gaps" Protector)
    "handle_zeros_as_nan": True,  # Convert 0.0 prices to NaN to prevent math errors
    "max_data_gap_ffill": 1,  # Max consecutive days to "Forward Fill" missing data
    # IMPLICATION OF nan_price_replacement:
    # - This defines what happens if the "Forward Fill" limit is exceeded.
    # - If set to 0.0: A permanent data gap will look like a "total loss" (-100%).
    #   The equity curve will plummet. Good for "disaster detection."
    #   Sharpe and Sharpe(ATR) drop because: return (gets smaller) / std (gets larger)
    # - If set to np.nan: A permanent gap will cause portfolio calculations to return NaN.
    #   The chart may break or show gaps. Good for "math integrity."
    "nan_price_replacement": 0.0,
    # STRATEGY PARAMETERS (The "How")
    "atr_period": 14,  # Used for volatility normalization
    "quality_window": 252,  # 1 year lookback for liquidity/quality stats
    "quality_min_periods": 126,  # Min history required to judge a stock
    # QUALITY THRESHOLDS (The "Rules")
    "thresholds": {
        # HARD LIQUIDITY FLOOR
        # Logic: Calculates (Adj Close * Volume) daily, then takes the ROLLING MEDIAN
        # over the quality_window (252 days). Filters out stocks where the
        # typical daily dollar turnover is below this absolute value.
        "min_median_dollar_volume": 1_000_000,
        # DYNAMIC LIQUIDITY CUTOFF (Relative to Universe)
        # Logic: On the decision date, the engine calculates the X-quantile
        # of 'RollMedDollarVol' across ALL available stocks.
        # Setting this to 0.40 calculates the 60th percentile and requires
        # stocks to be above it‚Äîeffectively keeping only the TOP 60% of the market.
        "min_liquidity_percentile": 0.40,
        # PRICE/VOLUME STALENESS
        # Logic: Creates a binary flag (1 if Volume is 0 OR High equals Low).
        # It then calculates the ROLLING MEAN of this flag.
        # A value of 0.05 means the stock is rejected if it was "stale"
        # for more than 5% of the trading days in the rolling window.
        "max_stale_pct": 0.05,
        # DATA INTEGRITY (FROZEN VOLUME)
        # Logic: Checks if Volume is identical to the previous day (Volume.diff() == 0).
        # It calculates the ROLLING SUM of these occurrences over the window.
        # If the exact same volume is reported more than 10 times, the stock
        # is rejected as having "frozen" or low-quality data.
        "max_same_vol_count": 10,
    },
}


# ==============================================================================
# SECTION A: CORE KERNELS & QUANT UTILITIES (THE SAFE ROOM)
# ==============================================================================


class QuantUtils:
    """
    MATHEMATICAL KERNEL REGISTRY: THE SINGLE SOURCE OF TRUTH.
    !!! DANGER: DO NOT REFACTOR NULL HANDLING !!!
    """

    @staticmethod
    def compute_returns(
        series: Union[pd.Series, pd.DataFrame],
    ) -> Union[pd.Series, pd.DataFrame]:
        # 1. Calculate raw percentage change
        returns_WITH_BOUNDARY_NAN = series.pct_change()

        # 2. THE AI-PROOF GUARDRAIL (ROBUST VERSION)
        if len(returns_WITH_BOUNDARY_NAN) > 0:
            # We look at the first observation (scalar for Series, row for DataFrame)
            first_obs = returns_WITH_BOUNDARY_NAN.iloc[0]

            # np.all handles both a single True/False AND a list of True/False
            if not np.all(pd.isna(first_obs)):
                raise AssertionError(
                    "!!! REGRESSION: Leading NaN boundary was destroyed !!!"
                )

        # 3. Numerical Stability replacement
        if isinstance(returns_WITH_BOUNDARY_NAN, (pd.Series, pd.DataFrame)):
            returns_WITH_BOUNDARY_NAN = returns_WITH_BOUNDARY_NAN.replace(
                [np.inf, -np.inf], np.nan
            )

        return returns_WITH_BOUNDARY_NAN

    @staticmethod
    def calculate_total_gain(price_series: pd.Series) -> float:
        clean = price_series.dropna()
        if len(clean) < 2:
            return 0.0
        res = (clean.iloc[-1] / clean.iloc[0]) - 1
        return float(res) if np.isfinite(res) else 0.0

    @staticmethod
    def calculate_sharpe(
        returns_WITH_BOUNDARY_NAN: pd.Series, periods: int = 252
    ) -> float:
        mu, std = returns_WITH_BOUNDARY_NAN.mean(), returns_WITH_BOUNDARY_NAN.std()
        if pd.isna(std) or std < 1e-8:
            return 0.0
        return float((mu / std) * np.sqrt(periods))

    @staticmethod
    def calculate_sharpe_atr(
        returns_WITH_BOUNDARY_NAN: pd.Series, atrp_series: pd.Series
    ) -> float:
        avg_atrp = atrp_series.mean()
        if pd.isna(avg_atrp) or avg_atrp < 1e-8:
            return 0.0
        return float(returns_WITH_BOUNDARY_NAN.mean() / avg_atrp)


# ==============================================================================
# SECTION B: STRATEGY HELPERS & FEATURES
# ==============================================================================
# ... (Keep generate_features, calculate_gain, calculate_sharpe,
#      calculate_sharpe_atr, calculate_buy_and_hold_performance as is) ...


def generate_features(
    df_ohlcv: pd.DataFrame,
    df_indices: pd.DataFrame = None,
    benchmark_ticker: str = "SPY",
    atr_period: int = 14,
    rsi_period: int = 14,
    quality_window: int = 252,
    quality_min_periods: int = 126,
) -> pd.DataFrame:

    print(f"‚ö° Generating SOTA Quant Features (Benchmark: {benchmark_ticker})...")

    # 1. Sort and Group
    if not df_ohlcv.index.is_monotonic_increasing:
        df_ohlcv = df_ohlcv.sort_index()
    grouped = df_ohlcv.groupby(level="Ticker")

    # 2. VECTORIZED ATR (Wilder's)
    prev_close = grouped["Adj Close"].shift(1)
    high_low = df_ohlcv["Adj High"] - df_ohlcv["Adj Low"]
    high_prev = abs(df_ohlcv["Adj High"] - prev_close)
    low_prev = abs(df_ohlcv["Adj Low"] - prev_close)
    tr = pd.concat([high_low, high_prev, low_prev], axis=1).max(axis=1, skipna=False)

    atr = (
        tr.groupby(level="Ticker")
        .ewm(alpha=1 / atr_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )
    atrp = (atr / df_ohlcv["Adj Close"]).replace([np.inf, -np.inf], np.nan)

    # 3. VECTORIZED RSI
    delta = grouped["Adj Close"].diff()
    up = delta.clip(lower=0)
    down = -1 * delta.clip(upper=0)

    ma_up = (
        up.groupby(level="Ticker")
        .ewm(alpha=1 / rsi_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )
    ma_down = (
        down.groupby(level="Ticker")
        .ewm(alpha=1 / rsi_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )

    rsi = 100 - (100 / (1 + (ma_up / ma_down)))
    rsi = rsi.replace([np.inf, -np.inf], 50).fillna(50)

    # 4. OBV (Ticker Specific)
    direction = np.sign(delta).fillna(0)
    obv_raw = (direction * df_ohlcv["Volume"]).groupby(level="Ticker").cumsum()
    obv_roll_mean = (
        obv_raw.groupby(level="Ticker")
        .rolling(21)
        .mean()
        .reset_index(level=0, drop=True)
    )
    obv_roll_std = (
        obv_raw.groupby(level="Ticker")
        .rolling(21)
        .std()
        .reset_index(level=0, drop=True)
    )
    obv_score = (
        ((obv_raw - obv_roll_mean) / obv_roll_std)
        .fillna(0.0)
        .clip(lower=-5.0, upper=5.0)
    )

    # Dollar Volume
    dollar_vol_series = df_ohlcv["Adj Close"] * df_ohlcv["Volume"]

    # 5. BENCHMARK FEATURES (Price & Volume)
    bench_close_series = None
    bench_vol_series = None

    found_bench = False
    if (
        df_indices is not None
        and benchmark_ticker in df_indices.index.get_level_values(0)
    ):
        try:
            bench_close_series = df_indices.xs(benchmark_ticker, level=0)["Adj Close"]
            bench_vol_series = df_indices.xs(benchmark_ticker, level=0)["Volume"]
            found_bench = True
        except Exception:
            pass
    if not found_bench and benchmark_ticker in df_ohlcv.index.get_level_values(0):
        try:
            bench_close_series = df_ohlcv.xs(benchmark_ticker, level=0)["Adj Close"]
            bench_vol_series = df_ohlcv.xs(benchmark_ticker, level=0)["Volume"]
            found_bench = True
        except Exception:
            pass

    # Initialize Containers
    rel_strength_21 = pd.Series(0.0, index=df_ohlcv.index)
    spy_rvol = pd.Series(1.0, index=df_ohlcv.index)
    spy_obv_score = pd.Series(0.0, index=df_ohlcv.index)  # <--- NEW CONTAINER

    if found_bench:
        try:
            # A. Relative Strength
            bench_close_aligned = bench_close_series.reindex(
                df_ohlcv.index.get_level_values("Date")
            ).values
            rel_ratio = df_ohlcv["Adj Close"] / bench_close_aligned
            rel_strength_21 = (
                rel_ratio.groupby(level="Ticker")
                .pct_change(21, fill_method=None)
                .fillna(0.0)
            )

            # B. Spy RVol (Magnitude)
            bench_dvol = bench_close_series * bench_vol_series
            bench_dvol_avg = bench_dvol.rolling(21).mean()
            bench_rvol_raw = (bench_dvol / bench_dvol_avg).fillna(1.0)

            # C. SPY OBV Score (Direction) <--- NEW
            # Calculate OBV for SPY specifically
            spy_delta = bench_close_series.diff()
            spy_direction = np.sign(spy_delta).fillna(0)
            spy_obv_raw = (spy_direction * bench_vol_series).cumsum()

            # Normalize SPY OBV (Z-Score)
            spy_obv_mean = spy_obv_raw.rolling(21).mean()
            spy_obv_std = spy_obv_raw.rolling(21).std()
            spy_obv_z = (
                ((spy_obv_raw - spy_obv_mean) / spy_obv_std)
                .fillna(0.0)
                .clip(lower=-5.0, upper=5.0)
            )

            # D. BROADCAST TO ALL TICKERS
            # Reindex creates a Series aligned to the full DataFrame (Date-matched)
            spy_rvol_values = (
                bench_rvol_raw.reindex(df_ohlcv.index.get_level_values("Date"))
                .fillna(1.0)
                .values
            )
            spy_obv_values = (
                spy_obv_z.reindex(df_ohlcv.index.get_level_values("Date"))
                .fillna(0.0)
                .values
            )

            spy_rvol = pd.Series(spy_rvol_values, index=df_ohlcv.index).clip(upper=10.0)
            spy_obv_score = pd.Series(spy_obv_values, index=df_ohlcv.index)  # <--- NEW

        except Exception as e:
            print(f"‚ö†Ô∏è Benchmark Math Error: {e}")

    # 6. TICKER RELATIVE VOLUME (RVol)
    dvol_grouped = dollar_vol_series.groupby(level="Ticker")
    dvol_avg = dvol_grouped.rolling(21).mean().reset_index(level=0, drop=True)
    ticker_rvol = (
        (dollar_vol_series / dvol_avg)
        .replace([np.inf, -np.inf], 1.0)
        .fillna(1.0)
        .clip(upper=10.0)
    )

    # 7. MOMENTUM / RETURN FEATURES
    daily_returns = grouped["Adj Close"].pct_change(1, fill_method=None)
    roc_1 = daily_returns
    roc_3 = grouped["Adj Close"].pct_change(3, fill_method=None)
    roc_5 = grouped["Adj Close"].pct_change(5, fill_method=None)
    roc_10 = grouped["Adj Close"].pct_change(10, fill_method=None)
    roc_21 = grouped["Adj Close"].pct_change(21, fill_method=None)

    # 8. VOLATILITY REGIME
    returns_grouped = daily_returns.groupby(level="Ticker")
    std_5 = returns_grouped.rolling(5).std().reset_index(level=0, drop=True)
    std_21 = returns_grouped.rolling(21).std().reset_index(level=0, drop=True)

    if std_5.index.nlevels > df_ohlcv.index.nlevels:
        std_5 = std_5.reset_index(level=0, drop=True)
        std_21 = std_21.reset_index(level=0, drop=True)

    vol_regime = (std_5 / std_21).replace([np.inf, -np.inf], 1.0)

    # 9. MERGE
    indicator_df = pd.DataFrame(
        {
            "ATR": atr,
            "ATRP": atrp,
            "RSI": rsi,
            "RelStrength": rel_strength_21,
            "VolRegime": vol_regime,
            "RVol": ticker_rvol,
            "Spy_RVol": spy_rvol,
            "OBV_Score": obv_score,
            "Spy_OBV_Score": spy_obv_score,  # <--- NEW
            "ROC_1": roc_1,
            "ROC_3": roc_3,
            "ROC_5": roc_5,
            "ROC_10": roc_10,
            "ROC_21": roc_21,
        }
    )

    # 10. Quality/Liquidity Features
    quality_temp_df = pd.DataFrame(
        {
            "IsStale": np.where(
                (df_ohlcv["Volume"] == 0)
                | (df_ohlcv["Adj High"] == df_ohlcv["Adj Low"]),
                1,
                0,
            ),
            "DollarVolume": dollar_vol_series,
            "HasSameVolume": (grouped["Volume"].diff() == 0).astype(int),
        },
        index=df_ohlcv.index,
    )

    rolling_result = (
        quality_temp_df.groupby(level="Ticker")
        .rolling(window=quality_window, min_periods=quality_min_periods)
        .agg({"IsStale": "mean", "DollarVolume": "median", "HasSameVolume": "sum"})
        .rename(
            columns={
                "IsStale": "RollingStalePct",
                "DollarVolume": "RollMedDollarVol",
                "HasSameVolume": "RollingSameVolCount",
            }
        )
        .reset_index(level=0, drop=True)
    )

    return pd.concat([indicator_df, rolling_result], axis=1)


def verify_feature_engineering_integrity():
    """
    üõ°Ô∏è TRIPWIRE: Validates Feature Engineering Logic.
    Enforces:
    1. Day 1 ATR must be NaN (No PrevClose).
    2. Wilder's Smoothing must use Alpha = 1/Period.
    3. Recursion must match manual calculation.
    """
    print("\n--- üõ°Ô∏è Starting Feature Engineering Audit ---")

    # 1. Create Synthetic Data (3 Days)
    # Day 1: High-Low = 10. No PrevClose.
    # Day 2: High-Low = 20. Gap up implies TR might be larger.
    # Day 3: High-Low = 10.
    dates = pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03"])
    idx = pd.MultiIndex.from_product([["TEST"], dates], names=["Ticker", "Date"])

    df_mock = pd.DataFrame(
        {
            "Adj Open": [100, 110, 110],
            "Adj High": [110, 130, 120],
            "Adj Low": [100, 110, 110],
            "Adj Close": [105, 120, 115],  # PrevClose: NaN, 105, 120
            "Volume": [1000, 1000, 1000],
        },
        index=idx,
    )

    # 2. Run the Generator
    # We use Period=2 to make manual math easy (Alpha = 1/2 = 0.5)
    feats = generate_features(
        df_mock, atr_period=2, rsi_period=2, quality_min_periods=1
    )
    atr_series = feats["ATR"]

    # 3. MANUAL CALCULATION (The "Truth")
    # Day 1:
    #   TR = Max(H-L, |H-PC|, |L-PC|)
    #   TR = Max(10, NaN, NaN) -> NaN (Because skipna=False)
    #   Expected ATR: NaN

    # Day 2:
    #   PrevClose = 105
    #   H-L=20, |130-105|=25, |110-105|=5
    #   TR = 25
    #   Expected ATR: First valid observation = 25.0

    # Day 3:
    #   PrevClose = 120
    #   H-L=10, |120-120|=0, |110-120|=10
    #   TR = 10
    #   Wilder's Smoothing (Alpha=0.5):
    #   ATR_3 = (TR_3 * alpha) + (ATR_2 * (1-alpha))
    #   ATR_3 = (10 * 0.5) + (25 * 0.5) = 5 + 12.5 = 17.5

    print(f"Audit Values:\n{atr_series.values}")

    # 4. ASSERTIONS
    try:
        # Check Day 1
        if not np.isnan(atr_series.iloc[0]):
            raise AssertionError(
                f"Day 1 Regression: Expected NaN, got {atr_series.iloc[0]}. (Check skipna=False)"
            )

        # Check Day 2 (Initialization)
        if not np.isclose(atr_series.iloc[1], 25.0):
            raise AssertionError(
                f"Initialization Regression: Expected 25.0, got {atr_series.iloc[1]}."
            )

        # Check Day 3 (Recursion)
        if not np.isclose(atr_series.iloc[2], 17.5):
            raise AssertionError(
                f"Wilder's Logic Regression: Expected 17.5, got {atr_series.iloc[2]}. (Check Alpha=1/N)"
            )

        print("‚úÖ FEATURE INTEGRITY PASSED: Wilder's ATR logic is strictly enforced.")

    except AssertionError as e:
        print(f"üî• LOGIC FAILURE: {str(e)}")
        raise e


def _compute_vol_adjusted_performance(
    prices: pd.DataFrame, atrp_matrix: pd.DataFrame, weights: pd.Series
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    KERNEL: Pure Matrix Math.
    Calculates drift-adjusted portfolio value and volatility-normalized returns.
    """
    # 1. Equity Curve Logic (Price-Weighted Drift)
    norm_prices = prices.div(prices.bfill().iloc[0])
    weighted_components = norm_prices.mul(weights, axis=1)
    equity_curve = weighted_components.sum(axis=1)

    # !!! MANDATORY: Use QuantUtils to preserve boundary NaN !!!
    returns_WITH_BOUNDARY_NAN = QuantUtils.compute_returns(equity_curve)

    # 2. Portfolio ATRP Logic (Weighted Volatility)
    current_weights = weighted_components.div(equity_curve, axis=0)
    w, a = current_weights.align(atrp_matrix, join="inner", axis=0)
    portfolio_atrp = (w * a).sum(axis=1)

    return equity_curve, returns_WITH_BOUNDARY_NAN, portfolio_atrp


def _prepare_initial_weights(tickers: List[str]) -> pd.Series:
    """
    METADATA: Converts a list of tickers into a weight map.
    Example: ['AAPL', 'AAPL', 'TSLA'] -> {'AAPL': 0.66, 'TSLA': 0.33}
    """
    ticker_counts = Counter(tickers)
    total = len(tickers)
    return pd.Series({t: c / total for t, c in ticker_counts.items()})


def calculate_buy_and_hold_performance(
    df_close_wide: pd.DataFrame,  # Use the WIDE version
    df_atrp_wide: pd.DataFrame,  # Use the WIDE version
    tickers: List[str],
    start_date: pd.Timestamp,
    end_date: pd.Timestamp,
):
    if not tickers:
        return pd.Series(), pd.Series(), pd.Series()

    initial_weights = _prepare_initial_weights(tickers)

    ################################
    # # SLICE - This is fast because no pivot/unstack is happening
    # p_slice = df_close_wide[initial_weights.index].loc[start_date:end_date]
    # a_slice = df_atrp_wide[initial_weights.index].loc[start_date:end_date]

    # SLICE (Fix Part B)
    ticker_list = initial_weights.index.tolist()
    p_slice = df_close_wide.reindex(columns=ticker_list).loc[start_date:end_date]
    a_slice = df_atrp_wide.reindex(columns=ticker_list).loc[start_date:end_date]

    ################################

    # KERNEL - Pure Math
    return _compute_vol_adjusted_performance(p_slice, a_slice, initial_weights)


def calculate_summary_gain(price_series: pd.Series) -> float:
    """REPORTING: Returns the total return of a single series."""
    if price_series.dropna().shape[0] < 2:
        return 0.0
    # (Final Price / Starting Price) - 1
    res = (price_series.ffill().iloc[-1] / price_series.bfill().iloc[0]) - 1
    return float(res) if np.isfinite(res) else 0.0


def calculate_cross_sectional_gain(price_df: pd.DataFrame) -> pd.Series:
    """RANKING: Returns the total return for every ticker in the universe."""
    if price_df.empty:
        return pd.Series(dtype=float)
    # Vectorized calculation across all columns (tickers)
    res = (price_df.ffill().iloc[-1] / price_df.bfill().iloc[0]) - 1
    return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)


def calculate_summary_sharpe(return_series: pd.Series) -> float:
    """REPORTING: Returns a single Reward value."""
    if return_series.dropna().shape[0] < 2:
        return 0.0
    mu, std = return_series.mean(), return_series.std()

    # SENIOR FIX: Volatility floor to prevent 'Infinity' or 'Exploding' rewards
    if std < 1e-6:
        return 0.0

    with np.errstate(divide="ignore", invalid="ignore"):
        res = (mu / std) * np.sqrt(252)
    return float(res) if np.isfinite(res) else 0.0


def calculate_cross_sectional_sharpe(return_df: pd.DataFrame) -> pd.Series:
    """RANKING: Returns a Series of values for the whole universe."""
    if return_df.empty:
        return pd.Series(dtype=float)
    mu, std = return_df.mean(), return_df.std()

    with np.errstate(divide="ignore", invalid="ignore"):
        res = (mu / std) * np.sqrt(252)

    # SENIOR FIX: Convert 'Broken' data (std=0) into 0.0 reward
    return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)


def calculate_summary_sharpe_atr(
    return_series: pd.Series, atrp_input: Union[pd.Series, float]
) -> float:
    """REPORTING: Returns a single Reward value normalized by Volatility."""
    if return_series.dropna().shape[0] < 2:
        return 0.0
    avg_atrp = atrp_input.mean() if hasattr(atrp_input, "mean") else atrp_input

    if avg_atrp < 1e-6:
        return 0.0  # Safety floor

    with np.errstate(divide="ignore", invalid="ignore"):
        res = return_series.mean() / avg_atrp
    return float(res) if np.isfinite(res) else 0.0


def calculate_cross_sectional_sharpe_atr(
    return_df: pd.DataFrame, atrp_series: pd.Series
) -> pd.Series:
    """RANKING: Returns a Series of Volatility-normalized values."""
    with np.errstate(divide="ignore", invalid="ignore"):
        res = return_df.mean() / atrp_series
    return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)


# ==============================================================================
# SECTION C: METRIC REGISTRY
# ==============================================================================


class MarketObservation(TypedDict):
    """
    The 'STATE' (Observation) in Reinforcement Learning.
    This defines the context given to the agent to make a decision.
    """

    # --- The Movie (Time Series) ---
    lookback_returns: pd.DataFrame  # (Time x Tickers)
    lookback_close: pd.DataFrame  # (Time x Tickers)

    # --- The Snapshot (Scalar values at Decision Time) ---
    atrp: pd.Series  # Volatility (Mean over lookback)

    # NEW SENSORS
    rsi: pd.Series  # Internal Momentum (0-100)
    rel_strength: pd.Series  # Performance vs SPY
    vol_regime: pd.Series  # Volatility Expansion/Compression
    rvol: pd.Series  # Ticker Conviction
    spy_rvol: pd.Series  # Market Participation
    obv_score: pd.Series  # Ticker Accumulation/Distribution
    spy_obv_score: pd.Series  # Market Tide

    # MOMENTUM VECTORS
    roc_1: pd.Series
    roc_3: pd.Series
    roc_5: pd.Series
    roc_10: pd.Series
    roc_21: pd.Series


def metric_price(obs: MarketObservation) -> pd.Series:
    return calculate_cross_sectional_gain(obs["lookback_close"])


def metric_sharpe(obs: MarketObservation) -> pd.Series:
    return calculate_cross_sectional_sharpe(obs["lookback_returns"])


def metric_sharpe_atr(obs: MarketObservation) -> pd.Series:
    return calculate_cross_sectional_sharpe_atr(obs["lookback_returns"], obs["atrp"])


METRIC_REGISTRY = {
    # --- CLASSIC METRICS ---
    "Price": metric_price,
    "Sharpe": metric_sharpe,
    "Sharpe (ATR)": metric_sharpe_atr,
    # --- MOMENTUM VECTORS ---
    "Momentum 1D": lambda obs: obs["roc_1"],
    "Momentum 3D": lambda obs: obs["roc_3"],
    "Momentum 5D": lambda obs: obs["roc_5"],
    "Momentum 10D": lambda obs: obs["roc_10"],
    "Momentum 1M": lambda obs: obs["roc_21"],
    # --- PULLBACK VECTORS ---
    "Pullback 1D": lambda obs: -obs["roc_1"],
    "Pullback 3D": lambda obs: -obs["roc_3"],
    "Pullback 5D": lambda obs: -obs["roc_5"],
    "Pullback 10D": lambda obs: -obs["roc_10"],
    "Pullback 1M": lambda obs: -obs["roc_21"],
    # --- NEW SOTA SENSORS ---
    "RSI (Reversal)": lambda obs: -obs["rsi"],  # Rank Low RSI (Oversold) higher
    "RSI (Trend)": lambda obs: obs["rsi"],  # Rank High RSI (Strong Trend) higher
    "Alpha (RelStrength)": lambda obs: obs["rel_strength"],  # Rank stocks beating SPY
    "OBV (Accumulation)": lambda obs: obs["obv_score"],  # Rank High OBV Score
    "Volume Conviction": lambda obs: obs["rvol"],  # Rank High Relative Volume
    "Volatility Regime (Breakout)": lambda obs: obs[
        "vol_regime"
    ],  # Rank High Volatility Expansion
}


# ==============================================================================
# SECTION D: DATA CONTRACTS
# ==============================================================================


@dataclass
class EngineInput:
    mode: str
    start_date: pd.Timestamp
    lookback_period: int
    holding_period: int
    metric: str
    benchmark_ticker: str
    rank_start: int = 1
    rank_end: int = 10
    # Default factory pulls from Global thresholds
    quality_thresholds: Dict[str, float] = field(
        default_factory=lambda: GLOBAL_SETTINGS["thresholds"].copy()
    )
    manual_tickers: List[str] = field(default_factory=list)
    debug: bool = False


@dataclass
class EngineOutput:
    portfolio_series: pd.Series
    benchmark_series: pd.Series
    normalized_plot_data: pd.DataFrame
    tickers: List[str]
    initial_weights: pd.Series
    perf_metrics: Dict[str, float]
    results_df: pd.DataFrame

    # Dates
    start_date: pd.Timestamp
    decision_date: pd.Timestamp
    buy_date: pd.Timestamp
    holding_end_date: pd.Timestamp

    error_msg: Optional[str] = None
    debug_data: Optional[Dict[str, Any]] = None


class AlphaEngine:
    def __init__(
        self,
        df_ohlcv: pd.DataFrame,
        features_df: pd.DataFrame = None,
        df_close_wide: pd.DataFrame = None,
        df_atrp_wide: pd.DataFrame = None,  # <--- PINPOINT 1: Add this argument
        master_ticker: str = GLOBAL_SETTINGS["calendar_ticker"],
    ):
        print("--- ‚öôÔ∏è Initializing AlphaEngine v2.2 (Sanitized) ---")

        # 1. SETUP PRICES (CLEAN-AT-ENTRY)
        if df_close_wide is not None:
            self.df_close = df_close_wide
        else:
            print("üê¢ Pivoting and Sanitizing Price Data...")
            self.df_close = df_ohlcv["Adj Close"].unstack(level=0)

        # APPLY DATA SANITIZER LOGIC
        if GLOBAL_SETTINGS["handle_zeros_as_nan"]:
            # Replace 0.0 with NaN so math functions (mean/std) ignore them
            self.df_close = self.df_close.replace(0, np.nan)

        # Smooth over 1-2 day glitches (The "FNV" Fix)
        self.df_close = self.df_close.ffill(limit=GLOBAL_SETTINGS["max_data_gap_ffill"])

        # Handle the remaining "unfillable" gaps
        self.df_close = self.df_close.fillna(GLOBAL_SETTINGS["nan_price_replacement"])

        # 2. SETUP FEATURES
        if features_df is not None:
            self.features_df = features_df
        else:
            # We pass the cleaned price data if needed, or calculate from raw
            self.features_df = generate_features(
                df_ohlcv,
                atr_period=GLOBAL_SETTINGS["atr_period"],
                quality_window=GLOBAL_SETTINGS["quality_window"],
                quality_min_periods=GLOBAL_SETTINGS["quality_min_periods"],
            )

        # --- PINPOINT 2: Logic Swap for Speed ---
        if df_atrp_wide is not None:
            # INSTANT: Use the matrix precomputed outside the UI
            self.df_atrp = df_atrp_wide
        else:
            # SLOW FALLBACK: Only runs if you forget to precompute
            print("üöÄ Pre-aligning Volatility (ATRP) Matrix (Slow Fallback)...")
            self.df_atrp = self.features_df["ATRP"].unstack(level=0)

        # Final safety alignment (Always cheap once already unstacked)
        self.df_atrp = self.df_atrp.reindex(
            index=self.df_close.index, columns=self.df_close.columns
        )

        # # --- THE PINPOINT CHANGE: Create Wide Feature Matrix ---
        # print("üöÄ Pre-aligning Volatility (ATRP) Matrix...")
        # self.df_atrp = self.features_df["ATRP"].unstack(level=0)

        # # Ensure the ATRP matrix has the exact same index (Dates) as our Price matrix
        # self.df_atrp = self.df_atrp.reindex(self.df_close.index)

        # 3. Setup Calendar
        if master_ticker not in self.df_close.columns:
            master_ticker = self.df_close.columns[0]
        self.trading_calendar = (
            self.df_close[master_ticker].dropna().index.unique().sort_values()
        )

    def run(self, inputs: EngineInput) -> EngineOutput:
        dates, error = self._validate_timeline(inputs)
        if error:
            return self._error_result(error)
        (safe_start, safe_decision, safe_buy, safe_end) = dates

        tickers_to_trade, results_table, debug_dict, error = self._select_tickers(
            inputs, safe_start, safe_decision
        )
        if error:
            return self._error_result(error)

        # GENERATE TRACKS (Fix Part A)
        p_f_val, p_f_ret, p_f_atrp = calculate_buy_and_hold_performance(
            self.df_close, self.df_atrp, tickers_to_trade, safe_start, safe_end
        )
        b_f_val, b_f_ret, b_f_atrp = calculate_buy_and_hold_performance(
            self.df_close, self.df_atrp, [inputs.benchmark_ticker], safe_start, safe_end
        )

        p_h_val, p_h_ret, p_h_atrp = calculate_buy_and_hold_performance(
            self.df_close, self.df_atrp, tickers_to_trade, safe_buy, safe_end
        )
        b_h_val, b_h_ret, b_h_atrp = calculate_buy_and_hold_performance(
            self.df_close, self.df_atrp, [inputs.benchmark_ticker], safe_buy, safe_end
        )

        # CALCULATE METRICS
        p_metrics, p_slices = self._calculate_period_metrics(
            p_f_val,
            p_f_ret,
            p_f_atrp,
            safe_decision,
            p_h_val,
            p_h_ret,
            p_h_atrp,
            prefix="p",
        )
        b_metrics, b_slices = self._calculate_period_metrics(
            b_f_val,
            b_f_ret,
            b_f_atrp,
            safe_decision,
            b_h_val,
            b_h_ret,
            b_h_atrp,
            prefix="b",
        )

        # CONSOLIDATE DEBUG DATA
        debug_dict["verification"] = {"portfolio": p_slices, "benchmark": b_slices}

        # ADD RAW COMPONENT EXPORTS
        debug_dict["portfolio_raw_components"] = {
            "prices": self.df_close[tickers_to_trade].loc[safe_start:safe_end],
            "atrp": self.features_df.loc[
                (tickers_to_trade, slice(safe_start, safe_end)), "ATRP"
            ].unstack(level=0),
        }
        debug_dict["benchmark_raw_components"] = {
            "prices": self.df_close[[inputs.benchmark_ticker]].loc[safe_start:safe_end],
            "atrp": self.features_df.loc[
                ([inputs.benchmark_ticker], slice(safe_start, safe_end)), "ATRP"
            ].unstack(level=0),
        }

        # FINAL OUTPUT
        results_table["Holding Gain"] = (p_h_val.iloc[-1] / p_h_val.iloc[0]) - 1
        return EngineOutput(
            portfolio_series=p_f_val,
            benchmark_series=b_f_val,
            normalized_plot_data=self._get_normalized_plot_data(
                tickers_to_trade, safe_start, safe_end
            ),
            tickers=tickers_to_trade,
            initial_weights=pd.Series(
                {t: 1 / len(tickers_to_trade) for t in tickers_to_trade}
            ),
            perf_metrics={**p_metrics, **b_metrics},
            results_df=results_table,
            start_date=safe_start,
            decision_date=safe_decision,
            buy_date=safe_buy,
            holding_end_date=safe_end,
            debug_data=debug_dict,
        )

    # ==============================================================================
    # INTERNAL LOGIC MODULES
    # ==============================================================================

    def _validate_timeline(self, inputs: EngineInput):
        cal = self.trading_calendar
        last_idx = len(cal) - 1

        if len(cal) <= inputs.lookback_period:
            return (
                None,
                f"‚ùå Dataset too small.\nNeed > {inputs.lookback_period} days of history.",
            )

        # 2. Check "Past" Constraints (Lookback)
        min_decision_date = cal[inputs.lookback_period]
        if inputs.start_date < min_decision_date:
            # Added \n here
            return None, (
                f"‚ùå Not enough history for a {inputs.lookback_period}-day lookback.\n"
                f"Earliest valid Decision Date: {min_decision_date.date()}"
            )

        # 3. Check "Future" Constraints (Entry T+1 and Holding Period)
        required_future_days = 1 + inputs.holding_period
        latest_valid_idx = last_idx - required_future_days

        if latest_valid_idx < 0:
            return (
                None,
                f"‚ùå Holding period too long.\n{inputs.holding_period} days exceeds available data.",
            )

        # If user picked a date beyond the available "future" runway
        if inputs.start_date > cal[latest_valid_idx]:
            latest_date = cal[latest_valid_idx].date()
            # Added \n here and shortened the text slightly to fit better
            return None, (
                f"‚ùå Decision Date too late for a {inputs.holding_period}-day hold.\n"
                f"Latest valid date: {latest_date}. Please move picker back."
            )

        # 4. Map the safe indices
        decision_idx = cal.searchsorted(inputs.start_date)
        if decision_idx > latest_valid_idx:
            decision_idx = latest_valid_idx

        start_idx = decision_idx - inputs.lookback_period
        entry_idx = decision_idx + 1
        end_idx = entry_idx + inputs.holding_period

        return (cal[start_idx], cal[decision_idx], cal[entry_idx], cal[end_idx]), None

    def _select_tickers(self, inputs: EngineInput, start_date, decision_date):
        debug_dict = {}

        # --- PATH A: MANUAL LIST ---
        if inputs.mode == "Manual List":
            validation_errors = []
            valid_tickers = []
            for t in inputs.manual_tickers:
                if t not in self.df_close.columns:
                    validation_errors.append(f"‚ùå {t}: Not found.")
                    continue
                if pd.isna(self.df_close.at[start_date, t]):
                    validation_errors.append(f"‚ö†Ô∏è {t}: No data on start date.")
                    continue
                valid_tickers.append(t)

            if validation_errors:
                return [], pd.DataFrame(), {}, "\n".join(validation_errors)
            if not valid_tickers:
                return [], pd.DataFrame(), {}, "No valid tickers found."
            return valid_tickers, pd.DataFrame(index=valid_tickers), {}, None

        # --- PATH B: RANKING ---
        else:
            audit_info = {}
            eligible_tickers = self._filter_universe(
                decision_date, inputs.quality_thresholds, audit_info
            )
            debug_dict["audit_liquidity"] = audit_info

            if not eligible_tickers:
                return (
                    [],
                    pd.DataFrame(),
                    debug_dict,
                    "No tickers passed quality filters.",
                )

            lookback_close = self.df_close.loc[
                start_date:decision_date, eligible_tickers
            ]

            # 1. Get the Snapshot of Features for the Decision Date
            feat_slice_current = self.features_df.xs(
                decision_date, level="Date"
            ).reindex(eligible_tickers)

            # Calculate mean ATRP over the lookback period
            idx_product = pd.MultiIndex.from_product(
                [eligible_tickers, lookback_close.index], names=["Ticker", "Date"]
            )
            feat_slice_period = self.features_df.reindex(idx_product)
            atrp_value_for_obs = (
                feat_slice_period["ATRP"].groupby(level="Ticker").mean()
            )

            # 2. Package the Observation (The 'State')
            observation: MarketObservation = {
                # Time Series Data
                "lookback_close": lookback_close,
                "lookback_returns": lookback_close.ffill().pct_change(),
                # Snapshot Data (Scalar values for today)
                "atrp": atrp_value_for_obs,  # <--- USES THE TOGGLED VALUE HERE
                "rsi": feat_slice_current["RSI"],
                "rel_strength": feat_slice_current["RelStrength"],
                "vol_regime": feat_slice_current["VolRegime"],
                "rvol": feat_slice_current["RVol"],
                "spy_rvol": feat_slice_current["Spy_RVol"],
                "obv_score": feat_slice_current["OBV_Score"],
                "spy_obv_score": feat_slice_current["Spy_OBV_Score"],
                # Momentum Vectors
                "roc_1": feat_slice_current["ROC_1"],
                "roc_3": feat_slice_current["ROC_3"],
                "roc_5": feat_slice_current["ROC_5"],
                "roc_10": feat_slice_current["ROC_10"],
                "roc_21": feat_slice_current["ROC_21"],
            }

            # 3. Run the Strategy (The 'Agent')
            if inputs.metric not in METRIC_REGISTRY:
                return [], pd.DataFrame(), {}, f"Strategy '{inputs.metric}' not found."

            metric_vals = METRIC_REGISTRY[inputs.metric](observation)
            sorted_tickers = metric_vals.sort_values(ascending=False)
            start_r = max(0, inputs.rank_start - 1)
            end_r = inputs.rank_end
            selected_tickers = sorted_tickers.iloc[start_r:end_r].index.tolist()

            # Audit
            debug_dict["full_universe_ranking"] = pd.DataFrame(
                {
                    "Strategy_Score": metric_vals,
                    "Lookback_Return_Ann": observation["lookback_returns"].mean() * 252,
                    "Lookback_ATRP": observation["atrp"],
                }
            )

            if not selected_tickers:
                return (
                    [],
                    pd.DataFrame(),
                    debug_dict,
                    "No tickers generated from ranking.",
                )

            results_table = pd.DataFrame(
                {
                    "Rank": range(
                        inputs.rank_start, inputs.rank_start + len(selected_tickers)
                    ),
                    "Ticker": selected_tickers,
                    "Strategy Value": sorted_tickers.loc[selected_tickers].values,
                }
            ).set_index("Ticker")

            return selected_tickers, results_table, debug_dict, None

    def _filter_universe(self, date_ts, thresholds, audit_container=None):
        avail_dates = (
            self.features_df.index.get_level_values("Date").unique().sort_values()
        )
        valid_dates = avail_dates[avail_dates <= date_ts]
        if valid_dates.empty:
            return []
        target_date = valid_dates[-1]
        day_features = self.features_df.xs(target_date, level="Date")

        vol_cutoff = thresholds.get("min_median_dollar_volume", 0)
        percentile_used = "N/A"
        if "min_liquidity_percentile" in thresholds:
            percentile_used = thresholds["min_liquidity_percentile"]
            dynamic_val = day_features["RollMedDollarVol"].quantile(percentile_used)
            vol_cutoff = max(vol_cutoff, dynamic_val)

        mask = (
            (day_features["RollMedDollarVol"] >= vol_cutoff)
            & (day_features["RollingStalePct"] <= thresholds["max_stale_pct"])
            & (day_features["RollingSameVolCount"] <= thresholds["max_same_vol_count"])
        )

        if audit_container is not None:
            audit_container["date"] = target_date
            audit_container["total_tickers_available"] = len(day_features)
            audit_container["percentile_setting"] = percentile_used
            audit_container["final_cutoff_usd"] = vol_cutoff
            audit_container["tickers_passed"] = mask.sum()
            snapshot = day_features.copy()
            snapshot["Calculated_Cutoff"] = vol_cutoff
            snapshot["Passed_Vol_Check"] = snapshot["RollMedDollarVol"] >= vol_cutoff
            snapshot["Passed_Final"] = mask
            snapshot = snapshot.sort_values("RollMedDollarVol", ascending=False)
            audit_container["universe_snapshot"] = snapshot

        return day_features[mask].index.tolist()

    def _calculate_period_metrics(
        self, f_val, f_ret, f_atrp, decision_date, h_val, h_ret, h_atrp, prefix
    ):
        metrics = {}
        slices = {}

        # 1. Temporal Slicing (Routing)
        lb_val, lb_ret, lb_atrp = (
            f_val.loc[:decision_date],
            f_ret.loc[:decision_date],
            f_atrp.loc[:decision_date],
        )

        # 2. GAIN
        metrics[f"full_{prefix}_gain"] = QuantUtils.calculate_total_gain(f_val)
        metrics[f"lookback_{prefix}_gain"] = QuantUtils.calculate_total_gain(lb_val)
        metrics[f"holding_{prefix}_gain"] = QuantUtils.calculate_total_gain(h_val)

        # 3. SHARPE
        metrics[f"full_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(f_ret)
        metrics[f"lookback_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(lb_ret)
        metrics[f"holding_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(h_ret)

        # 4. SHARPE (ATR)
        metrics[f"full_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_atr(
            f_ret, f_atrp
        )
        metrics[f"lookback_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_atr(
            lb_ret, lb_atrp
        )
        metrics[f"holding_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_atr(
            h_ret, h_atrp
        )

        # 5. Metadata Collection
        slices["full_val"], slices["full_ret"], slices["full_atrp"] = (
            f_val,
            f_ret,
            f_atrp,
        )
        slices["lookback_val"], slices["lookback_ret"], slices["lookback_atrp"] = (
            lb_val,
            lb_ret,
            lb_atrp,
        )
        slices["holding_val"], slices["holding_ret"], slices["holding_atrp"] = (
            h_val,
            h_ret,
            h_atrp,
        )

        return metrics, slices

    def _get_normalized_plot_data(self, tickers, start_date, end_date):
        if not tickers:
            return pd.DataFrame()
        data = self.df_close[list(set(tickers))].loc[start_date:end_date]
        if data.empty:
            return pd.DataFrame()
        return data / data.bfill().iloc[0]

    def _error_result(self, msg):
        return EngineOutput(
            portfolio_series=pd.Series(dtype=float),
            benchmark_series=pd.Series(dtype=float),
            normalized_plot_data=pd.DataFrame(),
            tickers=[],
            initial_weights=pd.Series(dtype=float),
            perf_metrics={},
            results_df=pd.DataFrame(),
            start_date=pd.Timestamp.min,
            decision_date=pd.Timestamp.min,
            buy_date=pd.Timestamp.min,
            holding_end_date=pd.Timestamp.min,
            error_msg=msg,
        )


# ==============================================================================
# SECTION E: THE UI (Visualization)
# ==============================================================================


def plot_walk_forward_analyzer(
    df_ohlcv,
    precomputed_features=None,
    precomputed_close=None,
    precomputed_atrp=None,  # <--- NEW ARGUMENT
    default_start_date="2025-01-17",
    default_lookback=10,
    default_holding=5,
    default_strategy="Sharpe (ATR)",
    default_rank_start=1,
    default_rank_end=10,
    default_benchmark_ticker=GLOBAL_SETTINGS["benchmark_ticker"],
    master_calendar_ticker=GLOBAL_SETTINGS["calendar_ticker"],
    quality_thresholds=GLOBAL_SETTINGS["thresholds"],
    debug=False,
):

    engine = AlphaEngine(
        df_ohlcv,
        features_df=precomputed_features,
        df_close_wide=precomputed_close,
        df_atrp_wide=precomputed_atrp,  # <--- PASS IT HERE
        master_ticker=master_calendar_ticker,
    )

    # Initialize containers
    results_container = [None]
    debug_container = [{}]

    # If no thresholds passed, use the global Source of Truth
    if quality_thresholds is None:
        quality_thresholds = GLOBAL_SETTINGS["thresholds"]

    # --- Widgets ---
    mode_selector = widgets.RadioButtons(
        options=["Ranking", "Manual List"],
        value="Ranking",
        description="Mode:",
        layout={"width": "max-content"},
        style={"description_width": "initial"},
    )
    lookback_input = widgets.IntText(
        value=default_lookback,
        description="Lookback (Days):",
        layout={"width": "200px"},
        style={"description_width": "initial"},
    )
    decision_date_picker = widgets.DatePicker(
        description="Decision Date:",
        value=pd.to_datetime(default_start_date),
        layout={"width": "auto"},
        style={"description_width": "initial"},
    )
    holding_input = widgets.IntText(
        value=default_holding,
        description="Holding (Days):",
        layout={"width": "200px"},
        style={"description_width": "initial"},
    )
    strategy_dropdown = widgets.Dropdown(
        options=list(METRIC_REGISTRY.keys()),
        value=default_strategy,
        description="Strategy:",
        layout={"width": "220px"},
        style={"description_width": "initial"},
    )
    benchmark_input = widgets.Text(
        value=default_benchmark_ticker,
        description="Benchmark:",
        placeholder="Enter Ticker",
        layout={"width": "180px"},
        style={"description_width": "initial"},
    )
    rank_start_input = widgets.IntText(
        value=default_rank_start,
        description="Rank Start:",
        layout={"width": "150px"},
        style={"description_width": "initial"},
    )
    rank_end_input = widgets.IntText(
        value=default_rank_end,
        description="Rank End:",
        layout={"width": "150px"},
        style={"description_width": "initial"},
    )
    manual_tickers_input = widgets.Textarea(
        value="",
        placeholder="Enter tickers...",
        description="Manual Tickers:",
        layout={"width": "400px", "height": "80px"},
        style={"description_width": "initial"},
    )
    update_button = widgets.Button(description="Run Simulation", button_style="primary")
    ticker_list_output = widgets.Output()

    # --- Layouts ---
    timeline_box = widgets.HBox(
        [lookback_input, decision_date_picker, holding_input],
        layout=widgets.Layout(
            justify_content="space-between",
            border="1px solid #ddd",
            padding="10px",
            margin="5px",
        ),
    )
    strategy_box = widgets.HBox([strategy_dropdown, benchmark_input])
    ranking_box = widgets.HBox([rank_start_input, rank_end_input])

    def on_mode_change(c):
        ranking_box.layout.display = "flex" if c["new"] == "Ranking" else "none"
        manual_tickers_input.layout.display = (
            "none" if c["new"] == "Ranking" else "flex"
        )
        strategy_dropdown.disabled = c["new"] == "Manual List"

    mode_selector.observe(on_mode_change, names="value")
    on_mode_change({"new": mode_selector.value})

    ui = widgets.VBox(
        [
            widgets.HTML(
                "<b>1. Timeline Configuration:</b> (Past <--- Decision ---> Future)"
            ),
            timeline_box,
            widgets.HTML("<b>2. Strategy Settings:</b>"),
            widgets.HBox([mode_selector, strategy_box]),
            ranking_box,
            manual_tickers_input,
            widgets.HTML("<hr>"),
            update_button,
            ticker_list_output,
        ],
        layout=widgets.Layout(margin="10px 0 20px 0"),
    )

    fig = go.FigureWidget()
    fig.update_layout(
        title="Event-Driven Walk-Forward Analysis",
        height=600,
        template="plotly_white",
        hovermode="x unified",
    )
    for i in range(50):
        fig.add_trace(go.Scatter(visible=False, line=dict(width=2)))
    fig.add_trace(
        go.Scatter(
            name="Benchmark",
            visible=True,
            line=dict(color="black", width=3, dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            name="Group Portfolio", visible=True, line=dict(color="green", width=3)
        )
    )

    # --- Update Logic ---
    def update_plot(b):
        ticker_list_output.clear_output()
        manual_list = [
            t.strip().upper()
            for t in manual_tickers_input.value.split(",")
            if t.strip()
        ]
        decision_date_raw = pd.to_datetime(decision_date_picker.value)

        inputs = EngineInput(
            mode=mode_selector.value,
            start_date=decision_date_raw,
            lookback_period=lookback_input.value,
            holding_period=holding_input.value,
            metric=strategy_dropdown.value,
            benchmark_ticker=benchmark_input.value.strip().upper(),
            rank_start=rank_start_input.value,
            rank_end=rank_end_input.value,
            quality_thresholds=quality_thresholds,
            manual_tickers=manual_list,
            debug=debug,
        )

        # --- CAPTURE INPUTS FOR AUDIT ---
        debug_container[0]["inputs"] = inputs

        with ticker_list_output:
            res = engine.run(inputs)
            results_container[0] = res

            # --- MERGE ENGINE DEBUG DATA ---
            if res.debug_data:
                debug_container[0].update(res.debug_data)

            if res.error_msg:
                print(f"‚ö†Ô∏è Simulation Stopped: {res.error_msg}")
                return

            # Plotting
            with fig.batch_update():
                cols = res.normalized_plot_data.columns.tolist()
                for i in range(50):
                    if i < len(cols):
                        fig.data[i].update(
                            x=res.normalized_plot_data.index,
                            y=res.normalized_plot_data[cols[i]],
                            name=cols[i],
                            visible=True,
                        )
                    else:
                        fig.data[i].visible = False

                fig.data[50].update(
                    x=res.benchmark_series.index,
                    y=res.benchmark_series.values,
                    name=f"Benchmark ({inputs.benchmark_ticker})",
                    visible=not res.benchmark_series.empty,
                )
                fig.data[51].update(
                    x=res.portfolio_series.index,
                    y=res.portfolio_series.values,
                    visible=True,
                )

                # Visual Lines
                fig.layout.shapes = [
                    dict(
                        type="line",
                        x0=res.decision_date,
                        y0=0,
                        x1=res.decision_date,
                        y1=1,
                        xref="x",
                        yref="paper",
                        line=dict(color="red", width=2, dash="dash"),
                    ),
                    dict(
                        type="line",
                        x0=res.buy_date,
                        y0=0,
                        x1=res.buy_date,
                        y1=1,
                        xref="x",
                        yref="paper",
                        line=dict(color="blue", width=2, dash="dot"),
                    ),
                ]

                fig.layout.annotations = [
                    dict(
                        x=res.decision_date,
                        y=0.05,
                        xref="x",
                        yref="paper",
                        text="DECISION",
                        showarrow=False,
                        bgcolor="red",
                        font=dict(color="white"),
                    ),
                    dict(
                        x=res.buy_date,
                        y=1.0,
                        xref="x",
                        yref="paper",
                        text="ENTRY (T+1)",
                        showarrow=False,
                        bgcolor="blue",
                        font=dict(color="white"),
                    ),
                ]

            start_date = res.start_date.date()
            act_date = res.decision_date.date()
            entry_date = res.buy_date.date()

            # Liquidity Audit Print
            if (
                inputs.mode == "Ranking"
                and res.debug_data
                and "audit_liquidity" in res.debug_data
            ):
                audit = res.debug_data["audit_liquidity"]
                if audit:
                    raw_percentile = audit.get("percentile_setting", 0)
                    keep_pct = (
                        1 - raw_percentile
                    ) * 100  # Calculates the actual portion kept
                    cut_val = audit.get("final_cutoff_usd", 0)

                    print("-" * 60)
                    print(f"üîç LIQUIDITY CHECK (On Decision Date: {act_date})")
                    print(
                        f"   Universe Size: {audit.get('total_tickers_available')} tickers"
                    )
                    print(
                        f"   Liquidity Threshold: {raw_percentile*100:.0f}th Percentile"
                    )
                    print(f"   Action: Keeping the Top {keep_pct:.0f}% of Market")
                    print(f"   Calculated Cutoff: ${cut_val:,.0f} / day")
                    print(f"   Tickers Remaining: {audit.get('tickers_passed')}")
                    print("-" * 60)

            # --- UPDATED TIMELINE PRINT ---
            print(
                f"Timeline: Start [ {start_date} ] --> Decision [ {act_date} ] --> Cash (1d) --> Entry [ {entry_date} ] --> End [ {res.holding_end_date.date()} ]"
            )

            if inputs.mode == "Ranking":
                print(f"Ranked Tickers ({len(res.tickers)}):")
                for i in range(0, len(res.tickers), 10):
                    print(", ".join(res.tickers[i : i + 10]))
            else:
                print("Manual Portfolio Tickers:")
                for i in range(0, len(res.tickers), 10):
                    print(", ".join(res.tickers[i : i + 10]))

            m = res.perf_metrics

            # --- DRY UI GENERATION ---
            # 1. Define the metrics we want to display
            metrics_to_show = [
                ("Gain", "gain"),
                ("Sharpe", "sharpe"),
                ("Sharpe (ATR)", "sharpe_atr"),
            ]

            rows = []
            for label, key in metrics_to_show:
                p_row = {
                    "Metric": f"Group {label}",
                    "Full": m.get(f"full_p_{key}"),
                    "Lookback": m.get(f"lookback_p_{key}"),
                    "Holding": m.get(f"holding_p_{key}"),
                }
                b_row = {
                    "Metric": f"Benchmark {label}",
                    "Full": m.get(f"full_b_{key}"),
                    "Lookback": m.get(f"lookback_b_{key}"),
                    "Holding": m.get(f"holding_b_{key}"),
                }

                # Delta calculation
                d_row = {"Metric": f"== {label} Delta"}
                for col in ["Full", "Lookback", "Holding"]:
                    d_row[col] = (p_row[col] or 0) - (b_row[col] or 0)

                rows.extend([p_row, b_row, d_row])

            df_report = pd.DataFrame(rows).set_index("Metric")

            # --- 2. STYLING (The "Senior" Design) ---
            # --- 1. PREP DATA (Flattening the Index) ---
            # We convert the index to a column so "Metric" sits on the same row as other headers
            df_report = pd.DataFrame(rows)
            df_report = df_report.set_index("Metric")

            # --- 2. THE STYLING (Sleek & Proportional) ---
            def apply_sleek_style(styler):
                # Match notebook font size (usually 13px)
                styler.format("{:+.4f}", na_rep="N/A")

                # Dynamic Row Highlighting
                def row_logic(row):
                    if "Delta" in row.name:
                        return [
                            "background-color: #f9f9f9; font-weight: 600; border-top: 1px solid #ddd"
                        ] * len(row)
                    if "Group" in row.name:
                        return ["color: #2c5e8f; background-color: #fcfdfe"] * len(row)
                    return ["color: #555"] * len(
                        row
                    )  # Benchmark rows are slightly muted

                styler.apply(row_logic, axis=1)

                styler.set_table_styles(
                    [
                        # Base Table Font - Scaling down to match standard text
                        {
                            "selector": "",
                            "props": [
                                ("font-family", "inherit"),
                                ("font-size", "12px"),
                                ("border-collapse", "collapse"),
                                ("width", "auto"),
                                ("margin-left", "0"),
                            ],
                        },
                        # Header Row - Flattened and Muted
                        {
                            "selector": "th",
                            "props": [
                                ("background-color", "white"),
                                ("color", "#222"),
                                ("font-weight", "600"),
                                ("padding", "6px 12px"),
                                ("border-bottom", "2px solid #444"),
                                ("text-align", "center"),
                                (
                                    "vertical-align",
                                    "bottom",
                                ),  # Aligns 'Metric' with others
                            ],
                        },
                        # Index Column (The "Metric" labels)
                        {
                            "selector": "th.row_heading",
                            "props": [
                                ("text-align", "left"),
                                ("padding-right", "30px"),
                                ("border-bottom", "1px solid #eee"),
                            ],
                        },
                        # Cell Data - Tighter padding
                        {
                            "selector": "td",
                            "props": [
                                ("padding", "4px 12px"),
                                ("border-bottom", "1px solid #eee"),
                            ],
                        },
                        # Remove the extra "Index Name" row completely
                        {
                            "selector": "thead tr:nth-child(1) th",
                            "props": [("display", "table-cell")],
                        },
                    ]
                )

                # Hack to fix the 'Metric' alignment:
                # We remove the index name and set it as the horizontal label for the index
                styler.index.name = None

                return styler

            display(apply_sleek_style(df_report.style))

    update_button.on_click(update_plot)
    update_plot(None)
    display(ui, fig)
    return results_container, debug_container


# ==============================================================================
# INTEGRITY PROTECTION: THE TRIPWIRE
# ==============================================================================


def verify_math_integrity():
    """
    üõ°Ô∏è TRIPWIRE: Ensures Sample Boundary Integrity.
    """
    print("\n--- üõ°Ô∏è Starting Final Integrity Audit ---")

    try:
        # Test 1: Series Input
        mock_series = pd.Series([100.0, 102.0, 101.0])
        rets_s = QuantUtils.compute_returns(mock_series)
        # Verify first value is actually NaN
        if not pd.isna(rets_s.iloc[0]):
            raise ValueError("Series Leading NaN missing")
        print("‚úÖ Series Boundary: OK")

        # Test 2: DataFrame Input
        mock_df = pd.DataFrame({"A": [100, 101], "B": [200, 202]})
        rets_df = QuantUtils.compute_returns(mock_df)
        if not rets_df.iloc[0].isna().all():
            raise ValueError("DataFrame Leading NaN missing")
        print("‚úÖ DataFrame Boundary: OK")

        print("‚úÖ AUDIT PASSED: Mathematical boundaries are strictly enforced.")
    except Exception as e:
        print(f"üî• SYSTEM BREACH: {str(e)}")
        raise e


# Auto-run the check
verify_math_integrity()

# Run the Tripwire
verify_feature_engineering_integrity()

#


--- üõ°Ô∏è Starting Final Integrity Audit ---
‚úÖ Series Boundary: OK
‚úÖ DataFrame Boundary: OK
‚úÖ AUDIT PASSED: Mathematical boundaries are strictly enforced.

--- üõ°Ô∏è Starting Feature Engineering Audit ---
‚ö° Generating SOTA Quant Features (Benchmark: SPY)...
Audit Values:
[ nan 25.  17.5]
‚úÖ FEATURE INTEGRITY PASSED: Wilder's ATR logic is strictly enforced.


In [5]:
df_ohlcv = pd.read_parquet('/content/df_OHLCV_stocks_etfs.parquet')

In [6]:
# Calculate features ONCE and store them in a variable
print("Calculating features... this might take about 10 minutes...")
print("1. Calculating Features...")
features_df = generate_features(
    df_ohlcv=df_ohlcv,
    atr_period=GLOBAL_SETTINGS["atr_period"],
    quality_window=GLOBAL_SETTINGS["quality_window"],
    quality_min_periods=GLOBAL_SETTINGS["quality_min_periods"],
)

print("2. Pivoting Price Matrix...")
# This is the line that takes 12 seconds, but now we only run it ONCE.
my_close_matrix = df_ohlcv["Adj Close"].unstack(level=0)

# --- NEW: PRE-PIVOT ATRP ---
print("3. Pivoting Volatility (ATRP) Matrix...")
my_atrp_matrix = features_df["ATRP"].unstack(level=0)

# --- SENIOR SAFETY: Align them perfectly once ---
# This ensures dates and tickers match 100% before the UI even starts
my_atrp_matrix = my_atrp_matrix.reindex_like(my_close_matrix)

print("‚úÖ Pre-computation Complete. UI will now be fast.")

Calculating features... this might take about 10 minutes...
1. Calculating Features...
‚ö° Generating SOTA Quant Features (Benchmark: SPY)...
2. Pivoting Price Matrix...
3. Pivoting Volatility (ATRP) Matrix...
‚úÖ Pre-computation Complete. UI will now be fast.


In [8]:
df_close_wide = my_close_matrix.copy()
df_atrp_wide = my_atrp_matrix.copy()

In [38]:
# ==============================================================================
# 1. THE RL ENVIRONMENT (Fixed for NaN Safety)
# ==============================================================================
class TradingEnv:
    """
    The Gym: Where the Agent learns to play.
    Wraps AlphaEngine to provide States and accept Actions.
    """
    def __init__(self, engine: AlphaEngine, benchmark_ticker="SPY"):
        self.engine = engine
        self.benchmark_ticker = benchmark_ticker

        # ACTION SPACE: Keys from METRIC_REGISTRY
        self.action_map = list(METRIC_REGISTRY.keys())
        self.n_actions = len(self.action_map)

        # STATE SPACE: Available dates aligned with the engine
        self.available_dates = engine.df_close.index
        self.current_step = 0

    def reset(self):
        """Resets the game to a random start date."""
        # Ensure we don't pick the very start (need history) or very end (need future)
        valid_range_start = 252
        valid_range_end = len(self.available_dates) - 25 # Buffer for holding period

        # Safety: Ensure dataset is large enough
        if valid_range_end <= valid_range_start:
             raise ValueError("Dataset too small for RL training")

        self.current_step = np.random.randint(valid_range_start, valid_range_end)
        return self._get_observation()

    def _get_observation(self):
        """
        Returns the 'State' (Market Context).
        """
        current_date = self.available_dates[self.current_step]

        # Get feature vector for the specific date
        # We use .mean() to get the "Market Average" condition
        try:
            day_features = self.engine.features_df.xs(current_date, level="Date")

            state_values = [
                day_features['VolRegime'].mean(),       # 1. Volatility State
                day_features['RSI'].mean() / 100.0,     # 2. Overbought/Oversold
                day_features['OBV_Score'].mean(),       # 3. Money Flow
                day_features['RelStrength'].mean()      # 4. Market Breadth
            ]
        except KeyError:
            # Fallback if date is missing in features
            state_values = [0.0, 0.5, 0.0, 0.0]

        # PINPOINT FIX: Data Sanitization for Neural Net
        # Neural Nets explode if fed NaN or Infinity
        clean_state = []
        for x in state_values:
            if pd.isna(x) or np.isinf(x):
                clean_state.append(0.0)
            else:
                clean_state.append(float(x))

        return torch.tensor(clean_state, dtype=torch.float32)

    def step(self, action_idx):
        """
        The Agent takes an action. We calculate the reward.
        """
        metric_name = self.action_map[action_idx]
        current_date = self.available_dates[self.current_step]

        # 1. SETUP ENGINE INPUT
        inputs = EngineInput(
            mode="Ranking",
            start_date=current_date,
            lookback_period=20,
            holding_period=5,
            metric=metric_name,
            benchmark_ticker=self.benchmark_ticker,
            rank_end=5
        )

        # 2. RUN SIMULATION
        try:
            output = self.engine.run(inputs)

            # --- PINPOINT FIX: Use QuantUtils for Math Safety ---
            # Old Code: Manual division that caused RuntimeWarning/NaN
            # New Code: Delegate to the robust kernel

            port_gain = QuantUtils.calculate_total_gain(output.portfolio_series)
            bench_gain = QuantUtils.calculate_total_gain(output.benchmark_series)

            # 3. CALCULATE REWARD (Alpha)
            # If the calculation failed (result is 0.0 exactly often implies empty),
            # or if it returns huge outliers, we clip it.

            reward = port_gain - bench_gain

            # Sanity Check: If reward is NaN (e.g., benchmark failed), punish slightly
            if np.isnan(reward) or np.isinf(reward):
                reward = -0.01

        except Exception as e:
            # If Engine crashes, give negative reward to teach agent to avoid this state
            reward = -0.05

        # 4. MOVE TIME FORWARD
        self.current_step += 5
        done = self.current_step >= (len(self.available_dates) - 20)

        next_state = self._get_observation() if not done else torch.zeros(4)

        return next_state, reward, done, {}


# ==============================================================================
# 2. THE AGENT (The Brain)
# ==============================================================================
class SimpleAgent(nn.Module):
    def __init__(self, state_dim, n_actions):
        super().__init__()
        # A simple Multi-Layer Perceptron (MLP)
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_actions),
            # Softmax outputs probabilities (e.g., 20% Momentum, 80% RSI)
            nn.Softmax(dim=-1)
        )

    def forward(self, state):
        return self.net(state)

    def get_action(self, state):
        # 1. Get probabilities from Neural Net
        probs = self.forward(state)

        # 2. Create a distribution to sample from
        dist = Categorical(probs)

        # 3. Sample an action (Exploration is built-in via sampling)
        action = dist.sample()

        return action.item(), dist.log_prob(action)

# ==============================================================================
# 3. TEST DRIVER (Verification)
# ==============================================================================
def verify_rl_setup(engine):
    print("--- ü§ñ Initializing RL Environment ---")

    # 1. Create Env
    env = TradingEnv(engine)
    obs = env.reset()
    print(f"State Vector: {obs} (Shape: {obs.shape})")

    # 2. Create Agent
    state_dim = obs.shape[0]
    n_actions = env.n_actions
    agent = SimpleAgent(state_dim, n_actions)
    print(f"Agent Action Space: {n_actions} Strategies")

    # 3. Test One Step
    action, log_prob = agent.get_action(obs)
    print(f"Agent chose Action Index {action}: '{env.action_map[action]}'")

    next_state, reward, done, _ = env.step(action)
    print(f"Reward (Alpha): {reward:.4f}")
    print("‚úÖ RL Setup Verified.")


In [39]:
# Instantiate the AlphaEngine with precomputed data
engine = AlphaEngine(
    df_ohlcv=df_ohlcv,
    features_df=features_df,
    df_close_wide=df_close_wide,
    df_atrp_wide=df_atrp_wide,
    master_ticker=GLOBAL_SETTINGS["calendar_ticker"],
)

--- ‚öôÔ∏è Initializing AlphaEngine v2.2 (Sanitized) ---


In [42]:
verify_rl_setup(engine)

--- ü§ñ Initializing RL Environment ---
State Vector: tensor([0.9859, 0.5563, 0.4250, 0.0000]) (Shape: torch.Size([4]))
Agent Action Space: 19 Strategies
Agent chose Action Index 18: 'Volatility Regime (Breakout)'
Reward (Alpha): 0.0741
‚úÖ RL Setup Verified.


  res = (clean.iloc[-1] / clean.iloc[0]) - 1
  res = (clean.iloc[-1] / clean.iloc[0]) - 1


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# ==============================================================================
# 1. THE RL ENVIRONMENT (Bridging AlphaEngine to PyTorch)
# ==============================================================================
class TradingEnv:
    """
    The Gym: Where the Agent learns to play.
    It wraps your AlphaEngine to provide States and accept Actions.
    """
    def __init__(self, engine: AlphaEngine, benchmark_ticker="SPY"):
        self.engine = engine
        self.benchmark_ticker = benchmark_ticker

        # ACTION SPACE: The keys from your METRIC_REGISTRY
        # 0: Price, 1: Sharpe, 2: Momentum 1D, etc.
        self.action_map = list(METRIC_REGISTRY.keys())
        self.n_actions = len(self.action_map)

        # STATE SPACE: We will use SPY features as the "Market Context"
        # We need to pre-calculate which features define the "State"
        # For simplicity, we use: [Spy_RVol, Spy_OBV_Score, VolRegime]
        self.available_dates = engine.df_close.index
        self.current_step = 0

    def reset(self):
        """Resets the game to a random start date."""
        # Pick a random index, ensuring we have enough history and future data
        valid_range_start = 252 # Need history for lookback
        valid_range_end = len(self.available_dates) - 20 # Need future for reward

        self.current_step = np.random.randint(valid_range_start, valid_range_end)
        return self._get_observation()

    def _get_observation(self):
        """
        Returns the 'State' (Market Context) for the Agent to see.
        Returns a PyTorch Tensor.
        """
        current_date = self.available_dates[self.current_step]

        # We need to grab specific features for the Benchmark (SPY) or Average Market
        # Let's use the features_df we already have.
        # Note: We take the average of all stocks to get a "Market Pulse"
        # OR specifically look at SPY if it exists in features.

        # Fast vector lookup for the specific date
        day_features = self.engine.features_df.xs(current_date, level="Date")

        # Simple State: Mean VolRegime, Mean RSI, Mean OBV across the universe
        # This tells the agent: "Is the market hot or cold?"
        state_values = [
            day_features['VolRegime'].mean(),       # Volatility State
            day_features['RSI'].mean() / 100.0,     # Overbought/Oversold (Normalized)
            day_features['OBV_Score'].mean(),       # Money Flow
            day_features['RelStrength'].mean()      # Breadth
        ]

        # Convert to Tensor (The language of PyTorch)
        # Handle NaNs by replacing with 0
        state_values = [0.0 if np.isnan(x) else x for x in state_values]
        return torch.tensor(state_values, dtype=torch.float32)

    def step(self, action_idx):
        """
        The Agent takes an action (picks a strategy).
        We calculate the reward.
        """
        metric_name = self.action_map[action_idx]
        current_date = self.available_dates[self.current_step]

        # 1. SETUP THE ENGINE INPUT
        # The Agent acts as the "User" selecting the dropdown
        inputs = EngineInput(
            mode="Ranking",
            start_date=current_date,
            lookback_period=20,     # Fixed for now
            holding_period=5,       # Fixed for now
            metric=metric_name,
            benchmark_ticker=self.benchmark_ticker,
            rank_end=5              # Top 5 stocks
        )

        # 2. RUN THE ENGINE
        # We suppress prints to keep training clean
        try:
            output = self.engine.run(inputs)

            # 3. CALCULATE REWARD
            # Reward = The Gain of the Portfolio relative to Benchmark
            # RLVR Principle: The reward must be Veritable (True Return)

            # Extract % return from the portfolio series
            if len(output.portfolio_series) > 0:
                port_return = (output.portfolio_series.iloc[-1] / output.portfolio_series.iloc[0]) - 1
                bench_return = (output.benchmark_series.iloc[-1] / output.benchmark_series.iloc[0]) - 1

                # Alpha (Excess Return) is a great reward
                reward = port_return - bench_return
            else:
                reward = -0.01 # Penalty for crashing/no stocks found

        except Exception as e:
            reward = -0.01 # Penalty for error

        # 4. MOVE TIME FORWARD
        self.current_step += 5 # Skip ahead by the holding period
        done = self.current_step >= (len(self.available_dates) - 20)

        next_state = self._get_observation() if not done else torch.zeros(4)

        return next_state, reward, done, {}

# ==============================================================================
# 2. THE AGENT (The Brain)
# ==============================================================================
class SimpleAgent(nn.Module):
    def __init__(self, state_dim, n_actions):
        super().__init__()
        # A simple Multi-Layer Perceptron (MLP)
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_actions),
            # Softmax outputs probabilities (e.g., 20% Momentum, 80% RSI)
            nn.Softmax(dim=-1)
        )

    def forward(self, state):
        return self.net(state)

    def get_action(self, state):
        # 1. Get probabilities from Neural Net
        probs = self.forward(state)

        # 2. Create a distribution to sample from
        dist = Categorical(probs)

        # 3. Sample an action (Exploration is built-in via sampling)
        action = dist.sample()

        return action.item(), dist.log_prob(action)

# ==============================================================================
# 3. TEST DRIVER (Verification)
# ==============================================================================
def verify_rl_setup(engine):
    print("--- ü§ñ Initializing RL Environment ---")

    # 1. Create Env
    env = TradingEnv(engine)
    obs = env.reset()
    print(f"State Vector: {obs} (Shape: {obs.shape})")

    # 2. Create Agent
    state_dim = obs.shape[0]
    n_actions = env.n_actions
    agent = SimpleAgent(state_dim, n_actions)
    print(f"Agent Action Space: {n_actions} Strategies")

    # 3. Test One Step
    action, log_prob = agent.get_action(obs)
    print(f"Agent chose Action Index {action}: '{env.action_map[action]}'")

    next_state, reward, done, _ = env.step(action)
    print(f"Reward (Alpha): {reward:.4f}")
    print("‚úÖ RL Setup Verified.")

# To run this, you need the 'engine' instance from the previous plot function
# If you ran plot_walk_forward_analyzer, the engine is inside it.
# For now, we assume 'engine' exists or we create a quick one:
# engine = AlphaEngine(df_ohlcv, features_df=features_df)
# verify_rl_setup(engine)

In [16]:
# Instantiate the AlphaEngine with precomputed data
engine = AlphaEngine(
    df_ohlcv=df_ohlcv,
    features_df=features_df,
    df_close_wide=df_close_wide,
    df_atrp_wide=df_atrp_wide,
    master_ticker=GLOBAL_SETTINGS["calendar_ticker"],
)

verify_rl_setup(engine)

--- ‚öôÔ∏è Initializing AlphaEngine v2.2 (Sanitized) ---
--- ü§ñ Initializing RL Environment ---
State Vector: tensor([0.8772, 0.5343, 0.4296, 0.0000]) (Shape: torch.Size([4]))
Agent Action Space: 19 Strategies
Agent chose Action Index 6: 'Momentum 10D'
Reward (Alpha): nan
‚úÖ RL Setup Verified.


  res = (clean.iloc[-1] / clean.iloc[0]) - 1
  bench_return = (output.benchmark_series.iloc[-1] / output.benchmark_series.iloc[0]) - 1


In [26]:
# To run this, you need the 'engine' instance from the previous plot function
# If you ran plot_walk_forward_analyzer, the engine is inside it.
# For now, we assume 'engine' exists or we create a quick one:
# engine = AlphaEngine(df_ohlcv, features_df=features_df)
verify_rl_setup(engine)

--- ü§ñ Initializing RL Environment ---
State Vector: tensor([ 0.8182,  0.5165, -0.1428,  0.0000]) (Shape: torch.Size([4]))
Agent Action Space: 19 Strategies
Agent chose Action Index 12: 'Pullback 1M'
Reward (Alpha): nan
‚úÖ RL Setup Verified.


  res = (clean.iloc[-1] / clean.iloc[0]) - 1
  bench_return = (output.benchmark_series.iloc[-1] / output.benchmark_series.iloc[0]) - 1
