<a href="https://colab.research.google.com/github/racoope70/exploratory_daytrading/blob/main/ppo_alpaca_paper_trading_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Notebook shell cell (IPython `!` magics): reset and pin the RL stack before it is imported.
# Clean any partials
!pip uninstall -y stable-baselines3 shimmy gymnasium gym autorom AutoROM.accept-rom-license ale-py

# Install the compatible trio (no [extra] to avoid Atari deps)
!pip install "gymnasium==0.29.1" "shimmy==1.3.0" "stable-baselines3==2.3.0"

# Your other libs (safe to keep separate)
# NOTE(review): the next cell imports `pywt` (PyWavelets) and `websockets`, which are not
# listed here — confirm they are preinstalled or pulled in transitively.
!pip install alpaca-trade-api ta python-dotenv gym-anytrading


In [None]:
import torch, gymnasium, shimmy, stable_baselines3 as sb3
import alpaca_trade_api, websockets, pywt

# Record the library versions actually loaded in this runtime (handy when comparing runs).
print("torch:", torch.__version__)
print("gymnasium:", gymnasium.__version__)
print("shimmy:", shimmy.__version__)
print("stable-baselines3:", sb3.__version__)
print("alpaca-trade-api:", alpaca_trade_api.__version__)
print("websockets:", websockets.__version__)
print("pywavelets:", pywt.__version__)


In [None]:
from functools import lru_cache
import os, re, json, csv, shutil, logging, pickle, warnings, time, math, gc
from pathlib import Path
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Union, Mapping
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")   # save-to-file only; no inline rendering
import matplotlib.pyplot as plt

from dotenv import load_dotenv
import alpaca_trade_api as tradeapi
from alpaca_trade_api.rest import TimeFrame, APIError
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN


def round_to_cents(x: float) -> float:
    """Truncate *x* toward zero to whole cents and return it as a float.

    Despite the name this is not half-up rounding: ROUND_DOWN drops any
    fraction of a cent (1.239 -> 1.23, -1.239 -> -1.23). Converting through
    ``str(x)`` keeps binary-float noise out of the Decimal value.
    """
    cent = Decimal("0.01")
    exact = Decimal(str(x))
    return float(exact.quantize(cent, rounding=ROUND_DOWN))

# Detect Colab and (optionally) mount Drive
IN_COLAB = False
try:
    import google.colab  # type: ignore
    from google.colab import drive, files  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    try:
        drive.mount("/content/drive", force_remount=False)
    except Exception as e:
        # Without a mounted Drive, PROJECT_ROOT below is created on the VM's local
        # disk instead of Drive, so anything written there will not persist.
        # print() rather than logging: logging.basicConfig runs later in this file,
        # and an earlier logging call would install a default handler first.
        print(f"WARNING: Google Drive mount failed ({e}); outputs under /content/drive will NOT persist.")

# Project root (Drive in Colab; cwd locally)
if IN_COLAB:
    PROJECT_ROOT = Path("/content/drive/MyDrive/AlpacaPaper")
else:
    PROJECT_ROOT = Path.cwd() / "AlpacaPaper"

# Ensure project folders exist early
PROJECT_ROOT.mkdir(parents=True, exist_ok=True)

# ---------------------------------- Upload / Conversion Helpers -------------------------------
def upload_env_and_artifacts_in_colab():
    """Interactively upload credentials and model artifacts when running in Colab.

    Does nothing outside Colab. Otherwise it prompts twice:
      1) an env file. ``Alpaca_keys.env.txt`` is preferred, then ``.env``,
         otherwise the first uploaded file. It is moved to PROJECT_ROOT/.env.
      2) artifacts (model zips, vecnorm pickles, feature files). Each is moved
         into $ARTIFACTS_DIR (default PROJECT_ROOT/artifacts).
    """
    if not IN_COLAB:
        return

    target_dir = Path(os.getenv("ARTIFACTS_DIR", str(PROJECT_ROOT / "artifacts")))
    target_dir.mkdir(parents=True, exist_ok=True)

    print("Upload your .env (or Alpaca_keys.env.txt). Cancel if already on Drive.")
    env_uploads = files.upload()
    if env_uploads:
        env_dst = PROJECT_ROOT / ".env"
        preferred = [n for n in ("Alpaca_keys.env.txt", ".env") if n in env_uploads]
        src_name = preferred[0] if preferred else next(iter(env_uploads.keys()))
        shutil.move(str(Path(src_name)), str(env_dst))
        if preferred:
            print(f"Saved env → {env_dst}")
        else:
            print(f"Saved env (renamed {src_name}) → {env_dst}")

    print("Upload your artifacts (ppo_*_model.zip, *_vecnorm*.pkl, *_features*.json or .txt).")
    artifact_uploads = files.upload()
    for name in artifact_uploads:
        shutil.move(name, target_dir / name)
    print("Artifacts now in:", sorted(p.name for p in target_dir.iterdir()))

def _maybe_convert_features_txt_to_json():
    """
    Convert each 'features_<TICKER>.txt' in the artifacts dir into
    'ppo_<TICKER>_features.json' holding {"features": [...]}.

    Items may be separated by commas and/or newlines; blank items are dropped.
    Existing JSON outputs are overwritten. A failure is printed for that file
    and the loop continues with the next one.
    """
    art_dir = Path(os.getenv("ARTIFACTS_DIR", str(PROJECT_ROOT / "artifacts")))
    art_dir.mkdir(parents=True, exist_ok=True)
    for txt_path in art_dir.glob("features_*.txt"):
        ticker = re.sub(r"^features_|\.txt$", "", txt_path.name, flags=re.IGNORECASE)
        try:
            text = txt_path.read_text().strip()
            tokens = (tok.strip() for tok in text.replace(",", "\n").splitlines())
            features = [tok for tok in tokens if tok]
            json_path = art_dir / f"ppo_{ticker}_features.json"
            json_path.write_text(json.dumps({"features": features}, indent=2))
            print(f"Converted {txt_path.name} → {json_path.name}  ({len(features)} features)")
        except Exception as e:
            print(f"Could not convert {txt_path.name}: {e}")

def _maybe_rename_vecnorm_scaler():
    """
    Rename each 'scaler_<TICKER>.pkl' in the artifacts dir to
    'ppo_<TICKER>_vecnorm.pkl'. A scaler whose target already exists is left untouched.
    """
    art_dir = Path(os.getenv("ARTIFACTS_DIR", str(PROJECT_ROOT / "artifacts")))
    art_dir.mkdir(parents=True, exist_ok=True)
    for src in art_dir.glob("scaler_*.pkl"):
        ticker = re.sub(r"^scaler_|\.pkl$", "", src.name, flags=re.IGNORECASE)
        dst = art_dir / f"ppo_{ticker}_vecnorm.pkl"
        if dst.exists():
            continue
        shutil.move(str(src), str(dst))
        print(f"Renamed {src.name} → {dst.name}")

# ---------------------------------- Env & logging --------------------------------------------
# Silences ALL Python warnings process-wide (including library deprecation notices).
warnings.filterwarnings("ignore")

# Load env: the first existing file among PROJECT_ROOT/.env and ./.env; if neither
# exists, let python-dotenv search for one. override=True lets values from the file
# replace variables that are already set in os.environ.
env_candidates = [PROJECT_ROOT / ".env", Path(".env")]
for env_path in env_candidates:
    if env_path.exists():
        load_dotenv(dotenv_path=env_path, override=True)
        break
else:
    load_dotenv(override=True)  # fallback to default search

# Root logger at DEBUG: third-party libraries that log via the root hierarchy may
# also emit debug-level output.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s | %(levelname)s | %(message)s")

# Centralize Knobs
def _to_bool(x: str) -> bool:
    """Parse an env-style flag: "1", "true", "yes", "y", "on" (any case, surrounding spaces ok) are True."""
    token = str(x).strip().lower()
    return token in {"1", "true", "yes", "y", "on"}

def _to_list_csv(x: str) -> list:
    """Split a comma-separated string into stripped, upper-cased, non-empty items."""
    pieces = (piece.strip() for piece in str(x).split(","))
    return [piece.upper() for piece in pieces if piece]

@dataclass
class Knobs:
    """Runtime configuration ("knobs") for the paper-trading notebook.

    Field defaults are the baseline. ``from_env`` layers environment variables
    and optional programmatic overrides on top. ``apply_to_globals`` publishes
    the result as module-level globals (``DRY_RUN``, ``TICKERS``,
    ``RESULTS_DIR``, ...) read by the rest of the notebook.
    """
    # API / mode
    APCA_API_BASE_URL: str = "https://paper-api.alpaca.markets"
    DRY_RUN: bool          = False       # False => place PAPER orders on PAPER endpoint
    AUTO_RUN_LIVE: bool    = True
    INF_DETERMINISTIC: bool= True

    # Universe / files
    TICKERS: Optional[list] = None       # None => ["UNH", "GE"] when applied
    ARTIFACTS_DIR: str     = ""
    RESULTS_ROOT: str      = ""

    # Data feed / cadence / staleness
    BARS_FEED: str         = "iex"          # "" lets Alpaca choose; "iex" for IEX
    COOLDOWN_MIN: int      = 1
    STALE_MAX_SEC: int     = 1800

    # Sizing & entry/exit sensitivity
    SIZING_MODE: str       = "linear"    # "linear" | "threshold"
    WEIGHT_CAP: float      = 0.35
    CONF_FLOOR: float      = 0.20        # threshold-mode only
    ENTER_CONF_MIN: float  = 0.005
    ENTER_WEIGHT_MIN: float= 0.010
    EXIT_WEIGHT_MAX: float = 0.003
    REBALANCE_MIN_NOTIONAL: float = 1.00
    USE_FRACTIONALS: bool  = True
    SEED_FIRST_SHARE: bool = True
    ALLOW_SHORTS: bool = False

    # add-ons
    DELTA_WEIGHT_MIN: float = 0.005
    RAW_POS_MIN: float = 0.00
    RAW_NEG_MAX: float = 0.00

    # Risk
    TAKE_PROFIT_PCT: float = 0.05
    STOP_LOSS_PCT: float   = 0.03

    # Misc
    STALE_BEST_WINDOW: str = ""    # e.g. "3" (exposed as BEST_WINDOW_ENV)

    # Secrets
    APCA_API_KEY_ID: str   = ""
    APCA_API_SECRET_KEY: str = ""

    @classmethod
    def from_env(cls, defaults: "Knobs", project_root: Path, env: Mapping[str, str], overrides: Mapping[str, object] = None):
        """Build a ``Knobs`` from *defaults*, then *env*, then *overrides* (highest wins).

        Args:
            defaults: Baseline values, e.g. the seed built by ``configure_knobs``.
            project_root: Used to derive ARTIFACTS_DIR / RESULTS_ROOT when neither
                *env* nor *defaults* supplies them.
            env: Environment mapping (typically ``os.environ``).
            overrides: Optional field -> value mapping applied last. Keys match
                field names case-insensitively. A string TICKERS is split as CSV.
                A string value for a bool field is parsed like an env flag,
                because ``bool("false")`` would be True.

        Returns:
            A new ``Knobs`` instance.

        Raises:
            ValueError: a numeric env var holds a non-numeric value.
            TypeError: *overrides* names a field that does not exist.
        """
        kv = {**defaults.__dict__}

        def env_bool(name: str) -> bool:
            return _to_bool(env.get(name, str(kv[name])))

        def env_num(name: str, cast):
            # Unset or blank values fall back to the current value instead of
            # crashing inside int()/float() on "".
            raw = env.get(name)
            if raw is None or not str(raw).strip():
                return cast(kv[name])
            return cast(raw)

        kv.update({
            "APCA_API_BASE_URL": env.get("APCA_API_BASE_URL", kv["APCA_API_BASE_URL"]),
            "AUTO_RUN_LIVE":     env_bool("AUTO_RUN_LIVE"),
            "DRY_RUN":           env_bool("DRY_RUN"),
            "INF_DETERMINISTIC": env_bool("INF_DETERMINISTIC"),

            "TICKERS":           _to_list_csv(env.get("TICKERS", ",".join(kv["TICKERS"] or ["UNH","GE"]))),
            "ARTIFACTS_DIR":     env.get("ARTIFACTS_DIR", kv["ARTIFACTS_DIR"] or str(project_root / "artifacts")),
            "RESULTS_ROOT":      env.get("RESULTS_ROOT",  kv["RESULTS_ROOT"]  or str(project_root / "results")),

            "BARS_FEED":         env.get("BARS_FEED", kv["BARS_FEED"]),
            "COOLDOWN_MIN":      env_num("COOLDOWN_MIN", int),
            "STALE_MAX_SEC":     env_num("STALE_MAX_SEC", int),

            "SIZING_MODE":       env.get("SIZING_MODE", kv["SIZING_MODE"]),
            "WEIGHT_CAP":        env_num("WEIGHT_CAP", float),
            "CONF_FLOOR":        env_num("CONF_FLOOR", float),
            "ENTER_CONF_MIN":    env_num("ENTER_CONF_MIN", float),
            "ENTER_WEIGHT_MIN":  env_num("ENTER_WEIGHT_MIN", float),
            "EXIT_WEIGHT_MAX":   env_num("EXIT_WEIGHT_MAX", float),
            "REBALANCE_MIN_NOTIONAL": env_num("REBALANCE_MIN_NOTIONAL", float),
            "USE_FRACTIONALS":   env_bool("USE_FRACTIONALS"),
            "SEED_FIRST_SHARE":  env_bool("SEED_FIRST_SHARE"),
            # Configurable from the environment like every other flag.
            "ALLOW_SHORTS":      env_bool("ALLOW_SHORTS"),

            "TAKE_PROFIT_PCT":   env_num("TAKE_PROFIT_PCT", float),
            "STOP_LOSS_PCT":     env_num("STOP_LOSS_PCT", float),

            "DELTA_WEIGHT_MIN":  env_num("DELTA_WEIGHT_MIN", float),
            "RAW_POS_MIN":       env_num("RAW_POS_MIN", float),
            "RAW_NEG_MAX":       env_num("RAW_NEG_MAX", float),

            "STALE_BEST_WINDOW": env.get("BEST_WINDOW", kv["STALE_BEST_WINDOW"]),
        })
        kv["APCA_API_KEY_ID"]     = env.get("APCA_API_KEY_ID")     or env.get("ALPACA_API_KEY_ID", "")     or ""
        kv["APCA_API_SECRET_KEY"] = env.get("APCA_API_SECRET_KEY") or env.get("ALPACA_API_SECRET_KEY", "") or ""
        if overrides:
            for key, value in overrides.items():
                name = str(key).upper()
                if name == "TICKERS" and isinstance(value, str):
                    value = _to_list_csv(value)
                elif isinstance(value, str) and isinstance(kv.get(name), bool):
                    value = _to_bool(value)
                kv[name] = value
        return cls(**kv)

    def apply_to_globals(self):
        """Publish these knobs as module globals, create output dirs, and sync env vars.

        Side effects:
          * Creates ARTIFACTS_DIR, RESULTS_ROOT/<UTC YYYY-MM-DD>/ and
            RESULTS_ROOT/latest/ if missing.
          * Sets module globals (BASE_URL, DRY_RUN, TICKERS, log/plot paths, ...).
          * Sets APCA_API_BASE_URL, DRY_RUN, AUTO_RUN_LIVE and BARS_FEED in
            os.environ. Helpers such as get_recent_bars read BARS_FEED from there.
        """
        g = globals()
        artifacts_dir = Path(self.ARTIFACTS_DIR)
        results_root = Path(self.RESULTS_ROOT)
        results_dir = results_root / datetime.now(timezone.utc).strftime("%Y-%m-%d")
        latest_dir = results_root / "latest"
        # One normalised feed value for both the global and the env var.
        feed = str(self.BARS_FEED or "").strip()

        g["BASE_URL"]           = self.APCA_API_BASE_URL
        g["DRY_RUN"]            = bool(self.DRY_RUN)
        g["INF_DETERMINISTIC"]  = bool(self.INF_DETERMINISTIC)

        g["TICKERS"]            = list(self.TICKERS or ["UNH","GE"])
        g["ARTIFACTS_DIR"]      = artifacts_dir
        g["RESULTS_ROOT"]       = results_root
        g["RESULTS_DIR"]        = results_dir
        g["LATEST_DIR"]         = latest_dir
        for p in (artifacts_dir, results_dir, latest_dir):
            p.mkdir(parents=True, exist_ok=True)

        g["BARS_FEED"]          = feed
        g["COOLDOWN_MIN"]       = int(self.COOLDOWN_MIN)
        g["STALE_MAX_SEC"]      = int(self.STALE_MAX_SEC)

        g["SIZING_MODE"]        = self.SIZING_MODE
        g["WEIGHT_CAP"]         = float(self.WEIGHT_CAP)
        g["ENTER_CONF_MIN"]     = float(self.ENTER_CONF_MIN)
        g["ENTER_WEIGHT_MIN"]   = float(self.ENTER_WEIGHT_MIN)
        g["EXIT_WEIGHT_MAX"]    = float(self.EXIT_WEIGHT_MAX)
        g["REBALANCE_MIN_NOTIONAL"] = float(self.REBALANCE_MIN_NOTIONAL)
        g["USE_FRACTIONALS"]    = bool(self.USE_FRACTIONALS)
        g["SEED_FIRST_SHARE"]   = bool(self.SEED_FIRST_SHARE)
        g["ALLOW_SHORTS"]       = bool(self.ALLOW_SHORTS)
        g["CONF_FLOOR"]         = float(self.CONF_FLOOR)
        g["TAKE_PROFIT_PCT"]    = float(self.TAKE_PROFIT_PCT)
        g["STOP_LOSS_PCT"]      = float(self.STOP_LOSS_PCT)

        g["BEST_WINDOW_ENV"]    = (self.STALE_BEST_WINDOW or None)

        g["API_KEY"]    = self.APCA_API_KEY_ID or ""
        g["API_SECRET"] = self.APCA_API_SECRET_KEY or ""

        g["TRADE_LOG_CSV"]      = results_dir / "trade_log_master.csv"
        g["EQUITY_LOG_CSV"]     = results_dir / "equity_log.csv"
        g["PLOT_PATH"]          = results_dir / "equity_curve.png"
        g["PLOT_PATH_LATEST"]   = latest_dir / "equity_curve.png"
        g["EQUITY_LOG_LATEST"]  = latest_dir / "equity_log.csv"
        g["TRADE_LOG_LATEST"]   = latest_dir / "trade_log_master.csv"
        g["DELTA_WEIGHT_MIN"]   = float(self.DELTA_WEIGHT_MIN)
        g["RAW_POS_MIN"]        = float(self.RAW_POS_MIN)
        g["RAW_NEG_MAX"]        = float(self.RAW_NEG_MAX)

        os.environ["APCA_API_BASE_URL"] = self.APCA_API_BASE_URL
        os.environ["DRY_RUN"]           = "1" if self.DRY_RUN else "0"
        os.environ["AUTO_RUN_LIVE"]     = "1" if self.AUTO_RUN_LIVE else "0"
        os.environ["BARS_FEED"]         = feed


