<a href="https://colab.research.google.com/github/racoope70/exploratory_daytrading/blob/main/multi_stock_sac_inference_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
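#Fix Potential Library Conflicts
# WARNING: this purges every apt package matching cuda*/libcuda*/nvidia*. The globs are unquoted, so the
# shell may also expand them against files in the current directory. It can remove the runtime's
# preinstalled CUDA/NVIDIA libraries; the cells below reinstall a CUDA 12.4 toolkit.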
!apt-get remove --purge -y cuda* libcuda* nvidia* || echo "No conflicting CUDA packages"
!apt-get autoremove -y
!apt-get clean

In [None]:
#Protocol Buffer Fix (for TensorFlow)
!pip uninstall -y protobuf
!pip install protobuf==3.20.3

In [None]:
#Update Colab Environment and System Libraries
!apt-get update -y && apt-get upgrade -y


In [None]:
#Install Correct Version of CUDA for Colab GPU
!apt-get update -qq && apt-get install -y \
    libcusolver11 libcusparse11 libcurand10 libcufft10 libnppig10 libnppc10 libnppial10 \
    cuda-toolkit-12-4

In [None]:
#Set Correct CUDA Paths
import os

CUDA_HOME_DIR = '/usr/local/cuda-12.4'


def _append_env_path(var, entry):
    """Append *entry* to the os.pathsep-separated environment variable *var*.

    Works when *var* is unset or empty (LD_LIBRARY_PATH may not be defined at all),
    and skips entries that are already present so re-running the cell does not keep
    growing the variable.
    """
    current = os.environ.get(var, "")
    if entry in current.split(os.pathsep):
        return
    os.environ[var] = f"{current}{os.pathsep}{entry}" if current else entry


os.environ['CUDA_HOME'] = CUDA_HOME_DIR
_append_env_path('PATH', CUDA_HOME_DIR + '/bin')
_append_env_path('LD_LIBRARY_PATH', CUDA_HOME_DIR + '/lib64')


In [None]:
#Install RAPIDS and NVIDIA Dependencies
!pip install --extra-index-url=https://pypi.nvidia.com \
    cuml-cu12==25.2.0 cudf-cu12==25.2.0 cupy-cuda12x dask-cuda==25.2.0 dask-cudf-cu12==25.2.0


In [None]:
#Install TensorFlow (pinned to 2.18.0)
!pip install tensorflow==2.18.0

#Install Stable Baselines3 and Trading Libraries
!pip install stable-baselines3[extra] gymnasium gym-anytrading yfinance xgboost joblib

#Install Miscellaneous Libraries
!pip install matplotlib scikit-learn pandas numba==0.61.0

#Install PyTorch with GPU Support
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124


In [None]:
# In a notebook cell:
!pip uninstall -y jax jaxlib

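# Restart the runtime/kernel after the uninstall above (modules already loaded in this process are not
# unloaded by pip), then run the import below in the fresh kernel: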
import tensorflow as tf


In [None]:
#!rm -rf /content/drive

In [None]:
# ---------- Imports & Logging ----------
import os, glob, gc, time, json, logging, warnings, re, sys
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import yfinance as yf
import pywt
import warnings
warnings.filterwarnings("ignore", message=".*Gym has been unmaintained.*", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="jupyter_client.session")


def setup_logger(name="download_fe", level=logging.INFO, log_dir="./sac_parity/logs"):
    """Return logger *name* writing to stdout and to ``<log_dir>/<name>.log``.

    Safe to call repeatedly (e.g. when a notebook cell is re-run). Existing handlers
    are removed and closed before new ones are attached, so messages are not
    duplicated and the previous log file handle is released.

    Args:
        name: Logger name; also used as the log file's base name.
        level: Level applied to the logger and to both handlers.
        log_dir: Directory for the log file (created if missing).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Close old handlers instead of just dropping them: a dropped FileHandler keeps
    # its file open. Closing a StreamHandler does not close sys.stdout.
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()
    fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(fmt)
    sh.setLevel(level)
    logger.addHandler(sh)

    os.makedirs(log_dir, exist_ok=True)
    fh = logging.FileHandler(os.path.join(log_dir, f"{name}.log"))
    fh.setFormatter(fmt)
    fh.setLevel(level)
    logger.addHandler(fh)
    return logger

# Module logger for the download / feature-engineering stage (stdout + ./sac_parity/logs/download_fe.log).
log = setup_logger("download_fe", level=logging.INFO)
# Suppresses every FutureWarning process-wide, not only pandas/yfinance ones.
warnings.filterwarnings("ignore", category=FutureWarning)
log.info("yfinance=%s | pandas=%s", getattr(yf, "__version__", "unknown"), pd.__version__)

# ---------- Paths ----------
# Local working area: per-ticker parquet files go to DATA_DIR (see build_features).
BASE_DIR   = "./sac_parity"
DATA_DIR   = os.path.join(BASE_DIR, "data")
LOG_DIR    = os.path.join(BASE_DIR, "logs")
for d in (BASE_DIR, DATA_DIR, LOG_DIR):
    os.makedirs(d, exist_ok=True)

# Try Google Drive (Colab); else local CWD
# Any failure (not running in Colab, mount refused, ...) falls back to the current working directory.
try:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive', force_remount=False)
    DRIVE_BASE = "/content/drive/MyDrive"
    log.info("Mounted Google Drive at /content/drive")
except Exception:
    DRIVE_BASE = os.getcwd()
    log.info("Google Drive not available; using local working directory")

RESULTS_DIR = os.path.join(DRIVE_BASE, "Results_May_2025", "results_sac_walkforward")
TRADING_DIR = os.path.join(DRIVE_BASE, "trading_data")
for d in (RESULTS_DIR, TRADING_DIR):
    os.makedirs(d, exist_ok=True)

# Diagnostics only. RESULTS_DIR was just created above, so "Exists?" is expected to print True.
print("RESULTS_DIR =", RESULTS_DIR)
print("Exists?", os.path.exists(RESULTS_DIR))
print("Listdir sample:", os.listdir(RESULTS_DIR)[:10] if os.path.exists(RESULTS_DIR) else "N/A")

csvs = sorted(glob.glob(os.path.join(RESULTS_DIR, "*_sac_ws_sweep_summary.csv")))
print("Matched:", [os.path.basename(p) for p in csvs])

# Reads each sweep-summary CSV in full just to report its row count (debug output; rebinds `df`).
for p in csvs:
    df = pd.read_csv(p)
    print(os.path.basename(p), "rows:", len(df))

# Primary output filenames (three mirrors)
FEATURE_CSV_RESULTS = os.path.join(RESULTS_DIR, "multi_stock_feature_engineered_dataset.csv")
FEATURE_CSV_TRADING = os.path.join(TRADING_DIR, "multi_stock_feature_engineered_dataset.csv")
FEATURE_CSV_LOCAL   = "multi_stock_feature_engineered_dataset.csv"

# Optional Parquet mirrors (local)
PARQ_FULL  = "features_full.parquet"
PARQ_TRAIN = "train.parquet"
PARQ_VAL   = "val.parquet"

# ---------- Config / Toggles ----------
# USE_WAVELET / USE_REGIME are read by compute_enhanced_features. USE_SENTIMENT, USE_EVAL_CALLBACK
# and FORCE_RETRAIN are not referenced in this cell (presumably consumed by training code).
USE_SENTIMENT = False   # placeholder; off by default
USE_REGIME    = True
USE_WAVELET   = True
USE_EVAL_CALLBACK = False
FORCE_RETRAIN = True

# Data params
# PERIOD_DAYS may be raised (PREP_FOR_SAC) and then capped to the Yahoo interval limit further below.
INTERVAL     = os.getenv("INTERVAL", "1h")
PERIOD_DAYS  = int(os.getenv("PERIOD_DAYS", "720"))  # ~2 years by default

# If you want to prep for a 1y train + 1y test split later, ensure ~744 days
TRAIN_DAYS_TARGET = 365
TEST_DAYS_TARGET  = 365
BUFFER_DAYS       = 14
MIN_PERIOD_FOR_TRAINER = TRAIN_DAYS_TARGET + TEST_DAYS_TARGET + BUFFER_DAYS  # ~744
# --- Yahoo intraday maximum window helper ---
def _max_days_for_interval(interval: str) -> int:
    """Yahoo intraday windows are restricted; daily/weekly can go further."""
    intraday_caps = {
        "1m": 30, "2m": 60, "5m": 60, "15m": 60, "30m": 60,
        "60m": 730, "90m": 60, "1h": 730
    }
    return intraday_caps.get(interval.lower(), 3650)


# Env flags: "0", "false", "no" and "" (empty) all mean "off"; anything else means "on".
_FALSY_FLAGS = ("0", "false", "no", "")
TEST_MODE = os.getenv("TEST_MODE", "1").strip().lower() not in _FALSY_FLAGS
PREP_FOR_SAC = os.getenv("PREP_FOR_SAC", "1").lower() not in _FALSY_FLAGS

# Widen the lookback so a ~1y train / ~1y test split is possible (may still be capped below).
if PREP_FOR_SAC and PERIOD_DAYS < MIN_PERIOD_FOR_TRAINER:
    log.info("Bumping PERIOD_DAYS %d → %d to support ~1y train / ~1y test", PERIOD_DAYS, MIN_PERIOD_FOR_TRAINER)
    PERIOD_DAYS = MIN_PERIOD_FOR_TRAINER

# --- Cap for Yahoo intraday data limits (avoid >730 days error) ---
_cap = _max_days_for_interval(INTERVAL)
if PERIOD_DAYS > _cap:
    log.warning("PERIOD_DAYS=%d exceeds Yahoo limit (%d) for interval='%s'; capping.",
                PERIOD_DAYS, _cap, INTERVAL)
    PERIOD_DAYS = _cap

# ---------- Universe ----------
TICKERS_ALL = [
    'AAPL', 'TSLA', 'MSFT', 'GOOGL', 'AMZN', 'NVDA', 'META', 'BRK-B', 'JPM', 'JNJ',
    'XOM', 'V', 'PG', 'UNH', 'MA', 'HD', 'LLY', 'MRK', 'PEP', 'KO',
    'BAC', 'ABBV', 'AVGO', 'PFE', 'COST', 'CSCO', 'TMO', 'ABT', 'ACN', 'WMT',
    'MCD', 'ADBE', 'DHR', 'CRM', 'NKE', 'INTC', 'QCOM', 'NEE', 'AMD', 'TXN',
    'AMGN', 'UPS', 'LIN', 'PM', 'UNP', 'BMY', 'LOW', 'RTX', 'CVX', 'IBM',
    'GE', 'SBUX', 'ORCL',
]

DEFAULT_TEST_TICKERS = ['AAPL', 'NVDA', 'MSFT']
# Comma-separated env lists, whitespace-trimmed and upper-cased; empty tokens dropped.
RUN_SYMBOLS_ENV = [tok.strip().upper() for tok in os.getenv("RUN_SYMBOLS", "").split(",") if tok.strip()]
TEST_TICKERS_ENV = [tok.strip().upper() for tok in os.getenv("TEST_TICKERS", "").split(",") if tok.strip()]

# Precedence: explicit RUN_SYMBOLS > TEST_MODE (TEST_TICKERS or defaults) > full universe.
if RUN_SYMBOLS_ENV:
    SYMBOLS = RUN_SYMBOLS_ENV
    log.info("RUN_SYMBOLS override → %s", SYMBOLS)
elif TEST_MODE:
    SYMBOLS = TEST_TICKERS_ENV or DEFAULT_TEST_TICKERS
    log.info("TEST_MODE=on → symbols=%s | period_days=%d", SYMBOLS, PERIOD_DAYS)
else:
    SYMBOLS = TICKERS_ALL
    log.info("TEST_MODE=off → %d symbols | period_days=%d", len(SYMBOLS), PERIOD_DAYS)

# ---------- Helpers: atomic save + verification ----------
def _ensure_dir(path: str):
    d = os.path.dirname(path)
    if d and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)

def save_csv_atomically(df: pd.DataFrame, dest_path: str, max_wait_s: float = 3.0):
    """Write via dest.tmp then os.replace(); verify non-zero size (Colab/Drive-safe).

    The CSV (without index) is first written to ``dest_path + ".tmp"`` and then moved
    over ``dest_path`` with ``os.replace``, so readers never see a half-written file.
    If the write or rename fails, the temporary file is removed before the error is
    re-raised. Afterwards the destination is polled every 0.2 s for up to
    ``max_wait_s`` seconds until it exists with a non-zero size.

    Args:
        df: Frame to write.
        dest_path: Final CSV path; its parent directory is created if missing.
        max_wait_s: Upper bound on the post-write visibility poll, in seconds.

    Returns:
        ``dest_path``.

    Raises:
        AssertionError: if the destination is still missing or empty after the poll.
            It is raised explicitly (not via ``assert``) so the check also runs under ``python -O``.
        OSError and pandas errors: propagated from the write or rename.
    """
    parent = os.path.dirname(dest_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

    tmp_path = dest_path + ".tmp"
    try:
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Best-effort cleanup of a partial ".tmp" file; the original error is what propagates.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    def _ready() -> bool:
        return os.path.exists(dest_path) and os.path.getsize(dest_path) > 0

    t0 = time.time()
    while not _ready() and (time.time() - t0 < max_wait_s):
        time.sleep(0.2)
    if not _ready():
        raise AssertionError(f"Save failed for {dest_path}")
    return dest_path

def _verify_csv_ok(p: str) -> bool:
    try:
        if not (os.path.exists(p) and os.path.getsize(p) > 0):
            return False
        _ = pd.read_csv(p, nrows=2)
        log.info("Verified dataset CSV: %s (%s bytes)", p, f"{os.path.getsize(p):,}")
        return True
    except Exception as e:
        log.warning("Dataset candidate failed verification (%s): %s", p, e)
        return False

def _log_artifact(p: str):
    """Log size, mtime, first column names and a two-row preview of CSV *p*.

    Never raises: any failure (including a missing file, which makes the CSV read
    fail) is logged as a warning.
    """
    try:
        present = os.path.exists(p)
        size = os.path.getsize(p) if present else 0
        mtime = time.ctime(os.path.getmtime(p)) if present else "-"
        preview = pd.read_csv(p, nrows=2)
        first_cols = list(preview.columns)[:8]
        log.info("Artifact OK: %s | size=%s | mtime=%s | cols=%s",
                 p, f"{size:,}", mtime, first_cols)
        log.info("Head preview:\n%s", preview.to_string(index=False))
    except Exception as err:
        log.warning("Artifact check failed for %s: %s", p, err)

# ---------- Helpers: schema / normalize ----------
def _force_datetime_column(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure tz-naive Datetime column exists; dedupe/sort."""
    if isinstance(df.index, pd.DatetimeIndex):
        try:
            if df.index.tz is not None:
                df.index = df.index.tz_convert(None)
        except Exception:
            try:
                df.index = df.index.tz_localize(None)
            except Exception:
                pass
        df.index.name = 'Datetime'
        df = df.reset_index()
    else:
        df = df.reset_index()
        first = df.columns[0]
        if np.issubdtype(df[first].dtype, np.datetime64):
            df = df.rename(columns={first: 'Datetime'})
        elif 'Date' in df.columns:
            df['Datetime'] = pd.to_datetime(df['Date'])
        elif 'Datetime' not in df.columns:
            df['Datetime'] = pd.to_datetime(df[first], errors='coerce')

    if 'Datetime' not in df.columns:
        raise KeyError("Failed to construct 'Datetime' from data.")

    df['Datetime'] = pd.to_datetime(df['Datetime'])
    return df.drop_duplicates(subset=['Datetime']).sort_values('Datetime').reset_index(drop=True)

def _normalize_ohlcv(df_in: pd.DataFrame, ticker: str) -> pd.DataFrame:
    """Flatten MultiIndex; strip ticker tokens; map to canonical OHLCV names."""
    df = df_in.copy()
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [" ".join([str(p) for p in col if p]) for col in df.columns]
    df.columns = [re.sub(r"\s+", " ", str(c)).strip() for c in df.columns]

    tkr = ticker.upper().replace("-", "[- ]?")
    cleaned = {}
    for c in df.columns:
        cu = c.upper()
        cu = re.sub(rf"^(?:{tkr})[\s/_-]+", "", cu)
        cu = re.sub(rf"[\s/_-]+(?:{tkr})$", "", cu)
        cleaned[c] = cu.title()
    if any(cleaned[c] != c for c in df.columns):
        df = df.rename(columns=cleaned)

    cols_ci = {c.lower(): c for c in df.columns}
    wants = {
        "Open":      ["open"],
        "High":      ["high"],
        "Low":       ["low"],
        "Close":     ["close", "last", "close*"],
        "Adj Close": ["adj close","adj_close","adjclose","adjusted close"],
        "Volume":    ["volume","vol"]
    }
    rename_map = {}
    for desired, alts in wants.items():
        if desired.lower() in cols_ci:
            rename_map[cols_ci[desired.lower()]] = desired
        else:
            for a in alts:
                if a in cols_ci:
                    rename_map[cols_ci[a]] = desired
                    break
    if rename_map:
        df = df.rename(columns=rename_map)
    return df

# ---------- Downloader (with history() fallback & retries) ----------
def download_stock_data(ticker, interval="1h", period_days=720, max_retries=5, sleep_base=3):
    """
    Robust yfinance intraday downloader.
    Guarantees: Open, High, Low, Close, Volume (+Adj Close), Datetime, Symbol.

    Each of up to ``max_retries`` attempts:
      1. calls ``yf.download(period=f"{period_days}d", interval=interval)``;
      2. on any error whose message looks like Yahoo's lookback limit, re-requests once
         with ``period_days`` clamped to ``_max_days_for_interval(interval)``. The clamp
         persists for later attempts;
      3. otherwise, or if that re-request does not return data, falls back to
         ``yf.Ticker(ticker).history(...)``;
      4. if the fallback also fails, sleeps ``sleep_base * attempt`` seconds and retries.

    Args:
        ticker: Yahoo symbol (e.g. 'AAPL', 'BRK-B').
        interval: Bar size passed to yfinance (e.g. '1h').
        period_days: Lookback window in days.
        max_retries: Number of attempts before giving up.
        sleep_base: Linear backoff base, in seconds.

    Returns:
        DataFrame with the guaranteed columns ('Datetime' normalized by
        ``_force_datetime_column``), or None if every attempt failed.
    """
    period_days = int(period_days)
    period_str = f"{period_days}d"

    # Shared post-processing for every successful fetch path.
    def _post(df: pd.DataFrame) -> pd.DataFrame:
        df = _normalize_ohlcv(df, ticker)
        df = _force_datetime_column(df)
        needed = {'Open', 'High', 'Low', 'Close', 'Volume'}
        missing = needed - set(df.columns)
        if missing:
            raise ValueError(f"Missing OHLCV columns after normalize: {missing}")
        # Without auto-adjust data an 'Adj Close' may be absent; mirror Close so the column always exists.
        if 'Adj Close' not in df.columns:
            df['Adj Close'] = df['Close']
        return df

    for attempt in range(1, max_retries + 1):
        try:
            log.info(f"[{ticker}] Attempt {attempt}: download(period={period_str}, interval={interval})")
            df = yf.download(
                tickers=ticker,
                period=period_str,
                interval=interval,
                progress=False,
                auto_adjust=False,
                group_by='column',
                threads=False,
                prepost=False,
                repair=True
            )
            if df is None or df.empty:
                raise ValueError("Empty data from download()")

            df = _post(df)
            df['Symbol'] = ticker
            log.info(f"[{ticker}] rows={len(df)} {df['Datetime'].min()} → {df['Datetime'].max()}")
            return df

        except Exception as e1:
            # --- Yahoo 730-day clamp logic (self-heal) ---
            # Only fires when the exception text itself mentions the limit / missing data.
            # The "Empty data from download()" error raised above does not match either phrase.
            msg = str(e1)
            too_long = ("must be within the last 730 days" in msg.lower()) or ("no price data found" in msg.lower())
            if too_long:
                clamp_days = _max_days_for_interval(interval)
                if period_days > clamp_days:
                    log.warning("[%s] Yahoo limit hit (%s). Retrying with period_days=%d.", ticker, e1, clamp_days)
                    # Rebinds the outer variables, so later attempts and the history() fallback use the clamp.
                    period_days = clamp_days
                    period_str = f"{period_days}d"
                    try:
                        df = yf.download(
                            tickers=ticker, period=period_str, interval=interval,
                            progress=False, auto_adjust=False, group_by='column',
                            threads=False, prepost=False, repair=True
                        )
                        if df is not None and not df.empty:
                            df = _post(df)
                            df['Symbol'] = ticker
                            log.info(f"[{ticker}] rows={len(df)} {df['Datetime'].min()} → {df['Datetime'].max()} (clamped)")
                            return df
                    except Exception as e1b:
                        log.warning("[%s] clamp retry failed (%s); will try history() fallback.", ticker, e1b)

            # --- history() fallback with backoff ---
            log.warning(f"[{ticker}] download() error: {e1} | trying Ticker().history()")
            try:
                hist = yf.Ticker(ticker).history(
                    period=period_str,
                    interval=interval,
                    auto_adjust=False,
                    actions=False
                )
                if hist is None or hist.empty:
                    raise ValueError("Empty data from history()")

                df = _post(hist)
                df['Symbol'] = ticker
                log.info(f"[{ticker}] (fallback) rows={len(df)} {df['Datetime'].min()} → {df['Datetime'].max()}")
                return df

            except Exception as e2:
                # Linear backoff; only reached when both download() and history() failed this attempt.
                wait = sleep_base * attempt
                log.warning(f"[{ticker}] history() error: {e2} | retrying in {wait}s")
                time.sleep(wait)

    log.error(f"[{ticker}] Failed after {max_retries} attempts.")
    return None


# ---------- Feature Engineering ----------
def denoise_wavelet(series, wavelet='db1', level=2):
    """Return a smoothed copy of *series* using wavelet low-pass reconstruction.

    The input is coerced to float, forward/back-filled, decomposed with
    ``pywt.wavedec`` (``mode='symmetric'``), all detail coefficients are zeroed, and
    the approximation is reconstructed and trimmed to the input length.

    Args:
        series: pandas Series or any 1-D array-like. A Series keeps its index; other
            inputs get a default RangeIndex.
        wavelet: PyWavelets wavelet name.
        level: Decomposition depth.

    Returns:
        pandas Series aligned with the input. If the wavelet step fails, a warning is
        logged and the filled (un-smoothed) values are returned instead.

    NOTE(review): the transform runs over the whole series at once, so an output
    value can depend on later samples (look-ahead). Confirm this column is never fed
    to a model that must not see the future.
    """
    # Accept plain arrays/lists as well as Series; only a Series carries an index.
    src = series if isinstance(series, pd.Series) else pd.Series(series)
    index = src.index
    s = src.astype(float).ffill().bfill().to_numpy()
    try:
        coeffs = pywt.wavedec(s, wavelet, mode='symmetric', level=level)
        # Hard smoothing: zero high-frequency detail
        for i in range(1, len(coeffs)):
            coeffs[i] = np.zeros_like(coeffs[i])
        rec = pywt.waverec(coeffs, wavelet, mode='symmetric')
        # waverec can return one extra sample for odd-length input; trim to match.
        return pd.Series(rec[:len(s)], index=index)
    except Exception as e:
        log.warning(f"Wavelet denoise failed ({e}); returning raw series.")
        return pd.Series(s, index=index)

def add_regime(df: pd.DataFrame) -> pd.DataFrame:
    """Add Vol20, Ret20 and a 4-state Regime4 label to *df* in place and return it.

    Vol20 is the 20-bar rolling std of 1-bar returns and Ret20 is the 20-bar return.
    Regime4 = 2 * (Vol20 above its median) + (|Ret20| above its median), giving 0..3.
    NaN warm-up rows compare False and fall into the "low" buckets.
    NOTE(review): the medians are taken over the entire frame, so labels use future
    data (look-ahead).
    """
    close = df['Close']
    df['Vol20'] = close.pct_change().rolling(20).std()
    df['Ret20'] = close.pct_change(20)

    trend_size = df['Ret20'].abs()
    high_vol = df['Vol20'].gt(df['Vol20'].median()).astype(int)
    strong_trend = trend_size.gt(trend_size.median()).astype(int)

    df['Regime4'] = high_vol * 2 + strong_trend
    return df

def compute_enhanced_features(df: pd.DataFrame) -> tuple[pd.DataFrame, list]:
    """
    Returns (feature_df, FEATURES). Input must contain:
      Datetime, Symbol, Open, High, Low, Close, Volume.

    Computes technical indicators on a copy of *df* and returns that copy with the
    indicator columns appended, plus the list of column names used as model features
    (cast to float32).

    Notes:
      - Only Open/High/Low/Close/Volume are checked; other columns are carried through.
      - Rolling/EWM windows run over the rows in their given order. NOTE(review): the
        frame is assumed to hold ONE symbol sorted by time (build_features calls this
        per ticker); mixed symbols would leak across tickers.
      - After computation, +/-inf and NaN in *every* column are replaced with 0.0.
      - 'denoised_close' (USE_WAVELET) and 'Vol20'/'Ret20'/'Regime4' (USE_REGIME) are
        added to the frame but are NOT part of FEATURES.
      - NOTE(review): feature formulas should stay bit-identical to whatever models
        were trained on them; change with care.

    Raises:
        AssertionError: if any OHLCV column is missing (skipped under ``python -O``).
    """
    df = df.copy()
    req = {"Open", "High", "Low", "Close", "Volume"}
    assert req.issubset(df.columns), f"OHLCV columns missing: {req - set(df.columns)}"

    close = df["Close"].astype("float64")
    high  = df["High"].astype("float64")
    low   = df["Low"].astype("float64")
    open_ = df["Open"].astype("float64")
    # +1 keeps volume strictly positive (assuming non-negative input), so pct_change never divides by 0.
    vol   = (df["Volume"].astype("float64") + 1.0)

    # Returns & lags
    df["ret_1"]     = close.pct_change(1)
    df["ret_3"]     = close.pct_change(3)
    df["ret_5"]     = close.pct_change(5)
    df["ret_10"]    = close.pct_change(10)
    df["logret_1"]  = np.log(close).diff(1)

    # MAs & volatility
    df["ma_5"]      = close.rolling(5).mean()
    df["ma_10"]     = close.rolling(10).mean()
    df["ma_20"]     = close.rolling(20).mean()
    df["ema_10"]    = close.ewm(span=10, adjust=False).mean()
    df["ema_20"]    = close.ewm(span=20, adjust=False).mean()
    df["std_10"]    = close.rolling(10).std()
    df["std_20"]    = close.rolling(20).std()
    df["ema10_ratio"] = df["ema_10"] / close - 1.0
    df["ema20_ratio"] = df["ema_20"] / close - 1.0
    df["z_close_20"] = (close - df["ma_20"]) / df["std_20"].replace(0, np.nan)

    # RSI(14)
    # Smoothing uses EWM with alpha=1/14 (Wilder-style), not a simple rolling mean.
    delta      = close.diff()
    up         = delta.clip(lower=0.0)
    down       = -delta.clip(upper=0.0)
    roll_up    = up.ewm(alpha=1/14, adjust=False).mean()
    roll_down  = down.ewm(alpha=1/14, adjust=False).mean()
    # NOTE(review): when roll_down == 0 (no down-moves yet), rs -> NaN and the final fillna sets RSI to 0,
    # whereas the conventional value would be 100. Kept as-is for feature parity; confirm intent.
    rs         = roll_up / roll_down.replace(0, np.nan)
    df["rsi_14"] = 100 - (100 / (1 + rs))

    # Stochastic(14,3)
    # A flat 14-bar range gives NaN here, which the final fillna turns into 0.
    ll14       = low.rolling(14).min()
    hh14       = high.rolling(14).max()
    den_14     = (hh14 - ll14).replace(0, np.nan)
    df["stoch_k"] = 100 * (close - ll14) / den_14
    df["stoch_d"] = df["stoch_k"].rolling(3).mean()

    # MACD (12,26,9)
    ema12         = close.ewm(span=12, adjust=False).mean()
    ema26         = close.ewm(span=26, adjust=False).mean()
    macd          = ema12 - ema26
    macd_sig      = macd.ewm(span=9, adjust=False).mean()
    df["macd"]        = macd
    df["macd_signal"] = macd_sig
    df["macd_hist"]   = macd - macd_sig

    # Bollinger (20,2)
    bb_mid      = df["ma_20"]
    bb_std      = df["std_20"]
    bb_up       = bb_mid + 2 * bb_std
    bb_lo       = bb_mid - 2 * bb_std
    band_width  = (bb_up - bb_lo)
    df["bb_perc_b"]    = (close - bb_lo) / band_width.replace(0, np.nan)
    df["bb_bandwidth"] = band_width / bb_mid.replace(0, np.nan)

    # ATR(14)
    # True range = max(H-L, |H-prevC|, |L-prevC|); averaged with a simple 14-bar mean (not Wilder smoothing).
    prev_close  = close.shift(1)
    tr = np.maximum(high - low,
                    np.maximum((high - prev_close).abs(), (low - prev_close).abs()))
    df["atr_14"] = tr.rolling(14).mean()

    # Volume features
    df["vol_ma_20"]    = vol.rolling(20).mean()
    df["vol_std_20"]   = vol.rolling(20).std()
    df["vol_z_20"]     = (vol - df["vol_ma_20"]) / df["vol_std_20"].replace(0, np.nan)
    df["vol_change_1"] = vol.pct_change(1)

    # Spreads & crosses
    df["hl_spread"] = (high - low) / close.replace(0, np.nan)
    df["oc_spread"] = (close - open_) / open_.replace(0, np.nan)
    df["ema10_gt_ema20"] = (df["ema_10"] > df["ema_20"]).astype("float32")

    # Optional extras
    # Neither extra is included in FEATURES below; they are stored for inspection/other consumers.
    if USE_WAVELET:
        df["denoised_close"] = denoise_wavelet(pd.Series(close, index=df.index))
    if USE_REGIME:
        df = add_regime(df)  # adds 'Regime4'

    # Clean up
    # Applies to all columns: warm-up NaNs and divide-by-zero infs become 0.0.
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df.fillna(0.0, inplace=True)

    FEATURES = [
        "ret_1","ret_3","ret_5","ret_10","logret_1",
        "ma_5","ma_10","ma_20","ema_10","ema_20","ema10_ratio","ema20_ratio",
        "std_10","std_20","z_close_20",
        "rsi_14","stoch_k","stoch_d",
        "macd","macd_signal","macd_hist",
        "bb_perc_b","bb_bandwidth",
        "atr_14",
        "vol_ma_20","vol_std_20","vol_z_20","vol_change_1",
        "hl_spread","oc_spread","ema10_gt_ema20",
    ]
    df[FEATURES] = df[FEATURES].astype("float32")
    return df, FEATURES

# ---------- Orchestrator ----------
# Fraction of rows (by time) held out for the optional local train/val split in __main__.
VAL_FRACTION = float(os.getenv("VAL_FRACTION", "0.20"))
# Lookup order for a reusable, previously built dataset: local file first, then the Drive mirrors.
CANDIDATE_CSVS = [
    FEATURE_CSV_LOCAL,
    FEATURE_CSV_RESULTS,
    FEATURE_CSV_TRADING,
]

def build_features() -> pd.DataFrame:
    """Return the combined feature dataset for SYMBOLS, building it only if no cached copy is usable.

    1. Return the first entry of CANDIDATE_CSVS that passes ``_verify_csv_ok``, with
       its 'Datetime' column parsed as tz-aware UTC.
    2. Otherwise download each symbol (INTERVAL / PERIOD_DAYS), compute features, and
       write a per-symbol parquet under DATA_DIR. Symbols with no data or a failed
       feature step are logged and skipped.
    3. Concatenate and convert 'Datetime' to tz-aware UTC. Keep only weekday rows whose
       New York time is in [09:30, 16:00). Save the result to the local and both Drive
       CSV mirrors plus a local parquet.

    Returns:
        The combined DataFrame.

    Raises:
        RuntimeError: if no symbol produced usable features.
    """
    # ---- PPO-style: re-use any healthy dataset across locations ----
    for cand in CANDIDATE_CSVS:
        if _verify_csv_ok(cand):
            d = pd.read_csv(cand)
            d['Datetime'] = pd.to_datetime(d['Datetime'], utc=True)  # keep UTC in memory
            log.info("Using existing features CSV: %s | rows=%d cols=%d", cand, len(d), d.shape[1])
            return d

    all_dfs = []
    for i, ticker in enumerate(SYMBOLS, 1):
        log.info(f"[{i}/{len(SYMBOLS)}] {ticker} — downloading")
        raw = download_stock_data(ticker, interval=INTERVAL, period_days=PERIOD_DAYS)
        if raw is None or raw.empty:
            log.warning(f"[{ticker}] no data; skipping.")
            continue
        # Bound before the try so the cleanup in `finally` never hits an unbound name.
        feats = None
        try:
            feats, _ = compute_enhanced_features(raw)
            if feats is not None and not feats.empty:
                outp = os.path.join(DATA_DIR, f"{ticker}.parquet")
                feats.to_parquet(outp, index=False)
                log.info(f"[{ticker}] features={len(feats)} rows → {outp}")
                all_dfs.append(feats)
            else:
                log.warning(f"[{ticker}] empty features; skipped.")
        except Exception as e:
            log.error(f"[{ticker}] FE failed: {e}")
        finally:
            # Drop this iteration's references (all_dfs keeps what it needs) and pace requests.
            del raw, feats
            gc.collect()
            time.sleep(0.2)

    if not all_dfs:
        raise RuntimeError("No usable data found for any ticker.")

    full = pd.concat(all_dfs, ignore_index=True)

    # ---- Keep UTC in storage; only convert a view to NY for RTH mask ----
    # Per-ticker timestamps are naive; utc=True interprets them as UTC.
    full['Datetime'] = pd.to_datetime(full['Datetime'], utc=True)  # tz-aware UTC for storage
    dt_ny = full['Datetime'].dt.tz_convert('America/New_York')
    rth_mask = (
        (dt_ny.dt.weekday < 5) &
        (dt_ny.dt.time >= pd.to_datetime("09:30").time()) &
        (dt_ny.dt.time <  pd.to_datetime("16:00").time())
    )
    full = full[rth_mask].reset_index(drop=True)

    # ---- Save atomically: local + both Drive locations ----
    save_csv_atomically(full, FEATURE_CSV_LOCAL)
    save_csv_atomically(full, FEATURE_CSV_RESULTS)
    save_csv_atomically(full, FEATURE_CSV_TRADING)

    # Parquet mirror (local)
    full.to_parquet(PARQ_FULL, index=False)

    # Diagnostics: existence + size + small preview
    for p in (FEATURE_CSV_LOCAL, FEATURE_CSV_RESULTS, FEATURE_CSV_TRADING):
        _log_artifact(p)

    log.info("Saved combined CSV (rows=%d) →\n- %s\n- %s\n- %s",
             len(full), FEATURE_CSV_LOCAL, FEATURE_CSV_RESULTS, FEATURE_CSV_TRADING)
    return full

# ---------- Optional: utility used by training scripts ----------
def find_or_build_dataset():
    """Try local/Drive mirrors; if none, build and save to all mirrors.

    Returns:
        ``(path, df)``. This is the first entry of CANDIDATE_CSVS that passes
        ``_verify_csv_ok``, with 'Datetime' parsed as tz-aware UTC. When none is
        usable, it is ``(FEATURE_CSV_LOCAL, build_features())``.
    """
    for cand in CANDIDATE_CSVS:
        if _verify_csv_ok(cand):
            df = pd.read_csv(cand)
            df["Datetime"] = pd.to_datetime(df["Datetime"], utc=True)
            return cand, df
    # When build_features() builds, it already saves atomically to all three mirrors.
    # Re-saving here would only repeat three full CSV writes (two of them to Drive).
    df = build_features()
    # build_features() may instead return an existing mirror that became valid after the
    # check above. Make sure the path we hand back actually exists.
    if not os.path.exists(FEATURE_CSV_LOCAL):
        save_csv_atomically(df, FEATURE_CSV_LOCAL)
    return FEATURE_CSV_LOCAL, df

# ---------- Run end-to-end ----------
if __name__ == "__main__":
    # Entry point for the data-prep stage (also runs inside a notebook, where __name__ == "__main__").
    df_full = build_features()

    # (Optional) Create quick time-based train/val splits locally for inspection
    # Any failure in this section is logged as a warning and does not abort the run.
    try:
        df_full = df_full.sort_values('Datetime').reset_index(drop=True)
        cutoff_idx = int((1.0 - VAL_FRACTION) * len(df_full))
        cutoff_idx = min(max(1, cutoff_idx), len(df_full) - 1)  # guardrails
        # Split on a timestamp, not a row index, so every bar at the cutoff time (all symbols) lands in val.
        cutoff_time = df_full.loc[cutoff_idx, 'Datetime']
        train_df = df_full[df_full['Datetime'] <  cutoff_time].reset_index(drop=True)
        val_df   = df_full[df_full['Datetime'] >= cutoff_time].reset_index(drop=True)

        # Atomic saves everywhere (local + Drive mirror for convenience)
        # Re-saving the combined set guarantees a local copy even when build_features() loaded a Drive mirror.
        save_csv_atomically(df_full,  FEATURE_CSV_LOCAL)   # ensure latest combined locally
        save_csv_atomically(train_df, "train.csv")
        save_csv_atomically(val_df,   "val.csv")

        # Drive mirror of splits
        save_csv_atomically(train_df, os.path.join(TRADING_DIR, "train.csv"))
        save_csv_atomically(val_df,   os.path.join(TRADING_DIR, "val.csv"))

        # Parquet mirrors
        train_df.to_parquet(PARQ_TRAIN, index=False)
        val_df.to_parquet(PARQ_VAL, index=False)

        log.info("Time split cutoff @ %s", cutoff_time)
        log.info("Train: %s, Val: %s", train_df.shape, val_df.shape)

        for p in ("train.csv", "val.csv", PARQ_TRAIN, PARQ_VAL):
            exists = os.path.exists(p)
            sz = os.path.getsize(p) if exists else 0
            log.info("Saved split artifact: exists=%s | size=%s | path=%s", exists, sz, p)

        # Light summary
        # Best-effort diagnostics; failures here (e.g. no 'Symbol' column) are ignored.
        try:
            sym_counts = df_full['Symbol'].value_counts()
            log.info("Symbols saved (top 10):\n%s", sym_counts.head(10).to_string())
            log.info("Datetime range: %s → %s", df_full['Datetime'].min(), df_full['Datetime'].max())
        except Exception:
            pass

        # Log artifacts
        for p in (FEATURE_CSV_LOCAL,
                  os.path.join(TRADING_DIR, "train.csv"),
                  os.path.join(TRADING_DIR, "val.csv")):
            _log_artifact(p)

        # Cleanup refs
        del train_df, val_df
    except Exception as e:
        log.warning("Train/Val split skipped: %s", e)

    gc.collect()
    log.info("Download + Feature Build complete.")

In [None]:
import os, time, pandas as pd

# Sanity check: which copies of the combined feature CSV exist, and what do they look like?
_feature_csv_mirrors = (
    "multi_stock_feature_engineered_dataset.csv",
    "/content/drive/MyDrive/Results_May_2025/results_sac_walkforward/multi_stock_feature_engineered_dataset.csv",
    "/content/drive/MyDrive/trading_data/multi_stock_feature_engineered_dataset.csv",
)
for p in _feature_csv_mirrors:
    found = os.path.exists(p)
    print(p, "->", found)
    if not found:
        continue
    print("  size:", os.path.getsize(p), "modified:", time.ctime(os.path.getmtime(p)))
    print("  cols:", list(pd.read_csv(p, nrows=1).columns)[:8])


In [None]:
import os, sys, random, logging

# Modes. TEST_MODE uses the same parsing as the data-prep cell ("", "0", "false", "no" -> off),
# so an empty env var means the same thing in both cells.
TEST_MODE = os.getenv("TEST_MODE", "1").strip().lower() not in ("0", "false", "no", "")
FAST_TEST = os.getenv("FAST_TEST", "1").lower() in ("1", "true", "yes")
RELAX_CPU_THREADS = TEST_MODE and os.getenv("RELAX_CPU_THREADS", "0").lower() in ("1", "true", "yes")

# Determinism knobs.
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this assignment affects child
# processes, not str hashing in the current kernel. The thread-count and cuBLAS variables are read
# when the libraries initialise the relevant backend. If numpy/torch were already imported and
# initialised earlier in this kernel (the data-prep cell imports numpy), they may need a restart.
os.environ["PYTHONHASHSEED"] = "42"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # needed by torch for deterministic cuBLAS (CUDA >= 10.2)

# CPU threading policy (assign so we truly override): 1 thread per pool, or up to 4 when relaxed.
_thr = str(min(4, (os.cpu_count() or 4)))
for k in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ[k] = _thr if RELAX_CPU_THREADS else "1"

# Logging: root logger -> stdout + append-mode sac_run.log.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler("sac_run.log", mode="a")],
    # Without force=True, basicConfig silently does nothing when the root logger already has
    # handlers (e.g. when this cell is re-run). force closes and replaces them.
    force=True,
)
log = logging.getLogger("SAC-TrainExec")

import numpy as np
import torch

# Torch thread caps: one inter-op thread, intra-op threads follow the OMP_NUM_THREADS policy.
# Best-effort: torch refuses to change these once parallel work has started.
try:
    torch.set_num_interop_threads(1)
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "1")))
except Exception:
    pass

# Seed the independent RNGs used by the training stack.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # TF32 trades precision for speed; turn it off so GPU math matches full fp32.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

# Deterministic cuDNN algorithm selection, plus torch's global determinism switch.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
try:
    torch.use_deterministic_algorithms(True)
except Exception as e:
    # Some ops/hardware combos may not support full determinism; continue safely.
    log.warning("Could not force full deterministic algorithms: %s", e)

# ---- libs (no yfinance / no feature-building here) ----
import gc, time, json, re, warnings, heapq, glob
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta, timezone
from shutil import copyfile
from math import isfinite
from collections import defaultdict

# gymnasium + SB3
import gymnasium as gym
from gymnasium.spaces import Box
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import (
    EvalCallback, StopTrainingOnNoModelImprovement, CheckpointCallback, CallbackList, BaseCallback
)
from stable_baselines3.common.utils import set_random_seed

# Silence two known-noisy third-party warnings, then seed SB3's RNG helper.
for _wf_kwargs in (
    {"message": ".*Gym has been unmaintained.*"},
    {"category": DeprecationWarning, "module": "jupyter_client.session"},
):
    warnings.filterwarnings("ignore", **_wf_kwargs)
del _wf_kwargs
set_random_seed(SEED)

# ========= Paths =========
DRIVE_BASE  = os.getenv("DRIVE_BASE", "/content/drive/MyDrive")  # no auto-mount here
BASE_DIR    = "./sac_parity"
RESULTS_DIR = os.path.join(DRIVE_BASE, "Results_May_2025", "results_sac_walkforward")
SAVE_DIR    = os.path.join(RESULTS_DIR, "models_sac")
for p in [BASE_DIR, RESULTS_DIR, SAVE_DIR]:
    os.makedirs(p, exist_ok=True)

print("RESULTS_DIR =", RESULTS_DIR)
print("Exists?", os.path.exists(RESULTS_DIR))
print("Listdir sample:", os.listdir(RESULTS_DIR)[:10] if os.path.exists(RESULTS_DIR) else "N/A")

# Diagnostic: report previously written window-sweep summaries.
csvs = sorted(glob.glob(os.path.join(RESULTS_DIR, "*_sac_ws_sweep_summary.csv")))
print("Matched:", [os.path.basename(p) for p in csvs])

for p in csvs:
    # Dedicated name so this diagnostic never clobbers the main `df`, and an empty or
    # corrupt summary only produces a warning instead of aborting the run pre-training.
    try:
        _sweep_df = pd.read_csv(p)
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        log.warning("Could not read sweep summary %s: %s", os.path.basename(p), e)
        continue
    print(os.path.basename(p), "rows:", len(_sweep_df))

# Artifact sub-directories under RESULTS_DIR.
for sub in ["scalers", "features", "signals", "plots", "vecnorms", "tmp", "model_info", "best"]:
    os.makedirs(os.path.join(RESULTS_DIR, sub), exist_ok=True)

# Canonical CSV locations (normally prebuilt by the data-prep script)
FEATURE_CSV_RESULTS = os.path.join(RESULTS_DIR, "multi_stock_feature_engineered_dataset.csv")
FEATURE_CSV_DRIVE   = os.path.join(DRIVE_BASE, "trading_data", "multi_stock_feature_engineered_dataset.csv")
FEATURE_CSV_LOCAL   = "multi_stock_feature_engineered_dataset.csv"

# ========= Data: load the prebuilt CSV; rebuild only if none is found =========
CANDIDATES_DATA = [FEATURE_CSV_LOCAL, FEATURE_CSV_RESULTS, FEATURE_CSV_DRIVE]
DATA_PATH = next((p for p in CANDIDATES_DATA if os.path.exists(p)), None)

if DATA_PATH is None:
    log.warning("Dataset CSV not found in %s; attempting to build it now...", CANDIDATES_DATA)
    # Capture a builder defined earlier in this runtime (combined single-script
    # version) BEFORE rebinding the name: once `build_features = None` runs at module
    # level, a globals() lookup can only ever see that None.
    _runtime_builder = globals().get("build_features")
    build_features = None
    _import_err = None

    # Preferred source: the separate download/feature-build module.
    try:
        from download_fe import build_features as _bf, FEATURE_CSV_LOCAL as _F_LOCAL
        build_features = _bf
        FEATURE_CSV_LOCAL = _F_LOCAL  # keep paths consistent with the builder, if different
        log.info("Imported build_features() from download_fe.py")
    except Exception as e:  # missing module, or an error raised while importing it
        _import_err = e
        if callable(_runtime_builder):
            build_features = _runtime_builder
            log.info("Using build_features() found in current runtime (combined script).")

    if build_features is None:
        raise FileNotFoundError(
            "Prebuilt dataset CSV not found and build_features() is unavailable.\n"
            "→ Run the Download + Feature Build script first OR ensure `from download_fe import build_features` works."
        ) from _import_err

    built_df = build_features()
    built_df.to_csv(FEATURE_CSV_LOCAL, index=False)
    DATA_PATH = FEATURE_CSV_LOCAL
    log.info("Rebuilt dataset → %s", DATA_PATH)

log.info("Using dataset: %s", DATA_PATH)
df = pd.read_csv(DATA_PATH)

if "Datetime" not in df.columns or "Symbol" not in df.columns or "Close" not in df.columns:
    raise ValueError("Dataset must contain at least ['Datetime','Symbol','Close'] plus numeric features.")

df["Datetime"] = pd.to_datetime(df["Datetime"], utc=True)
# Upper-case tickers: the universe is built from upper-cased symbols and per-symbol
# rows are later selected with df["Symbol"] == symbol, so mixed-case tickers in the
# CSV would otherwise select no rows.
df["Symbol"] = df["Symbol"].astype(str).str.upper()

# Identify numeric feature columns (observations exclude Symbol/Datetime/Close)
feature_cols = [
    c for c in df.columns
    if c not in ["Symbol", "Datetime", "Close"] and pd.api.types.is_numeric_dtype(df[c])
]
if not feature_cols:
    raise ValueError("No numeric feature columns found for observations.")
log.info("Feature columns detected: %d", len(feature_cols))

# ========= SAC / Env toggles =========
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
policy_kwargs = {"net_arch": [256, 256]}

# Feature / mode switches
LONG_ONLY = True        # set False to allow shorts again
ENABLE_PLOTS = False
ENABLE_SLO = True       # SL/TP + cooldown in env
LIVE_MODE = False       # (no live helpers in this script)
SIM_LATENCY_MS = 0
BROKER = "log"
# Compatibility flags only -- no feature-building happens in this script.
USE_REGIME_TRAIN = USE_SENTIMENT_TRAIN = ENABLE_WAVELET_TRAIN = False

# Risk controls & hygiene
DEAD_BAND, MIN_TRADE_DELTA = 0.01, 0.05
STOP_LOSS_PCT, TAKE_PROFIT_PCT = 0.04, 0.08
COOLDOWN_STEPS = 6

# Reward shaping
ENABLE_WHIPSAW_PENALTY, WHIPSAW_PENALTY = True, 1e-4
ENABLE_COOLDOWN_PENALTY, COOLDOWN_STEP_PENALTY = True, 0.0
CARRY_WEIGHT = 0.00     # 0.02–0.05
IDLE_PENALTY = 0.0

# Sizing, trading costs (basis points) and starting equity
MAX_EXPOSURE = 1.0
COMMISSION_BPS, SLIPPAGE_BPS = 0.5, 1.0
INITIAL_CAPITAL = 100_000.0

# Window sweep / schedule
CANDIDATE_WS = [12, 16, 24]
TOP_N_WINDOWS = 3 if not TEST_MODE else 1
FORCE_RETRAIN = os.getenv("FORCE_RETRAIN", "0").lower() in {"1", "true", "yes"}

# Per-phase trade logging (train env / eval-callback env / final eval env)
LOG_TRADES_TRAIN = LOG_TRADES_EVALCB = LOG_TRADES_FINAL = False

# Checkpointing: save every N env steps under RESULTS_DIR/ckpts
CKPT_FREQ_STEPS = int(os.getenv("CKPT_FREQ_STEPS", "50000"))
CKPT_DIR_ROOT = os.path.join(RESULTS_DIR, "ckpts")
os.makedirs(CKPT_DIR_ROOT, exist_ok=True)

# Base train steps (the FAST_TEST block further below re-assigns MIN_TRAIN_STEPS)
TRAIN_TOTAL_STEPS = 50_000
MIN_TRAIN_STEPS = 1000

# SAC hyper-parameter buckets, selected per symbol by get_sac_cfg()
FAST = {
    "learning_rate": 3e-4, "batch_size": 256, "train_freq": 1, "gradient_steps": 2,
    "gamma": 0.995, "tau": 0.005, "target_update_interval": 1, "ent_coef": "auto",
    "buffer_size": int(1e6), "learning_starts": 2_000,
}
SLOW = {
    "learning_rate": 1e-4, "batch_size": 512, "train_freq": 4, "gradient_steps": 2,
    "gamma": 0.997, "tau": 0.005, "target_update_interval": 2, "ent_coef": "auto",
    "buffer_size": int(2e6), "learning_starts": 4_000,
}

# === Modes / flags ===
# USE_EVAL_CALLBACK precedence: FAST_TEST forces it on, TEST_MODE forces it off,
# otherwise the USE_EVAL_CALLBACK environment variable ("1"/"true"/"yes") decides.
if FAST_TEST:
    USE_EVAL_CALLBACK = True
elif TEST_MODE:
    USE_EVAL_CALLBACK = False
else:
    USE_EVAL_CALLBACK = os.getenv("USE_EVAL_CALLBACK", "0").lower() in ("1", "true", "yes")

# Quick sanity print
log.info(
    "Flags → TEST_MODE=%s | FAST_TEST=%s | USE_EVAL_CALLBACK=%s",
    TEST_MODE, FAST_TEST, USE_EVAL_CALLBACK,
)

# FAST_TEST smoke-run overrides: one window, small replay buffers and batches.
if FAST_TEST:
    CANDIDATE_WS = [10]
    TOP_N_WINDOWS = 1
    MIN_TRAIN_STEPS = 1000
    for _cfg in (FAST, SLOW):
        _cfg.update(buffer_size=50_000, batch_size=128)
    del _cfg
    LOG_TRADES_FINAL = False

# ========= Universe bookkeeping =========
# Every distinct dataset symbol, upper-cased and sorted.
symbols_all = sorted(set(map(lambda s: str(s).upper(), df["Symbol"].unique())))

# Symbols routed to the FAST / SLOW SAC configs by get_sac_cfg(); anything else
# falls into the DEFAULT bucket.
fast_names = set(
    "TSLA NVDA AMD AVGO AAPL MSFT AMZN GOOGL META ADBE CRM "
    "INTC QCOM TXN ORCL NEE GE XOM CVX LLY NKE SBUX".split()
)
slow_names = set(
    "BRK-B JPM BAC JNJ UNH MRK PFE ABBV ABT AMGN PG PEP KO "
    "V MA WMT MCD TMO DHR ACN IBM LIN PM RTX UPS UNP COST HD LOW".split()
)
def get_sac_cfg(symbol: str) -> dict:
    """Return a fresh copy of the SAC hyper-parameters for *symbol*.

    Symbols in ``fast_names`` use ``FAST``, symbols in ``slow_names`` use ``SLOW``, and
    everything else uses ``FAST`` under the "DEFAULT" label. The chosen label is
    stored under the extra key ``"_bucket"``. The shared FAST/SLOW dicts are not
    mutated.
    """
    if symbol in fast_names:
        bucket, base = "FAST", FAST
    elif symbol in slow_names:
        bucket, base = "SLOW", SLOW
    else:
        bucket, base = "DEFAULT", FAST
    log.info("Config bucket for %s -> %s", symbol, bucket)
    return {**base, "_bucket": bucket}

# Detect already-trained models to resume/skip.
# Final models are named  sac_<SYMBOL>_ws<N>.zip . The symbol group accepts any
# characters (greedy, so it ends at the last "_ws<N>.zip"), matching what the
# per-symbol scan in the training loop accepts via re.escape(symbol); a narrower
# [A-Za-z0-9-] class would silently ignore tickers such as "BRK.B" or "^GSPC".
os.makedirs(SAVE_DIR, exist_ok=True)
_MODEL_FILE_RE = re.compile(r"^sac_(.+)_ws(\d+)\.zip$", re.IGNORECASE)
ws_by_symbol = defaultdict(set)
for f in os.listdir(SAVE_DIR):
    m = _MODEL_FILE_RE.match(f)
    if m:
        ws_by_symbol[m.group(1).upper()].add(int(m.group(2)))

_required_ws = set(CANDIDATE_WS)
completed_symbols = sorted(s for s, wsset in ws_by_symbol.items() if _required_ws.issubset(wsset))

# With FORCE_RETRAIN every symbol stays in the plan: the training loop retrains
# existing windows in that mode, but fully trained symbols would otherwise be
# filtered out here and never reach it.
if FORCE_RETRAIN:
    remaining_symbols = list(symbols_all)
else:
    _done = set(completed_symbols)
    remaining_symbols = [s for s in symbols_all if s not in _done]
log.info("Completed: %d | Remaining: %d", len(completed_symbols), len(remaining_symbols))

# ======= Run-list selection (manual/env/TEST_MODE) =======
# Sources, in the precedence applied below: manual list > RUN_SYMBOLS env >
# TEST_MODE slice > every symbol in `universe`.

# 1) Manual short list for quick tests (leave [] to disable manual override)
RUNLIST_MANUAL: list[str] = ["AAPL", "NVDA", "MSFT"]

# 2) Comma-separated symbols from the environment, e.g.  export RUN_SYMBOLS="TSLA,NVDA"
_env = os.getenv("RUN_SYMBOLS", "").strip()
RUNLIST_ENV = []
if _env:
    RUNLIST_ENV = [tok.strip().upper() for tok in _env.split(",") if tok.strip()]

# 3) TEST_MODE keeps the first N symbols of `universe` when nothing else is given
TEST_PICK_N = int(os.getenv("TEST_PICK_N", "3"))

# Candidate pool for this run: remaining_symbols as computed from SAVE_DIR above.
universe = remaining_symbols

def _restrict_to_universe(cands: list[str]) -> list[str]:
    """Map requested symbols onto ``universe``, matching case-insensitively.

    Returns the universe's spelling of each requested symbol that is present, in
    request order. Unknown symbols are dropped. Repeats (e.g. "AAPL,aapl" in
    RUN_SYMBOLS) are collapsed so a symbol is never scheduled twice in one run.
    """
    uni = {str(s).upper(): s for s in universe}
    picked: list[str] = []
    seen: set[str] = set()
    for cand in cands:
        key = str(cand).upper()
        if key in uni and key not in seen:
            seen.add(key)
            picked.append(uni[key])
    return picked

# Pick the run list from the first non-empty source (see precedence above).
if RUNLIST_MANUAL:
    run_symbols, why = _restrict_to_universe(RUNLIST_MANUAL), "manual override"
elif RUNLIST_ENV:
    run_symbols, why = _restrict_to_universe(RUNLIST_ENV), "RUN_SYMBOLS env"
elif TEST_MODE:
    run_symbols, why = universe[:TEST_PICK_N], f"TEST_MODE first {TEST_PICK_N}"
else:
    run_symbols, why = universe, "all remaining"

# Never start with an empty plan while candidates exist: fall back to the first one.
if universe and not run_symbols:
    run_symbols = universe[:1]
    why = why + " (fallback to first remaining)"

log.info("SAC run plan → %s | count=%d | %s", run_symbols, len(run_symbols), why)

# ========= Inference config artifact (optional) =========
# Written only when absent; an existing inference_config.json is left untouched.
inference_cfg_path = os.path.join(RESULTS_DIR, "inference_config.json")
if not os.path.exists(inference_cfg_path):
    with open(inference_cfg_path, "w") as f:
        json.dump(
            dict(
                algo="SAC",
                action_mode="deterministic",
                conf_source="none",
                conf_thresh=None,
                n_eval_runs=1,
                created_at=datetime.now(timezone.utc).isoformat(),
            ),
            f,
            indent=2,
        )
    log.info("Wrote inference_config.json -> %s", inference_cfg_path)

# ========= Env (continuous exposure with costs + risk controls) =========
class ContinuousTradingEnv(gym.Env):
    """
    Continuous exposure env with costs + risk controls.
      - Action a in [-1,1] -> exposure in [-MAX_EXPOSURE, MAX_EXPOSURE]
        (negative actions are floored to 0 when ``long_only``)
      - Deadband around zero, min trade delta
      - Commission + slippage on position change notional
      - Stop-loss/Take-profit + sign-flip guard + cooldown

    Observation: the ``window_size`` rows of numeric feature columns (all numeric
    columns except Symbol/Datetime/Close) ending just before the current tick.
    A step re-targets the position at the current tick's Close, charges costs,
    then marks the portfolio to the next tick's Close (clamped to the frame).

    Args:
        df: rows with at least ``Close`` and ``Datetime``. Rows are sorted by
            Datetime only, so presumably one symbol per env (the training loop
            passes per-symbol frames).
        frame_bound: ``(start_tick, end_tick)``; ``end_tick`` is exclusive and at
            most ``len(df)``; ``start_tick`` must be >= ``window_size``.
        window_size: number of past rows per observation.
        max_exposure: absolute position cap, as a fraction of portfolio value.
        commission_bps, slippage_bps: per-trade costs in basis points of notional.
        initial_capital: starting portfolio value.
        log_trades: log every position change / risk event.
        dead_band: |action| below this is treated as 0.
        min_trade_delta: position changes smaller than this are ignored.
        stop_loss_pct, take_profit_pct: exit thresholds vs. entry price (None = off).
        cooldown_steps: steps forced flat after a stop / take-profit / blocked flip.
        whipsaw_penalty: reward penalty for a long<->short reversal or blocked flip.
        cooldown_step_penalty: reward penalty per step while in cooldown.
        long_only: disallow short positions.
    """
    metadata = {"render_modes": []}

    def __init__(
        self, df: pd.DataFrame, frame_bound: tuple, window_size: int,
        max_exposure: float = MAX_EXPOSURE,
        commission_bps: float = COMMISSION_BPS,
        slippage_bps: float = SLIPPAGE_BPS,
        initial_capital: float = INITIAL_CAPITAL,
        log_trades: bool = True,
        dead_band: float = 0.0,
        min_trade_delta: float = 0.0,
        stop_loss_pct: float | None = None,
        take_profit_pct: float | None = None,
        cooldown_steps: int = 0,
        whipsaw_penalty: float = 0.0,
        cooldown_step_penalty: float = 0.0,
        long_only: bool = False
    ):
        super().__init__()
        assert "Close" in df.columns and "Datetime" in df.columns, "DataFrame must have Close & Datetime"
        self.raw_df = df.sort_values("Datetime").reset_index(drop=True).copy()
        self.window_size = int(window_size)
        self.start_tick, self.end_tick = int(frame_bound[0]), int(frame_bound[1])
        assert self.end_tick <= len(self.raw_df), "frame_bound[1] exceeds data length"
        assert self.start_tick >= self.window_size, "start tick must be >= window_size"

        self.prices = self.raw_df["Close"].values.astype(np.float64)
        feats = self.raw_df.drop(columns=["Symbol", "Datetime", "Close"], errors="ignore")
        feats = feats.select_dtypes(include=[np.number]).astype(np.float32).values
        self.features = feats
        n_feat = self.features.shape[1]

        self.action_space = Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self.window_size, n_feat), dtype=np.float32)

        self.max_exposure = float(max_exposure)
        # Combined commission + slippage as a fraction of traded notional.
        self.bps_total = (float(commission_bps) + float(slippage_bps)) / 10_000.0
        self.initial_capital = float(initial_capital)

        self.dead_band = float(dead_band)
        self.min_trade_delta = float(min_trade_delta)
        self.stop_loss_pct = float(stop_loss_pct) if stop_loss_pct is not None else None
        self.take_profit_pct = float(take_profit_pct) if take_profit_pct is not None else None
        self.cooldown_steps = int(cooldown_steps)

        self.whipsaw_penalty = float(whipsaw_penalty)
        self.cooldown_step_penalty = float(cooldown_step_penalty)
        self.log_trades = bool(log_trades)
        self.log_every = 500 if TEST_MODE else 2000
        self.long_only = bool(long_only)

        # Episode state (initialised in reset()).
        self.current_tick = None
        self.position = None
        self.portfolio_value = None
        self.entry_price = None
        self.cooldown_until = -1
        self.done_tick = self.end_tick - 1

    def _obs(self):
        """Return the feature window [current_tick - window_size, current_tick) as float32."""
        start = self.current_tick - self.window_size
        obs = self.features[start:self.current_tick]
        return obs.astype(np.float32)

    def reset(self, *, seed=None, options=None):
        """Start a flat episode at the frame's start tick with initial capital."""
        super().reset(seed=seed)
        self.current_tick = max(self.start_tick, self.window_size)
        self.position = 0.0
        self.portfolio_value = self.initial_capital
        self.entry_price = None
        self.cooldown_until = -1
        obs = self._obs()
        info = {
            "portfolio_value": float(self.portfolio_value),
            "position": float(self.position),
            "price": float(self.prices[self.current_tick]),
            "trade_cost": 0.0,
            "risk_event": None,
            "cooldown_left": 0,
            "entry_price": None,
        }
        return obs, info

    def _pnl_from_entry(self, price_now: float) -> float:
        """Signed fractional P&L of the open position vs. its entry price (0 when flat)."""
        if self.entry_price is None or self.entry_price <= 0 or self.position == 0.0:
            return 0.0
        side = 1.0 if self.position > 0 else -1.0
        return side * (price_now / self.entry_price - 1.0)

    def _apply_risk_controls(self, price_t: float, proposed_pos: float):
        """Apply cooldown, flip guard and SL/TP to a proposed position.

        Returns ``(position, event, cooldown_left)``. ``event`` is None or one of
        "cooldown", "flip_block", "stop_loss", "take_profit"; every event forces a
        flat position, and all but "cooldown" (re)arm a ``cooldown_steps`` window.
        """
        event = None
        cooldown_left = max(0, self.cooldown_until - self.current_tick)

        # Flat and still cooling down: stay flat.
        if self.position == 0.0 and self.current_tick < self.cooldown_until:
            return 0.0, "cooldown", cooldown_left

        # Block a direct long<->short reversal. Only a genuinely opposite-signed
        # target counts: a target of exactly 0 is an exit, not a flip (with
        # min_trade_delta == 0 the old sign(x) != sign(y) test mis-flagged exits).
        if (
            self.cooldown_steps > 0
            and self.position != 0.0
            and proposed_pos * self.position < 0.0
            and abs(proposed_pos) >= self.min_trade_delta
        ):
            self.cooldown_until = self.current_tick + self.cooldown_steps
            return 0.0, "flip_block", self.cooldown_steps

        if self.position != 0.0:
            pnl_entry = self._pnl_from_entry(price_t)
            if (self.stop_loss_pct is not None) and (pnl_entry <= -self.stop_loss_pct):
                self.cooldown_until = self.current_tick + self.cooldown_steps
                return 0.0, "stop_loss", self.cooldown_steps
            if (self.take_profit_pct is not None) and (pnl_entry >= self.take_profit_pct):
                self.cooldown_until = self.current_tick + self.cooldown_steps
                return 0.0, "take_profit", self.cooldown_steps

        return proposed_pos, event, cooldown_left

    def step(self, action):
        """Rebalance toward ``action`` at the current Close, then mark to the next Close."""
        if SIM_LATENCY_MS and SIM_LATENCY_MS > 0:
            try:
                time.sleep(float(SIM_LATENCY_MS) / 1000.0)
            except Exception:
                pass

        a = float(np.array(action).reshape(-1)[0])
        if abs(a) < self.dead_band:
            a = 0.0

        if self.long_only:
            a = max(0.0, a)  # disallow shorts
            target_pos = float(np.clip(a, 0.0, 1.0)) * self.max_exposure
        else:
            target_pos = float(np.clip(a, -1.0, 1.0)) * self.max_exposure
        target_pos = float(np.clip(target_pos, -self.max_exposure, self.max_exposure))

        # Ignore re-targets smaller than min_trade_delta.
        if abs(target_pos - (0.0 if self.position is None else self.position)) < self.min_trade_delta:
            target_pos = self.position

        price_t   = float(self.prices[self.current_tick])
        old_pos   = float(0.0 if self.position is None else self.position)

        target_pos, risk_event, cooldown_left = self._apply_risk_controls(price_t, target_pos)

        # Costs are charged on the notional of the position change.
        base_cash = self.portfolio_value if self.portfolio_value else self.initial_capital
        trade_notional = abs(target_pos - old_pos) * base_cash
        trade_cost = self.bps_total * trade_notional

        self.position = float(target_pos)
        self.portfolio_value = float(base_cash - trade_cost)
        if not np.isfinite(self.portfolio_value) or self.portfolio_value <= 0.0:
            self.portfolio_value = max(1e-8, float(self.portfolio_value))

        next_tick = self.current_tick + 1
        terminated = next_tick >= len(self.prices) or self.current_tick >= self.done_tick
        # Mark to the next Close without reading past the frame's last row: clamping
        # to len(prices) - 1 let the final step use a price outside frame_bound
        # whenever end_tick < len(df). Identical when end_tick == len(df).
        price_tp1 = float(self.prices[min(next_tick, self.end_tick - 1)])
        ret = 0.0 if price_t <= 0 else (price_tp1 / price_t - 1.0)
        v_prev = self.portfolio_value
        self.portfolio_value = v_prev * (1.0 + self.position * ret)

        # --- carry term: reward scales with ret and long size (no ret>0 gate)
        reward = (self.portfolio_value - v_prev) / v_prev if v_prev > 0 else 0.0
        long_frac = max(self.position, 0.0) / max(self.max_exposure, 1e-9)
        baseline  = CARRY_WEIGHT * ret
        reward   += baseline * long_frac

        # Direct long<->short reversal between the previous and the new position.
        flipped = (
            old_pos != 0.0 and self.position != 0.0
            and np.sign(old_pos) != np.sign(self.position)
        )

        penalty = 0.0
        penalty += 1e-4 * abs(target_pos - old_pos)  # small fixed turnover penalty
        whipsaw_applied = False
        cooldown_applied = False
        if self.current_tick >= self.cooldown_until:
            cash_frac = 1.0 - min(abs(self.position) / max(self.max_exposure, 1e-9), 1.0)
            penalty += IDLE_PENALTY * cash_frac
        floor_target, floor_w = 0.20, 0.0  # try target 0.2–0.3, weight 1e-4 → 3e-4
        penalty += floor_w * max(0.0, floor_target - long_frac)
        if self.whipsaw_penalty > 0.0:
            if flipped:
                penalty += self.whipsaw_penalty
                whipsaw_applied = True
            if risk_event == "flip_block":
                penalty += self.whipsaw_penalty
                whipsaw_applied = True
        if self.cooldown_step_penalty > 0.0 and self.current_tick < self.cooldown_until:
            penalty += self.cooldown_step_penalty
            cooldown_applied = True

        reward = float(np.clip(reward - penalty, -1.0, 1.0))

        # Entry price drives stop-loss / take-profit. Reset it on a direct reversal
        # as well, otherwise the new side's P&L is measured against the old side's entry.
        if (old_pos == 0.0 and self.position != 0.0) or flipped:
            self.entry_price = price_t
        if self.position == 0.0:
            self.entry_price = None

        if self.log_trades and (abs(target_pos - old_pos) > 1e-9 or risk_event):
            log.info(
                "t=%d | px=%.4f | pos %.3f→%.3f | notional=%.2f | cost=%.2f bps=%.2f | risk=%s | cd=%d | V=%.2f",
                self.current_tick, price_t, old_pos, self.position,
                trade_notional, trade_cost, self.bps_total * 10_000.0,
                (risk_event or "none"), int(cooldown_left), self.portfolio_value
            )

        self.current_tick = next_tick
        obs = self._obs()  # already a fresh float32 array
        info = {
            "portfolio_value": float(self.portfolio_value),
            "position": float(self.position),
            "price": float(price_tp1),
            "trade_cost": float(trade_cost),
            "risk_event": risk_event,
            "cooldown_left": int(max(0, self.cooldown_until - self.current_tick)) if self.cooldown_until >= 0 else 0,
            "entry_price": None if self.entry_price is None else float(self.entry_price),
            "whipsaw_penalty_applied": bool(whipsaw_applied),
            "cooldown_penalty_applied": bool(cooldown_applied),
            "broker": str(BROKER),
        }
        truncated = False
        return obs, reward, bool(terminated), bool(truncated), info

def _risk_kwargs():
    """Risk-control kwargs for ContinuousTradingEnv.

    Dead band and min trade delta always apply. Stop-loss, take-profit and cooldown
    are passed only when ENABLE_SLO is set; otherwise they are None / None / 0.
    """
    slo_on = ENABLE_SLO
    return {
        "dead_band": DEAD_BAND,
        "min_trade_delta": MIN_TRADE_DELTA,
        "stop_loss_pct": STOP_LOSS_PCT if slo_on else None,
        "take_profit_pct": TAKE_PROFIT_PCT if slo_on else None,
        "cooldown_steps": COOLDOWN_STEPS if slo_on else 0,
    }

def _reward_kwargs():
    """Reward-shaping kwargs for ContinuousTradingEnv; disabled penalties are passed as 0.0."""
    whipsaw = WHIPSAW_PENALTY if ENABLE_WHIPSAW_PENALTY else 0.0
    cooldown = COOLDOWN_STEP_PENALTY if ENABLE_COOLDOWN_PENALTY else 0.0
    return {"whipsaw_penalty": whipsaw, "cooldown_step_penalty": cooldown}

# ====== Helpers ======
def _nan_if_none(x):
    """Coerce *x* to a finite float, or NaN.

    Returns NaN for None, for values ``float()`` rejects (TypeError / ValueError),
    for integers too large to represent as a float (OverflowError), and for
    non-finite results (inf / NaN).
    """
    if x is None:
        return float("nan")
    try:
        xf = float(x)
    except (TypeError, ValueError, OverflowError):
        return float("nan")
    return xf if np.isfinite(xf) else float("nan")

def _safe_round(x, ndigits):
    """Round *x* to *ndigits* after coercion via _nan_if_none; non-finite input yields NaN."""
    value = _nan_if_none(x)
    if not np.isfinite(value):
        return float("nan")
    return round(value, ndigits)

# ---- Callback to periodically save VecNormalize stats during training ----
class SaveVecNormCallback(BaseCallback):
    """Periodically persist a VecNormalize wrapper's statistics during training.

    Args:
        vec_env: object exposing ``save(path)`` (the training VecNormalize here).
        save_path: destination passed to ``vec_env.save``.
        save_freq: save whenever ``num_timesteps`` is a multiple of this value.
            A value <= 0 disables periodic saving instead of raising
            ZeroDivisionError on every step.
        verbose: print a line after each successful save when truthy.
    """

    def __init__(self, vec_env, save_path: str, save_freq: int = 50_000, verbose: int = 0):
        super().__init__(verbose)
        self.vec_env = vec_env
        self.save_path = save_path
        self.save_freq = int(save_freq)

    def _on_step(self) -> bool:
        if self.save_freq <= 0 or self.num_timesteps % self.save_freq != 0:
            return True
        try:
            self.vec_env.save(self.save_path)
        except Exception as e:
            # Best-effort: never abort training, but always report the failure so a
            # missing/stale normalization file does not go unnoticed.
            print(f"[SaveVecNorm] failed: {e}")
        else:
            if self.verbose:
                print(f"[SaveVecNorm] saved -> {self.save_path}")
        return True

class HeartbeatCallback(BaseCallback):
    """Emit a throughput / ETA line at WARNING level every ``every`` timesteps.

    Args:
        every: minimum number of timesteps between heartbeat lines.
        target_steps: optional timestep target used to estimate the ETA from the
            instantaneous rate.
        py_logger: logger to write to; defaults to ``logging.getLogger(__name__)``.
    """

    def __init__(self, every: int = 10_000, target_steps: int | None = None, py_logger=None):
        super().__init__()
        self.every = int(every)
        self.target_steps = target_steps
        self._pylog = py_logger or logging.getLogger(__name__)
        self._t0 = None
        self._last_t = None
        self._last_step = 0
        self._start_step = 0

    def _on_training_start(self) -> None:
        self._t0 = time.time()
        self._last_t = self._t0
        self._last_step = self.num_timesteps
        # num_timesteps need not start at 0 (e.g. when continuing a loaded model),
        # so the average rate must count only steps taken during this run.
        self._start_step = self.num_timesteps

    def _on_step(self) -> bool:
        if (self.num_timesteps - self._last_step) >= self.every:
            now = time.time()
            dt = max(now - self._last_t, 1e-9)
            total_dt = max(now - self._t0, 1e-9)
            steps = self.num_timesteps
            inst_sps = (steps - self._last_step) / dt
            avg_sps = (steps - self._start_step) / total_dt
            eta_str = ""
            if self.target_steps is not None and inst_sps > 0:
                remaining = max(self.target_steps - steps, 0)
                eta_sec = remaining / inst_sps
                eta_str = f" | eta={eta_sec/60:.1f}m"
            self._pylog.warning(
                f"[HB] steps={steps:,} | inst={inst_sps:,.0f} sps | avg={avg_sps:,.0f} sps"
                f" | elapsed={total_dt/60:.1f}m{eta_str}"
            )
            self._last_t = now
            self._last_step = steps
        return True


# ========= Train/Eval per symbol =========
# Run-level accumulators: skipped_all collects symbols whose windows were all
# already trained; global_rows is filled further down the loop.
skipped_all, global_rows = [], []

for idx, symbol in enumerate(run_symbols, 1):
    log.warning("▶ [%d/%d] Processing %s | candidate windows=%s",
                idx, len(run_symbols), symbol, CANDIDATE_WS)
    log.warning("---------------------------------------------------------------")

    sac_cfg = get_sac_cfg(symbol)
    sdf = df[df["Symbol"] == symbol].sort_values("Datetime")

    # Walk-forward split (1y train, 1y test within available ~2y window)
    max_ts = sdf["Datetime"].max()
    test_end_dt    = max_ts
    train_start_dt = test_end_dt - timedelta(days=729)
    train_end_dt   = train_start_dt + timedelta(days=365)
    test_start_dt  = train_end_dt

    train_start_ts = pd.to_datetime(train_start_dt, utc=True)
    train_end_ts   = pd.to_datetime(train_end_dt,   utc=True)
    test_start_ts  = pd.to_datetime(test_start_dt,  utc=True)
    test_end_ts    = pd.to_datetime(test_end_dt,    utc=True)

    train_df = sdf[(sdf["Datetime"] >= train_start_ts) & (sdf["Datetime"] <  train_end_ts)].reset_index(drop=True)
    test_df  = sdf[(sdf["Datetime"] >= test_start_ts)  & (sdf["Datetime"] <= test_end_ts)].reset_index(drop=True)
    if len(train_df) < max(CANDIDATE_WS) + 10 or len(test_df) < max(CANDIDATE_WS) + 10:
        log.warning("Symbol %s has too few rows for this split; skipping.", symbol)
        continue
    steps_est = max(len(train_df), 1)

    existing_ws = set()
    try:
        for f in os.listdir(SAVE_DIR):
            m = re.match(rf"^sac_{re.escape(symbol)}_ws(\d+)\.zip$", f, re.IGNORECASE)
            if m:
                existing_ws.add(int(m.group(1)))
    except Exception as e:
        log.warning("Could not scan SAVE_DIR for existing windows (%s). Proceeding as if none exist.", e)

    pending_ws = CANDIDATE_WS if FORCE_RETRAIN else [ws for ws in CANDIDATE_WS if ws not in existing_ws]
    if not pending_ws:
        log.warning("⏭ Ticker %s fully skipped (all %d windows already complete).", symbol, len(CANDIDATE_WS))
        skipped_all.append(symbol)
        continue
    else:
        have_str = ", ".join(map(str, sorted(existing_ws))) if existing_ws else "none"
        todo_str = ", ".join(map(str, pending_ws))
        log.info("▶️  %s pending windows: [%s] (already have: %s)", symbol, todo_str, have_str)

    # Write per-symbol features.json (auditing)
    try:
        sym_dir = os.path.join(RESULTS_DIR, symbol); os.makedirs(sym_dir, exist_ok=True)
        _feature_cols = sorted(feature_cols)
        with open(os.path.join(sym_dir, "features.json"), "w") as f:
            json.dump(
                {
                    "algorithm": "SAC",
                    "created_at": datetime.now(timezone.utc).isoformat(),
                    "features": _feature_cols,
                    "note": "Active numeric observation columns for this symbol/run."
                },
                f, indent=2
            )
        log.info("Wrote %s/features.json with %d features", symbol, len(_feature_cols))
    except Exception as e:
        log.warning("Per-symbol features.json write skipped for %s: %s", symbol, e)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ==== WS sweep (train+eval per window size, rank by Sharpe) ====
    def train_eval_for_ws(ws: int):
        log.info("  • WS=%d | building envs ...", ws)
        model_path = f"{SAVE_DIR}/sac_{symbol}_ws{ws}"
        vec_path   = f"{RESULTS_DIR}/vecnorms/{symbol}_ws{ws}_vecnorm.pkl"
        ckpt_dir   = os.path.join(CKPT_DIR_ROOT, f"{symbol}_ws{ws}")
        os.makedirs(ckpt_dir, exist_ok=True)


        # --- factories for this ws ---
        def make_train_env():
            return ContinuousTradingEnv(
                train_df, frame_bound=(ws, len(train_df)), window_size=ws,
                max_exposure=MAX_EXPOSURE, commission_bps=COMMISSION_BPS,
                slippage_bps=SLIPPAGE_BPS, initial_capital=INITIAL_CAPITAL,
                log_trades=LOG_TRADES_TRAIN, long_only=LONG_ONLY,
                **_risk_kwargs(), **_reward_kwargs()
            )

        def make_test_env_for_callback():
            return ContinuousTradingEnv(
                test_df, frame_bound=(ws, len(test_df)), window_size=ws,
                max_exposure=MAX_EXPOSURE, commission_bps=COMMISSION_BPS,
                slippage_bps=SLIPPAGE_BPS, initial_capital=INITIAL_CAPITAL,
                log_trades=LOG_TRADES_EVALCB, long_only=LONG_ONLY,
                **_risk_kwargs(), **_reward_kwargs()
            )

        def make_test_env_for_final():
            return ContinuousTradingEnv(
                test_df, frame_bound=(ws, len(test_df)), window_size=ws,
                max_exposure=MAX_EXPOSURE, commission_bps=COMMISSION_BPS,  # <- fixed typo
                slippage_bps=SLIPPAGE_BPS, initial_capital=INITIAL_CAPITAL,
                log_trades=LOG_TRADES_FINAL, long_only=LONG_ONLY,
                **_risk_kwargs(), **_reward_kwargs()
            )

        # --- vec wrapper(s) ---
        train_venv = DummyVecEnv([make_train_env])

        # (re)load or create VecNormalize for training
        if os.path.exists(vec_path):
            try:
                train_env = VecNormalize.load(vec_path, train_venv)
                train_env.training = True
                log.info("  • WS=%d | restored VecNormalize from %s", ws, vec_path)
            except Exception as e:
                log.warning("  • WS=%d | VecNormalize load failed (%s). Starting fresh.", ws, e)
                train_env = VecNormalize(train_venv, training=True, norm_obs=True, norm_reward=False, clip_obs=10.0)
        else:
            train_env = VecNormalize(train_venv, training=True, norm_obs=True, norm_reward=False, clip_obs=10.0)

        # Seed + reset after stats are in place
        train_env.seed(SEED); _ = train_env.reset()

        # --- SAC config (bucketed) ---
        loc_cfg = sac_cfg.copy()
        # Cap replay buffer for this run
        loc_cfg["buffer_size"] = min(int(loc_cfg["buffer_size"]), steps_est * 20)

        # --- choose total steps (define BEFORE eval callback) ---
        if FAST_TEST:
            base_steps  = max(200, min(MIN_TRAIN_STEPS, len(train_df)))
            total_steps = max(base_steps, loc_cfg["learning_starts"] + 5_000)
        else:
            total_steps = (
                TRAIN_TOTAL_STEPS if TRAIN_TOTAL_STEPS is not None
                else (min(4_000, len(train_df) * 2) if TEST_MODE else min(25_000, len(train_df) * 10))
            )
        eval_every = max(1_000, total_steps // 50)

        # Optional eval env (only if we're using the eval callback)
        eval_env = None
        if USE_EVAL_CALLBACK:
            eval_venv = DummyVecEnv([lambda: Monitor(make_test_env_for_callback())])
            eval_env  = VecNormalize(eval_venv, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0)
            eval_env.obs_rms = train_env.obs_rms
            eval_env.seed(SEED); _ = eval_env.reset()

        # --- callbacks (create ONCE) ---
        ckpt_cb = CheckpointCallback(
            save_freq=CKPT_FREQ_STEPS,
            save_path=ckpt_dir,
            name_prefix="sac",
            save_replay_buffer=True
        )
        vec_cb = SaveVecNormCallback(train_env, vec_path, save_freq=CKPT_FREQ_STEPS, verbose=1)

        eval_cb = None
        if USE_EVAL_CALLBACK and eval_env is not None:
            n_eval_no_improve, min_evals = (1, 2) if FAST_TEST else (2, 4)
            eval_cb = EvalCallback(
                eval_env,
                best_model_save_path=os.path.join(RESULTS_DIR, "tmp", f"best_{symbol}_ws{ws}"),
                eval_freq=eval_every,
                n_eval_episodes=1,
                callback_after_eval=StopTrainingOnNoModelImprovement(n_eval_no_improve, min_evals, verbose=1),
                verbose=1,
            )
        hb_cb = HeartbeatCallback(
          every=max(1_000, total_steps // 100),
          target_steps=total_steps,
          py_logger=log,
        )

        callback = CallbackList([cb for cb in (eval_cb, ckpt_cb, vec_cb, hb_cb) if cb is not None])

        log.info(
            "EvalCallback: %s | Checkpoint: ENABLED | VecNormSaver: ENABLED | Heartbeat: ENABLED",
            "ENABLED" if eval_cb else "DISABLED")

        # --- create model OR resume from latest checkpoint
        resuming = False
        latest_ckpt = None
        try:
            ckpts = sorted(glob.glob(os.path.join(ckpt_dir, "sac_*_steps.zip")), key=os.path.getmtime)
            if ckpts:
                latest_ckpt = ckpts[-1]
        except Exception:
            latest_ckpt = None

        if latest_ckpt:
            try:
                model = SAC.load(latest_ckpt, env=train_env, device=device)
                resuming = True
                log.info("  • WS=%d | resumed model from %s", ws, latest_ckpt)
            except Exception as e:
                log.warning("  • WS=%d | failed to load checkpoint (%s). Starting fresh.", ws, e)
                model = SAC(
                    "MlpPolicy", train_env, device=device, policy_kwargs=policy_kwargs, verbose=1,
                    learning_rate=loc_cfg["learning_rate"], batch_size=loc_cfg["batch_size"],
                    train_freq=loc_cfg["train_freq"], gradient_steps=loc_cfg["gradient_steps"],
                    gamma=loc_cfg["gamma"], tau=loc_cfg["tau"], ent_coef=loc_cfg["ent_coef"],
                    target_update_interval=loc_cfg["target_update_interval"],
                    buffer_size=loc_cfg["buffer_size"], learning_starts=loc_cfg["learning_starts"], seed=SEED,
                )
        else:
            model = SAC(
                "MlpPolicy", train_env, device=device, policy_kwargs=policy_kwargs, verbose=1,
                learning_rate=loc_cfg["learning_rate"], batch_size=loc_cfg["batch_size"],
                train_freq=loc_cfg["train_freq"], gradient_steps=loc_cfg["gradient_steps"],
                gamma=loc_cfg["gamma"], tau=loc_cfg["tau"], ent_coef=loc_cfg["ent_coef"],
                target_update_interval=loc_cfg["target_update_interval"],
                buffer_size=loc_cfg["buffer_size"], learning_starts=loc_cfg["learning_starts"], seed=SEED,
            )
        if resuming:
            buffer_path = latest_ckpt.replace(".zip", "_replay_buffer.pkl")
            if os.path.exists(buffer_path):
                try:
                    model.load_replay_buffer(buffer_path)
                    log.info("  • WS=%d | loaded replay buffer %s", ws, buffer_path)
                except Exception as e:
                    log.warning("  • WS=%d | failed to load replay buffer: %s", ws, e)

        log.info("  • WS=%d | learn steps=%s | eval_every=%s | resume=%s",
                ws, f"{total_steps:,}", f"{eval_every:,}", resuming)

        t0 = time.time()
        model.learn(
            total_timesteps=total_steps,
            callback=callback,
            reset_num_timesteps=not resuming,
            progress_bar=False,
        )
        dt = time.time() - t0
        log.info("  • WS=%d | learn done in %.1fs", ws, dt)

        # (these lines already exist in your code – keep them as-is)
        model.save(model_path)
        train_env.save(vec_path)
        log.info("  • WS=%d | saved model=%s.zip vecnorm=%s", ws, model_path, vec_path)
        try:
            final_rb_path = model_path + "_replay_buffer.pkl"
            model.save_replay_buffer(final_rb_path)
            log.info("  • WS=%d | saved final replay buffer -> %s", ws, final_rb_path)
        except Exception as e:
            log.warning("  • WS=%d | could not save final replay buffer: %s", ws, e)

        # ---- model_info.json (concise) ----
        try:
            model_info = {
                "algorithm": "SAC",
                "created_at": datetime.now(timezone.utc).isoformat(),
                "symbol": symbol,
                "ws": int(ws),
                "bucket": sac_cfg.get("_bucket", "UNKNOWN"),
                "seed": int(SEED),
                "features_used": int(len(feature_cols)),
                "reward_shaping": {
                    "whipsaw_penalty_enabled": bool(ENABLE_WHIPSAW_PENALTY),
                    "whipsaw_penalty": float(WHIPSAW_PENALTY),
                    "cooldown_step_penalty_enabled": bool(ENABLE_COOLDOWN_PENALTY),
                    "cooldown_step_penalty": float(COOLDOWN_STEP_PENALTY),
                    "confidence_source": "env_penalties"
                },
                "policy": {
                    "type": "MlpPolicy",
                    "net_arch": policy_kwargs.get("net_arch", None),
                    "device": str(device)
                },
                "hyperparams": {
                    "learning_rate": float(loc_cfg["learning_rate"]),
                    "batch_size": int(loc_cfg["batch_size"]),
                    "train_freq": int(loc_cfg["train_freq"]),
                    "gradient_steps": int(loc_cfg["gradient_steps"]),
                    "gamma": float(loc_cfg["gamma"]),
                    "tau": float(loc_cfg["tau"]),
                    "ent_coef": str(loc_cfg["ent_coef"]),
                    "buffer_size": int(loc_cfg["buffer_size"]),
                    "learning_starts": int(loc_cfg["learning_starts"]),
                    "target_update_interval": int(loc_cfg["target_update_interval"]),
                },
                "env_params": {
                    "window_size": int(ws),
                    "max_exposure": float(MAX_EXPOSURE),
                    "commission_bps": float(COMMISSION_BPS),
                    "slippage_bps": float(SLIPPAGE_BPS),
                    "dead_band": float(DEAD_BAND),
                    "min_trade_delta": float(MIN_TRADE_DELTA),
                    "stop_loss_pct": float(STOP_LOSS_PCT) if ENABLE_SLO else None,
                    "take_profit_pct": float(TAKE_PROFIT_PCT) if ENABLE_SLO else None,
                    "cooldown_steps": int(COOLDOWN_STEPS if ENABLE_SLO else 0),
                    "sim_latency_ms": int(SIM_LATENCY_MS),
                    "broker": str(BROKER),
                    "long_only": bool(LONG_ONLY),
                },
                "training": {
                    "total_steps": int(total_steps),
                    "eval_every": int(eval_every),
                    "use_eval_callback": bool(eval_cb is not None),
                    "test_mode": bool(TEST_MODE),
                    "fast_test": bool(FAST_TEST),
                    "deterministic_eval": True
                },
                "artifacts": {
                    "model_zip": model_path + ".zip",
                    "vecnorm_pkl": vec_path
                },
            }
            mi_dir = os.path.join(RESULTS_DIR, "model_info")
            os.makedirs(mi_dir, exist_ok=True)
            mi_path = os.path.join(mi_dir, f"sac_{symbol}_ws{ws}_model_info.json")
            with open(mi_path, "w") as f:
                json.dump(model_info, f, indent=2)
            log.info("  • WS=%d | wrote model_info.json -> %s", ws, mi_path)
        except Exception as e:
            log.warning("  • WS=%d | model_info.json write skipped: %s", ws, e)
            mi_path = None

        final_eval_raw = DummyVecEnv([make_test_env_for_final])
        final_eval_raw.seed(SEED)

        # Defensive load of VecNormalize stats
        if os.path.exists(vec_path):
            final_eval = VecNormalize.load(vec_path, final_eval_raw)
        else:
            final_eval = VecNormalize(
                final_eval_raw, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0
            )

        final_eval.training = False
        final_eval.norm_reward = False
        final_eval.obs_rms = train_env.obs_rms
        final_eval.seed(SEED)
        obs = final_eval.reset()
        portfolio, positions = [], []


        # Confidence bucket accumulators (per-step stats)
        bucket_stats = {
            "clean":    {"step_pnls": [], "count": 0},
            "whipsaw":  {"step_pnls": [], "count": 0},
            "cooldown": {"step_pnls": [], "count": 0},
        }
        for i in range(len(test_df)):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, infos = final_eval.step(action)
            info0 = infos[0]

            portfolio.append(float(info0.get("portfolio_value", np.nan)))
            positions.append(float(info0.get("position", 0.0)))

            if len(portfolio) >= 2:
                step_pnl = (portfolio[-1] / max(portfolio[-2], 1e-12)) - 1.0
                if info0.get("whipsaw_penalty_applied", False):
                    b = "whipsaw"
                elif info0.get("cooldown_penalty_applied", False):
                    b = "cooldown"
                else:
                    b = "clean"
                bucket_stats[b]["step_pnls"].append(float(step_pnl))
                bucket_stats[b]["count"] += 1

            if bool(dones[0]):
                break

        # save signals for this ws
        aligned_len = min(len(test_df), len(portfolio))
        signals_path = f"{RESULTS_DIR}/signals/{symbol}_ws{ws}_sac_signals.csv"
        test_eval_df = test_df.iloc[:aligned_len].copy()
        if aligned_len > 0:
            buy_hold = float(INITIAL_CAPITAL) * (
                test_eval_df["Close"].iloc[-1] / test_eval_df["Close"].iloc[0]
            )
        else:
            buy_hold = float(INITIAL_CAPITAL)

        test_eval_df["Position"] = positions[:aligned_len]
        test_eval_df["Portfolio"] = portfolio[:aligned_len]
        test_eval_df.to_csv(signals_path, index=False)
        log.info("  • WS=%d | wrote signals CSV → %s (rows=%d)", ws, signals_path, aligned_len)

        # try to attach signals path to model_info
        try:
            if mi_path and os.path.exists(mi_path):
                with open(mi_path, "r") as f:
                    _mi = json.load(f)
                _mi.setdefault("artifacts", {})["signals_csv"] = signals_path
                with open(mi_path, "w") as f:
                    json.dump(_mi, f, indent=2)
        except Exception as e:
            log.warning("  • WS=%d | Could not update model_info with signals_csv: %s", ws, e)
        # ---- metrics (Sharpe + extended) ----
        final_value = portfolio[-1] if portfolio else INITIAL_CAPITAL
        ret_series  = pd.Series(portfolio, dtype="float64").pct_change().dropna()
        if len(ret_series) > 1 and ret_series.std() > 0:
            # infer periods/year from median step size of timestamps
            step_delta = pd.to_datetime(test_eval_df["Datetime"]).diff().median()
            step_sec = float(step_delta.total_seconds()) if pd.notna(step_delta) else 0.0
            periods_per_year = int(np.clip(((365*24*3600)/max(step_sec, 1)) if step_sec > 0 else 252, 252, 100_000))
            sharpe = (ret_series.mean() / ret_series.std()) * np.sqrt(max(periods_per_year, 1))
        else:
            sharpe = 0.0

        # always compute drawdowns & trade stats
        equity = pd.Series(portfolio, dtype="float64")
        roll_max = equity.cummax()
        dd = (equity / roll_max) - 1.0
        max_dd = float(dd.min()) if len(dd) else 0.0
        dur_days = (
            (test_eval_df["Datetime"].iloc[aligned_len-1] - test_eval_df["Datetime"].iloc[0]).total_seconds() / 86400.0
            if aligned_len > 1 else 0.0
        )
        ann_ret = ((final_value / float(INITIAL_CAPITAL)) ** (365.0 / max(dur_days, 1e-9)) - 1.0) if dur_days > 0 else 0.0
        calmar = (ann_ret / abs(max_dd)) if (max_dd < 0.0 and abs(max_dd) > 1e-12) else float("nan")

        trades = []
        entry_val = None
        prev_pos = 0.0
        for i in range(aligned_len):
            p = float(positions[i])
            if prev_pos == 0.0 and abs(p) > 0.0:
                entry_val = float(portfolio[i])
            if prev_pos != 0.0 and abs(p) > 0.0 and np.sign(prev_pos) != np.sign(p):
                if entry_val and entry_val > 0.0:
                    trades.append((float(portfolio[i]) - entry_val) / entry_val)
                entry_val = float(portfolio[i])
            if prev_pos != 0.0 and abs(p) == 0.0 and entry_val and entry_val > 0.0:
                trades.append((float(portfolio[i]) - entry_val) / entry_val)
                entry_val = None
            prev_pos = p

        if prev_pos != 0.0 and entry_val and entry_val > 0.0:
            trades.append((float(portfolio[aligned_len - 1]) - entry_val) / entry_val)
        trade_count = int(len(trades))
        hit_ratio = float(np.mean(np.array(trades) > 0.0)) if trade_count > 0 else float("nan")
        avg_trade_pnl = float(np.mean(trades)) if trade_count > 0 else float("nan")


        # Persist trade ledger
        try:
            trade_rows = []
            entry_val = None; prev_pos = 0.0; entry_idx = None
            for i in range(aligned_len):
                p = float(positions[i]); nav = float(portfolio[i])
                if prev_pos == 0.0 and abs(p) > 0.0:
                    entry_val = nav; entry_idx = i
                if prev_pos != 0.0 and abs(p) > 0.0 and np.sign(prev_pos) != np.sign(p):
                    if entry_val and entry_val > 0.0 and entry_idx is not None:
                        trade_rows.append({
                            "entry_idx": int(entry_idx),
                            "exit_idx": int(i),
                            "entry_nav": float(entry_val),
                            "exit_nav": float(nav),
                            "trade_pnl": float((nav - entry_val) / entry_val),
                            "flip_close": True
                        })
                    entry_val = nav; entry_idx = i
                if prev_pos != 0.0 and abs(p) == 0.0 and entry_val and entry_val > 0.0 and entry_idx is not None:
                    trade_rows.append({
                        "entry_idx": int(entry_idx),
                        "exit_idx": int(i),
                        "entry_nav": float(entry_val),
                        "exit_nav": float(nav),
                        "trade_pnl": float((nav - entry_val) / entry_val),
                        "flip_close": False
                    })
                    entry_val = None; entry_idx = None
                prev_pos = p
            if prev_pos != 0.0 and entry_val and entry_idx is not None:
                trade_rows.append({
                    "entry_idx": int(entry_idx),
                    "exit_idx": int(aligned_len - 1),
                    "entry_nav": float(entry_val),
                    "exit_nav": float(portfolio[aligned_len - 1]),
                    "trade_pnl": float((portfolio[aligned_len - 1] - entry_val) / entry_val),
                    "flip_close": False
                })
            trades_csv = f"{RESULTS_DIR}/signals/{symbol}_ws{ws}_sac_trades.csv"
            pd.DataFrame(trade_rows).to_csv(trades_csv, index=False)
            log.info("  • WS=%d | wrote trades CSV → %s (trades=%d)", ws, trades_csv, len(trade_rows))
        except Exception as e:
            trades_csv = None
            log.warning("  • WS=%d | Could not write trades CSV: %s", ws, e)

        pos_arr = np.array(positions[:aligned_len], dtype="float64")
        turnover = float(np.sum(np.abs(np.diff(pos_arr)))) if aligned_len > 1 else 0.0
        time_in_mkt = float(np.mean(np.abs(pos_arr) > 0.0)) if aligned_len > 0 else 0.0

        # Persist WS metrics JSON
        try:
            def _bucket_summary(bs):
                cnt = int(bs.get("count", 0))
                arr = np.array(bs.get("step_pnls", []), dtype="float64")
                mean_pnl = float(np.mean(arr)) if cnt > 0 else float("nan")
                std_pnl  = float(np.std(arr, ddof=1)) if cnt > 1 else float("nan")
                hit_ratio_step = float(np.mean(arr > 0)) if cnt > 0 else float("nan")
                return {
                    "count": cnt,
                    "mean_step_pnl": mean_pnl,
                    "std_step_pnl": std_pnl,
                    "hit_ratio_step": hit_ratio_step,
                }

            for k in ("clean", "whipsaw", "cooldown"):
                bucket_stats.setdefault(k, {"step_pnls": [], "count": 0})

            metrics = {
                "symbol": symbol,
                "ws": int(ws),
                "final_value": float(final_value),
                "buy_hold": float(buy_hold),
                "sharpe": float(sharpe),
                "max_drawdown": float(max_dd),
                "calmar": float(calmar),
                "trade_count": trade_count,
                "hit_ratio": hit_ratio,
                "avg_trade_pnl": avg_trade_pnl,
                "turnover": turnover,
                "time_in_market": time_in_mkt,
                "confidence_buckets": {
                    "clean":    _bucket_summary(bucket_stats["clean"]),
                    "whipsaw":  _bucket_summary(bucket_stats["whipsaw"]),
                    "cooldown": _bucket_summary(bucket_stats["cooldown"]),
                }
            }
            mi_dir = os.path.join(RESULTS_DIR, "model_info"); os.makedirs(mi_dir, exist_ok=True)
            metrics_path = os.path.join(mi_dir, f"sac_{symbol}_ws{ws}_metrics.json")
            with open(metrics_path, "w") as f:
                json.dump(metrics, f, indent=2)
            log.info("  • WS=%d | wrote metrics JSON → %s", ws, metrics_path)
        except Exception as e:
            log.warning("  • WS=%d | WS metrics JSON write skipped: %s", ws, e)

        if 'eval_env' in locals() and eval_env is not None:
            try:
                eval_env.close()
            except Exception:
                pass

        for _env in ( 'train_env', 'final_eval', 'final_eval_raw' ):
            try:
                if _env in locals() and locals()[_env] is not None:
                    locals()[_env].close()
            except Exception:
                pass

        # Null out first (safe even if some names never existed)
        model = train_env = eval_env = final_eval = final_eval_raw = None

        # Now delete names if they exist (optional, helps free locals sooner)
        try:
            del final_eval, final_eval_raw, eval_env, train_env, model
        except NameError:
            pass

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


        return dict(
            ws=ws,
            sharpe=float(sharpe),
            final_value=float(final_value),
            buy_hold=float(buy_hold),
            max_dd=float(max_dd),
            calmar=(float(calmar) if np.isfinite(calmar) else None),
            trade_count=int(trade_count),
            hit_ratio=(float(hit_ratio) if trade_count > 0 else None),
            avg_trade_pnl=(float(avg_trade_pnl) if trade_count > 0 else None),
            turnover=float(turnover),
            time_in_mkt=float(time_in_mkt),
            model_zip=model_path + ".zip",
            vecnorm_pkl=vec_path,
            signals_csv=signals_path
        )

    # === run sweep ===
    results_ws = []
    top_heap = []

    for j, WS in enumerate(pending_ws, 1):
        log.warning("   • %s: WS %d/%d = %s", symbol, j, len(pending_ws), WS)
        try:
            res = train_eval_for_ws(WS)

            results_ws.append({
                "Symbol": symbol,
                "WS": WS,
                "Sharpe": _safe_round(res["sharpe"], 3),
                "SAC_Portfolio": _safe_round(res["final_value"], 2),
                "BuyHold": _safe_round(res["buy_hold"], 2),
                "MaxDD": _safe_round(res.get("max_dd"), 4),
                "Calmar": _safe_round(res.get("calmar"), 3),
                "TradeCount": int(res.get("trade_count", 0)),
                "HitRatio": _safe_round(res.get("hit_ratio"), 3),
                "AvgTradePnL": _safe_round(res.get("avg_trade_pnl"), 4),
                "Turnover": _safe_round(res.get("turnover", 0.0), 3),
                "TimeInMkt": _safe_round(res.get("time_in_mkt", 0.0), 3),
                "ModelZip": res["model_zip"],
                "VecNorm": res["vecnorm_pkl"],
                "SignalsCSV": res["signals_csv"],
                "Edge_vs_BH_%": _safe_round((res["final_value"] / max(res["buy_hold"], 1e-12) - 1.0) * 100.0, 6),
                "SAC_vs_BH_Ratio": _safe_round(res["final_value"] / max(res["buy_hold"], 1e-12), 6),

            })

            s = float(res["sharpe"])
            score = s if (np.isfinite(s) and s == s) else float("-inf")
            heapq.heappush(top_heap, (score, WS, res))
            if len(top_heap) > TOP_N_WINDOWS:
                heapq.heappop(top_heap)

        except Exception as e:
            log.warning("WS=%s failed for %s: %s", WS, symbol, e)

    # save per-symbol WS sweep summary
    ws_summary_path = os.path.join(RESULTS_DIR, f"{symbol}_sac_ws_sweep_summary.csv")

    if results_ws:
        df_ws = pd.DataFrame(results_ws)
        if "Sharpe" not in df_ws.columns:
            df_ws["Sharpe"] = np.nan
        df_ws.sort_values("Sharpe", ascending=False).to_csv(ws_summary_path, index=False)
    else:
        cols = ["Symbol","WS","Sharpe","SAC_Portfolio","BuyHold","MaxDD","Calmar",
                "TradeCount","HitRatio","AvgTradePnL","Turnover","TimeInMkt",
                "ModelZip","VecNorm","SignalsCSV","Edge_vs_BH_%","SAC_vs_BH_Ratio"]
        pd.DataFrame(columns=cols).to_csv(ws_summary_path, index=False)

    try:
        ws_json_path = ws_summary_path.replace(".csv", ".json")
        with open(ws_json_path, "w") as f:
            if results_ws:
                safe_sorted = sorted(
                    results_ws,
                    key=lambda r: (r.get("Sharpe") if (isinstance(r.get("Sharpe"), (int, float)) and np.isfinite(r.get("Sharpe"))) else -1e9),
                    reverse=True
                )
                json.dump(safe_sorted, f, indent=2)
            else:
                json.dump([], f, indent=2)
        log.info("Wrote WS sweep JSON -> %s", ws_json_path)
    except Exception as e:
        ws_json_path = None
        log.warning("WS sweep JSON write skipped: %s", e)

    global_rows.extend(results_ws)

    # --- per-symbol console summary (INSIDE LOOP) ---
    def _sharpe_key(row):
        try:
            s = float(row.get("Sharpe"))
            return s if isfinite(s) else float("-inf")
        except Exception:
            return float("-inf")

    def _fmt_num(v, fmt):
        try:
            x = float(v)
            return fmt % x if isfinite(x) else "nan"
        except Exception:
            return "nan"

    if results_ws:
        best_row = max(results_ws, key=_sharpe_key)
        log.warning(
            "%s summary → best WS=%s | Sharpe=%s | Final=%s | B&H=%s | MaxDD=%s | Trades=%d | Hit=%s",
            symbol,
            best_row.get("WS"),
            _fmt_num(best_row.get("Sharpe"),        "%.3f"),
            _fmt_num(best_row.get("SAC_Portfolio"), "%.2f"),
            _fmt_num(best_row.get("BuyHold"),       "%.2f"),
            _fmt_num(best_row.get("MaxDD"),         "%.4f"),
            int(best_row.get("TradeCount", 0)),
            _fmt_num(best_row.get("HitRatio"),      "%.3f"),
        )
        if ws_json_path:
            log.warning("Artifacts: %s | %s", ws_summary_path, ws_json_path)
        else:
            log.warning("Artifacts: %s", ws_summary_path)
    else:
        log.warning("%s summary → no successful window results.", symbol)

    # report winners
    top_sorted = sorted(top_heap, key=lambda t: t[0], reverse=True)
    log.info(
        f"Top {min(TOP_N_WINDOWS, len(top_sorted))} window sizes for {symbol}: " +
        ", ".join([f"WS={w} (Sharpe={s:.3f})" for s, w, _ in top_sorted])
    )

    # === Persist Top-N winners' artifacts ===
    best_root = os.path.join(RESULTS_DIR, "best", symbol)
    os.makedirs(best_root, exist_ok=True)

    top_k = min(TOP_N_WINDOWS, len(top_sorted))
    manifest = []
    for rank, (s_val, ws_best, res_obj) in enumerate(top_sorted[:top_k], start=1):
        ws_dir = os.path.join(best_root, f"ws{int(ws_best)}")
        os.makedirs(ws_dir, exist_ok=True)

        src_model  = res_obj["model_zip"]
        src_vec    = res_obj["vecnorm_pkl"]
        src_sig    = res_obj["signals_csv"]
        src_metrics_json = os.path.join(RESULTS_DIR, "model_info", f"sac_{symbol}_ws{ws_best}_metrics.json")
        src_model_info   = os.path.join(RESULTS_DIR, "model_info", f"sac_{symbol}_ws{ws_best}_model_info.json")
        src_conf_json    = os.path.join(RESULTS_DIR, "model_info", f"sac_{symbol}_ws{ws_best}_confidence_metrics.json")  # may not exist

        def _maybe_copy(src, dst_name):
            """Best-effort copy of src into ws_dir under the name dst_name.

            Returns the destination path on success. Returns None when src is falsy,
            does not exist on disk, or the copy raises (that failure is logged as a warning).
            """
            try:
                if not src or not os.path.exists(src):
                    return None
                target = os.path.join(ws_dir, dst_name)
                copyfile(src, target)
                return target
            except Exception as e:
                log.warning("Copy failed %s -> %s: %s", src, dst_name, e)
            return None

        dsts = {
            "model_zip": _maybe_copy(src_model,  os.path.basename(src_model)),
            "vecnorm_pkl": _maybe_copy(src_vec,  os.path.basename(src_vec)),
            "signals_csv": _maybe_copy(src_sig,  os.path.basename(src_sig)),
            "metrics_json": _maybe_copy(src_metrics_json, os.path.basename(src_metrics_json)),
            "model_info_json": _maybe_copy(src_model_info, os.path.basename(src_model_info)),
            "confidence_json": _maybe_copy(src_conf_json, os.path.basename(src_conf_json)),
        }

        manifest.append({
            "rank": rank,
            "ws": int(ws_best),
            "sharpe": float(s_val),
            "artifacts": dsts,
        })

    try:
        top_manifest_path = os.path.join(best_root, "top_windows.json")
        with open(top_manifest_path, "w") as f:
            json.dump(manifest, f, indent=2)
        log.info("Persisted Top-%d winners for %s -> %s", top_k, symbol, top_manifest_path)
    except Exception as e:
        log.warning("Top-N manifest write failed for %s: %s", symbol, e)

    del results_ws, top_heap, top_sorted
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# --- Global skip summary (optional) ---
# Tickers whose windows were all already trained are logged and written to a
# one-column CSV. Any failure here is logged and otherwise ignored.
try:
    if not skipped_all:
        log.info("No fully skipped tickers.")
    else:
        skip_path = os.path.join(RESULTS_DIR, "skipped_windows_global.csv")
        unique_skipped = sorted(set(skipped_all))
        pd.DataFrame({"Symbol": unique_skipped}).to_csv(skip_path, index=False)
        log.info("Fully skipped tickers (all windows done): %s", ", ".join(unique_skipped))
        log.info("Global skip log: %s", skip_path)
except Exception as e:
    log.warning("Failed to write global skip log: %s", e)

log.info("✅ SAC run complete.")
# === Global summarizer (drop-in) ===
# Aggregates every (Symbol, WS) sweep row and writes three CSVs: all rows,
# best WS per symbol, and the SAC model selector. It then prints a console
# summary. Any failure is printed and does not stop the notebook.
try:
    # Start from the per-symbol sweep CSVs on disk. On a resumed run, symbols
    # whose windows were all done already never reach global_rows, so relying
    # on in-memory rows alone would silently drop them from the summary.
    _sweep_frames = []
    for _csv_path in sorted(glob.glob(os.path.join(RESULTS_DIR, "*_sac_ws_sweep_summary.csv"))):
        try:
            _sweep_frames.append(pd.read_csv(_csv_path))
        except Exception as _read_err:
            print(f"[summarizer] Skipping unreadable {_csv_path}: {_read_err}")
    # Rows collected in memory this run go last so they win on (Symbol, WS) clashes.
    if "global_rows" in globals() and global_rows:
        _sweep_frames.append(pd.DataFrame(global_rows))
    _sweep_frames = [_df for _df in _sweep_frames if not _df.empty]
    all_rows_df = pd.concat(_sweep_frames, ignore_index=True) if _sweep_frames else pd.DataFrame()
    if {"Symbol", "WS"}.issubset(all_rows_df.columns):
        all_rows_df = all_rows_df.drop_duplicates(subset=["Symbol", "WS"], keep="last").reset_index(drop=True)

    if not all_rows_df.empty:
        # Ensure derived columns exist (in case older CSVs lacked them)
        if "Edge_vs_BH_%" not in all_rows_df.columns:
            all_rows_df["Edge_vs_BH_%"] = (all_rows_df["SAC_Portfolio"] / all_rows_df["BuyHold"] - 1.0) * 100.0
        if "SAC_vs_BH_Ratio" not in all_rows_df.columns:
            all_rows_df["SAC_vs_BH_Ratio"] = (all_rows_df["SAC_Portfolio"] / all_rows_df["BuyHold"])

        # Save all rows
        out_all = os.path.join(RESULTS_DIR, "global_ws_sweep_all_rows.csv")
        all_rows_df.to_csv(out_all, index=False)

        # Backfill selector/filter columns if missing in older CSVs
        for col, default in [
            ("TimeInMkt", np.nan),
            ("TradeCount", 0),
            ("Turnover", np.nan),
            ("Calmar", np.nan),
            ("HitRatio", np.nan),
        ]:
            if col not in all_rows_df.columns:
                all_rows_df[col] = default

        # Best WS per symbol by highest Sharpe (NaN Sharpe sorts last)
        best_by_symbol = (
            all_rows_df.sort_values(["Symbol", "Sharpe"], ascending=[True, False])
                       .groupby("Symbol", as_index=False).head(1)
        )
        out_best = os.path.join(RESULTS_DIR, "global_best_by_symbol.csv")
        best_by_symbol.to_csv(out_best, index=False)

        # === Save model selector CSV for later pipeline ===
        try:
            # Guardrails. NaN in any gated column fails that comparison, so the row is excluded.
            flt = (
                (best_by_symbol["Sharpe"] > 0.0)
                & (best_by_symbol["SAC_Portfolio"] > best_by_symbol["BuyHold"])
                & (best_by_symbol["TimeInMkt"].between(0.20, 0.85, inclusive="both"))
                & (best_by_symbol["TradeCount"] >= 50)
                & (best_by_symbol["Turnover"] <= 600)
            )
            selector_src = best_by_symbol.loc[flt].copy()
            if selector_src.empty:
                selector_src = best_by_symbol.copy()  # fallback: nothing passed the guardrails

            # Score = Sharpe + small bonuses/penalties
            selector_src["Score"] = (
                selector_src["Sharpe"]
                + 0.10 * selector_src["Calmar"].fillna(0)
                - 0.0005 * selector_src["Turnover"].fillna(0)
            )
            selector_src = selector_src.sort_values(["Score", "Sharpe"], ascending=False)

            # Final selector schema
            selector_df = selector_src[["Symbol", "Sharpe", "HitRatio", "MaxDD", "SAC_Portfolio"]].rename(columns={
                "Symbol": "Ticker",
                "HitRatio": "Accuracy",
                "MaxDD": "Drawdown",
                "SAC_Portfolio": "Final_Portfolio"
            })
            selector_df["Model"] = "SAC"

            selector_out = os.path.join(RESULTS_DIR, "sac_model_selector.csv")
            selector_df.to_csv(selector_out, index=False)
            print(f"✅ Saved model selector file → {selector_out} (rows={len(selector_df)})")
        except Exception as e:
            print(f"[selector_writer] Failed to save model selector CSV: {e}")

        # Console summary
        syms = best_by_symbol["Symbol"].nunique()
        mean_sharpe = float(best_by_symbol["Sharpe"].mean())
        median_sharpe = float(best_by_symbol["Sharpe"].median())
        n_pos = int((best_by_symbol["Sharpe"] > 0).sum())
        n_ge1 = int((best_by_symbol["Sharpe"] >= 1.0).sum())
        beat_bh = int((best_by_symbol["SAC_Portfolio"] > best_by_symbol["BuyHold"]).sum())

        print(f"[summarizer] Using RESULTS_DIR = {RESULTS_DIR}\n")
        print("=== Summary ===")
        print(f"Symbols covered:         {syms}")
        print(f"Mean Sharpe (best WS):   {mean_sharpe:.3f}")
        print(f"Median Sharpe (best WS): {median_sharpe:.3f}")
        print(f"Sharpe > 0 (best WS):    {n_pos} / {syms}")
        print(f"Sharpe ≥ 1.0 (best WS):  {n_ge1} / {syms}")
        print(f"Beat Buy&Hold (best WS): {beat_bh} / {syms}\n")

        # Top 10 windows overall
        cols = ["Symbol", "WS", "Sharpe", "Edge_vs_BH_%", "SAC_vs_BH_Ratio", "MaxDD", "Calmar"]
        top10 = all_rows_df.sort_values("Sharpe", ascending=False).head(10)
        top10 = top10[[c for c in cols if c in top10.columns]]
        print("Top 10 windows overall by Sharpe:")
        print(top10.to_string(index=False))

        print("\nSaved:")
        print(" -", out_all)
        print(" -", out_best)
    else:
        print("[summarizer] No rows found to summarize.")
except Exception as e:
    print("[summarizer] Failed:", e)

In [None]:
# === SAC Model Selector: Full Aggregator + Enhancer ===
import os, glob, json
import numpy as np
import pandas as pd

# --- Paths (align with your SAC trainer) ---
DRIVE_BASE  = os.getenv("DRIVE_BASE", "/content/drive/MyDrive")
RESULTS_DIR = os.path.join(DRIVE_BASE, "Results_May_2025", "results_sac_walkforward")

# Where SAC models/vecnorms are expected (per the trainer's naming):
FINAL_MODEL_DIR = os.path.join(RESULTS_DIR, "models_sac")   # sac_{SYM}_ws{WS}.zip
VECNORM_DIR     = os.path.join(RESULTS_DIR, "vecnorms")     # {SYM}_ws{WS}_vecnorm.pkl

# Selector outputs
SELECTOR_FULL_PATH = os.path.join(RESULTS_DIR, "sac_model_selector_FULL.csv")
SELECTOR_JSON_PATH = os.path.join(RESULTS_DIR, "sac_model_selector_final.json")
MODEL_NAME = "SAC"

# --- 1) Collect all per-symbol sweep summaries ---
# Sort oldest -> newest by modification time. Plain glob order is arbitrary,
# and keep="last" in the de-duplication below must mean "most recently written".
summary_files = sorted(
    glob.glob(os.path.join(RESULTS_DIR, "*_sac_ws_sweep_summary.csv")),
    key=os.path.getmtime,
)
if not summary_files:
    raise SystemExit("❌ No SAC sweep summaries found (expected *_sac_ws_sweep_summary.csv).")

frames = []
for p in summary_files:
    try:
        dfp = pd.read_csv(p)
        dfp["SourceFile"] = os.path.basename(p)
        frames.append(dfp)
    except Exception as e:
        print(f"⚠️ Skipping {p} due to error: {e}")

if not frames:
    raise SystemExit("❌ No readable SAC summaries.")

raw = pd.concat(frames, ignore_index=True)

# Ensure expected columns exist (older runs may miss some)
for col in ["Symbol","WS","Sharpe","SAC_Portfolio","BuyHold","MaxDD","Calmar","HitRatio","Turnover","TimeInMkt"]:
    if col not in raw.columns:
        raw[col] = np.nan

# Drop dups by (Symbol, WS), keeping the row from the most recently modified file
raw = raw.drop_duplicates(subset=["Symbol","WS"], keep="last")

# --- 2) Pick best WS per symbol by Sharpe ---
# Treat NaN/inf Sharpe as very low
def _safe_sharpe(x):
    try:
        v = float(x)
        return v if np.isfinite(v) else -1e9
    except Exception:
        return -1e9

# Rank each symbol's windows by sanitized Sharpe (highest first) and keep the
# top row per symbol. `raw` keeps the helper column; `best` does not.
raw["_SharpeKey"] = raw["Sharpe"].apply(_safe_sharpe)
best = raw.sort_values(["Symbol", "_SharpeKey"], ascending=[True, False])
best = best.groupby("Symbol", as_index=False).head(1)
best = best.drop(columns=["_SharpeKey"])

# --- 3) Normalize to selector schema ---
selector = best.rename(columns={
    "Symbol": "Ticker",
    "WS": "Window",
    "SAC_Portfolio": "Final_Portfolio",
    "MaxDD": "Drawdown",
    "HitRatio": "Accuracy",
})
selector["Model"] = MODEL_NAME

# Backfills for optional columns (no-ops when the summaries already had them)
for col in ("Calmar", "Turnover", "TimeInMkt", "BuyHold"):
    if col not in selector.columns:
        selector[col] = np.nan

# --- 4) Add artifact paths (match your trainer's save patterns) ---
# sac_{TICKER}_ws{Window}.zip and {TICKER}_ws{Window}_vecnorm.pkl
# astype(int) raises on NaN/inf windows, and such a row cannot be mapped to an
# artifact anyway. Drop it with a warning instead of aborting the whole cell.
_window_num = pd.to_numeric(selector["Window"], errors="coerce").astype(float)
_bad_window = ~np.isfinite(_window_num)
if _bad_window.any():
    print(f"⚠️ Dropping {int(_bad_window.sum())} pick(s) without a valid window: "
          f"{selector.loc[_bad_window, 'Ticker'].tolist()}")
selector = selector.loc[~_bad_window].copy()
selector["Window"] = _window_num.loc[~_bad_window].astype(int)

# Use list comprehensions rather than DataFrame.apply(axis=1). On an empty
# frame, apply can return a DataFrame, which cannot be assigned to one column.
_pairs = list(zip(selector["Ticker"], selector["Window"]))
selector["artifact_path"] = [
    os.path.join(FINAL_MODEL_DIR, f"sac_{t}_ws{w}.zip") for t, w in _pairs
]
selector["vecnorm_path"] = [
    os.path.join(VECNORM_DIR, f"{t}_ws{w}_vecnorm.pkl") for t, w in _pairs
]
selector["artifact_exists"] = selector["artifact_path"].map(os.path.exists)
selector["vecnorm_exists"]  = selector["vecnorm_path"].map(os.path.exists)

# --- 5) Safety gates (light; adjust as you like)
# A pick survives only when every condition holds:
#   * Sharpe > 0
#   * Drawdown > -0.60. Drawdown is stored negative (e.g. -0.35), so this
#     keeps anything not worse than -60%.
#   * both the model .zip and the vecnorm .pkl exist on disk
# NaN metrics compare False, so those rows are rejected.
_metric_ok = selector["Sharpe"].gt(0.0) & selector["Drawdown"].gt(-0.60)
_files_ok = selector["artifact_exists"] & selector["vecnorm_exists"]
gates = _metric_ok & _files_ok

if gates.any():
    filtered = selector.loc[gates].copy()
else:
    # Fallback: the filters were too strict, so export every best pick and let downstream decide.
    print("⚠️ All candidates filtered out—exporting unfiltered best picks.")
    filtered = selector.copy()

# --- 6) Save aggregated CSV (FULL) ---
filtered.to_csv(SELECTOR_FULL_PATH, index=False)
print(f"✅ Aggregated SAC selector saved to → {SELECTOR_FULL_PATH} (rows={len(filtered)})")

# --- 7) Build final JSON with metadata & tie logic (optional ensemble on near-ties) ---
# `filtered` normally holds one row per ticker (the best window only). The
# near-tie branch below therefore fires only if several windows per ticker are kept.
EPS = 0.03  # Sharpes within 3% of the top Sharpe (relative) count as a tie


def _json_num(x, ndigits=None):
    """Return ``x`` as a finite float (rounded to ``ndigits`` if given), else None.

    json.dump writes NaN/inf as the bare tokens NaN/Infinity, which are not
    valid JSON. Mapping them to None keeps the selector file readable by
    strict JSON parsers.
    """
    try:
        v = float(x)
    except (TypeError, ValueError):
        return None
    if not np.isfinite(v):
        return None
    return round(v, ndigits) if ndigits is not None else v


selected_models = {}
for ticker, group in filtered.groupby("Ticker"):
    group_sorted = group.sort_values("Sharpe", ascending=False)
    top = group_sorted.iloc[0]
    second = group_sorted.iloc[1] if len(group_sorted) > 1 else None

    # NaN Sharpes make this comparison False, so they never form a tie.
    if second is not None and abs(top["Sharpe"] - second["Sharpe"]) <= abs(top["Sharpe"]) * EPS:
        mode = "ensemble"
        secondary = {
            "model": MODEL_NAME,
            "window": int(second["Window"]),
            "artifact": {
                "path": second["artifact_path"],
                "vecnorm": second["vecnorm_path"],
                "exists": bool(second["artifact_exists"]),
                "vecnorm_exists": bool(second["vecnorm_exists"]),
            },
        }
    else:
        mode = "single"
        secondary = None

    # str() keeps the JSON key valid even if Ticker was parsed as a non-str type.
    selected_models[str(ticker)] = {
        "model": MODEL_NAME,
        "window": int(top["Window"]),
        # Numeric fields go through _json_num. The unfiltered fallback in step 5
        # can export rows whose Sharpe (or other metrics) are NaN.
        "score": _json_num(top["Sharpe"], 4),
        "return": _json_num(top["Final_Portfolio"], 2),  # final portfolio value (SAC_Portfolio column)
        "sharpe": _json_num(top["Sharpe"], 3),
        "drawdown": _json_num(top["Drawdown"], 4),
        "sortino": None,
        "turnover": _json_num(top["Turnover"]),
        "trade_count": None,
        "stability": {},
        "regime": "unknown",
        "rl_profile": "sac",
        "artifact": {
            "path": top["artifact_path"],
            "vecnorm": top["vecnorm_path"],
            "features": None,
            "load_ms": 180,  # fixed placeholder, not measured here
            "mem_mb": 512,   # fixed placeholder, not measured here
            "exists": bool(top["artifact_exists"]),
            "vecnorm_exists": bool(top["vecnorm_exists"]),
        },
        "selection": {
            "mode": mode,
            "primary": MODEL_NAME,
            "secondary": secondary,
        },
    }

# allow_nan=False: fail loudly rather than silently write NaN/Infinity tokens,
# which are not valid JSON.
with open(SELECTOR_JSON_PATH, "w", encoding="utf-8") as f:
    json.dump(selected_models, f, indent=2, allow_nan=False)

print(f"✅ Final enhanced SAC selector JSON saved to → {SELECTOR_JSON_PATH}")


In [None]:
import os, pandas as pd, matplotlib.pyplot as plt
from glob import glob

# Honour the same DRIVE_BASE override as the selector cell; the default path
# is unchanged. This keeps both cells reading from the same results folder.
RESULTS_DIR = os.path.join(
    os.getenv("DRIVE_BASE", "/content/drive/MyDrive"),
    "Results_May_2025",
    "results_sac_walkforward",
)
winners = [("NVDA",24), ("AAPL",24), ("MSFT",16)]  # (ticker, requested window) pairs to plot

def pick_signals(sym, ws):
    """Return the path of the SAC signals CSV for ``sym`` at window ``ws``.

    Lookup order:
      1. ``RESULTS_DIR/signals/{sym}_ws{ws}_sac_signals.csv``
      2. ``RESULTS_DIR/best/{sym}/ws{ws}/{sym}_ws{ws}_sac_signals.csv``
      3. Another window for ``sym`` under ``RESULTS_DIR/signals``: the one
         numerically closest to ``ws``, with ties going to the larger window.
         A notice naming the substitute file is printed.

    Raises:
        FileNotFoundError: if no signals CSV exists for ``sym``.
    """
    import re

    exact_candidates = (
        os.path.join(RESULTS_DIR, "signals", f"{sym}_ws{ws}_sac_signals.csv"),
        os.path.join(RESULTS_DIR, "best", sym, f"ws{ws}", f"{sym}_ws{ws}_sac_signals.csv"),
    )
    for p in exact_candidates:
        if os.path.exists(p):
            return p

    alts = sorted(glob(os.path.join(RESULTS_DIR, "signals", f"{sym}_ws*_sac_signals.csv")))
    if not alts:
        raise FileNotFoundError(f"No signals CSV found for {sym} under {RESULTS_DIR}.")

    # Compare windows numerically. A plain string sort ranks "ws8" after
    # "ws24", so sorted(...)[-1] would pick an arbitrary window.
    win_re = re.compile(rf"^{re.escape(sym)}_ws(\d+)_sac_signals\.csv$")
    numbered = []
    for alt in alts:
        m = win_re.match(os.path.basename(alt))
        if m:
            numbered.append((int(m.group(1)), alt))
    if numbered:
        chosen = min(numbered, key=lambda c: (abs(c[0] - int(ws)), -c[0]))[1]
    else:
        chosen = alts[-1]  # no numeric window in any name: keep the old choice
    print(f"[{sym}] ws{ws} not found; using {os.path.basename(chosen)} instead.")
    return chosen

# Plot each requested SAC equity curve against a buy-and-hold line rescaled
# to start from the same initial portfolio value.
for sym, ws in winners:
    path = pick_signals(sym, ws)
    s = pd.read_csv(path, parse_dates=["Datetime"])
    if "Close" not in s.columns:
        raise RuntimeError(f"'Close' not in {path}; re-run training that writes Close into signals.")

    base = float(s["Portfolio"].iloc[0])   # starting NAV of the SAC run
    first_close = s["Close"].iloc[0]
    bh = s["Close"] / first_close * base    # buy & hold, scaled to `base`

    fig, ax = plt.subplots()
    ax.plot(s["Datetime"], s["Portfolio"], label="SAC")
    ax.plot(s["Datetime"], bh, label="Buy & Hold")
    ax.set_title(f"{sym} ({os.path.basename(path)})")
    ax.set_xlabel("Date")
    ax.set_ylabel("NAV")
    ax.legend()
    plt.show()
