v52  

You are very welcome! It was a highly productive session. We've successfully moved the codebase from a functional script to a **Senior-grade Quant Framework**.

To recap the milestones we achieved:

1.  **Mathematical Fortification:** We centralized all logic into a polymorphic `QuantUtils` kernel that handles both single-portfolio reports and whole-universe rankings with built-in numerical safety.
2.  **Volatility Evolution:** We successfully added `TRP` (True Range Percent) and the `Sharpe (TRP)` metric, giving you a raw, high-frequency alternative to the smoothed ATR.
3.  **Data Integrity:** We implemented the "Momentum Collapse" tripwire (`verify_ranking_integrity`) to ensure that your risk-adjusted rankings never accidentally devolve into simple price momentum.
4.  **The "Audit Pack" Architecture:** We collapsed fragmented results into a single, atomic container, ensuring that your inputs, results, and debug data are always perfectly synchronized.
5.  **Total Transparency:** We replaced scattered CSV files with a unified **Excel Audit Report**, allowing for 1-to-1 manual verification of every calculation in the system.
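
As a rough illustration of milestone 1, the sketch below shows how one kernel can serve both a single-portfolio report (a `pd.Series`) and a whole-universe ranking (a `pd.DataFrame`). It mirrors the shape of `QuantUtils.calculate_sharpe`; the standalone function name `sharpe` and the toy tickers `AAA`/`BBB` are assumptions made for this example only.

```python
import numpy as np
import pandas as pd


def sharpe(data, periods=252):
    """Annualized mean/std for a Series (report) or a DataFrame (ranking)."""
    mu, std = data.mean(), data.std()
    # np.maximum floors the denominator for scalars and Series alike,
    # so a flat (zero-volatility) ticker cannot trigger a division by zero.
    res = (mu / np.maximum(std, 1e-8)) * np.sqrt(periods)
    if isinstance(res, pd.Series):
        return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return float(res) if np.isfinite(res) else 0.0


# Toy daily returns; BBB is completely flat.
returns = pd.DataFrame({"AAA": [0.01, -0.02, 0.03], "BBB": [0.0, 0.0, 0.0]})

single = sharpe(returns["AAA"])  # Series in, float out (report mode)
ranking = sharpe(returns)        # DataFrame in, Series per ticker out (ranking mode)
```

Both calls go through the same arithmetic, so the per-ticker value in `ranking` agrees with the single-series result for the same column.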

v51

Undo v50: Sharpe (ATR) is again calculated using the mean ATRP over the lookback period.  

To switch between ``Averaged ATRP over lookback period`` and ``Current ATRP``, edit the block marked ``# --- PINPOINT START: ATRP SWITCH ---`` in function ``_select_tickers``:  

    # --- PINPOINT START: ATRP SWITCH ---  
    # To switch between Old (Averaged ATRP) and New (Current ATRP):  
    # 1. Comment out the logic you DON'T want.  
    # 2. Uncomment the logic you DO want.  
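
The two alternatives behind the switch can be sketched as follows. This is a minimal illustration on a toy ATRP matrix, not the actual body of `_select_tickers`; the tickers, dates, and values are invented, and only the name `atrp_value_for_obs` comes from these notes.

```python
import pandas as pd

# Toy wide ATRP matrix (dates x tickers); values are made up.
dates = pd.bdate_range("2024-01-02", periods=5)
df_atrp = pd.DataFrame(
    {
        "AAA": [0.020, 0.022, 0.024, 0.026, 0.030],
        "BBB": [0.050, 0.040, 0.030, 0.020, 0.010],
    },
    index=dates,
)
start_date, decision_date = dates[0], dates[-1]

# Option 1 (v51 behavior): average ATRP over the whole lookback window.
atrp_averaged = df_atrp.loc[start_date:decision_date].mean()

# Option 2 (v50 behavior): ATRP snapshot on the decision day only.
atrp_current = df_atrp.loc[decision_date]

# Keep exactly one of the two assignments active.
atrp_value_for_obs = atrp_averaged
# atrp_value_for_obs = atrp_current
```

With these numbers the two options disagree noticeably for `BBB` (0.03 averaged versus 0.01 current), which is the kind of difference that can reorder a Sharpe (ATR) ranking.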


v50

Ticker selection now uses `atrp_value_for_obs` taken on the decision day; previously it used the average ATRP over the lookback period.

v48  
### Summary of what you just accomplished:
1.  **Strict Math:** `QuantUtils` now contains an `assert` that prevents any dev (or AI) from filling the first day with 0.0.
2.  **Semantic Protection:** Variables are now named `returns_WITH_BOUNDARY_NAN`, signaling to the AI that the Null value is part of its identity.
3.  **Complete SOLID Separation:** The Engine CONDUCTS the simulation, while `QuantUtils` CALCULATES the results. They no longer share logic.
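
To see why the first-day NaN matters, here is a small sketch. The guard below is a hypothetical stand-in for the `assert` mentioned in point 1 (its exact form is not shown in these notes), and the prices are invented.

```python
import numpy as np
import pandas as pd


def compute_returns(prices: pd.Series) -> pd.Series:
    """Percent returns whose first element stays NaN (no prior price exists)."""
    returns_WITH_BOUNDARY_NAN = prices.pct_change(fill_method=None).replace(
        [np.inf, -np.inf], np.nan
    )
    # Hypothetical guard: fail loudly if the boundary NaN was filled upstream.
    assert pd.isna(returns_WITH_BOUNDARY_NAN.iloc[0]), "boundary NaN was overwritten"
    return returns_WITH_BOUNDARY_NAN


prices = pd.Series([100.0, 102.0, 101.0])
r = compute_returns(prices)

mean_kept = r.mean()                # NaN is skipped: average of 2 real returns
mean_filled = r.fillna(0.0).mean()  # a fake 0.0 dilutes it: average of 3 values
```

Filling the first day with 0.0 adds an observation that never happened, so the mean return (and every ratio built on it) shrinks by a factor that depends on the window length.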

**1. Data Flow of `plot_walk_forward_analyzer`**
The function acts as a **UI wrapper** around the `AlphaEngine` class. The flow is:
1.  **Input:** User selects parameters (Dates, Lookback, Strategy).
2.  **State Construction:** `AlphaEngine` slices the historical data (`df_ohlcv`, `df_atrp`) up to the `decision_date`.
3.  **Policy Execution (Hardcoded):** The engine applies the logic (e.g., `METRIC_REGISTRY['Sharpe']`) to rank stocks based *only* on the Lookback window.
4.  **Environment Step:** It simulates a "Buy" at `decision_date + 1` and calculates the returns over the `holding_period`.
5.  **Reward Generation:** It outputs performance metrics (`holding_p_gain`, `holding_p_sharpe`).
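
The timeline behind steps 2 to 4 can be sketched with plain index arithmetic on a trading calendar, in the same way `_validate_timeline` maps its dates. The business-day calendar and the parameter values below are stand-ins chosen for the example.

```python
import pandas as pd

# Stand-in trading calendar: 30 business days starting Monday 2024-01-01.
cal = pd.bdate_range("2024-01-01", periods=30)
lookback_period, holding_period = 10, 5

# Position of the decision date in the calendar.
decision_idx = cal.searchsorted(pd.Timestamp("2024-01-17"))

start_idx = decision_idx - lookback_period  # first day of the lookback window
entry_idx = decision_idx + 1                # "Buy" on the next trading day (T+1)
end_idx = entry_idx + holding_period        # last day of the holding period

start_date, decision_date = cal[start_idx], cal[decision_idx]
buy_date, holding_end_date = cal[entry_idx], cal[end_idx]
```

Ranking only ever sees data in `start_date:decision_date`, while the reward is measured on `buy_date:holding_end_date`, so the two windows never overlap.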

In [1]:
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets

from dataclasses import dataclass, field, asdict, is_dataclass
from typing import List, Dict, Optional, Any, Union, TypedDict, Tuple
from collections import Counter
from datetime import datetime, date
from pandas.testing import assert_series_equal


# pd.set_option("display.max_rows", None)  # display all rows
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 1000)
pd.set_option("display.max_colwidth", 50)
pd.set_option("display.precision", 4)


# ==============================================================================
# GLOBAL SETTINGS: The "Control Panel" for the Strategy
# ==============================================================================

GLOBAL_SETTINGS = {
    # ENVIRONMENT (The "Where")
    "benchmark_ticker": "SPY",
    "calendar_ticker": "SPY",  # Used as the "Master Clock" for trading days
    # DATA SANITIZER (The "Glitches & Gaps" Protector)
    "handle_zeros_as_nan": True,  # Convert 0.0 prices to NaN to prevent math errors
    "max_data_gap_ffill": 1,  # Max consecutive days to "Forward Fill" missing data
    # IMPLICATION OF nan_price_replacement:
    # - This defines what happens if the "Forward Fill" limit is exceeded.
    # - If set to 0.0: A permanent data gap will look like a "total loss" (-100%).
    #   The equity curve will plummet. Good for "disaster detection."
    #   Sharpe and Sharpe(ATR) drop because: return (gets smaller) / std (gets larger)
    # - If set to np.nan: A permanent gap will cause portfolio calculations to return NaN.
    #   The chart may break or show gaps. Good for "math integrity."
    "nan_price_replacement": 0.0,
    # STRATEGY PARAMETERS (The "How")
    "atr_period": 14,  # Used for volatility normalization
    "quality_window": 252,  # 1 year lookback for liquidity/quality stats
    "quality_min_periods": 126,  # Min history required to judge a stock
    # QUALITY THRESHOLDS (The "Rules")
    "thresholds": {
        # HARD LIQUIDITY FLOOR
        # Logic: Calculates (Adj Close * Volume) daily, then takes the ROLLING MEDIAN
        # over the quality_window (252 days). Filters out stocks where the
        # typical daily dollar turnover is below this absolute value.
        "min_median_dollar_volume": 1_000_000,
        # DYNAMIC LIQUIDITY CUTOFF (Relative to Universe)
        # Logic: On the decision date, the engine calculates the X-quantile
        # of 'RollMedDollarVol' across ALL available stocks.
        # Setting this to 0.40 calculates the 40th percentile and requires
        # stocks to be above it, effectively keeping only the TOP 60% of the market.
        "min_liquidity_percentile": 0.40,
        # PRICE/VOLUME STALENESS
        # Logic: Creates a binary flag (1 if Volume is 0 OR High equals Low).
        # It then calculates the ROLLING MEAN of this flag.
        # A value of 0.05 means the stock is rejected if it was "stale"
        # for more than 5% of the trading days in the rolling window.
        "max_stale_pct": 0.05,
        # DATA INTEGRITY (FROZEN VOLUME)
        # Logic: Checks if Volume is identical to the previous day (Volume.diff() == 0).
        # It calculates the ROLLING SUM of these occurrences over the window.
        # If the exact same volume is reported more than 10 times, the stock
        # is rejected as having "frozen" or low-quality data.
        "max_same_vol_count": 10,
    },
}


# ==============================================================================
# SECTION A: CORE KERNELS & QUANT UTILITIES (THE SAFE ROOM)
# ==============================================================================


class QuantUtils:
    """
    MATHEMATICAL KERNEL REGISTRY: THE SINGLE SOURCE OF TRUTH.
    Handles both pd.Series (Report) and pd.DataFrame (Ranking) robustly.
    """

    @staticmethod
    def compute_returns(
        data: Union[pd.Series, pd.DataFrame],
    ) -> Union[pd.Series, pd.DataFrame]:
        return data.pct_change().replace([np.inf, -np.inf], np.nan)

    @staticmethod
    def calculate_gain(data: Union[pd.Series, pd.DataFrame]) -> Union[float, pd.Series]:
        if data.empty:
            return 0.0
        res = (data.ffill().iloc[-1] / data.bfill().iloc[0]) - 1

        if isinstance(res, (pd.Series, pd.DataFrame)):
            return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)
        return float(res) if np.isfinite(res) else 0.0

    @staticmethod
    def calculate_sharpe(
        data: Union[pd.Series, pd.DataFrame], periods: int = 252
    ) -> Union[float, pd.Series]:
        mu, std = data.mean(), data.std()
        # Use np.maximum for universal floor (works on scalars and Series)
        res = (mu / np.maximum(std, 1e-8)) * np.sqrt(periods)

        if isinstance(res, (pd.Series, pd.DataFrame)):
            return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)
        return float(res) if np.isfinite(res) else 0.0

    @staticmethod
    def calculate_sharpe_vol(
        returns: Union[pd.Series, pd.DataFrame],
        vol_data: Union[pd.Series, pd.DataFrame],
    ) -> Union[float, pd.Series]:
        """
        Unified Math for Sharpe(ATR) and Sharpe(TRP).
        Logic: Reward / Risk.
        """
        avg_ret = returns.mean()

        # --- DEFENSIVE LOGIC ---
        # If returns is a DataFrame, we are in RANKING mode (cross-sectional).
        # In this mode, vol_data is expected to be a Series indexed by Ticker.
        # Calling .mean() on it would collapse it to a market-average scalar (The Bug).
        if isinstance(returns, pd.DataFrame) and isinstance(vol_data, pd.Series):
            avg_vol = vol_data
        else:
            # We are in REPORT mode (single portfolio) or raw arrays.
            avg_vol = vol_data.mean()

        res = avg_ret / np.maximum(avg_vol, 1e-8)

        if isinstance(res, (pd.Series, pd.DataFrame)):
            return res.replace([np.inf, -np.inf], np.nan).fillna(0.0)
        return float(res) if np.isfinite(res) else 0.0

    @staticmethod
    def compute_portfolio_stats(
        prices: pd.DataFrame,
        atrp_matrix: pd.DataFrame,
        trp_matrix: pd.DataFrame,
        weights: pd.Series,
    ) -> Tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
        """
        MATRIX KERNEL: Calculates equity curve and weighted volatility.
        """
        # 1. Equity Curve Logic (Price-Weighted Drift)
        norm_prices = prices.div(prices.bfill().iloc[0])
        weighted_components = norm_prices.mul(weights, axis=1)
        equity_curve = weighted_components.sum(axis=1)

        # MANDATORY: Use internal compute_returns to preserve boundary NaN
        returns_WITH_BOUNDARY_NAN = QuantUtils.compute_returns(equity_curve)

        # 2. Portfolio Volatility Logic (Weighted Average)
        # We calculate current_weights (rebalanced daily by price drift)
        current_weights = weighted_components.div(equity_curve, axis=0)

        # Weighted average of ATRP and TRP
        portfolio_atrp = (current_weights * atrp_matrix).sum(axis=1, min_count=1)
        portfolio_trp = (current_weights * trp_matrix).sum(axis=1, min_count=1)

        return equity_curve, returns_WITH_BOUNDARY_NAN, portfolio_atrp, portfolio_trp


# ==============================================================================
# SECTION B: STRATEGY HELPERS & FEATURES
# ==============================================================================


def generate_features(
    df_ohlcv: pd.DataFrame,
    df_indices: pd.DataFrame = None,
    benchmark_ticker: str = "SPY",
    atr_period: int = 14,
    rsi_period: int = 14,
    quality_window: int = 252,
    quality_min_periods: int = 126,
) -> pd.DataFrame:

    print(f"⚡ Generating SOTA Quant Features (Benchmark: {benchmark_ticker})...")

    # 1. Sort and Group
    if not df_ohlcv.index.is_monotonic_increasing:
        df_ohlcv = df_ohlcv.sort_index()
    grouped = df_ohlcv.groupby(level="Ticker")

    # 2. VECTORIZED ATR (Wilder's)
    prev_close = grouped["Adj Close"].shift(1)
    high_low = df_ohlcv["Adj High"] - df_ohlcv["Adj Low"]
    high_prev = abs(df_ohlcv["Adj High"] - prev_close)
    low_prev = abs(df_ohlcv["Adj Low"] - prev_close)
    tr = pd.concat([high_low, high_prev, low_prev], axis=1).max(axis=1, skipna=False)

    atr = (
        tr.groupby(level="Ticker")
        .ewm(alpha=1 / atr_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )
    atrp = (atr / df_ohlcv["Adj Close"]).replace([np.inf, -np.inf], np.nan)
    trp = (tr / df_ohlcv["Adj Close"]).replace([np.inf, -np.inf], np.nan)

    # 3. VECTORIZED RSI
    delta = grouped["Adj Close"].diff()
    up = delta.clip(lower=0)
    down = -1 * delta.clip(upper=0)

    ma_up = (
        up.groupby(level="Ticker")
        .ewm(alpha=1 / rsi_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )
    ma_down = (
        down.groupby(level="Ticker")
        .ewm(alpha=1 / rsi_period, adjust=False)
        .mean()
        .reset_index(level=0, drop=True)
    )

    rsi = 100 - (100 / (1 + (ma_up / ma_down)))
    rsi = rsi.replace([np.inf, -np.inf], 50).fillna(50)

    # 4. OBV (Ticker Specific)
    direction = np.sign(delta).fillna(0)
    obv_raw = (direction * df_ohlcv["Volume"]).groupby(level="Ticker").cumsum()
    obv_roll_mean = (
        obv_raw.groupby(level="Ticker")
        .rolling(21)
        .mean()
        .reset_index(level=0, drop=True)
    )
    obv_roll_std = (
        obv_raw.groupby(level="Ticker")
        .rolling(21)
        .std()
        .reset_index(level=0, drop=True)
    )
    obv_score = (
        ((obv_raw - obv_roll_mean) / obv_roll_std)
        .fillna(0.0)
        .clip(lower=-5.0, upper=5.0)
    )

    # Dollar Volume
    dollar_vol_series = df_ohlcv["Adj Close"] * df_ohlcv["Volume"]

    # 5. BENCHMARK FEATURES (Price & Volume)
    bench_close_series = None
    bench_vol_series = None

    found_bench = False
    if (
        df_indices is not None
        and benchmark_ticker in df_indices.index.get_level_values(0)
    ):
        try:
            bench_close_series = df_indices.xs(benchmark_ticker, level=0)["Adj Close"]
            bench_vol_series = df_indices.xs(benchmark_ticker, level=0)["Volume"]
            found_bench = True
        except Exception:
            pass
    if not found_bench and benchmark_ticker in df_ohlcv.index.get_level_values(0):
        try:
            bench_close_series = df_ohlcv.xs(benchmark_ticker, level=0)["Adj Close"]
            bench_vol_series = df_ohlcv.xs(benchmark_ticker, level=0)["Volume"]
            found_bench = True
        except Exception:
            pass

    # Initialize Containers
    rel_strength_21 = pd.Series(0.0, index=df_ohlcv.index)
    spy_rvol = pd.Series(1.0, index=df_ohlcv.index)
    spy_obv_score = pd.Series(0.0, index=df_ohlcv.index)  # <--- NEW CONTAINER

    if found_bench:
        try:
            # A. Relative Strength
            bench_close_aligned = bench_close_series.reindex(
                df_ohlcv.index.get_level_values("Date")
            ).values
            rel_ratio = df_ohlcv["Adj Close"] / bench_close_aligned
            rel_strength_21 = (
                rel_ratio.groupby(level="Ticker")
                .pct_change(21, fill_method=None)
                .fillna(0.0)
            )

            # B. Spy RVol (Magnitude)
            bench_dvol = bench_close_series * bench_vol_series
            bench_dvol_avg = bench_dvol.rolling(21).mean()
            bench_rvol_raw = (bench_dvol / bench_dvol_avg).fillna(1.0)

            # C. SPY OBV Score (Direction) <--- NEW
            # Calculate OBV for SPY specifically
            spy_delta = bench_close_series.diff()
            spy_direction = np.sign(spy_delta).fillna(0)
            spy_obv_raw = (spy_direction * bench_vol_series).cumsum()

            # Normalize SPY OBV (Z-Score)
            spy_obv_mean = spy_obv_raw.rolling(21).mean()
            spy_obv_std = spy_obv_raw.rolling(21).std()
            spy_obv_z = (
                ((spy_obv_raw - spy_obv_mean) / spy_obv_std)
                .fillna(0.0)
                .clip(lower=-5.0, upper=5.0)
            )

            # D. BROADCAST TO ALL TICKERS
            # Reindex creates a Series aligned to the full DataFrame (Date-matched)
            spy_rvol_values = (
                bench_rvol_raw.reindex(df_ohlcv.index.get_level_values("Date"))
                .fillna(1.0)
                .values
            )
            spy_obv_values = (
                spy_obv_z.reindex(df_ohlcv.index.get_level_values("Date"))
                .fillna(0.0)
                .values
            )

            spy_rvol = pd.Series(spy_rvol_values, index=df_ohlcv.index).clip(upper=10.0)
            spy_obv_score = pd.Series(spy_obv_values, index=df_ohlcv.index)  # <--- NEW

        except Exception as e:
            print(f"⚠️ Benchmark Math Error: {e}")

    # 6. TICKER RELATIVE VOLUME (RVol)
    dvol_grouped = dollar_vol_series.groupby(level="Ticker")
    dvol_avg = dvol_grouped.rolling(21).mean().reset_index(level=0, drop=True)
    ticker_rvol = (
        (dollar_vol_series / dvol_avg)
        .replace([np.inf, -np.inf], 1.0)
        .fillna(1.0)
        .clip(upper=10.0)
    )

    # 7. MOMENTUM / RETURN FEATURES
    daily_returns = grouped["Adj Close"].pct_change(1, fill_method=None)
    roc_1 = daily_returns
    roc_3 = grouped["Adj Close"].pct_change(3, fill_method=None)
    roc_5 = grouped["Adj Close"].pct_change(5, fill_method=None)
    roc_10 = grouped["Adj Close"].pct_change(10, fill_method=None)
    roc_21 = grouped["Adj Close"].pct_change(21, fill_method=None)

    # 8. VOLATILITY REGIME
    returns_grouped = daily_returns.groupby(level="Ticker")
    std_5 = returns_grouped.rolling(5).std().reset_index(level=0, drop=True)
    std_21 = returns_grouped.rolling(21).std().reset_index(level=0, drop=True)

    if std_5.index.nlevels > df_ohlcv.index.nlevels:
        std_5 = std_5.reset_index(level=0, drop=True)
        std_21 = std_21.reset_index(level=0, drop=True)

    vol_regime = (std_5 / std_21).replace([np.inf, -np.inf], 1.0)

    # 9. MERGE
    indicator_df = pd.DataFrame(
        {
            "ATR": atr,
            "ATRP": atrp,
            "TRP": trp,  # <--- PINPOINT CHANGE: Add to output
            "RSI": rsi,
            "RelStrength": rel_strength_21,
            "VolRegime": vol_regime,
            "RVol": ticker_rvol,
            "Spy_RVol": spy_rvol,
            "OBV_Score": obv_score,
            "Spy_OBV_Score": spy_obv_score,  # <--- NEW
            "ROC_1": roc_1,
            "ROC_3": roc_3,
            "ROC_5": roc_5,
            "ROC_10": roc_10,
            "ROC_21": roc_21,
        }
    )

    # 10. Quality/Liquidity Features
    quality_temp_df = pd.DataFrame(
        {
            "IsStale": np.where(
                (df_ohlcv["Volume"] == 0)
                | (df_ohlcv["Adj High"] == df_ohlcv["Adj Low"]),
                1,
                0,
            ),
            "DollarVolume": dollar_vol_series,
            "HasSameVolume": (grouped["Volume"].diff() == 0).astype(int),
        },
        index=df_ohlcv.index,
    )

    rolling_result = (
        quality_temp_df.groupby(level="Ticker")
        .rolling(window=quality_window, min_periods=quality_min_periods)
        .agg({"IsStale": "mean", "DollarVolume": "median", "HasSameVolume": "sum"})
        .rename(
            columns={
                "IsStale": "RollingStalePct",
                "DollarVolume": "RollMedDollarVol",
                "HasSameVolume": "RollingSameVolCount",
            }
        )
        .reset_index(level=0, drop=True)
    )

    return pd.concat([indicator_df, rolling_result], axis=1)


def _prepare_initial_weights(tickers: List[str]) -> pd.Series:
    """
    METADATA: Converts a list of tickers into a weight map.
    Example: ['AAPL', 'AAPL', 'TSLA'] -> {'AAPL': 0.67, 'TSLA': 0.33}
    """
    ticker_counts = Counter(tickers)
    total = len(tickers)
    return pd.Series({t: c / total for t, c in ticker_counts.items()})


def calculate_buy_and_hold_performance(
    df_close_wide: pd.DataFrame,  # Use the WIDE version
    df_atrp_wide: pd.DataFrame,  # Use the WIDE version
    df_trp_wide: pd.DataFrame,  # <--- Added
    tickers: List[str],
    start_date: pd.Timestamp,
    end_date: pd.Timestamp,
):
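    if not tickers:
        # Return four empty Series so callers can always unpack
        # (equity, returns, ATRP, TRP), matching compute_portfolio_stats.
        empty = pd.Series(dtype=float)
        return empty, empty.copy(), empty.copy(), empty.copy()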

    initial_weights = _prepare_initial_weights(tickers)

    # SLICE (Fix Part B)
    ticker_list = initial_weights.index.tolist()
    p_slice = df_close_wide.reindex(columns=ticker_list).loc[start_date:end_date]
    a_slice = df_atrp_wide.reindex(columns=ticker_list).loc[start_date:end_date]
    t_slice = df_trp_wide.reindex(columns=ticker_list).loc[start_date:end_date]
    # KERNEL - Pure Math
    return QuantUtils.compute_portfolio_stats(
        p_slice, a_slice, t_slice, initial_weights
    )


# ==============================================================================
# SECTION C: METRIC REGISTRY
# ==============================================================================


class MarketObservation(TypedDict):
    """
    The 'STATE' (Observation) in Reinforcement Learning.
    This defines the context given to the agent to make a decision.
    """

    # --- The Movie (Time Series) ---
    lookback_returns: pd.DataFrame  # (Time x Tickers)
    lookback_close: pd.DataFrame  # (Time x Tickers)

    # --- The Snapshot (Scalar values at Decision Time) ---
    atrp: pd.Series  # Volatility (Mean over lookback)
    trp: pd.Series  # Volatility (Snapshot)

    # NEW SENSORS
    rsi: pd.Series  # Internal Momentum (0-100)
    rel_strength: pd.Series  # Performance vs SPY
    vol_regime: pd.Series  # Volatility Expansion/Compression
    rvol: pd.Series  # Ticker Conviction
    spy_rvol: pd.Series  # Market Participation
    obv_score: pd.Series  # Ticker Accumulation/Distribution
    spy_obv_score: pd.Series  # Market Tide

    # MOMENTUM VECTORS
    roc_1: pd.Series
    roc_3: pd.Series
    roc_5: pd.Series
    roc_10: pd.Series
    roc_21: pd.Series


METRIC_REGISTRY = {
    # --- CLASSIC METRICS ---
    "Price": lambda obs: QuantUtils.calculate_gain(obs["lookback_close"]),
    "Sharpe": lambda obs: QuantUtils.calculate_sharpe(obs["lookback_returns"]),
    "Sharpe (ATR)": lambda obs: QuantUtils.calculate_sharpe_vol(
        obs["lookback_returns"], obs["atrp"]
    ),
    "Sharpe (TRP)": lambda obs: QuantUtils.calculate_sharpe_vol(
        obs["lookback_returns"], obs["trp"]
    ),  # <--- New Strategy
    # --- MOMENTUM VECTORS ---
    "Momentum 1D": lambda obs: obs["roc_1"],
    "Momentum 3D": lambda obs: obs["roc_3"],
    "Momentum 5D": lambda obs: obs["roc_5"],
    "Momentum 10D": lambda obs: obs["roc_10"],
    "Momentum 1M": lambda obs: obs["roc_21"],
    # --- PULLBACK VECTORS ---
    "Pullback 1D": lambda obs: -obs["roc_1"],
    "Pullback 3D": lambda obs: -obs["roc_3"],
    "Pullback 5D": lambda obs: -obs["roc_5"],
    "Pullback 10D": lambda obs: -obs["roc_10"],
    "Pullback 1M": lambda obs: -obs["roc_21"],
    # --- NEW SOTA SENSORS ---
    "RSI (Reversal)": lambda obs: -obs["rsi"],  # Rank Low RSI (Oversold) higher
    "RSI (Trend)": lambda obs: obs["rsi"],  # Rank High RSI (Strong Trend) higher
    "Alpha (RelStrength)": lambda obs: obs["rel_strength"],  # Rank stocks beating SPY
    "OBV (Accumulation)": lambda obs: obs["obv_score"],  # Rank High OBV Score
    "Volume Conviction": lambda obs: obs["rvol"],  # Rank High Relative Volume
    "Volatility Regime (Breakout)": lambda obs: obs[
        "vol_regime"
    ],  # Rank High Volatility Expansion
}


# ==============================================================================
# SECTION D: DATA CONTRACTS
# ==============================================================================


@dataclass
class EngineInput:
    mode: str
    start_date: pd.Timestamp
    lookback_period: int
    holding_period: int
    metric: str
    benchmark_ticker: str
    rank_start: int = 1
    rank_end: int = 10
    # Default factory pulls from Global thresholds
    quality_thresholds: Dict[str, float] = field(
        default_factory=lambda: GLOBAL_SETTINGS["thresholds"].copy()
    )
    manual_tickers: List[str] = field(default_factory=list)
    debug: bool = False


@dataclass
class EngineOutput:
    # 1. CORE DATA (Required - No Defaults)
    portfolio_series: pd.Series
    benchmark_series: pd.Series
    normalized_plot_data: pd.DataFrame
    tickers: List[str]
    initial_weights: pd.Series
    perf_metrics: Dict[str, float]
    results_df: pd.DataFrame

    # 2. TIMELINE (Required - No Defaults)
    start_date: pd.Timestamp
    decision_date: pd.Timestamp
    buy_date: pd.Timestamp
    holding_end_date: pd.Timestamp

    # 3. OPTIONAL / AUDIT DATA (Must be at the bottom because they have defaults)
    portfolio_trp_series: Optional[pd.Series] = None
    benchmark_trp_series: Optional[pd.Series] = None
    error_msg: Optional[str] = None
    debug_data: Optional[Dict[str, Any]] = None


class AlphaEngine:
    def __init__(
        self,
        df_ohlcv: pd.DataFrame,
        features_df: pd.DataFrame = None,
        df_close_wide: pd.DataFrame = None,
        df_atrp_wide: pd.DataFrame = None,
        df_trp_wide: pd.DataFrame = None,
        master_ticker: str = GLOBAL_SETTINGS["calendar_ticker"],
    ):
        print("--- ⚙️ Initializing AlphaEngine v2.2 (Sanitized) ---")

        # 1. SETUP PRICES (CLEAN-AT-ENTRY)
        if df_close_wide is not None:
            self.df_close = df_close_wide
        else:
            print("🐢 Pivoting and Sanitizing Price Data...")
            self.df_close = df_ohlcv["Adj Close"].unstack(level=0)

        # APPLY DATA SANITIZER LOGIC
        if GLOBAL_SETTINGS["handle_zeros_as_nan"]:
            # Replace 0.0 with NaN so math functions (mean/std) ignore them
            self.df_close = self.df_close.replace(0, np.nan)

        # Smooth over short glitches, up to max_data_gap_ffill consecutive days (The "FNV" Fix)
        self.df_close = self.df_close.ffill(limit=GLOBAL_SETTINGS["max_data_gap_ffill"])

        # Handle the remaining "unfillable" gaps
        self.df_close = self.df_close.fillna(GLOBAL_SETTINGS["nan_price_replacement"])

        # 2. SETUP FEATURES
        if features_df is not None:
            self.features_df = features_df
        else:
            # We pass the cleaned price data if needed, or calculate from raw
            self.features_df = generate_features(
                df_ohlcv,
                atr_period=GLOBAL_SETTINGS["atr_period"],
                quality_window=GLOBAL_SETTINGS["quality_window"],
                quality_min_periods=GLOBAL_SETTINGS["quality_min_periods"],
            )

        # 3. SETUP ATRP
        if df_atrp_wide is not None:
            # INSTANT: Use the matrix precomputed outside the UI
            self.df_atrp = df_atrp_wide
        else:
            # SLOW FALLBACK: Only runs if you forget to precompute
            print("🚀 Pre-aligning Volatility (ATRP) Matrix (Slow Fallback)...")
            self.df_atrp = self.features_df["ATRP"].unstack(level=0)

        # 4. SETUP TRP
        if df_trp_wide is not None:
            self.df_trp = df_trp_wide
        else:
            print("🚀 Pre-aligning Volatility (TRP) Matrix (Slow Fallback)...")
            self.df_trp = self.features_df["TRP"].unstack(level=0)

        # 5. FINAL ALIGNMENT (The "Safety Seal")
        # Ensures all matrices have the exact same Dimensions, Tickers, and Dates
        common_idx = self.df_close.index
        common_cols = self.df_close.columns

        self.df_atrp = self.df_atrp.reindex(index=common_idx, columns=common_cols)
        self.df_trp = self.df_trp.reindex(index=common_idx, columns=common_cols)

        # 6. Setup Calendar
        if master_ticker not in self.df_close.columns:
            master_ticker = self.df_close.columns[0]
        self.trading_calendar = (
            self.df_close[master_ticker].dropna().index.unique().sort_values()
        )

    def run(self, inputs: EngineInput) -> EngineOutput:
        dates, error = self._validate_timeline(inputs)
        if error:
            return self._error_result(error)
        (safe_start, safe_decision, safe_buy, safe_end) = dates

        tickers_to_trade, results_table, debug_dict, error = self._select_tickers(
            inputs, safe_start, safe_decision
        )
        if error:
            return self._error_result(error)

        # GENERATE TRACKS (Fix Part A)
        p_f_val, p_f_ret, p_f_atrp, p_f_trp = calculate_buy_and_hold_performance(
            self.df_close,
            self.df_atrp,
            self.df_trp,
            tickers_to_trade,
            safe_start,
            safe_end,
        )
        b_f_val, b_f_ret, b_f_atrp, b_f_trp = calculate_buy_and_hold_performance(
            self.df_close,
            self.df_atrp,
            self.df_trp,
            [inputs.benchmark_ticker],
            safe_start,
            safe_end,
        )

        p_h_val, p_h_ret, p_h_atrp, p_h_trp = calculate_buy_and_hold_performance(
            self.df_close,
            self.df_atrp,
            self.df_trp,
            tickers_to_trade,
            safe_buy,
            safe_end,
        )
        b_h_val, b_h_ret, b_h_atrp, b_h_trp = calculate_buy_and_hold_performance(
            self.df_close,
            self.df_atrp,
            self.df_trp,
            [inputs.benchmark_ticker],
            safe_buy,
            safe_end,
        )

        # CALCULATE METRICS
        p_metrics, p_slices = self._calculate_period_metrics(
            p_f_val,
            p_f_ret,
            p_f_atrp,
            p_f_trp,
            safe_decision,
            p_h_val,
            p_h_ret,
            p_h_atrp,
            p_h_trp,
            prefix="p",
        )
        b_metrics, b_slices = self._calculate_period_metrics(
            b_f_val,
            b_f_ret,
            b_f_atrp,
            b_f_trp,
            safe_decision,
            b_h_val,
            b_h_ret,
            b_h_atrp,
            b_h_trp,
            prefix="b",
        )

        # CONSOLIDATE DEBUG DATA
        debug_dict["verification"] = {"portfolio": p_slices, "benchmark": b_slices}

        # ADD RAW COMPONENT EXPORTS
        debug_dict["portfolio_raw_components"] = {
            "prices": self.df_close[tickers_to_trade].loc[safe_start:safe_end],
            "atrp": self.features_df.loc[
                (tickers_to_trade, slice(safe_start, safe_end)), "ATRP"
            ].unstack(level=0),
        }
        debug_dict["benchmark_raw_components"] = {
            "prices": self.df_close[[inputs.benchmark_ticker]].loc[safe_start:safe_end],
            "atrp": self.features_df.loc[
                ([inputs.benchmark_ticker], slice(safe_start, safe_end)), "ATRP"
            ].unstack(level=0),
        }

        # FINAL OUTPUT
        results_table["Holding Gain"] = (p_h_val.iloc[-1] / p_h_val.iloc[0]) - 1

        # 1. FINAL CALCULATION / PRE-PACKING
        # Merge existing rankings or audits into the dict before sealing the result
        debug_dict["selection_audit"] = debug_dict.get("full_universe_ranking")

        # 2. CREATE THE OUTPUT OBJECT (The "Seal")
        res_output = EngineOutput(
            portfolio_series=p_f_val,
            benchmark_series=b_f_val,
            portfolio_trp_series=p_f_trp,
            benchmark_trp_series=b_f_trp,
            normalized_plot_data=self._get_normalized_plot_data(
                tickers_to_trade, safe_start, safe_end
            ),
            tickers=tickers_to_trade,
            # Reuse the same helper as the performance kernel so duplicate
            # tickers receive consistent weights that sum to 1.
            initial_weights=_prepare_initial_weights(tickers_to_trade),
            perf_metrics={**p_metrics, **b_metrics},
            results_df=results_table,
            start_date=safe_start,
            decision_date=safe_decision,
            buy_date=safe_buy,
            holding_end_date=safe_end,
            # --- PINPOINT FIX: YOU MUST PASS THE DICT HERE ---
            debug_data=debug_dict,
        )

        return res_output

    # ==============================================================================
    # INTERNAL LOGIC MODULES
    # ==============================================================================

    def _validate_timeline(self, inputs: EngineInput):
        cal = self.trading_calendar
        last_idx = len(cal) - 1

        # 1. Check Dataset Size
        if len(cal) <= inputs.lookback_period:
            return (
                None,
                f"❌ Dataset too small.\nNeed > {inputs.lookback_period} days of history.",
            )

        # 2. Check "Past" Constraints (Lookback)
        min_decision_date = cal[inputs.lookback_period]
        if inputs.start_date < min_decision_date:
            return None, (
                f"❌ Not enough history for a {inputs.lookback_period}-day lookback.\n"
                f"Earliest valid Decision Date: {min_decision_date.date()}"
            )

        # 3. Check "Future" Constraints (Entry T+1 and Holding Period)
        required_future_days = 1 + inputs.holding_period
        latest_valid_idx = last_idx - required_future_days

        if latest_valid_idx < 0:
            return (
                None,
                f"❌ Holding period too long.\n{inputs.holding_period} days exceed available data.",
            )

        # If user picked a date beyond the available "future" runway
        if inputs.start_date > cal[latest_valid_idx]:
            latest_date = cal[latest_valid_idx].date()
            return None, (
                f"❌ Decision Date too late for a {inputs.holding_period}-day hold.\n"
                f"Latest valid date: {latest_date}. Please move picker back."
            )

        # 4. Map the safe indices
        decision_idx = cal.searchsorted(inputs.start_date)
        if decision_idx > latest_valid_idx:
            decision_idx = latest_valid_idx

        start_idx = decision_idx - inputs.lookback_period
        entry_idx = decision_idx + 1
        end_idx = entry_idx + inputs.holding_period

        return (cal[start_idx], cal[decision_idx], cal[entry_idx], cal[end_idx]), None

    def _select_tickers(self, inputs: EngineInput, start_date, decision_date):
        debug_dict = {}

        # --- PATH A: MANUAL LIST ---
        if inputs.mode == "Manual List":
            validation_errors = []
            valid_tickers = []
            for t in inputs.manual_tickers:
                if t not in self.df_close.columns:
                    validation_errors.append(f"❌ {t}: Not found.")
                    continue
                if pd.isna(self.df_close.at[start_date, t]):
                    validation_errors.append(f"⚠️ {t}: No data on start date.")
                    continue
                valid_tickers.append(t)

            if validation_errors:
                return [], pd.DataFrame(), {}, "\n".join(validation_errors)
            if not valid_tickers:
                return [], pd.DataFrame(), {}, "No valid tickers found."
            return valid_tickers, pd.DataFrame(index=valid_tickers), {}, None

        # --- PATH B: RANKING ---
        else:
            audit_info = {}
            eligible_tickers = self._filter_universe(
                decision_date, inputs.quality_thresholds, audit_info
            )
            debug_dict["audit_liquidity"] = audit_info

            if not eligible_tickers:
                return (
                    [],
                    pd.DataFrame(),
                    debug_dict,
                    "No tickers passed quality filters.",
                )

            lookback_close = self.df_close.loc[
                start_date:decision_date, eligible_tickers
            ]

            # 1. Get the Snapshot of Features for the Decision Date
            feat_slice_current = self.features_df.xs(
                decision_date, level="Date"
            ).reindex(eligible_tickers)

            # Calculate mean ATRP over the lookback period
            idx_product = pd.MultiIndex.from_product(
                [eligible_tickers, lookback_close.index], names=["Ticker", "Date"]
            )
            feat_slice_period = self.features_df.reindex(idx_product)
            atrp_value_for_obs = (
                feat_slice_period["ATRP"].groupby(level="Ticker").mean()
            )

            # Mean TRP over the lookback period, per ticker
            trp_value_for_obs = feat_slice_period["TRP"].groupby(level="Ticker").mean()

            # 2. Package the Observation (The 'State')
            observation: MarketObservation = {
                # Time Series Data
                "lookback_close": lookback_close,
                "lookback_returns": lookback_close.ffill().pct_change(),
                # Lookback means (one value per ticker)
                "atrp": atrp_value_for_obs,
                "trp": trp_value_for_obs,
                # Snapshot Data (per-ticker values on the decision date)
                "rsi": feat_slice_current["RSI"],
                "rel_strength": feat_slice_current["RelStrength"],
                "vol_regime": feat_slice_current["VolRegime"],
                "rvol": feat_slice_current["RVol"],
                "spy_rvol": feat_slice_current["Spy_RVol"],
                "obv_score": feat_slice_current["OBV_Score"],
                "spy_obv_score": feat_slice_current["Spy_OBV_Score"],
                # Momentum Vectors
                "roc_1": feat_slice_current["ROC_1"],
                "roc_3": feat_slice_current["ROC_3"],
                "roc_5": feat_slice_current["ROC_5"],
                "roc_10": feat_slice_current["ROC_10"],
                "roc_21": feat_slice_current["ROC_21"],
            }

            # 3. Run the Strategy (The 'Agent')
            if inputs.metric not in METRIC_REGISTRY:
                # Return debug_dict (not {}) so the liquidity audit is not lost
                return (
                    [],
                    pd.DataFrame(),
                    debug_dict,
                    f"Strategy '{inputs.metric}' not found.",
                )

            metric_vals = METRIC_REGISTRY[inputs.metric](observation)
            sorted_tickers = metric_vals.sort_values(ascending=False)
            start_r = max(0, inputs.rank_start - 1)
            end_r = inputs.rank_end
            selected_tickers = sorted_tickers.iloc[start_r:end_r].index.tolist()

            # Audit
            debug_dict["full_universe_ranking"] = pd.DataFrame(
                {
                    "Strategy_Score": metric_vals,
                    "Lookback_Return_Ann": observation["lookback_returns"].mean() * 252,
                    "Lookback_ATRP": observation["atrp"],
                }
            )

            if not selected_tickers:
                return (
                    [],
                    pd.DataFrame(),
                    debug_dict,
                    "No tickers generated from ranking.",
                )

            results_table = pd.DataFrame(
                {
                    "Rank": range(
                        inputs.rank_start, inputs.rank_start + len(selected_tickers)
                    ),
                    "Ticker": selected_tickers,
                    "Strategy Value": sorted_tickers.loc[selected_tickers].values,
                }
            ).set_index("Ticker")

            return selected_tickers, results_table, debug_dict, None

    def _filter_universe(self, date_ts, thresholds, audit_container=None):
        avail_dates = (
            self.features_df.index.get_level_values("Date").unique().sort_values()
        )
        valid_dates = avail_dates[avail_dates <= date_ts]
        if valid_dates.empty:
            return []
        target_date = valid_dates[-1]
        day_features = self.features_df.xs(target_date, level="Date")

        vol_cutoff = thresholds.get("min_median_dollar_volume", 0)
        percentile_used = "N/A"
        if "min_liquidity_percentile" in thresholds:
            percentile_used = thresholds["min_liquidity_percentile"]
            dynamic_val = day_features["RollMedDollarVol"].quantile(percentile_used)
            vol_cutoff = max(vol_cutoff, dynamic_val)

        mask = (
            (day_features["RollMedDollarVol"] >= vol_cutoff)
            & (day_features["RollingStalePct"] <= thresholds["max_stale_pct"])
            & (day_features["RollingSameVolCount"] <= thresholds["max_same_vol_count"])
        )

        if audit_container is not None:
            audit_container["date"] = target_date
            audit_container["total_tickers_available"] = len(day_features)
            audit_container["percentile_setting"] = percentile_used
            audit_container["final_cutoff_usd"] = vol_cutoff
            audit_container["tickers_passed"] = mask.sum()
            snapshot = day_features.copy()
            snapshot["Calculated_Cutoff"] = vol_cutoff
            snapshot["Passed_Vol_Check"] = snapshot["RollMedDollarVol"] >= vol_cutoff
            snapshot["Passed_Final"] = mask
            snapshot = snapshot.sort_values("RollMedDollarVol", ascending=False)
            audit_container["universe_snapshot"] = snapshot

        return day_features[mask].index.tolist()

    def _calculate_period_metrics(
        self,
        f_val,
        f_ret,
        f_atrp,
        f_trp,
        decision_date,
        h_val,
        h_ret,
        h_atrp,
        h_trp,
        prefix,
    ):
        metrics = {}
        slices = {}

        # 1. Temporal Slicing (Routing)
        lb_val, lb_ret, lb_atrp, lb_trp = (
            f_val.loc[:decision_date],
            f_ret.loc[:decision_date],
            f_atrp.loc[:decision_date],
            f_trp.loc[:decision_date],
        )

        # Metrics for each window (full, lookback, holding) via QuantUtils
        metrics[f"full_{prefix}_gain"] = QuantUtils.calculate_gain(f_val)
        metrics[f"full_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(f_ret)
        metrics[f"full_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_vol(
            f_ret, f_atrp
        )
        metrics[f"full_{prefix}_sharpe_trp"] = QuantUtils.calculate_sharpe_vol(
            f_ret, f_trp
        )

        metrics[f"lookback_{prefix}_gain"] = QuantUtils.calculate_gain(lb_val)
        metrics[f"lookback_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(lb_ret)
        metrics[f"lookback_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_vol(
            lb_ret, lb_atrp
        )
        metrics[f"lookback_{prefix}_sharpe_trp"] = QuantUtils.calculate_sharpe_vol(
            lb_ret, lb_trp
        )

        metrics[f"holding_{prefix}_gain"] = QuantUtils.calculate_gain(h_val)
        metrics[f"holding_{prefix}_sharpe"] = QuantUtils.calculate_sharpe(h_ret)
        metrics[f"holding_{prefix}_sharpe_atr"] = QuantUtils.calculate_sharpe_vol(
            h_ret, h_atrp
        )
        metrics[f"holding_{prefix}_sharpe_trp"] = QuantUtils.calculate_sharpe_vol(
            h_ret, h_trp
        )

        # Collect the underlying slices so every metric can be audited
        (
            slices["full_val"],
            slices["full_ret"],
            slices["full_atrp"],
            slices["full_trp"],
        ) = (
            f_val,
            f_ret,
            f_atrp,
            f_trp,
        )
        (
            slices["lookback_val"],
            slices["lookback_ret"],
            slices["lookback_atrp"],
            slices["lookback_trp"],
        ) = (
            lb_val,
            lb_ret,
            lb_atrp,
            lb_trp,
        )
        (
            slices["holding_val"],
            slices["holding_ret"],
            slices["holding_atrp"],
            slices["holding_trp"],
        ) = (
            h_val,
            h_ret,
            h_atrp,
            h_trp,
        )

        return metrics, slices

    def _get_normalized_plot_data(self, tickers, start_date, end_date):
        if not tickers:
            return pd.DataFrame()
        # dict.fromkeys removes duplicates while keeping a stable column order
        data = self.df_close[list(dict.fromkeys(tickers))].loc[start_date:end_date]
        if data.empty:
            return pd.DataFrame()
        return data / data.bfill().iloc[0]

    def _error_result(self, msg):
        return EngineOutput(
            portfolio_series=pd.Series(dtype=float),
            benchmark_series=pd.Series(dtype=float),
            normalized_plot_data=pd.DataFrame(),
            tickers=[],
            initial_weights=pd.Series(dtype=float),
            perf_metrics={},
            results_df=pd.DataFrame(),
            start_date=pd.Timestamp.min,
            decision_date=pd.Timestamp.min,
            buy_date=pd.Timestamp.min,
            holding_end_date=pd.Timestamp.min,
            error_msg=msg,
        )
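

# --- Illustrative sketch (not called by the engine) ---------------------------
# The index arithmetic in `_validate_timeline` can be checked in isolation.
# This is a minimal sketch, assuming a plain `pd.DatetimeIndex` calendar. The
# helper name `_sketch_timeline_indices` is hypothetical, and it skips the
# error messages and range checks that the real method performs.
import pandas as pd


def _sketch_timeline_indices(cal, decision_date, lookback, holding):
    """Return (start, decision, entry, end) dates for one walk-forward window."""
    # Latest decision index that still leaves room for T+1 entry plus the hold
    latest_valid_idx = (len(cal) - 1) - (1 + holding)
    decision_idx = min(cal.searchsorted(decision_date), latest_valid_idx)
    start_idx = decision_idx - lookback  # first day of the lookback window
    entry_idx = decision_idx + 1  # buy on the next trading day (T+1)
    end_idx = entry_idx + holding  # last day of the holding window
    return cal[start_idx], cal[decision_idx], cal[entry_idx], cal[end_idx]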


# ==============================================================================
# SECTION E: THE UI (Visualization)
# ==============================================================================


def plot_walk_forward_analyzer(
    df_ohlcv,
    features_df=None,
    df_close_wide=None,
    df_atrp_wide=None,
    df_trp_wide=None,
    default_start_date="2025-01-17",
    default_lookback=10,
    default_holding=5,
    default_strategy="Sharpe (ATR)",
    default_rank_start=1,
    default_rank_end=10,
    default_benchmark_ticker=GLOBAL_SETTINGS["benchmark_ticker"],
    master_calendar_ticker=GLOBAL_SETTINGS["calendar_ticker"],
    quality_thresholds=GLOBAL_SETTINGS["thresholds"],
    debug=False,
):

    engine = AlphaEngine(
        df_ohlcv,
        features_df=features_df,
        df_close_wide=df_close_wide,
        df_atrp_wide=df_atrp_wide,
        df_trp_wide=df_trp_wide,
        master_ticker=master_calendar_ticker,
    )

    # One-element list so the button callback can update the result in place
    audit_pack = [None]

    # If no thresholds passed, use the global Source of Truth
    if quality_thresholds is None:
        quality_thresholds = GLOBAL_SETTINGS["thresholds"]

    # --- Widgets ---
    mode_selector = widgets.RadioButtons(
        options=["Ranking", "Manual List"],
        value="Ranking",
        description="Mode:",
        layout={"width": "max-content"},
        style={"description_width": "initial"},
    )
    lookback_input = widgets.IntText(
        value=default_lookback,
        description="Lookback (Days):",
        layout={"width": "200px"},
        style={"description_width": "initial"},
    )
    decision_date_picker = widgets.DatePicker(
        description="Decision Date:",
        value=pd.to_datetime(default_start_date),
        layout={"width": "auto"},
        style={"description_width": "initial"},
    )
    holding_input = widgets.IntText(
        value=default_holding,
        description="Holding (Days):",
        layout={"width": "200px"},
        style={"description_width": "initial"},
    )
    strategy_dropdown = widgets.Dropdown(
        options=list(METRIC_REGISTRY.keys()),
        value=default_strategy,
        description="Strategy:",
        layout={"width": "220px"},
        style={"description_width": "initial"},
    )
    benchmark_input = widgets.Text(
        value=default_benchmark_ticker,
        description="Benchmark:",
        placeholder="Enter Ticker",
        layout={"width": "180px"},
        style={"description_width": "initial"},
    )
    rank_start_input = widgets.IntText(
        value=default_rank_start,
        description="Rank Start:",
        layout={"width": "150px"},
        style={"description_width": "initial"},
    )
    rank_end_input = widgets.IntText(
        value=default_rank_end,
        description="Rank End:",
        layout={"width": "150px"},
        style={"description_width": "initial"},
    )
    manual_tickers_input = widgets.Textarea(
        value="",
        placeholder="Enter tickers...",
        description="Manual Tickers:",
        layout={"width": "400px", "height": "80px"},
        style={"description_width": "initial"},
    )
    update_button = widgets.Button(description="Run Simulation", button_style="primary")
    ticker_list_output = widgets.Output()

    # --- Layouts ---
    timeline_box = widgets.HBox(
        [lookback_input, decision_date_picker, holding_input],
        layout=widgets.Layout(
            justify_content="space-between",
            border="1px solid #ddd",
            padding="10px",
            margin="5px",
        ),
    )
    strategy_box = widgets.HBox([strategy_dropdown, benchmark_input])
    ranking_box = widgets.HBox([rank_start_input, rank_end_input])

    def on_mode_change(c):
        ranking_box.layout.display = "flex" if c["new"] == "Ranking" else "none"
        manual_tickers_input.layout.display = (
            "none" if c["new"] == "Ranking" else "flex"
        )
        strategy_dropdown.disabled = c["new"] == "Manual List"

    mode_selector.observe(on_mode_change, names="value")
    on_mode_change({"new": mode_selector.value})

    ui = widgets.VBox(
        [
            widgets.HTML(
                "<b>1. Timeline Configuration:</b> (Past <--- Decision ---> Future)"
            ),
            timeline_box,
            widgets.HTML("<b>2. Strategy Settings:</b>"),
            widgets.HBox([mode_selector, strategy_box]),
            ranking_box,
            manual_tickers_input,
            widgets.HTML("<hr>"),
            update_button,
            ticker_list_output,
        ],
        layout=widgets.Layout(margin="10px 0 20px 0"),
    )

    fig = go.FigureWidget()
    fig.update_layout(
        title="Event-Driven Walk-Forward Analysis",
        height=600,
        template="plotly_white",
        hovermode="x unified",
    )
    for i in range(50):
        fig.add_trace(go.Scatter(visible=False, line=dict(width=2)))
    fig.add_trace(
        go.Scatter(
            name="Benchmark",
            visible=True,
            line=dict(color="black", width=3, dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            name="Group Portfolio", visible=True, line=dict(color="green", width=3)
        )
    )

    # --- Update Logic ---
    def update_plot(b):
        ticker_list_output.clear_output()
        manual_list = [
            t.strip().upper()
            for t in manual_tickers_input.value.split(",")
            if t.strip()
        ]
        decision_date_raw = pd.to_datetime(decision_date_picker.value)

        inputs = EngineInput(
            mode=mode_selector.value,
            start_date=decision_date_raw,
            lookback_period=lookback_input.value,
            holding_period=holding_input.value,
            metric=strategy_dropdown.value,
            benchmark_ticker=benchmark_input.value.strip().upper(),
            rank_start=rank_start_input.value,
            rank_end=rank_end_input.value,
            quality_thresholds=quality_thresholds,
            manual_tickers=manual_list,
            debug=debug,
        )

        # --- CAPTURE INPUTS FOR AUDIT ---
        with ticker_list_output:
            res = engine.run(inputs)
            audit_pack[0] = {"inputs": inputs, "results": res, "debug": res.debug_data}

            if res.error_msg:
                print(f"⚠️ Simulation Stopped: {res.error_msg}")
                return

            # Plotting
            with fig.batch_update():
                cols = res.normalized_plot_data.columns.tolist()
                for i in range(50):
                    if i < len(cols):
                        fig.data[i].update(
                            x=res.normalized_plot_data.index,
                            y=res.normalized_plot_data[cols[i]],
                            name=cols[i],
                            visible=True,
                        )
                    else:
                        fig.data[i].visible = False

                fig.data[50].update(
                    x=res.benchmark_series.index,
                    y=res.benchmark_series.values,
                    name=f"Benchmark ({inputs.benchmark_ticker})",
                    visible=not res.benchmark_series.empty,
                )
                fig.data[51].update(
                    x=res.portfolio_series.index,
                    y=res.portfolio_series.values,
                    visible=True,
                )

                # Visual Lines
                fig.layout.shapes = [
                    dict(
                        type="line",
                        x0=res.decision_date,
                        y0=0,
                        x1=res.decision_date,
                        y1=1,
                        xref="x",
                        yref="paper",
                        line=dict(color="red", width=2, dash="dash"),
                    ),
                    dict(
                        type="line",
                        x0=res.buy_date,
                        y0=0,
                        x1=res.buy_date,
                        y1=1,
                        xref="x",
                        yref="paper",
                        line=dict(color="blue", width=2, dash="dot"),
                    ),
                ]

                fig.layout.annotations = [
                    dict(
                        x=res.decision_date,
                        y=0.05,
                        xref="x",
                        yref="paper",
                        text="DECISION",
                        showarrow=False,
                        bgcolor="red",
                        font=dict(color="white"),
                    ),
                    dict(
                        x=res.buy_date,
                        y=1.0,
                        xref="x",
                        yref="paper",
                        text="ENTRY (T+1)",
                        showarrow=False,
                        bgcolor="blue",
                        font=dict(color="white"),
                    ),
                ]

            start_date = res.start_date.date()
            act_date = res.decision_date.date()
            entry_date = res.buy_date.date()

            # Liquidity Audit Print
            if (
                inputs.mode == "Ranking"
                and res.debug_data
                and "audit_liquidity" in res.debug_data
            ):
                audit = res.debug_data["audit_liquidity"]
                if audit:
                    raw_percentile = audit.get("percentile_setting", "N/A")
                    cut_val = audit.get("final_cutoff_usd", 0)

                    print("-" * 60)
                    print(f"🔍 LIQUIDITY CHECK (On Decision Date: {act_date})")
                    print(
                        f"   Universe Size: {audit.get('total_tickers_available')} tickers"
                    )
                    # The filter records "N/A" when no percentile threshold is set
                    if isinstance(raw_percentile, (int, float)):
                        keep_pct = (1 - raw_percentile) * 100  # share of universe kept
                        print(
                            f"   Liquidity Threshold: {raw_percentile*100:.0f}th Percentile"
                        )
                        print(f"   Action: Keeping the Top {keep_pct:.0f}% of Market")
                    else:
                        print("   Liquidity Threshold: static cutoff only")
                    print(f"   Calculated Cutoff: ${cut_val:,.0f} / day")
                    print(f"   Tickers Remaining: {audit.get('tickers_passed')}")
                    print("-" * 60)

            # Timeline summary
            print(
                f"Timeline: Start [ {start_date} ] --> Decision [ {act_date} ] --> Cash (1d) --> Entry [ {entry_date} ] --> End [ {res.holding_end_date.date()} ]"
            )

            header = (
                f"Ranked Tickers ({len(res.tickers)}):"
                if inputs.mode == "Ranking"
                else "Manual Portfolio Tickers:"
            )
            print(header)
            for i in range(0, len(res.tickers), 10):
                print(", ".join(res.tickers[i : i + 10]))

            m = res.perf_metrics

            # --- DRY UI GENERATION ---
            # 1. Define the metrics we want to display
            metrics_to_show = [
                ("Gain", "gain"),
                ("Sharpe", "sharpe"),
                ("Sharpe (ATR)", "sharpe_atr"),
                ("Sharpe (TRP)", "sharpe_trp"),
            ]

            rows = []
            for label, key in metrics_to_show:
                p_row = {
                    "Metric": f"Group {label}",
                    "Full": m.get(f"full_p_{key}"),
                    "Lookback": m.get(f"lookback_p_{key}"),
                    "Holding": m.get(f"holding_p_{key}"),
                }
                b_row = {
                    "Metric": f"Benchmark {label}",
                    "Full": m.get(f"full_b_{key}"),
                    "Lookback": m.get(f"lookback_b_{key}"),
                    "Holding": m.get(f"holding_b_{key}"),
                }

                # Delta calculation
                d_row = {"Metric": f"== {label} Delta"}
                for col in ["Full", "Lookback", "Holding"]:
                    d_row[col] = (p_row[col] or 0) - (b_row[col] or 0)

                rows.extend([p_row, b_row, d_row])

            # Index by "Metric" so the labels render as row headings
            df_report = pd.DataFrame(rows).set_index("Metric")

            # --- STYLING (Sleek & Proportional) ---
            def apply_sleek_style(styler):
                # Signed, 4-decimal formatting; missing values render as N/A
                styler.format("{:+.4f}", na_rep="N/A")

                # Dynamic Row Highlighting
                def row_logic(row):
                    if "Delta" in row.name:
                        return [
                            "background-color: #f9f9f9; font-weight: 600; border-top: 1px solid #ddd"
                        ] * len(row)
                    if "Group" in row.name:
                        return ["color: #2c5e8f; background-color: #fcfdfe"] * len(row)
                    return ["color: #555"] * len(
                        row
                    )  # Benchmark rows are slightly muted

                styler.apply(row_logic, axis=1)

                styler.set_table_styles(
                    [
                        # Base Table Font - Scaling down to match standard text
                        {
                            "selector": "",
                            "props": [
                                ("font-family", "inherit"),
                                ("font-size", "12px"),
                                ("border-collapse", "collapse"),
                                ("width", "auto"),
                                ("margin-left", "0"),
                            ],
                        },
                        # Header Row - Flattened and Muted
                        {
                            "selector": "th",
                            "props": [
                                ("background-color", "white"),
                                ("color", "#222"),
                                ("font-weight", "600"),
                                ("padding", "6px 12px"),
                                ("border-bottom", "2px solid #444"),
                                ("text-align", "center"),
                                (
                                    "vertical-align",
                                    "bottom",
                                ),  # Aligns 'Metric' with others
                            ],
                        },
                        # Index Column (The "Metric" labels)
                        {
                            "selector": "th.row_heading",
                            "props": [
                                ("text-align", "left"),
                                ("padding-right", "30px"),
                                ("border-bottom", "1px solid #eee"),
                            ],
                        },
                        # Cell Data - Tighter padding
                        {
                            "selector": "td",
                            "props": [
                                ("padding", "4px 12px"),
                                ("border-bottom", "1px solid #eee"),
                            ],
                        },
                        # First header row: keep cells displayed as regular table cells
                        {
                            "selector": "thead tr:nth-child(1) th",
                            "props": [("display", "table-cell")],
                        },
                    ]
                )

                # Clear the index name so the "Metric" label is not rendered
                # as a separate header row
                styler.index.name = None

                return styler

            display(apply_sleek_style(df_report.style))

    update_button.on_click(update_plot)
    update_plot(None)
    display(ui, fig)
    return audit_pack  # One-element list; audit_pack[0] holds the latest run
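

# --- Illustrative sketch (not called by the UI) -------------------------------
# The liquidity audit above reports "Keeping the Top X%" as (1 - percentile).
# A minimal sketch of that arithmetic, assuming the cutoff is the larger of a
# static floor and the cross-sectional quantile, as in `_filter_universe`.
# The helper name and its arguments are hypothetical.
import pandas as pd


def _sketch_liquidity_cutoff(dollar_vol, percentile, static_floor=0.0):
    """Return (cutoff, share_kept) for a cross-section of dollar volumes."""
    cutoff = max(static_floor, dollar_vol.quantile(percentile))
    passed = dollar_vol >= cutoff
    # When the static floor dominates, share_kept falls below 1 - percentile
    return cutoff, passed.mean()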


# ==============================================================================
# INTEGRITY PROTECTION: THE TRIPWIRE
# ==============================================================================


def verify_math_integrity():
    """
    🛡️ TRIPWIRE: Ensures Sample Boundary Integrity.
    """
    print("\n--- 🛡️ Starting Final Integrity Audit ---")

    try:
        # Test 1: Series Input
        mock_series = pd.Series([100.0, 102.0, 101.0])
        rets_s = QuantUtils.compute_returns(mock_series)
        # Verify first value is actually NaN
        if not pd.isna(rets_s.iloc[0]):
            raise ValueError("Series Leading NaN missing")
        print("✅ Series Boundary: OK")

        # Test 2: DataFrame Input
        mock_df = pd.DataFrame({"A": [100, 101], "B": [200, 202]})
        rets_df = QuantUtils.compute_returns(mock_df)
        if not rets_df.iloc[0].isna().all():
            raise ValueError("DataFrame Leading NaN missing")
        print("✅ DataFrame Boundary: OK")

        print("✅ AUDIT PASSED: Mathematical boundaries are strictly enforced.")
    except Exception as e:
        print(f"🔥 SYSTEM BREACH: {e}")
        raise
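

# --- Illustrative sketch ------------------------------------------------------
# Why the first return must stay NaN: the first price has no predecessor, so
# filling it with 0.0 would add a fake "flat day" and bias the mean. A minimal
# sketch using plain pandas `pct_change`; it does not call `QuantUtils`, and
# the helper name is hypothetical.
import pandas as pd


def _sketch_boundary_effect(prices):
    """Return (mean with NaN boundary, mean with the boundary zero-filled)."""
    returns_with_boundary_nan = prices.pct_change()  # first element is NaN
    zero_filled = returns_with_boundary_nan.fillna(0.0)  # the pattern to avoid
    # Series.mean skips NaN, so only real observations are averaged
    return returns_with_boundary_nan.mean(), zero_filled.mean()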


def verify_feature_engineering_integrity():
    """
    🛡️ TRIPWIRE: Validates Feature Engineering Logic.
    Enforces:
    1. Day 1 ATR must be NaN (No PrevClose).
    2. Wilder's Smoothing must use Alpha = 1/Period.
    3. Recursion must match manual calculation.
    """
    print("\n--- 🛡️ Starting Feature Engineering Audit ---")

    # 1. Create Synthetic Data (3 Days)
    # Day 1: High-Low = 10. No PrevClose.
    # Day 2: High-Low = 20. Gap up implies TR might be larger.
    # Day 3: High-Low = 10.
    dates = pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03"])
    idx = pd.MultiIndex.from_product([["TEST"], dates], names=["Ticker", "Date"])

    df_mock = pd.DataFrame(
        {
            "Adj Open": [100, 110, 110],
            "Adj High": [110, 130, 120],
            "Adj Low": [100, 110, 110],
            "Adj Close": [105, 120, 115],  # PrevClose: NaN, 105, 120
            "Volume": [1000, 1000, 1000],
        },
        index=idx,
    )

    # 2. Run the Generator
    # We use Period=2 to make manual math easy (Alpha = 1/2 = 0.5)
    feats = generate_features(
        df_mock, atr_period=2, rsi_period=2, quality_min_periods=1
    )
    atr_series = feats["ATR"]

    # 3. MANUAL CALCULATION (The "Truth")
    # Day 1:
    #   TR = Max(H-L, |H-PC|, |L-PC|)
    #   TR = Max(10, NaN, NaN) -> NaN (Because skipna=False)
    #   Expected ATR: NaN

    # Day 2:
    #   PrevClose = 105
    #   H-L=20, |130-105|=25, |110-105|=5
    #   TR = 25
    #   Expected ATR: First valid observation = 25.0

    # Day 3:
    #   PrevClose = 120
    #   H-L=10, |120-120|=0, |110-120|=10
    #   TR = 10
    #   Wilder's Smoothing (Alpha=0.5):
    #   ATR_3 = (TR_3 * alpha) + (ATR_2 * (1-alpha))
    #   ATR_3 = (10 * 0.5) + (25 * 0.5) = 5 + 12.5 = 17.5

    print(f"Audit Values:\n{atr_series.values}")

    # 4. ASSERTIONS
    try:
        # Check Day 1
        if not np.isnan(atr_series.iloc[0]):
            raise AssertionError(
                f"Day 1 Regression: Expected NaN, got {atr_series.iloc[0]}. (Check skipna=False)"
            )

        # Check Day 2 (Initialization)
        if not np.isclose(atr_series.iloc[1], 25.0):
            raise AssertionError(
                f"Initialization Regression: Expected 25.0, got {atr_series.iloc[1]}."
            )

        # Check Day 3 (Recursion)
        if not np.isclose(atr_series.iloc[2], 17.5):
            raise AssertionError(
                f"Wilder's Logic Regression: Expected 17.5, got {atr_series.iloc[2]}. (Check Alpha=1/N)"
            )

        print("✅ FEATURE INTEGRITY PASSED: Wilder's ATR logic is strictly enforced.")

    except AssertionError as e:
        print(f"🔥 LOGIC FAILURE: {e}")
        raise


def verify_ranking_integrity():
    """
    🛡️ TRIPWIRE: Prevents 'Momentum Collapse' in Volatility-Adjusted Ranking.
    Ensures that Sharpe(Vol) distinguishes between High-Vol and Low-Vol stocks.
    """
    print("--- 🛡️ Starting Ranking Kernel Audit ---")

    # 1. Setup Mock Universe (2 Tickers, 2 Days)
    # Ticker 'VOLATILE': 10% return, but 10% Volatility
    # Ticker 'STABLE': 2% return, but 1% Volatility (The 'Sharpe' Winner)
    data = {"VOLATILE": [1.0, 1.10], "STABLE": [1.0, 1.02]}  # +10% and +2% respectively
    df_returns = pd.DataFrame(data).pct_change().dropna()

    # Pre-calculated Mean Volatility per ticker (as provided by Engine Observation)
    vol_series = pd.Series({"VOLATILE": 0.10, "STABLE": 0.01})

    # 2. Run Kernel
    results = QuantUtils.calculate_sharpe_vol(df_returns, vol_series)

    # 3. CALCULATE EXPECTED (Pure Math)
    # Volatile Sharpe: 0.10 / 0.10 = 1.0
    # Stable Sharpe:   0.02 / 0.01 = 2.0

    try:
        # Check A: Diversity. If they are the same, normalization didn't happen.
        if np.isclose(results["VOLATILE"], results["STABLE"]):
            raise AssertionError(
                "RANKING COLLAPSE: Both tickers have the same normalized score."
            )

        # Check B: Direction. STABLE must rank higher than VOLATILE.
        if results["STABLE"] < results["VOLATILE"]:
            # This is exactly what happens when the bug turns it into Momentum
            raise AssertionError(
                f"MOMENTUM REGRESSION: 'STABLE' ({results['STABLE']:.2f}) "
                f"ranked below 'VOLATILE' ({results['VOLATILE']:.2f}). "
                "The denominator was likely collapsed to a market average."
            )

        # Check C: Absolute Precision
        if not np.isclose(results["STABLE"], 2.0):
            raise AssertionError(
                f"MATH ERROR: Expected 2.0 for STABLE, got {results['STABLE']}"
            )

        print(
            "✅ RANKING INTEGRITY PASSED: Volatility normalization is strictly enforced."
        )

    except Exception as e:
        print(f"🔥 KERNEL BREACH: {e}")
        raise


# Auto-run the checks
verify_math_integrity()

verify_feature_engineering_integrity()

verify_ranking_integrity()

#


--- 🛡️ Starting Final Integrity Audit ---
✅ Series Boundary: OK
✅ DataFrame Boundary: OK
✅ AUDIT PASSED: Mathematical boundaries are strictly enforced.

--- 🛡️ Starting Feature Engineering Audit ---
⚡ Generating SOTA Quant Features (Benchmark: SPY)...
Audit Values:
[ nan 25.  17.5]
✅ FEATURE INTEGRITY PASSED: Wilder's ATR logic is strictly enforced.
--- 🛡️ Starting Ranking Kernel Audit ---
✅ RANKING INTEGRITY PASSED: Volatility normalization is strictly enforced.
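
The arithmetic behind the last two audits can be re-derived by hand without pandas. The sketch below is an independent check, not the notebook's `generate_features` or `QuantUtils.calculate_sharpe_vol`: the helper names `true_range`, `wilder_atr`, and `sharpe_vol` are hypothetical, and the score is assumed to be mean return divided by per-ticker volatility, as the comments in `verify_ranking_integrity` suggest.

```python
import math


def true_range(high, low, prev_close):
    """TR = max(H-L, |H-PC|, |L-PC|); NaN when there is no previous close."""
    if math.isnan(prev_close):
        return math.nan
    return max(high - low, abs(high - prev_close), abs(low - prev_close))


def wilder_atr(tr_values, period):
    """Recursive smoothing with alpha = 1/period, seeded by the first valid TR."""
    alpha = 1.0 / period
    atr, prev = [], math.nan
    for tr in tr_values:
        if not math.isnan(tr):
            # First valid TR seeds the average; later ones are blended in.
            prev = tr if math.isnan(prev) else alpha * tr + (1 - alpha) * prev
        atr.append(prev)
    return atr


def sharpe_vol(returns, vol):
    """Assumed definition: mean return divided by the ticker's own volatility."""
    return {t: (sum(r) / len(r)) / vol[t] for t, r in returns.items()}


# Same three mock bars as the feature audit.
highs, lows, closes = [110, 130, 120], [100, 110, 110], [105, 120, 115]
prev_closes = [math.nan] + closes[:-1]
tr = [true_range(h, l, pc) for h, l, pc in zip(highs, lows, prev_closes)]
atr = wilder_atr(tr, period=2)
print(atr)

# Same two-ticker universe as the ranking audit.
scores = sharpe_vol(
    {"VOLATILE": [1.10 / 1.0 - 1], "STABLE": [1.02 / 1.0 - 1]},
    {"VOLATILE": 0.10, "STABLE": 0.01},
)
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)
```

If the denominator were collapsed to a single market-wide volatility, both scores would be divided by the same number and `VOLATILE` would rank first, which is the failure the tripwire is meant to catch.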


In [2]:
# ==============================================================================
# SECTION F: UTILITIES
# ==============================================================================


def export_debug_to_csv(audit_pack, source_label="Audit"):
    """
    High-Transparency Exporter (Hardened Version).
    Dumps the entire simulation state into a folder for manual Excel verification.
    """
    if not audit_pack or not audit_pack[0]:
        print("❌ Error: Audit Pack is empty. Run a simulation first.")
        return

    data = audit_pack[0]
    # 'inputs' is the run-settings dataclass stored under the "inputs" key
    inputs = data.get("inputs")

    # 1. Folder Setup
    date_str = inputs.start_date.strftime("%Y-%m-%d")
    strat = inputs.metric.replace(" ", "").replace("(", "").replace(")", "")
    folder_name = f"{source_label}_{strat}_{date_str}"

    os.makedirs(folder_name, exist_ok=True)

    print(f"📂 [AUDIT EXPORT] Folder: ./{folder_name}/")

    def process_item(item, path_prefix=""):
        # A. Handle Nested Dicts
        if isinstance(item, dict):
            for k, v in item.items():
                process_item(v, f"{path_prefix}{k}_" if path_prefix else f"{k}_")

        # B. Handle DataFrames (Matrices - High Precision)
        elif isinstance(item, pd.DataFrame):
            fn = f"Matrix_{path_prefix.strip('_')}.csv"
            item.to_csv(os.path.join(folder_name, fn), float_format="%.8f")
            print(f"   ✅ Matrix: {fn}")

        # C. Handle Series (Vectors)
        elif isinstance(item, pd.Series):
            fn = f"Vector_{path_prefix.strip('_')}.csv"
            item.to_frame().to_csv(os.path.join(folder_name, fn), float_format="%.8f")
            print(f"   ✅ Vector: {fn}")

        # D. Handle Dataclasses (Metadata & Results)
        elif is_dataclass(item):
            class_name = item.__class__.__name__
            fn = f"Summary_{class_name}_{path_prefix.strip('_')}".strip("_") + ".csv"

            # --- THE FIX: Create a Safe Dictionary for Pandas ---
            raw_dict = asdict(item)
            summary_ready_dict = {}

            for k, v in raw_dict.items():
                # If it's a big data object, just note its existence in the summary
                if isinstance(v, (pd.DataFrame, pd.Series)):
                    summary_ready_dict[k] = f"<{v.__class__.__name__} shape={v.shape}>"
                # If it's a list or dict (the crash cause), stringify it for Excel
                elif isinstance(v, (list, dict)):
                    summary_ready_dict[k] = str(v)
                else:
                    summary_ready_dict[k] = v

            # Save the clean key-value summary
            pd.DataFrame.from_dict(
                summary_ready_dict, orient="index", columns=["Value"]
            ).to_csv(os.path.join(folder_name, fn))
            print(f"   📑 Summary: {fn}")

            # E. RECURSION: Now find the actual DataFrames inside the dataclass
            # We iterate the object attributes directly to avoid the 'asdict' list confusion
            for k in item.__dataclass_fields__.keys():
                val = getattr(item, k)
                if isinstance(val, (pd.DataFrame, pd.Series, dict)):
                    process_item(val, f"{path_prefix}{k}_")

    # 2. Execute Extraction
    process_item(data)
    print(f"\n✨ Export Complete. Open ./{folder_name}/ to verify results.")


def export_audit_to_excel(audit_pack, filename="Audit_Verification_Report.xlsx"):
    """
    Consolidates the audit_pack into a multi-sheet Excel workbook.
    Organizes data by shared axes (Date vs Ticker) for manual formula checking.
    """
    if not audit_pack or not audit_pack[0]:
        print("❌ Error: Audit Pack is empty.")
        return

    data = audit_pack[0]
    res = data["results"]
    inputs = data["inputs"]
    debug = data.get("debug", {})

    print(f"📂 [EXCEL AUDIT] Creating Report: {filename}")

    with pd.ExcelWriter(filename, engine="openpyxl") as writer:

        # --- SHEET 1: OVERVIEW (The Settings & Final Totals) ---
        # Combines Input settings and Result scalars into one vertical table
        meta_dict = {
            **asdict(inputs),
            **{
                k: v
                for k, v in asdict(res).items()
                if not isinstance(v, (pd.DataFrame, pd.Series, dict))
            },
        }
        # Stringify lists/dicts to prevent Excel/Pandas export crashes
        clean_meta = {
            k: (str(v) if isinstance(v, (list, dict)) else v)
            for k, v in meta_dict.items()
        }

        df_overview = pd.DataFrame.from_dict(
            clean_meta, orient="index", columns=["Value"]
        )
        df_overview.to_excel(writer, sheet_name="OVERVIEW")

        # --- SHEET 2: DAILY_AUDIT (Axis = Date) ---
        # Concatenates everything that happens day-by-day
        daily_items = {
            "Port_Value": res.portfolio_series,
            "Port_Ret": QuantUtils.compute_returns(res.portfolio_series),
            "Port_TRP": res.portfolio_trp_series,
            "Port_ATRP": getattr(
                res, "portfolio_atrp_series", None
            ),  # Optional if added
            "Bench_Value": res.benchmark_series,
            "Bench_Ret": QuantUtils.compute_returns(res.benchmark_series),
            "Bench_TRP": res.benchmark_trp_series,
        }
        # Filter out None values and concatenate side-by-side
        df_daily = pd.concat(
            {k: v for k, v in daily_items.items() if v is not None}, axis=1
        )
        df_daily.to_excel(writer, sheet_name="DAILY_AUDIT", float_format="%.8f")

        # --- SHEET 3: PORTFOLIO_SNAPSHOT (Axis = Ticker) ---
        # Ranking rows for the selected tickers only (the ones actually bought)
        if "full_universe_ranking" in debug:
            df_rank = debug["full_universe_ranking"]
            # Filter the leaderboard for only the tickers we actually bought
            df_composition = df_rank.reindex(res.tickers)
            df_composition.to_excel(
                writer, sheet_name="PORTFOLIO_SNAPSHOT", float_format="%.8f"
            )

        # --- SHEET 4: FULL_UNIVERSE_RANKING ---
        if "full_universe_ranking" in debug:
            debug["full_universe_ranking"].to_excel(
                writer, sheet_name="FULL_RANKING", float_format="%.8f"
            )

        # --- SHEET 5: RAW_PRICES_MATRIX ---
        if "portfolio_raw_components" in debug:
            raw_p = debug["portfolio_raw_components"].get("prices")
            if raw_p is not None:
                raw_p.to_excel(writer, sheet_name="RAW_PRICES", float_format="%.8f")

        # --- SHEET 6: RAW_VOL_DATA (ATRP) ---
        if "portfolio_raw_components" in debug:
            # Per-ticker volatility matrix. The key read here is "atrp";
            # switch to a TRP key only if one is stored separately.
            raw_v = debug["portfolio_raw_components"].get("atrp")
            if raw_v is not None:
                raw_v.to_excel(writer, sheet_name="RAW_VOL_DATA", float_format="%.8f")

    print(f"✨ Audit Report Complete. Manual verification ready in {filename}")


def print_nested(d, indent=0, width=4):
    """Pretty-print nested containers.
    Each key is printed on its own line, with its value on the next, indented line."""
    spacing = " " * indent

    def _kind(node):
        if not isinstance(node, dict):
            return None
        return "sep" if all(isinstance(v, dict) for v in node.values()) else "nest"

    if isinstance(d, dict):
        for k, v in d.items():
            kind = _kind(v)
            tag = "" if kind is None else f"  [{'SEP' if kind == 'sep' else 'NEST'}]"
            print(f"{spacing}{k}{tag}")
            print_nested(v, indent + width, width)

    elif isinstance(d, (list, tuple)):
        for idx, item in enumerate(d):
            print(f"{spacing}[{idx}]")
            print_nested(item, indent + width, width)

    else:  # leaf: primitive value
        print(f"{spacing}{d}")


def get_ticker_OHLCV(
    df_ohlcv: pd.DataFrame,
    tickers: Union[str, List[str]],
    date_start: str,
    date_end: str,
    return_format: str = "dataframe",
    verbose: bool = True,
) -> Union[pd.DataFrame, dict, list]:
    """
    Get OHLCV data for specified tickers within a date range.

    Parameters
    ----------
    df_ohlcv : pd.DataFrame
        DataFrame with MultiIndex of (ticker, date) and OHLCV columns
    tickers : str or list of str
        Ticker symbol(s) to retrieve
    date_start : str
        Start date in 'YYYY-MM-DD' format
    date_end : str
        End date in 'YYYY-MM-DD' format
    return_format : str, optional
        Format to return data in. Options:
        - 'dataframe': Single DataFrame with MultiIndex (default)
        - 'dict': Dictionary with tickers as keys and DataFrames as values
        - 'separate': List of separate DataFrames for each ticker
    verbose : bool, optional
        Whether to print summary information (default: True)

    Returns
    -------
    Union[pd.DataFrame, dict, list]
        Filtered OHLCV data in specified format

    Raises
    ------
    TypeError
        If df_ohlcv is not a DataFrame or tickers is not a str or list
    ValueError
        If the index, the dates, or return_format are invalid
    KeyError
        If tickers not found in DataFrame

    Examples
    --------
    >>> # Get data for single ticker
    >>> vlo_data = get_ticker_OHLCV(df_ohlcv, 'VLO', '2025-08-13', '2025-09-04')

    >>> # Get data for multiple tickers
    >>> multi_data = get_ticker_OHLCV(df_ohlcv, ['VLO', 'JPST'], '2025-08-13', '2025-09-04')

    >>> # Get data as dictionary
    >>> data_dict = get_ticker_OHLCV(df_ohlcv, ['VLO', 'JPST'], '2025-08-13',
    ...                              '2025-09-04', return_format='dict')
    """

    # Input validation
    if not isinstance(df_ohlcv, pd.DataFrame):
        raise TypeError("df_ohlcv must be a pandas DataFrame")

    if not isinstance(df_ohlcv.index, pd.MultiIndex):
        raise ValueError("DataFrame must have MultiIndex of (ticker, date)")

    if df_ohlcv.index.nlevels != 2:
        raise ValueError("MultiIndex must have exactly 2 levels: (ticker, date)")

    # Convert single ticker to list for consistent processing
    if isinstance(tickers, str):
        tickers = [tickers]
    elif not isinstance(tickers, list):
        raise TypeError("tickers must be a string or list of strings")

    # Convert dates to Timestamps
    try:
        start_date = pd.Timestamp(date_start)
        end_date = pd.Timestamp(date_end)
    except ValueError as e:
        raise ValueError(f"Invalid date format. Use 'YYYY-MM-DD': {e}") from e

    if start_date > end_date:
        raise ValueError("date_start must be before or equal to date_end")

    # Check if tickers exist in the DataFrame
    available_tickers = df_ohlcv.index.get_level_values(0).unique()
    missing_tickers = [t for t in tickers if t not in available_tickers]

    if missing_tickers:
        raise KeyError(f"Ticker(s) not found in DataFrame: {missing_tickers}")

    # Filter the data using MultiIndex slicing
    try:
        filtered_data = df_ohlcv.loc[(tickers, slice(date_start, date_end)), :]
    except Exception as e:
        raise ValueError(f"Error filtering data: {e}") from e

    # Handle empty results
    if filtered_data.empty:
        if verbose:
            print(
                f"No data found for tickers {tickers} in date range {date_start} to {date_end}"
            )
        return filtered_data

    # Print summary if verbose
    if verbose:
        print(
            f"Data retrieved for {len(tickers)} ticker(s) from {date_start} to {date_end}"
        )
        print(f"Total rows: {len(filtered_data)}")
        print(
            f"Date range in data: {filtered_data.index.get_level_values(1).min()} to "
            f"{filtered_data.index.get_level_values(1).max()}"
        )

        # Print ticker-specific counts
        ticker_counts = filtered_data.index.get_level_values(0).value_counts()
        for ticker in tickers:
            count = ticker_counts.get(ticker, 0)
            if count > 0:
                print(f"  {ticker}: {count} rows")
            else:
                print(f"  {ticker}: No data in range")

    # Return in requested format
    if return_format == "dict":
        result = {}
        for ticker in tickers:
            try:
                result[ticker] = filtered_data.xs(ticker, level=0).loc[
                    date_start:date_end
                ]
            except KeyError:
                result[ticker] = pd.DataFrame()
        return result

    elif return_format == "separate":
        result = []
        for ticker in tickers:
            try:
                result.append(
                    filtered_data.xs(ticker, level=0).loc[date_start:date_end]
                )
            except KeyError:
                result.append(pd.DataFrame())
        return result

    elif return_format == "dataframe":
        return filtered_data

    else:
        raise ValueError(
            f"Invalid return_format: {return_format}. "
            f"Must be 'dataframe', 'dict', or 'separate'"
        )


def get_ticker_features(
    features_df: pd.DataFrame,
    tickers: Union[str, List[str]],
    date_start: str,
    date_end: str,
    return_format: str = "dataframe",
    verbose: bool = True,
) -> Union[pd.DataFrame, dict, list]:
    """
    Get features data for specified tickers within a date range.

    Parameters
    ----------
    features_df : pd.DataFrame
        DataFrame with MultiIndex of (ticker, date) and feature columns
    tickers : str or list of str
        Ticker symbol(s) to retrieve
    date_start : str
        Start date in 'YYYY-MM-DD' format
    date_end : str
        End date in 'YYYY-MM-DD' format
    return_format : str, optional
        Format to return data in. Options:
        - 'dataframe': Single DataFrame with MultiIndex (default)
        - 'dict': Dictionary with tickers as keys and DataFrames as values
        - 'separate': List of separate DataFrames for each ticker
    verbose : bool, optional
        Whether to print summary information (default: True)

    Returns
    -------
    Union[pd.DataFrame, dict, list]
        Filtered features data in specified format
    """
    # Convert single ticker to list for consistent processing
    if isinstance(tickers, str):
        tickers = [tickers]

    # Filter the data using MultiIndex slicing
    try:
        filtered_data = features_df.loc[(tickers, slice(date_start, date_end)), :]
    except Exception as e:
        if verbose:
            print(f"Error filtering data: {e}")
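        # Return an empty container that matches the requested format
        if return_format == "dataframe":
            return pd.DataFrame()
        return {} if return_format == "dict" else []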

    # Handle empty results
    if filtered_data.empty:
        if verbose:
            print(
                f"No data found for tickers {tickers} in date range {date_start} to {date_end}"
            )
        return filtered_data

    # Print summary if verbose
    if verbose:
        print(
            f"Features data retrieved for {len(tickers)} ticker(s) from {date_start} to {date_end}"
        )
        print(f"Total rows: {len(filtered_data)}")
        print(
            f"Date range in data: {filtered_data.index.get_level_values(1).min()} to "
            f"{filtered_data.index.get_level_values(1).max()}"
        )
        print(f"Available features: {', '.join(filtered_data.columns.tolist())}")

        # Print ticker-specific counts
        ticker_counts = filtered_data.index.get_level_values(0).value_counts()
        for ticker in tickers:
            count = ticker_counts.get(ticker, 0)
            if count > 0:
                print(f"  {ticker}: {count} rows")
            else:
                print(f"  {ticker}: No data in range")

    # Return in requested format
    if return_format == "dict":
        result = {}
        for ticker in tickers:
            try:
                result[ticker] = filtered_data.xs(ticker, level=0).loc[
                    date_start:date_end
                ]
            except KeyError:
                result[ticker] = pd.DataFrame()
        return result

    elif return_format == "separate":
        result = []
        for ticker in tickers:
            try:
                result.append(
                    filtered_data.xs(ticker, level=0).loc[date_start:date_end]
                )
            except KeyError:
                result.append(pd.DataFrame())
        return result

    elif return_format == "dataframe":
        return filtered_data

    else:
        raise ValueError(
            f"Invalid return_format: {return_format}. "
            f"Must be 'dataframe', 'dict', or 'separate'"
        )


def create_combined_dict(
    df_ohlcv: pd.DataFrame,
    features_df: pd.DataFrame,
    tickers: Union[str, List[str]],
    date_start: str,
    date_end: str,
    verbose: bool = True,
) -> dict:
    """
    Create a combined dictionary with both OHLCV and features data for each ticker.

    Parameters
    ----------
    df_ohlcv : pd.DataFrame
        DataFrame with OHLCV data (MultiIndex: ticker, date)
    features_df : pd.DataFrame
        DataFrame with features data (MultiIndex: ticker, date)
    tickers : str or list of str
        Ticker symbol(s) to retrieve
    date_start : str
        Start date in 'YYYY-MM-DD' format
    date_end : str
        End date in 'YYYY-MM-DD' format
    verbose : bool, optional
        Whether to print progress information (default: True)

    Returns
    -------
    dict
        Dictionary with tickers as keys and combined DataFrames (OHLCV + features) as values
    """
    # Convert single ticker to list
    if isinstance(tickers, str):
        tickers = [tickers]

    if verbose:
        print(f"Creating combined dictionary for {len(tickers)} ticker(s)")
        print(f"Date range: {date_start} to {date_end}")
        print("=" * 60)

    # Get OHLCV data as dictionary
    ohlcv_dict = get_ticker_OHLCV(
        df_ohlcv, tickers, date_start, date_end, return_format="dict", verbose=verbose
    )

    # Get features data as dictionary
    features_dict = get_ticker_features(
        features_df,
        tickers,
        date_start,
        date_end,
        return_format="dict",
        verbose=verbose,
    )

    # Create combined_dict
    combined_dict = {}

    for ticker in tickers:
        if verbose:
            print(f"\nProcessing {ticker}...")

        # Check if ticker exists in both dictionaries
        if ticker in ohlcv_dict and ticker in features_dict:
            ohlcv_data = ohlcv_dict[ticker]
            features_data = features_dict[ticker]

            # Check if both dataframes have data
            if not ohlcv_data.empty and not features_data.empty:
                # Combine OHLCV and features data
                # Note: pd.concat(axis=1) aligns rows on the shared Date index
                combined_df = pd.concat([ohlcv_data, features_data], axis=1)

                # Ensure proper index naming
                combined_df.index.name = "Date"

                # Store in combined_dict
                combined_dict[ticker] = combined_df

                if verbose:
                    print("  ✓ Successfully combined data")
                    print(f"  OHLCV shape: {ohlcv_data.shape}")
                    print(f"  Features shape: {features_data.shape}")
                    print(f"  Combined shape: {combined_df.shape}")
                    print(
                        f"  Date range: {combined_df.index.min()} to {combined_df.index.max()}"
                    )
            else:
                if verbose:
                    print("  ✗ Cannot combine: One or both dataframes are empty")
                    print(f"    OHLCV empty: {ohlcv_data.empty}")
                    print(f"    Features empty: {features_data.empty}")
                combined_dict[ticker] = pd.DataFrame()
        else:
            if verbose:
                print("  ✗ Ticker not found in both dictionaries")
                if ticker not in ohlcv_dict:
                    print("    Not in OHLCV data")
                if ticker not in features_dict:
                    print("    Not in features data")
            combined_dict[ticker] = pd.DataFrame()

    # Print summary
    if verbose:
        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)
        print(f"Total tickers processed: {len(tickers)}")

        tickers_with_data = [
            ticker for ticker, df in combined_dict.items() if not df.empty
        ]
        print(f"Tickers with combined data: {len(tickers_with_data)}")

        if tickers_with_data:
            print("\nTicker details:")
            for ticker in tickers_with_data:
                df = combined_dict[ticker]
                print(f"  {ticker}: {df.shape} - {df.index.min()} to {df.index.max()}")
                print(f"    Columns: {len(df.columns)}")

        empty_tickers = [ticker for ticker, df in combined_dict.items() if df.empty]
        if empty_tickers:
            print(f"\nTickers with no data: {', '.join(empty_tickers)}")

    return combined_dict


#
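
The retrieval helpers above rely on MultiIndex slicing followed by a per-ticker cross-section. A minimal sketch of that pattern on a toy frame follows; the tickers `AAA` and `BBB` and their prices are made up for illustration, and the slice uses `pd.Timestamp` bounds rather than the date strings the helpers pass.

```python
import pandas as pd

# Toy (Ticker, Date) frame shaped like df_ohlcv, with one column only.
dates = pd.to_datetime(["2025-01-02", "2025-01-03", "2025-01-06"])
idx = pd.MultiIndex.from_product([["AAA", "BBB"], dates], names=["Ticker", "Date"])
df = pd.DataFrame({"Adj Close": [10.0, 11.0, 12.0, 20.0, 21.0, 22.0]}, index=idx)

# Level 0: list of tickers. Level 1: inclusive label slice on the dates.
# Label slicing on a MultiIndex needs a sorted index; from_product of
# sorted inputs satisfies that.
date_slice = slice(pd.Timestamp("2025-01-03"), pd.Timestamp("2025-01-06"))
subset = df.loc[(["AAA", "BBB"], date_slice), :]

# The 'dict' return format: one Date-indexed frame per ticker.
per_ticker = {t: subset.xs(t, level=0) for t in ["AAA", "BBB"]}

print(subset)
print(per_ticker["AAA"])
```

Both slice bounds are inclusive, so a one-day window still returns that day's row.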

In [3]:
data_path = r"c:\Users\ping\Files_win10\python\py311\stocks\data\df_indices.parquet"

df_indices = pd.read_parquet(data_path, engine="pyarrow")

In [4]:
df_indices.info()

<class 'pandas.core.frame.DataFrame'>
MultiIndex: 144059 entries, ('^AXJO', Timestamp('1992-11-22 00:00:00')) to ('^VIX3M', Timestamp('2025-12-31 00:00:00'))
Data columns (total 5 columns):
 #   Column     Non-Null Count   Dtype  
---  ------     --------------   -----  
 0   Adj Open   144059 non-null  float64
 1   Adj High   144059 non-null  float64
 2   Adj Low    144059 non-null  float64
 3   Adj Close  144059 non-null  float64
 4   Volume     144059 non-null  int64  
dtypes: float64(4), int64(1)
memory usage: 6.5+ MB


In [5]:
data_path = (
    r"c:\Users\ping\Files_win10\python\py311\stocks\data\df_OHLCV_stocks_etfs.parquet"
)

df_ohlcv = pd.read_parquet(data_path, engine="pyarrow")

In [None]:
df_ohlcv.info()

In [None]:
# ==============================================================================
# DATA PRE-COMPUTATION (The "Fast-Track" Setup)
# ==============================================================================

print("Calculating features... this might take about 3 minutes...")
features_df = generate_features(
    df_ohlcv=df_ohlcv,
    atr_period=GLOBAL_SETTINGS["atr_period"],
    quality_window=GLOBAL_SETTINGS["quality_window"],
    quality_min_periods=GLOBAL_SETTINGS["quality_min_periods"],
)

print("🚀 Generating Wide Matrices for Instant Backtesting...")

# 1. Price Matrix
df_close_wide = df_ohlcv["Adj Close"].unstack(level=0)

# 2. Volatility Matrices (Unstack and Align)
# Using reindex_like ensures Dates and Tickers match df_close_wide exactly
print("   - Unstacking ATRP...")
df_atrp_wide = features_df["ATRP"].unstack(level=0).reindex_like(df_close_wide)

print("   - Unstacking TRP...")
df_trp_wide = features_df["TRP"].unstack(level=0).reindex_like(df_close_wide)

# 3. Handle Data Gaps (Sanitize the Wide Matrices)
# This prevents NaN propagation during matrix multiplication
if GLOBAL_SETTINGS["handle_zeros_as_nan"]:
    df_close_wide = df_close_wide.replace(0, np.nan)

# Forward fill up to the limit, then fill remaining with the "Disaster Detection" value
df_close_wide = df_close_wide.ffill(limit=GLOBAL_SETTINGS["max_data_gap_ffill"])
df_close_wide = df_close_wide.fillna(GLOBAL_SETTINGS["nan_price_replacement"])

print(
    "✅ Pre-computation Complete. df_close_wide, df_atrp_wide, and df_trp_wide are ready."
)
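
The reshaping and gap handling in the cell above can be traced on a toy series. This is a sketch under stated assumptions: `MAX_GAP_FFILL` and `NAN_PRICE_REPLACEMENT` are placeholder values standing in for the `GLOBAL_SETTINGS` entries, whose real values are not shown in this notebook section.

```python
import numpy as np
import pandas as pd

MAX_GAP_FFILL = 1            # placeholder for GLOBAL_SETTINGS["max_data_gap_ffill"]
NAN_PRICE_REPLACEMENT = 0.0  # placeholder for GLOBAL_SETTINGS["nan_price_replacement"]

dates = pd.to_datetime(["2025-01-02", "2025-01-03", "2025-01-06", "2025-01-07"])
idx = pd.MultiIndex.from_product([["AAA", "BBB"], dates], names=["Ticker", "Date"])
long_close = pd.Series(
    [10.0, 0.0, np.nan, np.nan, 20.0, 21.0, 22.0, 23.0], index=idx, name="Adj Close"
)

# Long -> wide: rows become dates, columns become tickers.
wide = long_close.unstack(level=0)

# A feature series that covers only part of the universe still ends up
# with the same axes as the price matrix after reindex_like.
feat_idx = pd.MultiIndex.from_product([["AAA"], dates[:2]], names=["Ticker", "Date"])
feat_long = pd.Series([0.02, 0.03], index=feat_idx, name="ATRP")
feat_wide = feat_long.unstack(level=0).reindex_like(wide)

# Zero prices become gaps, short gaps are carried forward,
# and anything longer gets the replacement value.
wide = wide.replace(0, np.nan)
wide = wide.ffill(limit=MAX_GAP_FFILL)
wide = wide.fillna(NAN_PRICE_REPLACEMENT)

print(wide)
print(feat_wide)
```

With a limit of 1, only the first missing day after the last good price is forward-filled; the remaining two days of the `AAA` gap receive the replacement value.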

In [None]:
_ticker = "GOOG"
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 1000)  # keep lines from wrapping

In [None]:
df_indices.info()

In [None]:
_idx = "SPY"
print(df_ohlcv.loc[_idx].head(50).to_csv(index=True))

# save the slice as a real CSV file
start, end = "2004-08-19", "2025-12-30"
# df_ohlcv.loc[_idx, start:end].to_csv(f"{_idx}.csv", index=True)
# df_ohlcv.loc[_idx].to_csv(f"{_idx}.csv", index=True)
df_ohlcv.loc[_idx].loc[start:end].to_csv(f"_{_idx}.csv", index=True)

In [None]:
print(df_ohlcv.loc[_ticker].head(50).to_csv(index=True))
# save the slice as a real CSV file
df_ohlcv.loc[_ticker].to_csv(f"_{_ticker}.csv", index=True)

In [None]:
features_df.loc[_ticker].to_csv(f"_{_ticker}_features.csv", index=True)

In [None]:
print(features_df.loc[_ticker].head(70), "\n")
# print(features_df.loc[_ticker].tail(70))

In [None]:
print(f" features_df:\n{features_df}\n")
features_df.info()

In [None]:
print(f"df_atrp_wide:\n{df_atrp_wide}\n")
df_atrp_wide.info()

In [None]:
audit_pack = plot_walk_forward_analyzer(
    df_ohlcv=df_ohlcv,
    features_df=features_df,
    df_close_wide=df_close_wide,
    df_atrp_wide=df_atrp_wide,
    df_trp_wide=df_trp_wide,  # requires an analyzer version that accepts df_trp_wide
    default_start_date="2025-01-17",
    default_lookback=10,
    default_holding=5,
    default_strategy="Sharpe (ATR)",
    default_rank_start=1,
    default_rank_end=10,
    default_benchmark_ticker=GLOBAL_SETTINGS["benchmark_ticker"],
    master_calendar_ticker=GLOBAL_SETTINGS["calendar_ticker"],
    quality_thresholds=GLOBAL_SETTINGS["thresholds"],
    debug=True,
)

##### Output Debug Data to CSV Files

In [None]:
export_debug_to_csv(audit_pack, source_label="_Audit_bot_v52")
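
Two small steps inside `export_debug_to_csv` are easy to verify in isolation: building the folder name from the metric label, and stringifying list or dict fields before they go into a one-column summary table. The sketch below uses a hypothetical `DemoInputs` dataclass and hypothetical helper names; the real inputs dataclass has more fields.

```python
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class DemoInputs:
    """Hypothetical stand-in for the engine's inputs dataclass."""
    metric: str
    lookback: int
    tickers: List[str] = field(default_factory=list)


def to_summary_rows(obj):
    """Flatten a dataclass into key -> scalar-or-string pairs for a summary CSV."""
    rows = {}
    for key, value in asdict(obj).items():
        # Lists and dicts are stringified so each field occupies one cell.
        rows[key] = str(value) if isinstance(value, (list, dict)) else value
    return rows


def audit_folder_name(label, metric, date_str):
    """Strip spaces and parentheses from the metric, as the exporter does."""
    strat = metric.replace(" ", "").replace("(", "").replace(")", "")
    return f"{label}_{strat}_{date_str}"


inputs = DemoInputs(metric="Sharpe (ATR)", lookback=10, tickers=["SPY", "GOOG"])
rows = to_summary_rows(inputs)
folder = audit_folder_name("_Audit_bot_v52", inputs.metric, "2025-01-17")

print(rows)
print(folder)
```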

In [None]:
export_audit_to_excel(audit_pack)

##### Get Subset Data, Copy Cell Output and Paste into Excel with 'Import Wizard'

In [None]:
_ticker = "SPY"
_start_date = "2025-01-02"
_end_date = "2025-01-28"

_df = df_ohlcv.loc[_ticker].loc[_start_date:_end_date]
print(_df.to_csv())

_df = features_df.loc[_ticker].loc[_start_date:_end_date]
print(_df.to_csv())