def configure_knobs(overrides: Mapping[str, object] = None) -> Knobs:
    """Resolve knobs from the environment (plus *overrides*), apply them, and return them.

    TICKERS / ARTIFACTS_DIR / RESULTS_ROOT are seeded from the environment,
    falling back to "UNH,GE" and folders under PROJECT_ROOT. ``Knobs.from_env``
    then layers the remaining env vars and *overrides* on top, and
    ``apply_to_globals`` publishes the result as module globals.
    """
    seed = Knobs(
        TICKERS=_to_list_csv(os.getenv("TICKERS", "UNH,GE")),
        ARTIFACTS_DIR=os.getenv("ARTIFACTS_DIR", str(PROJECT_ROOT / "artifacts")),
        RESULTS_ROOT=os.getenv("RESULTS_ROOT", str(PROJECT_ROOT / "results")),
    )
    knobs = Knobs.from_env(seed, PROJECT_ROOT, os.environ, overrides=overrides)
    knobs.apply_to_globals()
    return knobs

# ---------------------------------- Utility: time ---------------------------------------------
def now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)

def utc_ts(dt_like) -> int:
    """Convert *dt_like* to integer Unix epoch seconds.

    Python/NumPy ints and floats are taken as epoch seconds already (floats are
    truncated by ``int``). Anything else goes through ``pd.Timestamp``: naive
    values are treated as UTC, aware values are converted to UTC.
    """
    if isinstance(dt_like, (int, np.integer, float, np.floating)):
        return int(dt_like)
    stamp = pd.Timestamp(dt_like)
    stamp = stamp.tz_localize("UTC") if stamp.tzinfo is None else stamp.tz_convert("UTC")
    return int(stamp.value // 10**9)  # .value is nanoseconds since the epoch

def utcnow_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (with a +00:00 offset)."""
    stamp = datetime.now(tz=timezone.utc)
    return stamp.isoformat()

def _sleep_to_next_minute_block(n: int):
    """Sleep until the next UTC minute whose minute-of-hour is a multiple of *n*.

    *n* is clamped to at least 1. The target is always strictly in the future:
    when the current minute is already a multiple, it waits a full *n* minutes
    from the start of that minute.
    """
    step = max(1, int(n))
    current = now_utc()
    minute_floor = current.replace(second=0, microsecond=0)
    offset = minute_floor.minute % step
    minutes_ahead = step - offset if offset else step
    target = minute_floor + timedelta(minutes=minutes_ahead)
    wait_s = (target - current).total_seconds()
    time.sleep(wait_s if wait_s > 0 else 0)


# --------------------------------- CSV logging (master, optional) -----------------------------
def ensure_trade_log_header():
    """Create TRADE_LOG_CSV containing only its header row if the file does not exist yet."""
    if TRADE_LOG_CSV.exists():
        return
    columns = ["datetime_utc", "ticker", "signal", "action",
               "price", "equity", "qty", "comment"]
    pd.DataFrame(columns=columns).to_csv(TRADE_LOG_CSV, index=False)

def log_trade(ticker:str, signal:float, action:str, price:float, equity:float, qty:float=None, comment:str=""):
    """Append one trade row to TRADE_LOG_CSV and mirror the file to TRADE_LOG_LATEST.

    Args:
        ticker: Symbol traded.
        signal: Model signal, logged as given.
        action: Action label, logged as given.
        price, equity, qty: Logged as floats; None becomes NaN.
        comment: Optional note; falsy values are written as "".

    When the file's header matches, the row is appended in place. Only when
    the existing header differs (e.g. an older schema) is the file re-read and
    rewritten, with columns merged by name. The copy to the latest/ folder is
    best-effort.
    """
    ensure_trade_log_header()
    row = {
        "datetime_utc": utcnow_iso(),
        "ticker": ticker,
        "signal": signal,
        "action": action,
        "price": float(price) if price is not None else np.nan,
        "equity": float(equity) if equity is not None else np.nan,
        "qty": float(qty) if qty is not None else np.nan,
        "comment": str(comment) if comment else ""
    }
    df_new = pd.DataFrame([row])
    existing_cols = list(pd.read_csv(TRADE_LOG_CSV, nrows=0).columns)
    if existing_cols == list(df_new.columns):
        # Append just the new line. Re-reading and rewriting the whole log on
        # every trade makes each call O(file size).
        df_new.to_csv(TRADE_LOG_CSV, mode="a", header=False, index=False)
    else:
        df_old = pd.read_csv(TRADE_LOG_CSV)
        pd.concat([df_old, df_new], ignore_index=True).to_csv(TRADE_LOG_CSV, index=False)
    try:
        shutil.copy2(TRADE_LOG_CSV, TRADE_LOG_LATEST)
    except Exception:
        pass  # best-effort mirror; the master log above is authoritative

# --------------------------------- Alpaca API init --------------------------------------------
def init_alpaca() -> "tradeapi.REST":
    """Create an Alpaca REST client from the API_KEY / API_SECRET / BASE_URL globals.

    Calls ``get_account()`` once so bad credentials fail immediately.

    Raises:
        RuntimeError: if API_KEY or API_SECRET is missing or empty.
    """
    g = globals()
    key, secret = g.get("API_KEY"), g.get("API_SECRET")
    if not (key and secret):
        raise RuntimeError("Missing Alpaca API keys (check your .env).")
    client = tradeapi.REST(API_KEY, API_SECRET, base_url=BASE_URL)
    client.get_account()  # fail fast on invalid credentials / endpoint
    return client

# ------------------------- Portfolio equity logging + metrics ---------------------------------
def fetch_portfolio_history(period="1M", timeframe="1Hour", api_in=None):
    """Return Alpaca portfolio history as a DataFrame with columns (timestamp_utc, equity).

    Uses *api_in* when given, otherwise the module-level ``api``. With neither,
    it returns an empty frame with those two columns. Rows with missing values
    are dropped.
    """
    client = globals().get("api", None) if api_in is None else api_in
    if client is None:
        return pd.DataFrame(columns=["timestamp_utc","equity"])
    history = client.get_portfolio_history(period=period, timeframe=timeframe)
    frame = pd.DataFrame({
        "timestamp_utc": pd.to_datetime(history.timestamp, unit="s", utc=True),
        "equity": pd.Series(history.equity, dtype="float64"),
    })
    return frame.dropna()

def log_equity_snapshot(api_in=None):
    """Add the newest 5-minute equity point to EQUITY_LOG_CSV and mirror it to EQUITY_LOG_LATEST.

    Rows are de-duplicated on datetime_utc, keeping the latest value. Does
    nothing if no history is available. The mirror copy is best-effort.
    """
    history = fetch_portfolio_history(period="1D", timeframe="5Min", api_in=api_in)
    if history.empty:
        return
    newest = history.iloc[-1:].rename(columns={"timestamp_utc": "datetime_utc"})

    if not EQUITY_LOG_CSV.exists():
        newest.to_csv(EQUITY_LOG_CSV, index=False)
    else:
        previous = pd.read_csv(EQUITY_LOG_CSV, parse_dates=["datetime_utc"])
        combined = pd.concat([previous, newest], ignore_index=True)
        combined = combined.drop_duplicates(subset=["datetime_utc"], keep="last")
        combined.to_csv(EQUITY_LOG_CSV, index=False)

    try:
        shutil.copy2(EQUITY_LOG_CSV, EQUITY_LOG_LATEST)
    except Exception:
        pass  # best-effort mirror into latest/

def plot_equity_curve(from_equity_csv: bool = True):
    """Save the equity curve to PLOT_PATH and PLOT_PATH_LATEST, with interactive mode off.

    Data comes from EQUITY_LOG_CSV when requested and present. Otherwise it is
    3 months of hourly portfolio history. Prints a message and returns early
    if there is nothing to plot.
    """
    with plt.ioff():
        use_csv = from_equity_csv and EQUITY_LOG_CSV.exists()
        if use_csv:
            data = pd.read_csv(EQUITY_LOG_CSV, parse_dates=["datetime_utc"]).sort_values("datetime_utc")
        else:
            data = fetch_portfolio_history(period="3M", timeframe="1Hour")
            data = data.rename(columns={"timestamp_utc": "datetime_utc"})
        if data.empty:
            print("No equity data to plot yet.")
            return

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(data["datetime_utc"], data["equity"])
        ax.set_title("Portfolio Value Over Time (Paper)")
        ax.set_xlabel("Time (UTC)")
        ax.set_ylabel("Equity ($)")
        fig.tight_layout()
        for target in (PLOT_PATH, PLOT_PATH_LATEST):
            fig.savefig(target, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved equity curve → {PLOT_PATH}")
        print(f"Updated latest copy → {PLOT_PATH_LATEST}")

def compute_performance_metrics(df_equity: pd.DataFrame):
    """Summarise an equity curve as cumulative return, annualised Sharpe and max drawdown.

    Args:
        df_equity: Frame with ``datetime_utc`` (datetimes or parseable strings)
            and ``equity`` columns.

    Returns:
        dict with float ``cum_return``, ``sharpe`` and ``max_drawdown`` (<= 0).
        Values are NaN where the input is too short or empty to define them.
    """
    if df_equity.empty or df_equity["equity"].isna().all():
        return {"cum_return": np.nan, "sharpe": np.nan, "max_drawdown": np.nan}

    # Coerce timestamps so CSVs read without parse_dates still support .dt below.
    times = pd.to_datetime(df_equity["datetime_utc"], utc=True, errors="coerce")
    df = df_equity.assign(datetime_utc=times).sort_values("datetime_utc")
    e = df["equity"].astype(float)
    r = e.pct_change().dropna()
    if r.empty:
        return {"cum_return": 0.0, "sharpe": np.nan, "max_drawdown": np.nan}

    # estimate periods/year from median spacing
    dt_sec = df["datetime_utc"].diff().dt.total_seconds().dropna().median()
    if not (isinstance(dt_sec, (int, float)) and dt_sec > 0):
        periods_per_year = 252 * 78  # ~5-min bars as fallback
    else:
        periods_per_day = (6.5 * 3600) / dt_sec
        periods_per_year = 252 * periods_per_day

    sharpe = (r.mean() / (r.std() + 1e-12)) * math.sqrt(periods_per_year)
    # Drawdown on the equity level itself. A curve built from returns alone
    # starts after the first step and misses a drop from the very first sample.
    dd = (e / e.cummax() - 1.0).min()
    cum_return = e.iloc[-1] / e.iloc[0] - 1.0

    return {"cum_return": float(cum_return), "sharpe": float(sharpe), "max_drawdown": float(dd)}


# -------------------------------- Hook for strategy loops -------------------------------------
def handle_signal_and_trade(ticker:str, signal:float, action:str, price:float, qty:int):
    """Snapshot portfolio equity, then log the trade with the last equity value in EQUITY_LOG_CSV.

    Equity is logged as NaN when the equity log is missing or empty.
    """
    log_equity_snapshot()
    equity_value = np.nan
    if EQUITY_LOG_CSV.exists():
        equity_df = pd.read_csv(EQUITY_LOG_CSV, parse_dates=["datetime_utc"])
        if not equity_df.empty:
            equity_value = float(equity_df.iloc[-1]["equity"])
    log_trade(ticker=ticker, signal=signal, action=action, price=price, equity=equity_value, qty=qty)

# -------------------------------- Per-ticker CSV logging -------------------------------------
def _append_csv_row(path: Path, row: dict):
    """Append *row* to the CSV at *path*, creating the file or widening its header as needed.

    Existing columns are never dropped. If *row* introduces keys the file does
    not have, the file is rewritten with the union of columns (existing order
    first, then new keys), and missing cells are filled with "". Rows whose keys
    are a subset of the header are appended directly in the header's column order.
    """
    new_keys = list(row.keys())
    if not path.exists():
        with path.open("w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=new_keys)
            w.writeheader()
            w.writerow(row)
        return

    try:
        with path.open("r", newline="") as f:
            old_header = next(csv.reader(f))
    except Exception:
        # e.g. an empty file (StopIteration): treat it as having no header.
        old_header = []

    fieldnames = old_header + [k for k in new_keys if k not in old_header]
    if fieldnames != old_header:
        tmp = path.with_suffix(".tmp")
        with tmp.open("w", newline="") as wf, path.open("r", newline="") as rf:
            w = csv.DictWriter(wf, fieldnames=fieldnames)
            w.writeheader()
            if old_header:
                for old_row in csv.DictReader(rf):
                    w.writerow({k: old_row.get(k, "") for k in fieldnames})
        tmp.replace(path)

    with path.open("a", newline="") as f:
        # Keys missing from *row* are written as "" (DictWriter restval default).
        csv.DictWriter(f, fieldnames=fieldnames).writerow(row)


def log_trade_symbol(symbol: str,
                     bar_time,
                     signal: int,
                     raw_action: float,
                     weight: float,
                     confidence: float,
                     price: float,
                     equity: float,
                     dry_run: bool,
                     note: str = ""):
    """Append one decision row to RESULTS_DIR/trade_log_<symbol>.csv.

    Args:
        symbol: Ticker; also selects the per-symbol CSV file.
        bar_time: Timestamp of the bar acted on. It is used for bar_time (ISO)
            and bar_age_sec; missing or unparseable values are logged as "".
        signal: 1 is logged as "BUY", anything else as "SELL_OR_HOLD".
        raw_action, weight, confidence, price, equity: Logged as floats.
            Missing (None), non-numeric or non-finite values are logged as "".
        dry_run: Logged as 0/1.
        note: Free text. When non-empty it is also used as the decision label.

    The decision label, when *note* is empty, is "rebalance" if |weight| >=
    ENTER_WEIGHT_MIN and confidence >= ENTER_CONF_MIN, "flatten" if |weight| <=
    EXIT_WEIGHT_MAX, otherwise "hold". Thresholds come from module globals and
    default to 0.0.
    """
    def _finite_or_blank(x):
        # float() instead of np.isfinite() so None / strings do not raise TypeError.
        try:
            v = float(x)
        except (TypeError, ValueError):
            return ""
        return v if math.isfinite(v) else ""

    try:
        if bar_time is not None and not pd.isna(bar_time):
            ts = pd.to_datetime(bar_time, utc=True)
            bt_iso = ts.isoformat()
            age_sec = max(0, int((now_utc() - ts).total_seconds()))
        else:
            bt_iso, age_sec = "", ""
    except Exception:
        bt_iso, age_sec = "", ""

    resolved_feed = (os.getenv("BARS_FEED", "") or "").strip() or "default"

    weight_v = _finite_or_blank(weight)
    conf_v = _finite_or_blank(confidence)
    ew = 0.0 if weight_v == "" else weight_v
    cf = 0.0 if conf_v == "" else conf_v

    decision = note or (
        "rebalance" if (abs(ew) >= float(globals().get("ENTER_WEIGHT_MIN", 0.0))
                        and cf >= float(globals().get("ENTER_CONF_MIN", 0.0)))
        else ("flatten" if abs(ew) <= float(globals().get("EXIT_WEIGHT_MAX", 0.0)) else "hold")
    )

    row = {
        "log_time": now_utc().isoformat(),
        "symbol": symbol,
        "bar_time": bt_iso,
        "bar_age_sec": age_sec,
        "feed": resolved_feed,
        "signal": "BUY" if int(signal) == 1 else "SELL_OR_HOLD",
        "raw_action": _finite_or_blank(raw_action),
        "weight": weight_v,
        "confidence": conf_v,
        "price": _finite_or_blank(price),
        "equity": _finite_or_blank(equity),
        "dry_run": int(bool(dry_run)),
        "decision": decision,
        "note": note,
    }

    _append_csv_row(RESULTS_DIR / f"trade_log_{symbol}.csv", row)

# -------------------------------- Artifacts: picker & loaders --------------------------------
def _extract_window_idx(path: Path) -> Optional[int]:
    """Return N from a '..._window<N>_...' file stem (case-insensitive), else None."""
    m = re.search(r"_window(\d+)_", path.stem, re.IGNORECASE)
    # \d+ always parses with int(), so no try/except is needed.
    return int(m.group(1)) if m else None

def pick_artifacts_for_ticker(
    ticker: str,
    artifacts_dir: str,
    best_window: Optional[str] = None
) -> Dict[str, Optional[Path]]:
    """Locate the PPO model zip for *ticker* plus its matching VecNormalize pickle and features JSON.

    Model choice: the '_window<best_window>_' model if requested and present;
    otherwise the highest window index; otherwise the last model in sort order.

    Companion files are searched in tiers, and the first tier with any hit wins:
      1. the model's base name ('<base>_vecnorm*.pkl' / '<base>_features*.json');
      2. the same base with a trailing ' (N)' duplicate suffix removed;
      3. any 'ppo_<ticker>_*' file. This tier logs a warning because the file
         may belong to a different window than the chosen model.

    Returns:
        {"model": Path, "vecnorm": Path | None, "features": Path | None}

    Raises:
        FileNotFoundError: if *artifacts_dir* or any model zip is missing.
    """
    p = Path(artifacts_dir)
    if not p.exists():
        raise FileNotFoundError(f"Artifacts directory not found: {p.resolve()}")

    models = sorted(p.glob(f"ppo_{ticker}_window*_model*.zip"))
    if not models:
        models = sorted(p.glob(f"ppo_{ticker}_model*.zip")) or sorted(p.glob(f"*{ticker}*model*.zip"))
    if not models:
        raise FileNotFoundError(f"No PPO model zip found for {ticker} in {p}")

    def _model_sort_key(path: Path):
        w = _extract_window_idx(path)
        return (w if w is not None else -1, " (1)" in path.stem)

    models = sorted(models, key=_model_sort_key)

    chosen: Optional[Path] = None
    if best_window:
        chosen = next((m for m in models if f"_window{best_window}_" in m.stem), None)
        if chosen is None:
            logging.warning("BEST_WINDOW=%s not found; falling back to best available.", best_window)

    if chosen is None:
        with_idx = [(m, _extract_window_idx(m)) for m in models]
        with_idx = [(m, w) for (m, w) in with_idx if w is not None]
        chosen = max(with_idx, key=lambda t: t[1])[0] if with_idx else models[-1]

    base = chosen.stem.replace("_model", "")
    base_nodup = re.sub(r"\s\(\d+\)$", "", base)

    def _first_in_tiers(kind: str, patterns: List[str]) -> Optional[Path]:
        # Sorting within a tier (not across all tiers) keeps an exact-base match
        # from losing to another window's file that happens to sort first.
        for tier, pattern in enumerate(patterns):
            hits = sorted(p.glob(pattern))
            if hits:
                if tier == len(patterns) - 1:
                    logging.warning("[%s] no %s matching %s; falling back to %s",
                                    ticker, kind, chosen.name, hits[0].name)
                return hits[0]
        return None

    vecnorm = _first_in_tiers("vecnorm", [base + "_vecnorm*.pkl",
                                          base_nodup + "_vecnorm*.pkl",
                                          f"ppo_{ticker}_*_vecnorm*.pkl"])
    feats = _first_in_tiers("features", [base + "_features*.json",
                                         base_nodup + "_features*.json",
                                         f"ppo_{ticker}_*_features*.json"])

    logging.info(f"[{ticker}] model={chosen.name} | vecnorm={bool(vecnorm)} | features={bool(feats)}")
    return {"model": chosen, "vecnorm": vecnorm, "features": feats}

def load_vecnormalize(path: Optional[Path]):
    """Unpickle and return the VecNormalize object stored at *path*.

    Returns None when *path* is None or loading fails for any reason; failures
    are logged as warnings.

    SECURITY: pickle can execute arbitrary code, so only load artifacts from a
    trusted source.
    """
    if path is None:
        return None
    try:
        with open(path, "rb") as fh:
            obj = pickle.load(fh)
    except Exception as e:
        logging.warning("VecNormalize load failed (%s). Proceeding without it.", e)
        return None
    return obj

def load_features(path: Optional[Path]):
    """Parse and return the JSON document at *path*, or None when *path* is None.

    Read and JSON errors propagate to the caller.
    """
    if path is not None:
        with open(path, "r") as fh:
            return json.load(fh)
    return None

def load_ppo_model(model_path: Path):
    """Load a Stable-Baselines3 PPO model from the saved archive at *model_path*."""
    archive = str(model_path)
    return PPO.load(archive)

# ---- Cached asset flags (tradable / fractionable / shortable) ----
@lru_cache(maxsize=256)
def _asset_flags_cached(symbol: str) -> Tuple[bool, bool, bool]:
    """Fetch (tradable, fractionable, shortable) for *symbol* from Alpaca; raises on API errors.

    lru_cache does not store results of calls that raise, so only successful
    lookups are memoised.
    """
    _api = globals().get("api") or init_alpaca()
    a = _api.get_asset(symbol)
    return (
        bool(getattr(a, "tradable", True)),
        bool(getattr(a, "fractionable", False)),
        bool(getattr(a, "shortable", False)),
    )


def _asset_flags(symbol: str) -> Tuple[bool, bool, bool]:
    """
    Return (tradable, fractionable, shortable) for a symbol.

    Successful lookups are cached per process. On any error the conservative
    fallback (True, False, False) is returned but NOT cached, so a transient
    API failure does not permanently disable fractional/short handling for
    the symbol.
    """
    try:
        return _asset_flags_cached(symbol)
    except Exception as e:
        logging.warning("[%s] asset lookup failed (%s); using conservative flags.", symbol, e)
        return True, False, False


# Keep the lru_cache management API reachable on the public helper.
_asset_flags.cache_clear = _asset_flags_cached.cache_clear
_asset_flags.cache_info = _asset_flags_cached.cache_info

# ---------------------------- Market data + account helpers ----------------------------------
def get_recent_bars(api, symbol: str, limit: int = 200, timeframe=TimeFrame.Minute) -> pd.DataFrame:
    """Fetch up to *limit* recent OHLCV bars for *symbol* as a time-sorted, UTC-indexed DataFrame.

    Columns are whichever of Open/High/Low/Close/Volume are present. Lookup order:
      1. ``get_bars(limit=...)`` on $BARS_FEED if set, retrying once on the
         default feed when that result is empty;
      2. a start/end request covering the last 5 days, trimmed to the last
         *limit* rows so both paths return the same amount of history.
    Failures are logged as warnings. If everything fails, an empty frame is returned.
    """
    def _as_df(bars):
        # Accepts either an object exposing `.df` or an iterable of bar objects.
        if hasattr(bars, "df"):
            df = bars.df.copy()
            if not df.empty:
                if isinstance(df.index, pd.MultiIndex):
                    try:
                        df = df.xs(symbol, level=0)
                    except KeyError:
                        df = df.reset_index(level=0, drop=True)
                df.index = pd.to_datetime(df.index, utc=True, errors="coerce")
                df = df.rename(columns={"open": "Open", "high": "High", "low": "Low",
                                        "close": "Close", "volume": "Volume"})
                cols = [c for c in ["Open","High","Low","Close","Volume"] if c in df.columns]
                return df[cols].sort_index()
            return pd.DataFrame(columns=["Open","High","Low","Close","Volume"])

        rows = []
        for b in bars:
            ts = getattr(b, "t", None)
            ts = pd.to_datetime(ts, utc=True) if ts is not None else pd.NaT
            rows.append({
                "timestamp": ts,
                "Open":   float(getattr(b, "o", getattr(b, "open",  np.nan))),
                "High":   float(getattr(b, "h", getattr(b, "high",  np.nan))),
                "Low":    float(getattr(b, "l", getattr(b, "low",   np.nan))),
                "Close":  float(getattr(b, "c", getattr(b, "close", np.nan))),
                "Volume": float(getattr(b, "v", getattr(b, "volume",np.nan))),
            })
        df = pd.DataFrame(rows)
        if df.empty:
            return pd.DataFrame(columns=["Open","High","Low","Close","Volume"])
        return df.set_index(pd.to_datetime(df["timestamp"], utc=True)).drop(columns=["timestamp"]).sort_index()

    feed = os.getenv("BARS_FEED", "").strip()
    try:
        logging.info(f"[{symbol}] fetching {limit} {timeframe} bars (feed='{feed or 'default'}')")
        bars = api.get_bars(symbol, timeframe, limit=limit, feed=feed) if feed else api.get_bars(symbol, timeframe, limit=limit)
        df = _as_df(bars)
        if not df.empty:
            return df
        if feed:
            logging.info(f"[{symbol}] explicit feed empty; retrying with default feed")
            df2 = _as_df(api.get_bars(symbol, timeframe, limit=limit))
            if not df2.empty:
                return df2
    except Exception as e:
        logging.warning(f"[{symbol}] get_bars(limit) failed: {e}")

    try:
        end_dt = datetime.now(timezone.utc).replace(microsecond=0)
        start_dt = end_dt - timedelta(days=5)
        end = end_dt.isoformat().replace("+00:00", "Z")
        start = start_dt.isoformat().replace("+00:00", "Z")
        logging.info(f"[{symbol}] retry window start={start} end={end} (feed='{feed or 'default'}')")
        bars = api.get_bars(symbol, timeframe, start=start, end=end, feed=feed) if feed else api.get_bars(symbol, timeframe, start=start, end=end)
        df = _as_df(bars)
        # The window request has no row cap; trim so callers get at most `limit`
        # of the most recent bars, matching the primary path.
        return df.tail(int(limit)) if limit else df
    except Exception as e:
        logging.warning(f"[{symbol}] get_bars(start/end) failed: {e}")
        return pd.DataFrame(columns=["Open","High","Low","Close","Volume"])

def get_account_equity(api) -> float:
    """Return the account's current equity as a float."""
    account = api.get_account()
    return float(account.equity)

def get_position(api, symbol: str):
    """Return ``api.get_position(symbol)``, or None if that call raises (e.g. no open position)."""
    try:
        position = api.get_position(symbol)
    except Exception:
        position = None
    return position

def get_position_qty(api, symbol: str):
    """Return the held quantity for *symbol*.

    With USE_FRACTIONALS the quantity is a float. Otherwise it is rounded to an
    int. Zero (of the matching type) is returned when there is no position or
    its qty cannot be converted.
    """
    pos = get_position(api, symbol)
    fractional = USE_FRACTIONALS
    if not pos:
        return 0.0 if fractional else 0
    try:
        q = float(pos.qty)
        return q if fractional else int(round(q))
    except Exception:
        return 0.0 if fractional else 0

def get_last_price(api, symbol: str) -> float:
    """Best-effort last price for *symbol*, trying sources in order:

    1. latest trade (``price`` or ``p`` attribute), if finite;
    2. close of a 1-minute bar fetched with ``limit=1`` (on $BARS_FEED when set);
    3. mid of the latest quote's ask/bid, or whichever side is truthy (non-zero);
    4. the open position's ``avg_entry_price``.

    Returns NaN if every source fails. Errors are swallowed at each stage; only
    the bar stage logs a warning.

    NOTE(review): stage 4 returns a historical cost basis, not a market price,
    and does so silently. Callers sizing orders from it may be using a stale value.
    """
    # Stage 1: latest trade.
    try:
        tr = api.get_latest_trade(symbol)
        price = getattr(tr, "price", None)
        if price is None:
            price = getattr(tr, "p", None)
        if price is not None and np.isfinite(price):
            return float(price)
    except Exception:
        pass

    # Stage 2: single 1-minute bar.
    # NOTE(review): with limit=1 and no start/end, confirm the data API returns the
    # most recent bar rather than the first bar of its default range.
    try:
        feed = os.getenv("BARS_FEED", "").strip() or None
        bars = api.get_bars(symbol, TimeFrame.Minute, limit=1, feed=feed) if feed else api.get_bars(symbol, TimeFrame.Minute, limit=1)
        if hasattr(bars, "df"):
            df = bars.df.copy()
            if isinstance(df.index, pd.MultiIndex):
                try:
                    df = df.xs(symbol, level=0)
                except Exception:
                    df = df.reset_index(level=0, drop=True)
            if not df.empty:
                if "close" in df.columns: return float(df["close"].iloc[-1])
                if "Close" in df.columns: return float(df["Close"].iloc[-1])
        elif bars:
            b = bars[0]
            close = getattr(b, "c", getattr(b, "close", None))
            if close is not None:
                return float(close)
    except Exception as e:
        logging.warning(f"[{symbol}] get_last_price via bars failed: {e}")

    # Stage 3: latest quote (mid if both sides present, else whichever side is truthy).
    try:
        qt = api.get_latest_quote(symbol)
        ap = getattr(qt, "ap", None) or getattr(qt, "ask_price", None)
        bp = getattr(qt, "bp", None) or getattr(qt, "bid_price", None)
        if ap and bp:
            return float((float(ap) + float(bp)) / 2.0)
        if ap: return float(ap)
        if bp: return float(bp)
    except Exception:
        pass

    # Stage 4: last resort, the position's average entry price; NaN if there is none.
    try:
        pos = api.get_position(symbol)
        return float(pos.avg_entry_price)
    except Exception:
        return float("nan")

def cancel_open_symbol_orders(api, symbol: str):
    """
    Cancel every open order for ``symbol``.

    Each cancel is attempted independently. One failure (e.g. an order that
    filled between listing and cancelling) therefore does not leave the
    remaining orders for the symbol open. Failures are logged, never raised.
    """
    try:
        open_orders = api.list_orders(status="open")
    except Exception as e:
        logging.warning(f"[{symbol}] cancel orders failed: {e}")
        return

    for o in open_orders:
        if getattr(o, "symbol", None) != symbol:
            continue
        try:
            api.cancel_order(o.id)
        except Exception as e:
            logging.warning(f"[{symbol}] cancel order {getattr(o, 'id', '?')} failed: {e}")

def to_2dp_str(x) -> str:
    """Render ``x`` as a plain fixed-point string with exactly two decimals, rounding half-up."""
    cents = Decimal("0.01")
    quantized = Decimal(str(x)).quantize(cents, rounding=ROUND_HALF_UP)
    return format(quantized, "f")

def to_6dp_str(x) -> str:
    """Render ``x`` as a plain fixed-point string with six decimals, truncating (never rounding up)."""
    micro = Decimal("0.000001")
    truncated = Decimal(str(x)).quantize(micro, rounding=ROUND_DOWN)
    return format(truncated, "f")

def market_order(api, symbol: str, side: str, qty=None, notional: float=None):
    """
    Submit a DAY market order sized by share quantity or by dollar notional.

    Args:
        api: Alpaca REST client.
        symbol: Ticker.
        side: "buy" or "sell".
        qty: Share quantity. Sent as a 6-dp string when USE_FRACTIONALS,
            otherwise as an int (truncated).
        notional: Dollar amount, sent rounded half-up to cents. If both are
            given, notional wins and qty is ignored.

    Returns:
        The order object returned by ``api.submit_order``, or ``None`` when the
        order was skipped or submission failed (the error is logged). An order
        is skipped when it has no size, a non-numeric/non-finite/non-positive
        size, a whole-share qty that truncates to 0, a notional that rounds to
        $0.00, or when DRY_RUN is on.
    """
    if qty is not None and notional is not None:
        logging.warning(f"[{symbol}] Both qty and notional provided; preferring notional and ignoring qty.")
        qty = None

    if qty is None and notional is None:
        logging.warning(f"[{symbol}] No order size provided; skipping.")
        return None

    qty_f = None
    if qty is not None:
        try:
            qty_f = float(qty)
        except (TypeError, ValueError):
            logging.warning(f"[{symbol}] Unparseable qty ({qty!r}); skipping.")
            return None
        if not np.isfinite(qty_f) or qty_f <= 0:
            logging.warning(f"[{symbol}] Non-positive or non-finite qty ({qty}); skipping.")
            return None
        # Whole-share mode sends int(qty); a fractional qty < 1 would be submitted as 0.
        if not USE_FRACTIONALS and int(qty_f) <= 0:
            logging.warning(f"[{symbol}] qty ({qty}) truncates to 0 whole shares; skipping.")
            return None

    if notional is not None:
        try:
            notional_f = float(notional)
        except (TypeError, ValueError):
            logging.warning(f"[{symbol}] Unparseable notional ({notional!r}); skipping.")
            return None
        # Reject NaN/inf, and amounts that would be submitted as "0.00" after rounding.
        # The finiteness test runs first so Decimal never sees NaN.
        if (not np.isfinite(notional_f)
                or Decimal(str(notional_f)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) <= 0):
            logging.warning(f"[{symbol}] Non-positive notional (${notional}); skipping.")
            return None

    if DRY_RUN:
        logging.info(
            f"[DRY_RUN] Would submit {side} "
            f"{('notional=$' + to_2dp_str(notional)) if notional is not None else ('qty=' + str(qty))} "
            f"{symbol} (market, day)"
        )
        return None

    try:
        qty_arg = None
        if qty is not None:
            qty_arg = to_6dp_str(qty) if USE_FRACTIONALS else int(qty_f)
        notional_arg = to_2dp_str(notional) if notional is not None else None

        o = api.submit_order(
            symbol=symbol,
            side=side,
            type="market",
            time_in_force="day",
            qty=qty_arg,
            notional=notional_arg,
        )

        logging.info(
            f"[{symbol}] Submitted {side} "
            f"{('notional=$' + notional_arg) if notional_arg else ('qty=' + str(qty_arg))}"
        )
        return o
    except Exception as e:
        logging.error(f"[{symbol}] submit_order failed: {e}")
        return None

def market_order_to_qty(api, symbol: str, side: str, qty: Union[int, float, str]):
    """
    Submit a market order for an explicit share quantity via ``market_order``.

    Without USE_FRACTIONALS the quantity is truncated to an int. With it, a
    quantity within 1e-8 of a whole number is sent as that int; anything else
    is sent as a 6-decimal string.
    """
    if not USE_FRACTIONALS:
        return market_order(api, symbol, side=side, qty=int(qty))

    as_float = float(qty)
    nearest = round(as_float)
    if abs(as_float - nearest) < 1e-8:
        normalized = int(nearest)
    else:
        normalized = to_6dp_str(as_float)
    return market_order(api, symbol, side=side, qty=normalized)


# ----------------------------- Sizing / risk + (un)flatten / rebalance ------------------------
def action_to_weight(action) -> Tuple[float, float, float]:
    """
    Map a raw PPO action to a signed target portfolio weight.

    Confidence is ``|tanh(a)|``. In "linear" SIZING_MODE the weight magnitude
    is ``WEIGHT_CAP * conf``. Otherwise it is 0 below CONF_FLOOR and rescales
    linearly from 0 at CONF_FLOOR to WEIGHT_CAP at conf == 1. The magnitude is
    clamped to [0, WEIGHT_CAP]. Negative actions produce a short (negative)
    weight only when the ALLOW_SHORTS global is truthy; otherwise 0.

    Args:
        action: Scalar or single-element array-like returned by the policy.

    Returns:
        (weight, conf, a), where ``a`` is the action as a Python float.
        A NaN action yields ``(0.0, 0.0, nan)``.
    """
    a = float(np.array(action).squeeze())
    # NaN must be caught here: comparisons with NaN are all False, and
    # min(WEIGHT_CAP, nan) returns WEIGHT_CAP, so it would become a full-size long.
    if np.isnan(a):
        logging.warning("Policy produced a NaN action; treating as flat.")
        return 0.0, 0.0, a

    conf = float(abs(np.tanh(a)))
    if a == 0:
        return 0.0, conf, a
    if a < 0 and not globals().get("ALLOW_SHORTS", False):
        return 0.0, conf, a

    # Same magnitude rule for both directions; the sign is applied afterwards.
    if SIZING_MODE == "linear":
        mag = WEIGHT_CAP * conf
    else:
        mag = 0.0 if conf < CONF_FLOOR else WEIGHT_CAP * (conf - CONF_FLOOR) / (1.0 - CONF_FLOOR)
    mag = max(0.0, min(WEIGHT_CAP, float(mag)))

    if a > 0:
        w = mag
    else:
        w = -mag if mag > 0 else 0.0  # avoid returning -0.0
    return w, conf, a


def compute_target_qty_by_cash(equity: float, price: float, target_weight: float, api=None) -> int:
    """
    Whole-share quantity for ``target_weight`` of ``equity`` at ``price``.

    The dollar target ``equity * target_weight`` is capped by a budget. When
    ``api`` is given, the budget is the account's ``buying_power`` (falling
    back to ``cash``, then ``equity``); otherwise it is ``equity``.

    Returns:
        A non-negative qty for positive weights. For negative weights, a
        non-positive qty when the ALLOW_SHORTS global is truthy, else 0.
        Returns 0 for a zero weight, or when price is non-positive or any of
        price/equity/target_weight/budget is not finite.
    """
    if not np.isfinite(price) or price <= 0:
        return 0
    # NaN equity/weight would otherwise slip through min() below
    # (min(budget, nan) returns budget) and size the order off the full budget.
    if not (np.isfinite(equity) and np.isfinite(target_weight)):
        return 0

    if api:
        acct = api.get_account()
        budget = float(getattr(acct, "buying_power", getattr(acct, "cash", equity)))
    else:
        budget = equity

    target_notional = equity * float(target_weight)           # can be negative
    allowed = min(budget, abs(target_notional))
    if not np.isfinite(allowed) or allowed <= 0:
        return 0
    qty = int(allowed // price)

    if target_weight > 0:
        return qty
    if target_weight < 0 and globals().get("ALLOW_SHORTS", False):
        return -qty  # negative qty means short
    return 0


def flatten_symbol(api, symbol: str):
    """
    Close any open position in ``symbol``.

    The method does nothing when the held quantity is effectively zero.
    Otherwise it first cancels the symbol's open orders, then calls
    ``api.close_position``. If that raises, it falls back to an
    opposite-side market order for the absolute held quantity. Under
    DRY_RUN it only logs the intent (after cancelling orders).
    """
    held = get_position_qty(api, symbol)
    if USE_FRACTIONALS:
        already_flat = abs(held) < 1e-8
    else:
        already_flat = int(held) == 0
    if already_flat:
        return

    cancel_open_symbol_orders(api, symbol)

    if DRY_RUN:
        logging.info(f"[DRY_RUN] Would close position {symbol}")
        return

    try:
        api.close_position(symbol)
        logging.info(f"[{symbol}] close_position submitted")
    except Exception:
        closing_side = "sell" if held > 0 else "buy"
        market_order_to_qty(api, symbol, closing_side, abs(held))

def _affordable_whole_shares(api, symbol: str, price: float, fallback_budget: float) -> int:
    """Whole shares the account's buying power can fund at ``price`` (0 if unknown or non-positive)."""
    try:
        acct = api.get_account()
        budget = float(getattr(acct, "buying_power", getattr(acct, "cash", fallback_budget)))
    except Exception as e:
        logging.warning(f"[{symbol}] Could not read buying power ({e}); not adding exposure this cycle.")
        return 0
    if not np.isfinite(budget) or budget <= 0:
        return 0
    return int(budget // price)


def rebalance_to_weight(api, symbol: str, equity: float, target_weight: float):
    """
    Rebalance toward target_weight.
    - Uses per-symbol fractionals flag (asset.fractionable).
    - Avoids fractional *buys* when covering integer shorts.
    - Checks tradable/shortable constraints.
    - In whole-share mode, buying power caps only the shares that *add*
      exposure; closing or reducing a position is never limited by it.
      Capping the whole target by buying power made a fully invested account
      sell down positions it meant to keep.
    - Skips the cycle when price, equity, or target_weight is not finite.
    """
    price = get_last_price(api, symbol)
    if not np.isfinite(price) or price <= 0:
        logging.warning(f"[{symbol}] Price unavailable; skipping rebalance this cycle.")
        return
    if not (np.isfinite(equity) and np.isfinite(target_weight)):
        logging.warning(f"[{symbol}] Non-finite equity/target weight; skipping rebalance this cycle.")
        return

    tradable, fractionable, shortable = _asset_flags(symbol)
    if not tradable:
        logging.info(f"[{symbol}] Not tradable; skipping rebalance.")
        return
    use_fractionals = bool(USE_FRACTIONALS and fractionable)

    have_qty        = get_position_qty(api, symbol)          # signed (negative if short)
    have_notional   = have_qty * price                       # current exposure
    target_notional = equity * float(target_weight)          # desired exposure
    delta_notional  = target_notional - have_notional        # change in exposure

    # Skip tiny changes
    if abs(delta_notional) < 1e-9:
        return
    if equity > 0:
        delta_weight = abs(delta_notional) / equity
        if delta_weight < float(globals().get("DELTA_WEIGHT_MIN", 0.0)):
            return

    if use_fractionals:
        dn = round_to_cents(abs(delta_notional))
        if dn < float(globals().get("REBALANCE_MIN_NOTIONAL", 0.0)):
            return

        side = "buy" if delta_notional > 0 else "sell"
        shorting = (target_notional < 0) and (side == "sell")  # increasing a short
        covering = (have_qty < 0) and (side == "buy")         # reducing a short

        if shorting:
            if not shortable:
                logging.info(f"[{symbol}] Not shortable; skipping rebalance toward short.")
                return
            qty = max(1, int(math.floor(dn / price))) if np.isfinite(price) and price > 0 else 1
            market_order_to_qty(api, symbol, side="sell", qty=qty)
            return

        if covering:
            # Covering shorts: buy whole shares (avoid fractional buy vs integer short)
            qty = max(1, int(math.ceil(dn / price))) if np.isfinite(price) and price > 0 else 1
            qty = min(int(abs(have_qty)), qty) if have_qty < 0 else qty
            market_order_to_qty(api, symbol, side="buy", qty=qty)
            return

        # Long exposure changes can safely use notional
        market_order(api, symbol, side=side, notional=dn)
        return

    # ---- Non-fractional mode (whole shares only) ----
    # Signed target position in whole shares, sized off equity (no buying-power cap yet).
    target_shares = int(abs(target_notional) // price)
    if target_weight > 0:
        want_qty = target_shares
    elif target_weight < 0 and globals().get("ALLOW_SHORTS", False):
        want_qty = -target_shares
    else:
        want_qty = 0
    delta_qty = want_qty - have_qty
    if delta_qty == 0:
        return

    # Shares in this order that open/extend exposure, as opposed to closing the
    # current (possibly opposite-side) position. Only these consume buying power.
    if delta_qty > 0:
        adding = want_qty - max(have_qty, 0)
    else:
        adding = -want_qty - max(-have_qty, 0)
    if adding > 0:
        affordable = _affordable_whole_shares(api, symbol, price, equity)
        if adding > affordable:
            trim = adding - affordable  # trim <= adding <= |delta_qty|, so the sign never flips
            delta_qty = delta_qty - trim if delta_qty > 0 else delta_qty + trim
            logging.info(f"[{symbol}] Buying power limits added exposure to {affordable} share(s).")
            if delta_qty == 0:
                return

    approx_delta_notional = abs(delta_qty) * price
    if equity > 0 and approx_delta_notional / equity < float(globals().get("DELTA_WEIGHT_MIN", 0.0)):
        return
    if approx_delta_notional < float(globals().get("REBALANCE_MIN_NOTIONAL", 0.0)):
        return

    side = "buy" if delta_qty > 0 else "sell"
    shorting = (target_notional < 0) and (side == "sell")
    if shorting and not shortable:
        logging.info(f"[{symbol}] Not shortable; skipping rebalance toward short.")
        return

    market_order_to_qty(api, symbol, side=side, qty=int(abs(delta_qty)))

def check_tp_sl_and_maybe_flatten(api, symbol: str) -> bool:
    """
    Flatten ``symbol`` when its unrealized P/L fraction crosses take-profit or stop-loss.

    The position's ``unrealized_plpc`` is compared against TAKE_PROFIT_PCT
    (upper) and ``-abs(STOP_LOSS_PCT)`` (lower). Each threshold is active only
    when > 0.

    Returns:
        True if a flatten was triggered (the caller should end the cycle).
        False otherwise: both thresholds disabled, no position, unparseable
        ``unrealized_plpc``, or no threshold crossed.
    """
    if TAKE_PROFIT_PCT <= 0 and STOP_LOSS_PCT <= 0:
        return False

    position = get_position(api, symbol)
    if not position:
        return False
    try:
        pnl_pct = float(position.unrealized_plpc)
    except Exception:
        return False

    if TAKE_PROFIT_PCT > 0 and pnl_pct >= TAKE_PROFIT_PCT:
        message = f"[{symbol}] TP hit ({pnl_pct:.4f} >= {TAKE_PROFIT_PCT:.4f}). Flattening."
    elif STOP_LOSS_PCT > 0 and pnl_pct <= -abs(STOP_LOSS_PCT):
        message = f"[{symbol}] SL hit ({pnl_pct:.4f} <= {-abs(STOP_LOSS_PCT):.4f}). Flattening."
    else:
        return False

    logging.info(message)
    flatten_symbol(api, symbol)
    return True

# ----------------------------- Inference helpers / features -----------------------------------
def expected_obs_shape(model, vecnorm) -> Optional[tuple]:
    """
    Observation shape the policy expects, or None if it cannot be determined.

    The VecNormalize wrapper is consulted first, then the model. The first one
    whose ``observation_space.shape`` is readable and non-empty wins.
    """
    candidates = (vecnorm, model)
    for candidate in candidates:
        try:
            shape = tuple(candidate.observation_space.shape)
        except Exception:
            continue
        if shape:
            return shape
    return None

def compute_art_feat_order(features_hint: Any, df: pd.DataFrame) -> List[str]:
    """
    Resolve the ordered list of feature columns to feed the policy.

    Without a hint, every numeric column of ``df`` is used in DataFrame order.
    Otherwise the hint's order is kept, which may be either a dict (its
    "features" entry, or the dict itself if that key is missing) or any
    iterable of names. Names not present in ``df``, non-numeric columns, and
    the bookkeeping names datetime/symbol/target/return are dropped.
    """
    def _numeric(col) -> bool:
        return pd.api.types.is_numeric_dtype(df[col])

    if features_hint is None:
        return [col for col in df.columns if _numeric(col)]

    if isinstance(features_hint, dict):
        requested = features_hint.get("features", features_hint)
    else:
        requested = list(features_hint)

    excluded = {"datetime", "symbol", "target", "return"}
    return [col for col in requested if col not in excluded and col in df.columns and _numeric(col)]

def build_obs_from_row(row: pd.Series, order: List[str]) -> np.ndarray:
    """
    Build a float32 vector with ``row[c]`` for each ``c`` in ``order``.

    Missing keys, NaN/NA, None and False all become 0.0; everything else goes
    through ``float()``.
    """
    def _coerce(value) -> float:
        if pd.isna(value) or value is None or value is False:
            return 0.0
        return float(value)

    return np.array([_coerce(row.get(col, np.nan)) for col in order], dtype=np.float32)

def _pick_columns_for_channels(features_hint: Any, df: pd.DataFrame, channels: int) -> List[str]:
    numeric = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    cols: List[str] = []
    if isinstance(features_hint, dict) and "features" in features_hint:
        cand = [c for c in features_hint["features"] if c in df.columns and pd.api.types.is_numeric_dtype(df[c])]
        if len(cand) >= channels:
            cols = cand[:channels]
    if not cols:
        pref = ["Close", "Volume", "Adj Close", "Open", "High", "Low"]
        cols = [c for c in pref if c in numeric]
        cols += [c for c in numeric if c not in cols]
        cols = cols[:channels]
    if len(cols) < channels and cols:
        while len(cols) < channels:
            cols.append(cols[-1])
    return cols[:channels]

def add_regime(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add regime columns to ``df`` in place and return the same object.

    Columns added, in order:
      * Vol20   -- 20-bar rolling std of close-to-close percentage returns.
      * Ret20   -- 20-bar percentage change of Close.
      * Regime4 -- ``2 * high_vol + high_trend`` in {0, 1, 2, 3}. A flag is 1
        when the value exceeds the median of that column over the whole frame
        (NaN rows compare False, i.e. 0).
    """
    close = df["Close"]
    df["Vol20"] = close.pct_change().rolling(20).std()
    df["Ret20"] = close.pct_change(20)

    vol = df["Vol20"]
    trend_strength = df["Ret20"].abs()
    high_vol = (vol > vol.median()).astype(int)
    high_trend = (trend_strength > trend_strength.median()).astype(int)
    df["Regime4"] = 2 * high_vol + high_trend
    return df

def denoise_wavelet(series: pd.Series, wavelet: str = "db1", level: int = 2) -> pd.Series:
    """
    Low-pass ``series`` by keeping only its wavelet approximation coefficients.

    The series is forward/back-filled and decomposed with PyWavelets.
    ``level`` is clamped to the maximum the length supports, every detail band
    is zeroed, and the result is reconstructed onto the original index.

    Fallbacks:
      * PyWavelets not importable -> EMA(span=5) of the raw (unfilled) series.
      * Too short for one level   -> the filled series unchanged.
      * Transform error           -> EMA(span=5) of the filled series.
    """
    try:
        import pywt
    except Exception:
        return pd.Series(series).astype(float).ewm(span=5, adjust=False).mean()

    filled = pd.Series(series).astype(float).ffill().bfill()
    values = filled.to_numpy()
    try:
        wav = pywt.Wavelet(wavelet)
        usable_level = int(max(0, min(level, pywt.dwt_max_level(len(values), wav.dec_len))))
        if usable_level < 1:
            return filled
        bands = pywt.wavedec(values, wav, mode="symmetric", level=usable_level)
        # Keep the approximation band (index 0); silence every detail band.
        bands = [bands[0]] + [np.zeros_like(band) for band in bands[1:]]
        smooth = pywt.waverec(bands, wav, mode="symmetric")
        return pd.Series(smooth[:len(values)], index=filled.index)
    except Exception:
        return filled.ewm(span=5, adjust=False).mean()

def add_features_live(
    df: pd.DataFrame,
    use_sentiment: bool = False,
    rsi_wilder: bool = True,
    atr_wilder: bool = True,
) -> pd.DataFrame:
    """
    Compute the live technical-indicator feature frame from OHLCV bars.

    Column names are matched case-insensitively, including a few aliases such
    as "last", "vol" and "adj_close". They are renamed to Open/High/Low/Close/
    Adj Close/Volume. "Adj Close" is copied from "Close" when absent. After
    renaming, High, Low, Close and Volume must exist; otherwise indexing
    raises KeyError.

    Indicator columns are appended in the order written below. That order
    matters: compute_art_feat_order(None, ...) and _pick_columns_for_channels
    take numeric columns in DataFrame order.
    NOTE(review): presumably this must mirror the training-time feature
    pipeline; confirm before reordering or changing any formula.

    Args:
        df: Bars indexed by time. It is copied, so the caller's frame is not mutated.
        use_sentiment: Keep an existing "SentimentScore" column (0.0 if
            missing); when False the column is set to 0.0.
        rsi_wilder: Wilder smoothing (EWM, alpha=1/14) for RSI instead of a
            simple 14-bar mean.
        atr_wilder: Same choice for ATR.

    Returns:
        The index-sorted copy with indicator columns added and +/-inf replaced
        by NaN. Leading rows are NaN for the rolling-window columns.
    """
    df = df.copy().sort_index()
    # Map lower-cased name -> original name for case-insensitive matching.
    cols_ci = {c.lower(): c for c in df.columns}
    rename = {}
    for final, alts in {
        "Open": ["open"], "High": ["high"], "Low": ["low"],
        "Close": ["close","close*","last"], "Adj Close":["adj close","adj_close","adjclose","adjusted close"],
        "Volume":["volume","vol"]
    }.items():
        # First matching alias wins.
        for a in [final.lower()] + alts:
            if a in cols_ci:
                rename[cols_ci[a]] = final
                break
    df = df.rename(columns=rename)
    if "Adj Close" not in df.columns and "Close" in df.columns:
        df["Adj Close"] = df["Close"]

    # Bollinger bands (20-bar, 2 std).
    df["SMA_20"] = df["Close"].rolling(20).mean()
    df["STD_20"] = df["Close"].rolling(20).std()
    df["Upper_Band"] = df["SMA_20"] + 2 * df["STD_20"]
    df["Lower_Band"] = df["SMA_20"] - 2 * df["STD_20"]

    # 14-bar stochastic %K; a flat high/low range yields NaN instead of dividing by zero.
    df["Lowest_Low"]   = df["Low"].rolling(14).min()
    df["Highest_High"] = df["High"].rolling(14).max()
    denom = (df["Highest_High"] - df["Lowest_Low"]).replace(0, np.nan)
    df["Stoch"] = ((df["Close"] - df["Lowest_Low"]) / denom) * 100

    # 10-bar rate of change (fraction) and on-balance volume.
    df["ROC"] = df["Close"].pct_change(10)
    sign = np.sign(df["Close"].diff().fillna(0))
    df["OBV"] = (sign * df["Volume"].fillna(0)).cumsum()

    # 20-bar CCI on typical price; zero mean deviation -> NaN.
    tp = (df["High"] + df["Low"] + df["Close"]) / 3.0
    sma_tp = tp.rolling(20).mean()
    md = (tp - sma_tp).abs().rolling(20).mean().replace(0, np.nan)
    df["CCI"] = (tp - sma_tp) / (0.015 * md)

    # EMAs and MACD(12, 26, 9).
    df["EMA_10"] = df["Close"].ewm(span=10, adjust=False).mean()
    df["EMA_50"] = df["Close"].ewm(span=50, adjust=False).mean()
    ema12 = df["Close"].ewm(span=12, adjust=False).mean()
    ema26 = df["Close"].ewm(span=26, adjust=False).mean()
    df["MACD_Line"]   = ema12 - ema26
    df["MACD_Signal"] = df["MACD_Line"].ewm(span=9, adjust=False).mean()

    # RSI(14).
    # NOTE(review): when avg_loss == 0 (no down moves in the window), rs is NaN,
    # so RSI is NaN rather than the conventional 100. Confirm this matches training.
    d = df["Close"].diff()
    gain = d.clip(lower=0)
    loss = (-d.clip(upper=0))
    if rsi_wilder:
        avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
        avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
    else:
        avg_gain = gain.rolling(14).mean()
        avg_loss = loss.rolling(14).mean()
    rs = avg_gain / avg_loss.replace(0, np.nan)
    df["RSI"] = 100 - (100 / (1 + rs))

    # True range -> ATR(14).
    tr = pd.concat([
        (df["High"] - df["Low"]),
        (df["High"] - df["Close"].shift()).abs(),
        (df["Low"]  - df["Close"].shift()).abs(),
    ], axis=1).max(axis=1)
    df["ATR"] = tr.ewm(alpha=1/14, adjust=False).mean() if atr_wilder else tr.rolling(14).mean()

    df["Volatility"]     = df["Close"].pct_change().rolling(20).std()
    df["Denoised_Close"] = denoise_wavelet(df["Close"])

    # add_regime mutates this (already copied) frame: Vol20, Ret20, Regime4.
    df = add_regime(df)
    df["SentimentScore"] = (df.get("SentimentScore", 0.0) if use_sentiment else 0.0)
    # "Delta"/"Gamma" here are the 1-bar pct return and its first difference (not option Greeks).
    df["Delta"] = df["Close"].pct_change(1).fillna(0.0)
    df["Gamma"] = df["Delta"].diff().fillna(0.0)

    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    return df

def prepare_observation_from_bars(
    bars_df: pd.DataFrame,
    features_hint: Any = None,
    min_required_rows: int = 60,
    expected_shape: Optional[tuple] = None,
) -> Tuple[np.ndarray, int]:
    """
    Build the policy observation from recent bars.

    Shape handling:
      * ``expected_shape`` of length 2 ``(lookback, channels)``: the last
        ``lookback`` rows over columns chosen by _pick_columns_for_channels,
        NaN -> 0. Zero rows are prepended when history is short.
      * length 1 ``(n,)``: the last row's first ``n`` features in artifact
        order, NaN -> 0, zero-padded at the end. Requires at least
        ``max(20, min_required_rows)`` rows.
      * otherwise: artifact feature order evaluated on the last row with no
        NaN in those features. Requires at least ``max(20, min_required_rows)``
        such rows.

    Returns:
        ``(obs, ts)``: ``obs`` is float32 and ``ts`` is the epoch-seconds UTC
        timestamp of the last bar. Naive index timestamps are treated as UTC,
        and "now" is used if the index cannot be parsed as a timestamp.

    Raises:
        ValueError: not enough rows, no usable features, or no numeric columns
        for a windowed observation.
    """
    feats_df = add_features_live(bars_df).replace([np.inf, -np.inf], np.nan)
    # Timestamp.utcnow() is deprecated (pandas 2.2+); now(tz="UTC") is the tz-aware equivalent.
    ts = pd.Timestamp.now(tz="UTC")
    try:
        idx_ts = pd.Timestamp(feats_df.index[-1])
        ts = idx_ts.tz_convert("UTC") if idx_ts.tzinfo else idx_ts.tz_localize("UTC")
    except Exception:
        pass

    if expected_shape is not None:
        if len(expected_shape) == 2:
            lookback, channels = int(expected_shape[0]), int(expected_shape[1])
            cols = _pick_columns_for_channels(features_hint, feats_df, channels)
            if channels > 0 and not cols:
                # Otherwise the zero-padding vstack below fails with an opaque shape error.
                raise ValueError("No numeric feature columns available to fill the observation window.")
            window_df = feats_df[cols].tail(lookback).fillna(0.0)
            arr = window_df.to_numpy(dtype=np.float32)
            if arr.shape[0] < lookback:
                pad_rows = lookback - arr.shape[0]
                arr = np.vstack([np.zeros((pad_rows, channels), dtype=np.float32), arr])
            arr = arr[-lookback:, :channels]
            return arr.reshape(lookback, channels), int(ts.timestamp())

        elif len(expected_shape) == 1:
            n = int(expected_shape[0])
            cand = compute_art_feat_order(features_hint, feats_df)
            if len(feats_df) < max(20, min_required_rows):
                raise ValueError(f"Not enough bars to compute features robustly (have {len(feats_df)}).")
            obs = build_obs_from_row(feats_df.iloc[-1], cand[:n])
            if obs.shape[0] < n:
                obs = np.concatenate([obs, np.zeros(n - obs.shape[0], dtype=np.float32)])
            return obs.astype(np.float32), int(ts.timestamp())

    order = compute_art_feat_order(features_hint, feats_df)
    if not order:
        raise ValueError("No usable features after resolving artifact order.")
    feats_df = feats_df.dropna(subset=order)
    if len(feats_df) < max(20, min_required_rows):
        raise ValueError(f"Not enough bars to compute features robustly (have {len(feats_df)}).")
    last = feats_df.iloc[-1]
    obs = build_obs_from_row(last, order)
    return obs.astype(np.float32), int(ts.timestamp())

# -------------------------------- Live step & loop --------------------------------------------
def ensure_market_open(api) -> bool:
    """Report whether Alpaca's clock says the market is open; any error counts as closed."""
    try:
        clock = api.get_clock()
        market_open = bool(clock.is_open)
    except Exception:
        market_open = False
    return market_open

def _sleep_until_open(api):
    try:
        clock = api.get_clock()
        if getattr(clock, "is_open", False):
            return
        nxt = pd.to_datetime(getattr(clock, "next_open"), utc=True, errors="coerce")
        if pd.isna(nxt):
            time.sleep(60)
            return
        wait = max(1, int((nxt - now_utc()).total_seconds()))
        logging.info("Market closed. Sleeping %ds until next open.", wait)
        time.sleep(wait)
    except Exception:
        time.sleep(60)

def infer_target_weight(model: PPO, vecnorm: Optional[VecNormalize], obs: np.ndarray) -> Tuple[float, float, float]:
    """
    Run the policy on one observation and convert its action to ``(weight, conf, raw)``.

    If a VecNormalize with ``normalize_obs`` is supplied, the observation is
    normalized with it first. When that raises, a warning is logged and the
    raw observation is used. Prediction honours the INF_DETERMINISTIC flag.
    """
    model_input = obs
    can_normalize = vecnorm is not None and hasattr(vecnorm, "normalize_obs")
    if can_normalize:
        try:
            model_input = vecnorm.normalize_obs(obs)
        except Exception as e:
            logging.warning(f"VecNormalize.normalize_obs failed; using raw obs. Err: {e}")
    action, _state = model.predict(model_input, deterministic=INF_DETERMINISTIC)
    return action_to_weight(action)

def maybe_patch_stale_with_latest_trade(api, symbol: str, bars_df: pd.DataFrame, max_age_sec: int = None) -> pd.DataFrame:
    """
    If the last minute bar is older than STALE_MAX_SEC (or max_age_sec) but a fresher latest trade exists,
    append a synthetic bar using the trade price (O=H=L=C=last trade, V=0).

    The synthetic timestamp is max(last bar + 1 min, trade time floored to the
    minute). It is expressed in the same tz convention as ``bars_df``'s index:
    a tz-naive index (treated as UTC) gets a naive UTC timestamp. Mixing naive
    and aware timestamps would make concat/sort fail and the patch would never
    apply. An existing "Adj Close" column is filled with the trade price too,
    so the synthetic row doesn't carry a NaN into that feature.

    Returns:
        The patched frame (sorted, duplicates keep the synthetic row), or
        ``bars_df`` unchanged. It is returned unchanged when the frame is
        empty, not stale, the trade is missing/invalid/also stale, or any
        step raises (logged at DEBUG). A falsy ``max_age_sec`` (None or 0)
        means STALE_MAX_SEC, default 600.
    """
    if bars_df.empty:
        return bars_df
    max_age_sec = max_age_sec or int(globals().get("STALE_MAX_SEC", 600))
    try:
        raw_last = pd.Timestamp(bars_df.index[-1])
        last_ts = raw_last.tz_convert("UTC") if raw_last.tzinfo else raw_last.tz_localize("UTC")
        age_sec = int((now_utc() - last_ts).total_seconds())
        if age_sec <= max_age_sec:
            return bars_df

        lt = api.get_latest_trade(symbol)
        price = float(getattr(lt, "price", getattr(lt, "p", float("nan"))))
        ts = pd.to_datetime(getattr(lt, "timestamp", getattr(lt, "t", None)), utc=True)
        if not (pd.notna(ts) and np.isfinite(price) and price > 0):
            return bars_df

        lt_age = int((now_utc() - ts).total_seconds())
        if lt_age > max_age_sec:
            return bars_df

        synth_time = max(last_ts + pd.Timedelta(minutes=1), ts.floor("min"))
        # Match the existing index's tz convention so concat/sort_index don't raise.
        index_tz = getattr(bars_df.index, "tz", None)
        if index_tz is None:
            synth_index = pd.DatetimeIndex([synth_time.tz_convert("UTC").tz_localize(None)])
        else:
            synth_index = pd.DatetimeIndex([synth_time.tz_convert(index_tz)])

        values = {"Open": price, "High": price, "Low": price, "Close": price, "Volume": 0.0}
        if "Adj Close" in bars_df.columns:
            values["Adj Close"] = price
        row = pd.DataFrame({k: [v] for k, v in values.items()}, index=synth_index)

        patched = pd.concat([bars_df, row]).sort_index()
        patched = patched[~patched.index.duplicated(keep="last")]
        logging.info(f"[{symbol}] Patched stale bars with synthetic trade bar @ {synth_time.isoformat()} px={price:.2f}")
        return patched
    except Exception as e:
        logging.debug(f"[{symbol}] maybe_patch_stale_with_latest_trade failed: {e}")
        return bars_df

def run_live_once_for_symbol(
    api,
    symbol: str,
    model: PPO,
    vecnorm: Optional[VecNormalize],
    features_hint: Optional[dict] = None,
):
    """
    Run one live decision cycle for a single symbol.

    Steps, each of which may end the cycle early:
      1. Fetch the last 200 minute bars; skip if none.
      2. Patch stale bars with the latest trade, then build the observation
         (build failure is logged as "obs_build_error").
      3. Skip when the observation timestamp is older than STALE_MAX_SEC ("skip_stale").
      4. Take-profit/stop-loss check (may flatten).
      5. Infer the target weight, then apply gates in order: raw-action gates,
         near-flat flatten, low-confidence no-op, and entry thresholds (seed a
         first position or rebalance), otherwise hold.

    Every outcome after inference is recorded through log_trade_symbol with a
    ``note`` describing the branch taken.
    """
    shape = expected_obs_shape(model, vecnorm)

    bars_df = get_recent_bars(api, symbol, limit=200, timeframe=TimeFrame.Minute)
    if bars_df.empty:
        logging.warning(f"[{symbol}] No recent bars; skipping.")
        return

    # Freshness + context
    last_ts = pd.Timestamp(bars_df.index[-1])
    if last_ts.tzinfo is None:
        last_ts = last_ts.tz_localize("UTC")
    else:
        last_ts = last_ts.tz_convert("UTC")
    age = int((now_utc() - last_ts).total_seconds())
    logging.info(f"[{symbol}] last bar: {last_ts} | age={age}s | feed='{BARS_FEED or 'default'}'")

    bars_df = maybe_patch_stale_with_latest_trade(api, symbol, bars_df)
    # Build observation (robust to early-session / data hiccups)
    try:
        obs, obs_ts = prepare_observation_from_bars(
            bars_df,
            features_hint=features_hint,
            min_required_rows=60,
            expected_shape=shape,
        )
    except Exception as e:
        logging.info(f"[{symbol}] Could not build observation this cycle: {e}")
        # bars_df is non-empty here (empty returned above); the emptiness checks are defensive.
        try:
            eq = get_account_equity(api)
            px = float(bars_df["Close"].iloc[-1]) if not bars_df.empty else get_last_price(api, symbol)
        except Exception:
            eq, px = float("nan"), float("nan")
        log_trade_symbol(
            symbol,
            bars_df.index[-1] if not bars_df.empty else pd.NaT,
            0, 0.0, 0.0, 0.0, px, eq, DRY_RUN,
            note="obs_build_error"
        )
        return
    # Stale observation guard
    if utc_ts(now_utc()) - obs_ts > STALE_MAX_SEC:
        logging.info(f"[{symbol}] Stale obs ...; skip.")
        try:
            eq = get_account_equity(api)
            px = float(bars_df["Close"].iloc[-1]) if not bars_df.empty else get_last_price(api, symbol)
        except Exception:
            eq, px = float("nan"), float("nan")
        log_trade_symbol(
            symbol,
            bars_df.index[-1] if not bars_df.empty else pd.NaT,
            0, 0.0, 0.0, 0.0, px, eq, DRY_RUN,
            note="skip_stale"
        )
        return

    # TP/SL guard (may flatten and exit this cycle)
    if check_tp_sl_and_maybe_flatten(api, symbol):
        return

    # --- inference ---
    target_w, conf, raw = infer_target_weight(model, vecnorm, obs)
    eq   = get_account_equity(api)
    # px is the last bar close (possibly the synthetic trade bar), not a live quote.
    px   = float(bars_df["Close"].iloc[-1]) if not bars_df.empty else get_last_price(api, symbol)
    have = get_position_qty(api, symbol)

    logging.info(
        f"[{symbol}] raw={raw:.4f} conf={conf:.3f} → target_w={target_w:.4f} "
        f"px=${px:.2f} eq=${eq:,.2f} have={have}"
    )
    logging.debug(
        f"[{symbol}] Gates: conf≥ENTER_CONF_MIN? {conf>=ENTER_CONF_MIN} | "
        f"|target_w|≥ENTER_WEIGHT_MIN? {abs(target_w)>=ENTER_WEIGHT_MIN} | "
        f"|target_w|≤EXIT_WEIGHT_MAX? {abs(target_w)<=EXIT_WEIGHT_MAX} | "
        f"Δw floor (DELTA_WEIGHT_MIN): {float(globals().get('DELTA_WEIGHT_MIN',0.0))}"
    )

    # Raw gates
    # NOTE(review): these run before the near-flat flatten check below, so a weak
    # raw signal also suppresses exiting an existing position. Confirm this is intended.
    RAW_POS_MIN = float(globals().get("RAW_POS_MIN", 0.0))
    if target_w > 0 and raw < RAW_POS_MIN:
        logging.info(f"[{symbol}] Raw {raw:.4f} < RAW_POS_MIN {RAW_POS_MIN:.4f}; no action.")
        log_trade_symbol(symbol, bars_df.index[-1], 0, raw, target_w, conf, px, eq, DRY_RUN, note="raw_gate_long")
        return

    # The short-side threshold is read from the RAW_NEG_MAX global and compared against |raw|.
    RAW_NEG_GATE = float(globals().get("RAW_NEG_MAX", 0.0))
    if target_w < 0 and abs(raw) < RAW_NEG_GATE:
        logging.info(f"[{symbol}] |raw| {abs(raw):.4f} < RAW_NEG_GATE {RAW_NEG_GATE:.4f}; no action.")
        log_trade_symbol(symbol, bars_df.index[-1], 0, raw, target_w, conf, px, eq, DRY_RUN, note="raw_gate_short")
        return

    # Flatten FIRST if near-flat and we have a position
    pos = get_position(api, symbol)
    if abs(target_w) <= EXIT_WEIGHT_MAX and pos:
        logging.info(f"[{symbol}] Model near-flat (≤{EXIT_WEIGHT_MAX:.3f}); flattening.")
        flatten_symbol(api, symbol)
        log_trade_symbol(symbol, bars_df.index[-1], int(target_w > 0), raw, target_w, conf, px, eq, DRY_RUN, note="flatten")
        return

    # Low confidence AND near-flat → do nothing
    if conf < ENTER_CONF_MIN and abs(target_w) <= EXIT_WEIGHT_MAX:
        logging.info(f"[{symbol}] Below conf/near-flat gates; no action.")
        log_trade_symbol(symbol, bars_df.index[-1], int(target_w > 0), raw, target_w, conf, px, eq, DRY_RUN, note="no_action")
        return

    # Entry / rebalance
    if abs(target_w) >= ENTER_WEIGHT_MIN and conf >= ENTER_CONF_MIN:
        # Seed a brand-new position (long OR short)
        if SEED_FIRST_SHARE and have == 0:
            seed_notional = max(REBALANCE_MIN_NOTIONAL, round_to_cents(px if np.isfinite(px) else 1.00))
            side = "buy" if target_w > 0 else "sell"

            # If seeding a short, verify shortable
            if side == "sell":
                try:
                    a = api.get_asset(symbol)
                    if not getattr(a, "shortable", False):
                        logging.info(f"[{symbol}] Not shortable; skipping seed short.")
                        log_trade_symbol(symbol, bars_df.index[-1], 0, raw, target_w, conf, px, eq, DRY_RUN, note="not_shortable_seed")
                        return
                except Exception as e:
                    logging.debug(f"[{symbol}] get_asset shortable check failed: {e}")

            # Long seeds use a notional order when the global USE_FRACTIONALS flag is set
            # (the per-asset fractionable flag is not consulted here); otherwise 1 share.
            if target_w > 0 and USE_FRACTIONALS:
                market_order(api, symbol, side=side, notional=seed_notional)
            else:
                market_order_to_qty(api, symbol, side=side, qty=1)

            log_trade_symbol(symbol, bars_df.index[-1], int(target_w > 0), raw, target_w, conf, px, eq, DRY_RUN, note="seed_open")
            return

        # Normal rebalance toward target weight (logged before the order is attempted)
        log_trade_symbol(symbol, bars_df.index[-1], int(target_w > 0), raw, target_w, conf, px, eq, DRY_RUN, note="rebalance")
        rebalance_to_weight(api, symbol, eq, target_w)
    else:
        # Did not meet entry thresholds → hold
        logging.info(f"[{symbol}] target_w ({target_w:.4f}) or conf ({conf:.3f}) below entry gates; hold.")
        log_trade_symbol(symbol, bars_df.index[-1], 0, raw, target_w, conf, px, eq, DRY_RUN, note="hold")

def run_live(tickers: List[str]):
    """
    Load per-ticker artifacts and run the live trading loop until interrupted.

    While the market is closed the loop sleeps until the next open. Each open
    cycle it:
      * runs one decision step per loaded symbol;
      * snapshots equity;
      * every 6th cycle refreshes the equity plot and logs performance metrics;
      * every 12th cycle forces garbage collection;
      * sleeps to the next COOLDOWN_MIN-minute block.

    A failure while processing one symbol, or while snapshotting equity, is
    logged and the loop continues. A transient API or data error therefore
    cannot stop trading for the remaining symbols or future cycles.
    KeyboardInterrupt stops the loop after a final snapshot and plot.

    Raises:
        RuntimeError: if no ticker's artifacts could be loaded.
    """
    api_local = init_alpaca()

    per_ticker: Dict[str, Tuple[PPO, Optional[VecNormalize], Optional[dict]]] = {}
    best = (globals().get("BEST_WINDOW_ENV") or None)

    for t in tickers:
        try:
            picks   = pick_artifacts_for_ticker(t, os.getenv("ARTIFACTS_DIR", str(ARTIFACTS_DIR)), best_window=best)
            model   = load_ppo_model(picks["model"])
            vecnorm = load_vecnormalize(picks.get("vecnorm"))
            # Freeze normalization statistics for inference.
            if vecnorm and hasattr(vecnorm, "training"): vecnorm.training = False
            if vecnorm and hasattr(vecnorm, "norm_reward"): vecnorm.norm_reward = False
            feats   = load_features(picks.get("features"))
            per_ticker[t] = (model, vecnorm, feats)
            logging.info("[%s] Artifacts loaded and ready.", t)
        except Exception as e:
            logging.exception("[%s] Failed to load artifacts: %s", t, e)

    if not per_ticker:
        raise RuntimeError("No models loaded for any ticker. Check artifacts directory and names.")

    loaded_syms = list(per_ticker.keys())
    logging.info("Starting live execution for (loaded): %s", loaded_syms)

    cycle = 0

    try:
        while True:
            if not ensure_market_open(api_local):
                _sleep_until_open(api_local)
                continue

            t_cycle_start = time.perf_counter()

            for t, (model, vecnorm, feat_hint) in per_ticker.items():
                t_sym_start = time.perf_counter()
                try:
                    run_live_once_for_symbol(api_local, t, model, vecnorm, features_hint=feat_hint)
                except Exception as e:
                    # Without this, one bad bar/quote/order response would reach the
                    # outer handler below and terminate the whole live loop.
                    logging.exception("[%s] Live step failed; continuing with next symbol: %s", t, e)
                logging.info("[TIMER] %s symbol work: %.3fs", t, time.perf_counter() - t_sym_start)

            try:
                log_equity_snapshot(api_in=api_local)
            except Exception as e:
                logging.warning("Equity snapshot failed: %s", e)
            cycle += 1

            if (cycle % 6) == 0:
                try:
                    plot_equity_curve(from_equity_csv=True)
                    df = pd.read_csv(EQUITY_LOG_CSV, parse_dates=["datetime_utc"])
                    m = compute_performance_metrics(df)
                    logging.info(
                        "Perf: cum_return=%.2f%% | sharpe=%.2f | maxDD=%.2f%%",
                        100*m["cum_return"], m["sharpe"], 100*m["max_drawdown"]
                    )
                except Exception as e:
                    logging.warning("Plot/metrics failed: %s", e)

            logging.info("[TIMER] full-cycle active time: %.3fs (cooldown=%d min)",
                         time.perf_counter() - t_cycle_start, COOLDOWN_MIN)

            if (cycle % 12) == 0:
                gc.collect()

            _sleep_to_next_minute_block(COOLDOWN_MIN)

    except KeyboardInterrupt:
        logging.info("KeyboardInterrupt: stopping live loop.")
        try:
            log_equity_snapshot(api_in=api_local)
            plot_equity_curve(from_equity_csv=True)
        except Exception as e:
            logging.warning("Finalization failed: %s", e)
    except Exception as e:
        # Last-resort handler for failures outside per-symbol work (e.g. the sleep helper).
        logging.exception("Live loop exception: %s", e)
        try:
            log_equity_snapshot(api_in=api_local)
        except Exception:
            pass
        time.sleep(5)

# --------------------------------- Diagnostic runner ------------------------------------------
def ticker_diagnostic(ticker: str,
                      dry_run: Optional[bool] = None,
                      timeframe: TimeFrame = TimeFrame.Minute,
                      limit: int = 300) -> Optional[Dict[str, Any]]:
    """Run one end-to-end prediction (and, when allowed, one order decision) for *ticker*.

    Steps, each wrapped so that failures are printed rather than raised:
      1. Connect via ``init_alpaca()`` and record the number of open positions
         and open orders.
      2. Pick and load the PPO model, optional VecNormalize statistics (with
         ``training`` and ``norm_reward`` switched off) and the feature list.
      3. Fetch recent bars and build an observation with
         ``prepare_observation_from_bars``.
      4. Infer ``(target_w, conf, raw)``, print it and append a row with
         note="diagnostic" via ``log_trade_symbol``.
      5. If the market clock is open, the observation is not older than
         ``STALE_MAX_SEC`` and *dry_run* is false, take at most one action:
         flatten, seed a first position, buy, or sell (see inline comments).
      6. Print a summary of positions / open orders before and after.

    Args:
        ticker: Symbol to evaluate.
        dry_run: When None, the module-level ``DRY_RUN`` global is used (True if
            it is not defined). When true, no orders are submitted.
        timeframe: Bar timeframe passed to ``get_recent_bars``.
        limit: Minimum number of bars to request; raised to
            ``max(200, 3 * lookback)`` when the model's observation is 2-D.

    Returns:
        A dict with keys ``signal``, ``target_w``, ``conf``, ``raw``,
        ``bar_time``, ``price``, ``equity`` and ``dry_run`` (None / NaN / NaT
        where a step did not run), or None if connecting to Alpaca (including
        listing positions/orders) or loading the model artifacts fails.
    """
    if dry_run is None:
        dry_run = bool(globals().get("DRY_RUN", True))

    print(f"\nRunning strategy for {ticker}...")

    # Connect and snapshot account state for the end-of-run summary.
    try:
        api_local = init_alpaca()
        positions_start = len(api_local.list_positions())
        orders_start    = len(api_local.list_orders(status="open"))
    except Exception as e:
        print(f"Error initializing Alpaca: {e}")
        return

    # Load model / VecNormalize / feature artifacts; BEST_WINDOW_ENV (if set) is forwarded as best_window.
    try:
        best   = (globals().get("BEST_WINDOW_ENV") or None)
        picks  = pick_artifacts_for_ticker(
            ticker,
            os.getenv("ARTIFACTS_DIR", str(globals().get("ARTIFACTS_DIR", PROJECT_ROOT / "artifacts"))),
            best_window=best
        )
        model   = load_ppo_model(picks["model"])
        vecnorm = load_vecnormalize(picks.get("vecnorm")) if picks.get("vecnorm") else None
        # Use the saved normalisation statistics as-is: no running-stat updates, no reward scaling.
        if vecnorm and hasattr(vecnorm, "training"): vecnorm.training = False
        if vecnorm and hasattr(vecnorm, "norm_reward"): vecnorm.norm_reward = False
        feats   = load_features(picks.get("features"))
        print(f"Model artifacts loaded for {ticker}")
    except Exception as e:
        print(f"Could not load model for {ticker}: {e}")
        return

    min_rows_needed = 60
    try:
        # NOTE(review): a 2-D observation shape is presumably (lookback, n_features) — confirm.
        shape     = expected_obs_shape(model, vecnorm)
        lookback  = int(shape[0]) if (shape is not None and len(shape) == 2) else None
        # Over-fetch (3x lookback, at least 200 bars), presumably so feature warm-up rows can be dropped.
        bars_need = max(200, (lookback or 0) * 3)
        bars_df   = get_recent_bars(api_local, ticker, limit=max(limit, bars_need), timeframe=timeframe)

        min_rows_needed = lookback if lookback is not None else 60
        if len(bars_df) < min_rows_needed:
            print(f"Not enough data for {ticker}: {len(bars_df)} rows (need ≥ {min_rows_needed})")
            bars_df = pd.DataFrame()
    except Exception as e:
        print(f"Error fetching bars for {ticker}: {e}")
        # `shape` may be unbound here; that is safe because an empty bars_df
        # skips its only later use (prepare_observation_from_bars below).
        bars_df = pd.DataFrame()

    obs, obs_ts = None, None
    if not bars_df.empty:
        try:
            obs, obs_ts = prepare_observation_from_bars(
                bars_df,
                features_hint=feats,
                min_required_rows=min_rows_needed,
                expected_shape=shape,
            )
        except Exception as e:
            print(f"Error preparing observation for {ticker}: {e}")

    # Defaults reported in the return value when a later step does not run.
    signal = None
    target_w = conf = raw = float("nan")
    predictions_made = 0
    bar_time = pd.NaT
    price = float("nan")
    equity = float("nan")

    # Counters for the summary block.
    orders_submitted = 0
    market_closed = 0

    if obs is not None:
        # OUTER TRY — wraps the whole prediction/trade logic
        try:
            target_w, conf, raw = infer_target_weight(model, vecnorm, obs)
            signal = int(target_w > 0.0)  # 1 if target weight > 0 else 0; also drives the BUY/SELL branches below
            predictions_made = 1
            print(f"Prediction for {ticker}: {signal} (1 = Buy, 0 = Sell)")

            bar_time = bars_df.index[-1] if not bars_df.empty else pd.NaT
            price    = float(bars_df["Close"].iloc[-1]) if not bars_df.empty else get_last_price(api_local, ticker)
            equity   = get_account_equity(api_local)
            print(f"raw={raw:.4f} conf={conf:.3f} target_w={target_w:.3f} price=${price:.2f} equity=${equity:,.2f}")

            # Log to per-ticker CSV
            log_trade_symbol(ticker, bar_time, signal, raw, target_w, conf, price, equity,
                             dry_run=dry_run, note="diagnostic")

            # INNER TRY — clock/orders section
            try:
                clock = api_local.get_clock()
                if not clock.is_open:
                    print("Market is closed.")
                    market_closed = 1
                else:
                    # Mirror live loop: skip if obs got stale
                    # NOTE(review): if obs_ts is None this subtraction raises TypeError,
                    # which the inner except reports as a trade/clock error.
                    if utc_ts(now_utc()) - obs_ts > STALE_MAX_SEC:
                        print("Stale observation; skipping order submission.")
                        log_trade_symbol(
                            ticker, bar_time, 0, raw, target_w, conf, price, equity,
                            dry_run, note="skip_stale_diag"
                        )
                    elif signal is not None and not dry_run:
                        # Env override: buy even when signal == 0, but only in the BUY branch
                        # below (earlier branches take priority) and only with no position.
                        FORCE_FIRST_BUY = os.getenv("FORCE_FIRST_BUY","0").lower() in ("1","true","yes")

                        # Do we already hold ticker?
                        try:
                            pos = api_local.get_position(ticker)
                            has_position = float(pos.qty) != 0.0
                        except APIError:
                            has_position = False

                        # Signed quantity (negative when short, judging by the `have > 0` check below).
                        # NOTE(review): has_position and have are two separate reads of the same position.
                        have = get_position_qty(api_local, ticker)

                        # 1) Near-zero target weight while holding: close the position.
                        if abs(target_w) <= EXIT_WEIGHT_MAX and has_position:
                            flatten_symbol(api_local, ticker)
                            print(f"FLATTEN submitted for {ticker}")
                            orders_submitted += 1


                        # 2) No shares held and both entry gates pass: open a minimal first
                        #    position in the direction of target_w.
                        elif (
                            SEED_FIRST_SHARE
                            and have == 0
                            and abs(target_w) >= ENTER_WEIGHT_MIN
                            and conf >= ENTER_CONF_MIN
                        ):

                            # Only used for fractional longs; at least REBALANCE_MIN_NOTIONAL dollars.
                            seed_notional = max(REBALANCE_MIN_NOTIONAL, round_to_cents(price if np.isfinite(price) else 1.00))
                            side = "buy" if target_w > 0 else "sell"

                            # NOTE(review): this path does not consult ALLOW_SHORTS — confirm intended.
                            skip_seed = False
                            if side == "sell":
                                try:
                                    a = api_local.get_asset(ticker)
                                    if not getattr(a, "shortable", False):
                                        print(f"[{ticker}] Not shortable; skipping seed short.")
                                        log_trade_symbol(ticker, bar_time, 0, raw, target_w, conf, price, equity, dry_run, note="not_shortable_seed")
                                        skip_seed = True
                                except Exception as e:
                                    # The seed order is still attempted if the lookup itself fails.
                                    print(f"[{ticker}] get_asset shortable check failed: {e}")

                            if not skip_seed:
                                if USE_FRACTIONALS and target_w > 0:
                                    market_order(api_local, ticker, side=side, notional=seed_notional)
                                else:
                                    # Whole-share seed: longs without fractionals, and all shorts.
                                    market_order_to_qty(api_local, ticker, side=side, qty=1)

                                log_trade_symbol(ticker, bar_time, int(target_w > 0), raw, target_w, conf, price, equity, dry_run, note="seed_open")
                                orders_submitted += 1

                        # 3) Long signal (or FORCE_FIRST_BUY) with no position: buy one share,
                        #    or one share's price as notional when fractionals are enabled.
                        elif (FORCE_FIRST_BUY and not has_position) or (signal == 1 and not has_position):
                            market_order(
                                api_local, symbol=ticker, side="buy",
                                qty=(1 if not USE_FRACTIONALS else None),
                                notional=(price if USE_FRACTIONALS else None),
                            )
                            print(f"BUY order submitted for {ticker}")
                            orders_submitted += 1

                        # 4) Flat signal while long: sell one share (or one share's notional);
                        #    presumably this trims rather than fully exits a larger position.
                        elif signal == 0 and has_position and have > 0:
                            market_order(
                                api_local, symbol=ticker, side="sell",
                                qty=(1 if not USE_FRACTIONALS else None),
                                notional=(price if USE_FRACTIONALS else None),
                            )
                            print(f"SELL order submitted for {ticker}")
                            orders_submitted += 1

                        else:
                            print(f"No action taken for {ticker}")
                    else:
                        print(f"(dry-run) No order submitted for {ticker} — signal={signal}")
            except Exception as e:
                print(f"Trade/clock error for {ticker}: {e}")

        except Exception as e:
            print(f"Error during prediction/trading for {ticker}: {e}")

    # Best-effort summary; any failure here is deliberately ignored.
    try:
        positions_end = len(api_local.list_positions())
        orders_end    = len(api_local.list_orders(status="open"))
        print("\n========== SUMMARY ==========")
        print(f"Processed:         1")
        print(f"Models loaded:     1")
        print(f"Predictions made:  {predictions_made}")
        print(f"Market closed:     {market_closed}")
        print(f"Orders submitted:  {orders_submitted} (dry_run={dry_run})")
        print(f"Existing positions (start -> end): {positions_start} -> {positions_end}")
        print(f"Open orders        (start -> end): {orders_start} -> {orders_end}")
        print("=============================")
    except Exception:
        pass

    return {
        "signal": signal,
        "target_w": target_w,
        "conf": conf,
        "raw": raw,
        "bar_time": bar_time,
        "price": price,
        "equity": equity,
        "dry_run": dry_run,
    }

def log_config_banner():
    """Log the effective runtime configuration once at startup.

    Writes the module-level configuration globals (paths, tickers, API base URL
    and trading knobs) to the root logger at INFO level. It then lists the files
    found in ARTIFACTS_DIR when that listing succeeds.

    AUTO_RUN_LIVE is reported with the same default ("1") that the ``__main__``
    block uses when deciding whether to start the live loop. An unset variable
    therefore no longer appears as an empty, seemingly disabled, value.
    """
    try:
        artifacts_list = sorted(p.name for p in ARTIFACTS_DIR.iterdir()) if ARTIFACTS_DIR.exists() else []
    except Exception as e:  # best-effort: the banner must never abort startup
        logging.debug("Could not list artifacts directory: %s", e)
        artifacts_list = []

    logging.info("=== CONFIG ===")
    logging.info("Project root        : %s", PROJECT_ROOT)
    logging.info("ARTIFACTS_DIR       : %s", ARTIFACTS_DIR)
    logging.info("RESULTS_DIR         : %s", RESULTS_DIR)
    logging.info("Tickers             : %s", TICKERS)
    logging.info("API base            : %s", BASE_URL)
    # Same default as the __main__ gate, so "unset" is shown as the value that is actually applied.
    logging.info("AUTO_RUN_LIVE       : %s", os.getenv("AUTO_RUN_LIVE", "1"))
    logging.info("INF_DETERMINISTIC   : %s", INF_DETERMINISTIC)
    logging.info(
        "DRY_RUN: %s | BARS_FEED: %s | USE_FRACTIONALS: %s | COOLDOWN_MIN: %s | STALE_MAX_SEC: %s",
        DRY_RUN, BARS_FEED, USE_FRACTIONALS, COOLDOWN_MIN, STALE_MAX_SEC,
    )
    logging.info(
        "WEIGHT_CAP: %.3f | SIZING_MODE: %s | ENTER_CONF_MIN: %.3f | ENTER_WEIGHT_MIN: %.3f | "
        "EXIT_WEIGHT_MAX: %.3f | REBALANCE_MIN_NOTIONAL: %.2f",
        WEIGHT_CAP, SIZING_MODE, ENTER_CONF_MIN, ENTER_WEIGHT_MIN, EXIT_WEIGHT_MAX, REBALANCE_MIN_NOTIONAL,
    )
    logging.info(
        "TAKE_PROFIT_PCT: %.3f | STOP_LOSS_PCT: %.3f | BEST_WINDOW_ENV: %s",
        TAKE_PROFIT_PCT, STOP_LOSS_PCT, (BEST_WINDOW_ENV or ""),
    )
    # Optional knobs: fall back to 0.0 when configure_knobs did not define them.
    logging.info(
        "DELTA_WEIGHT_MIN: %.3f | RAW_POS_MIN: %.3f | RAW_NEG_MAX: %.3f",
        float(globals().get("DELTA_WEIGHT_MIN", 0.0)),
        float(globals().get("RAW_POS_MIN", 0.0)),
        float(globals().get("RAW_NEG_MAX", 0.0)),
    )
    if artifacts_list:
        logging.info("Artifacts present (%d): %s", len(artifacts_list), ", ".join(artifacts_list))

# ===================================== MAIN ===================================================
if __name__ == "__main__":
    # Colab: pull .env/artifacts into the runtime, normalise artifact file names,
    # then (re)load the environment so the uploaded .env takes effect.
    if IN_COLAB:
        upload_env_and_artifacts_in_colab()
        _maybe_convert_features_txt_to_json()
        _maybe_rename_vecnorm_scaler()
        load_dotenv(dotenv_path=PROJECT_ROOT / ".env", override=True)

    # Knob overrides applied on top of configure_knobs' defaults for this session.
    _live_knob_overrides = {
        # cadence / freshness
        "BARS_FEED": "iex",
        "STALE_MAX_SEC": 600,
        "COOLDOWN_MIN": 3,
        # entry/exit gates
        "ENTER_CONF_MIN": 0.12,
        "ENTER_WEIGHT_MIN": 0.02,
        "EXIT_WEIGHT_MAX": 0.008,
        # rebalance tolerance
        "DELTA_WEIGHT_MIN": 0.012,       # ~1.2% of equity shift required
        "REBALANCE_MIN_NOTIONAL": 25.0,  # avoid $1 trickles (was 1.00)
        # raw-action sanity gates
        "RAW_POS_MIN": 0.10,
        "RAW_NEG_MAX": 0.10,             # NOTE(review): positive value for a "NEG_MAX" gate — confirm sign convention
        # sizing & risk
        "WEIGHT_CAP": 0.35,
        "TAKE_PROFIT_PCT": 0.02,
        "STOP_LOSS_PCT": 0.01,
        # shorts
        "ALLOW_SHORTS": True,
    }
    cfg = configure_knobs(overrides=_live_knob_overrides)

    log_config_banner()

    api = init_alpaca()
    acct = api.get_account()
    logging.info("Account status: %s | equity=%s | cash=%s", acct.status, acct.equity, acct.cash)

    # One-shot diagnostic per symbol; a failure on one symbol must not stop the others.
    for _sym in TICKERS:
        try:
            ticker_diagnostic(_sym, dry_run=DRY_RUN)
        except Exception as e:
            print(f"[DIAG] {_sym} failed: {e}")

    # The live loop starts unless AUTO_RUN_LIVE is explicitly set to a non-truthy value.
    _truthy_flags = {"1", "true", "yes", "y", "on"}
    if os.getenv("AUTO_RUN_LIVE", "1").lower() not in _truthy_flags:
        logging.info("AUTO_RUN_LIVE disabled; live loop not started.")
    else:
        run_live(TICKERS)


In [None]:
# ---------- Safe summary + diagnostics (no path clobbering) ----------
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Resolve dirs once: prefer globals set by the main script, then env vars, then
# the current directory (LATEST_DIR falls back to RESULTS_DIR).
_results_fallback = os.getenv("RESULTS_DIR", ".")
RESULTS_DIR = Path(globals().get("RESULTS_DIR", _results_fallback))
_latest_fallback = os.getenv("LATEST_DIR", str(RESULTS_DIR))
LATEST_DIR = Path(globals().get("LATEST_DIR", _latest_fallback))

# Candidate equity logs in priority order: explicit global paths first (None when
# the main script did not define them), then the conventional file in each folder.
eq_candidates = [globals().get(name) for name in ("EQUITY_LOG_CSV", "EQUITY_LOG_LATEST")]
eq_candidates += [folder / "equity_log.csv" for folder in (RESULTS_DIR, LATEST_DIR)]

def _first_existing(paths):
    for p in paths:
        if p:
            p = Path(p)
            if p.exists() and p.is_file():
                return p
    return None

def _periods_per_year(timestamps, trading_hours_per_year=252 * 6.5):
    """Estimate how many sampling intervals fit into one trading year.

    Equity snapshots are written once per live-loop cycle, so rows are not
    necessarily hourly. The typical spacing is the median positive gap between
    consecutive timestamps, which is robust to overnight/weekend gaps. Falls back
    to hourly sampling (``trading_hours_per_year``) when the timestamps are not
    datetimes or fewer than two distinct times are present.
    """
    if not pd.api.types.is_datetime64_any_dtype(timestamps):
        return trading_hours_per_year
    gaps = timestamps.sort_values().diff().dropna()
    gaps = gaps[gaps > pd.Timedelta(0)]
    if gaps.empty:
        return trading_hours_per_year
    hours_per_period = gaps.median() / pd.Timedelta(hours=1)
    return trading_hours_per_year / hours_per_period


eq_path = _first_existing(eq_candidates)
if eq_path is None:
    # try any "equity_log*.csv" and pick the most recent
    all_eq = list(RESULTS_DIR.glob("equity_log*.csv")) + list(LATEST_DIR.glob("equity_log*.csv"))
    eq_path = max(all_eq, key=lambda p: p.stat().st_mtime, default=None)

if eq_path and eq_path.exists():
    try:
        eq = pd.read_csv(eq_path, parse_dates=["datetime_utc"]).sort_values("datetime_utc")
        if not eq.empty:
            r = eq["equity"].pct_change().dropna()
            # Annualise with the observed sampling interval rather than assuming hourly
            # rows: the live loop snapshots equity every cycle (minutes apart).
            periods_per_year = _periods_per_year(eq["datetime_utc"])
            sharpe_ann = (r.mean() / (r.std() + 1e-12)) * np.sqrt(periods_per_year) if len(r) else float("nan")
            print(f"\nEquity summary — last: ${eq['equity'].iloc[-1]:,.2f} | "
                  f"n={len(eq)} pts | Sharpe(ann): {sharpe_ann:.2f} "
                  f"[{periods_per_year:,.0f} periods/yr] | src={eq_path}")
        else:
            print(f"No rows in equity log: {eq_path}")
    except Exception as e:
        print(f"Could not summarize equity ({eq_path}): {e}")
else:
    print("No equity_log*.csv found in RESULTS_DIR/LATEST_DIR.")

# Report only the tickers listed in the TICKERS env var (default UNH).
tickers_to_report = [t.strip().upper() for t in os.getenv("TICKERS", "UNH").split(",") if t.strip()]

print("\nTrade Summary:")
for ticker in tickers_to_report:
    # check both locations
    trade_candidates = [
        RESULTS_DIR / f"trade_log_{ticker}.csv",
        LATEST_DIR / f"trade_log_{ticker}.csv",
    ]
    log_path = _first_existing(trade_candidates)
    if not log_path:
        # tolerate Drive duplicates like "trade_log_XYZ (1).csv"
        any_logs = list(RESULTS_DIR.glob(f"trade_log_{ticker}*.csv")) + \
                   list(LATEST_DIR.glob(f"trade_log_{ticker}*.csv"))
        log_path = max(any_logs, key=lambda p: p.stat().st_mtime, default=None)

    if not log_path or not log_path.exists():
        print(f"{ticker}: no trades logged yet.")
        continue

    try:
        # Only parse the timestamp columns this log actually has: naming a missing
        # column in parse_dates raises and would hide the whole ticker summary.
        header = pd.read_csv(log_path, nrows=0).columns
        date_cols = [c for c in ("log_time", "bar_time") if c in header]
        df = pd.read_csv(log_path, on_bad_lines="skip", parse_dates=date_cols)
        key = "signal" if "signal" in df.columns else ("action" if "action" in df.columns else None)
        if key:
            counts = df[key].value_counts(dropna=False).to_dict()
            print(f"{ticker}: {counts} | src={log_path.name}")
        else:
            print(f"{ticker}: log present but missing 'signal'/'action' columns. src={log_path.name}")

        if "confidence" in df.columns and df["confidence"].notna().any():
            fig, ax = plt.subplots(figsize=(8, 3.5))
            try:
                df["confidence"].dropna().plot(kind="hist", bins=10, edgecolor="black", ax=ax)
                ax.set_title(f"{ticker} - Confidence Distribution")
                ax.set_xlabel("confidence")
                fig.tight_layout()
                # The main script selects the non-interactive Agg backend, under which
                # plt.show() renders nothing, so persist the figure next to the trade log.
                hist_path = log_path.with_name(f"confidence_hist_{ticker}.png")
                try:
                    fig.savefig(hist_path)
                    print(f"{ticker}: confidence histogram -> {hist_path}")
                except OSError as e:
                    print(f"{ticker}: could not save confidence histogram ({e})")
                if plt.get_backend().lower() != "agg":
                    plt.show()
            finally:
                plt.close(fig)  # don't accumulate open figures across tickers / re-runs

        for col in ["weight", "raw_action"]:
            if col in df.columns and df[col].notna().any():
                s = df[col].dropna()
                print(f"{ticker} {col}: mean={s.mean():.3f}, std={s.std():.3f}, "
                      f"min={s.min():.3f}, max={s.max():.3f}")
    except Exception as e:
        print(f"{ticker}: could not summarize trades ({log_path}): {e}")

# --- Position Summary: current holdings and their total market value ---
try:
    if "api" not in globals():
        api = init_alpaca()  # reuse the main script's client when it exists
    positions = api.list_positions()
    total_market_value = 0.0
    print("\nPosition Summary:")
    for pos in positions:
        value = float(pos.market_value)
        total_market_value += value
        last_px = float(pos.current_price)
        print(f"  {pos.symbol}: {pos.qty} shares @ ${last_px:.2f} | Value: ${value:,.2f}")
    print(f"\nTotal Market Value: ${total_market_value:,.2f}")
except Exception as e:
    print(f"Could not summarize positions: {e}")

# --- Filled order counts (last 14 days) ---
from datetime import datetime, timedelta, timezone

def count_filled_orders_since(api, symbol: str, days: int = 14, page_size: int = 500) -> int:
    """Count *symbol*'s filled / partially filled orders submitted in the last *days* days.

    Alpaca's ``list_orders`` returns at most ``limit`` orders per call (50 when
    ``limit`` is omitted), newest first, so a single unpaged call silently
    under-counts an active account. This pages backwards through the window. After
    each full page it moves ``until`` to the oldest ``submitted_at`` seen, and it
    de-duplicates by order id. It stops on a short page or when no new orders
    arrive.

    Args:
        api: Alpaca REST client exposing ``list_orders``.
        symbol: Ticker to count.
        days: Look-back window in days.
        page_size: Orders requested per call (Alpaca caps this at 500).

    Returns:
        Number of matching top-level orders. With ``nested=True`` multi-leg orders
        are rolled up under their parent, so legs are not counted separately.
    """
    after = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()
    wanted_status = ("filled", "partially_filled")
    seen_ids = set()
    count = 0
    until = None
    while True:
        query = dict(status="all", after=after, nested=True, limit=page_size, direction="desc")
        if until is not None:
            query["until"] = until
        page = api.list_orders(**query)
        fresh = [o for o in page if o.id not in seen_ids]
        if not fresh:
            break  # nothing new: also guards against boundary re-delivery loops
        for o in fresh:
            seen_ids.add(o.id)
            if o.symbol == symbol and o.status in wanted_status:
                count += 1
        if len(page) < page_size:
            break  # short page: the window is exhausted
        stamps = [o.submitted_at for o in fresh if getattr(o, "submitted_at", None) is not None]
        if not stamps:
            break
        oldest = min(stamps)
        until = oldest.isoformat() if hasattr(oldest, "isoformat") else str(oldest)
    return count

# Print a 14-day filled-order count for each reported ticker.
try:
    api_chk = globals()["api"] if "api" in globals() else init_alpaca()
    for sym in tickers_to_report:  # swap in TICKERS to report the live universe instead
        filled_n = count_filled_orders_since(api_chk, sym, days=14)
        print(f"{sym}: {filled_n} filled trades in last 14 days")
except Exception as e:
    print(f"Could not fetch filled orders: {e}")


In [None]:
# --- Export locally & download to your computer (Colab) ---
from pathlib import Path
from datetime import datetime, timezone
from google.colab import files   # <-- NEW: for browser download
import shutil, time, pandas as pd

# Drive root that holds the results written by the live script.
ROOT = Path("/content/drive/MyDrive/AlpacaPaper")
TODAY = datetime.now(timezone.utc).strftime("%Y-%m-%d")  # UTC date, e.g. 2025-10-13

# Drive sources: the day's results folder and the optional rescue-export folder.
SRC_RESULTS, SRC_EXPORT = (ROOT / sub / TODAY for sub in ("results", "results_export"))

# Stage the export on the local Colab VM disk rather than on Drive.
DEST = Path("/content/exports") / f"{TODAY}_export"
DEST.mkdir(parents=True, exist_ok=True)

def copy_all(src_dir, dest_dir):
    """Copy every regular file directly inside *src_dir* into *dest_dir* (non-recursive).

    Same-named files in *dest_dir* are overwritten; ``shutil.copy2`` preserves
    file metadata. A missing *src_dir* is reported on stdout, not raised.
    """
    if not src_dir.exists():
        print("Missing source:", src_dir)
        return
    for entry in src_dir.glob("*"):
        if not entry.is_file():
            continue
        shutil.copy2(entry, dest_dir / entry.name)
        print("Copied:", entry.name, "from", src_dir.name)

# Pull both Drive sources into the local staging folder; on a name clash the
# rescue-export copy (processed second) overwrites the results copy.
for _src in (SRC_RESULTS, SRC_EXPORT):
    copy_all(_src, DEST)

# Build/refresh trade_log_master.csv from the per-symbol logs in the LOCAL staging folder.
# The master file itself matches "trade_log_*.csv" (left over from an earlier run of this
# cell, or copied in from Drive), so it is excluded; otherwise each refresh would fold the
# previous master's rows into the new one.
MASTER_NAME = "trade_log_master.csv"
sym_logs = sorted(p for p in DEST.glob("trade_log_*.csv") if p.name != MASTER_NAME)
if sym_logs:
    frames = []
    for p in sym_logs:
        try:
            # Skip malformed rows instead of dropping the whole symbol file.
            df = pd.read_csv(p, on_bad_lines="skip")
            df["symbol_file"] = p.stem.replace("trade_log_", "")
            frames.append(df)
        except Exception as e:
            print("Skip", p.name, "->", e)
    if frames:
        master = pd.concat(frames, ignore_index=True, sort=False)
        master_path = DEST / MASTER_NAME
        master.to_csv(master_path, index=False)
        print("Wrote:", master_path)

# Zip the staging folder locally under /content; the epoch suffix keeps re-runs from colliding.
zip_base = Path("/content") / f"results_{TODAY}_{int(time.time())}"
archive_path = str(Path(shutil.make_archive(str(zip_base), "zip", DEST)))

print("ZIP ->", archive_path)

# OPTIONAL: also keep a copy in Drive (uncomment if wanted)
# shutil.copy2(archive_path, ROOT / "results" / Path(archive_path).name)

files.download(archive_path)  # Colab: prompt a browser download of the archive

# Show what's in the LOCAL export folder
print("\nLocal export now contains:")
for entry in sorted(DEST.iterdir()):
    print(" -", entry.name)
