# Astro pipeline: target variable and XGBoost (BTC)

This notebook shows the full cycle:
1) load quotes (daily)
2) build target variable (oracle labels)
3) compute astro data and build astro features
4) train and evaluate XGBoost.

Important: features are astro-only; price is used only for targets.


## 0. Environment setup

If any packages are missing, install them from conda-forge into the active environment:

```
conda install -c conda-forge xgboost scikit-learn matplotlib seaborn tqdm pyarrow jupyterlab
```

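Also check:
- `configs/astro.yaml` -> `ephe_path` (path to the Swiss Ephemeris files)
- `configs/subjects.yaml` -> `active_subject_id` and the subject's birth date
- `configs/db.yaml` -> `db.url` (Postgres database with a populated `market_daily` table for the active subject)
- `configs/market.yaml`, `configs/labels.yaml` and `configs/training.yaml` exist (all are loaded at startup)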


In [None]:

# --- TRACE HELPERS (auto-generated) ---
import inspect as _inspect
from pathlib import Path as _Path

def _format_value(val, max_items=5):
    """Render ``val`` as a short, human-readable summary for trace output.

    Modules, functions and classes are shown by name; pandas DataFrames/Series
    and numpy arrays by shape plus a small sample; dicts by their first
    ``max_items`` keys; lists/tuples/sets by length plus the first
    ``max_items`` items; strings longer than 200 characters are truncated.
    Anything else falls back to ``repr`` (or ``<TypeName>`` if ``repr`` fails).
    pandas and numpy are imported lazily so the helper works without them.
    """
    try:
        import pandas as pd_mod
    except Exception:
        pd_mod = None
    try:
        import numpy as np_mod
    except Exception:
        np_mod = None

    if _inspect.ismodule(val):
        return f"<module {getattr(val, '__name__', 'module')}>"
    if _inspect.isfunction(val) or _inspect.isclass(val):
        return f"<{type(val).__name__} {getattr(val, '__name__', '')}>"
    if isinstance(val, _Path):
        return f"Path('{val}')"
    if pd_mod is not None:
        if isinstance(val, pd_mod.DataFrame):
            return f"DataFrame shape={val.shape} cols={list(val.columns)} head=\n{val.head(3)}"
        if isinstance(val, pd_mod.Series):
            return f"Series len={len(val)} name={val.name} head={val.head(3).to_list()}"
    if np_mod is not None and isinstance(val, np_mod.ndarray):
        return f"ndarray shape={val.shape} dtype={val.dtype} sample={val.flatten()[:max_items]}"
    if isinstance(val, dict):
        key_list = list(val)
        more = "..." if len(key_list) > max_items else ""
        return f"dict keys={key_list[:max_items]}{more}"
    if isinstance(val, (list, tuple, set)):
        items = list(val)
        more = "..." if len(items) > max_items else ""
        return f"{type(val).__name__} len={len(val)} sample={items[:max_items]}{more}"
    if isinstance(val, str):
        return repr(val[:200] + '...') if len(val) > 200 else repr(val)
    try:
        return repr(val)
    except Exception:
        return f"<{type(val).__name__}>"

# Human-readable explanations printed next to variables by trace_cell().
# Keys are notebook global names; names without an entry are printed bare.
VAR_HELP = {
    'center': "astro coordinate center ('geo' or 'helio')",
    'price_mode': "oracle price_mode: 'log' or 'raw'",
    'LABEL_PRICE_MODE': "labeling price space: 'log' or 'raw'",
    'LABEL_MODE': "labeling mode (balanced_future_return / balanced_detrended)",
    'GAUSS_WINDOW': "Gaussian window size (odd) for centered detrend",
    'GAUSS_STD': "Gaussian std for centered detrend",
    'HORIZON': "prediction horizon (days ahead)",
    'TARGET_MOVE_SHARE': "target share of samples kept for balanced labeling",
    'MOVE_SHARE_TOTAL': "total share kept (split up/down)",
    'cfg_market': "market config loaded from configs/market.yaml",
    'cfg_astro': "astro config loaded from configs/astro.yaml",
    'cfg_labels': "labels config loaded from configs/labels.yaml",
    'df_market': "market DataFrame (daily OHLCV)",
    'df_bodies': "astro bodies table (daily)",
    'df_aspects': "astro aspects table (daily)",
    'df_transits': "transit-to-natal aspects table (daily)",
    'df_features': "feature matrix (astro features)",
    'df_labels': "oracle labels table",
    'feature_cols': "feature column names used for model training",
    'TWO_STAGE': "two-stage XGB (move + direction) flag",
    'model': "single-stage model wrapper",
    'model_move': "two-stage move model",
    'model_dir': "two-stage direction model",
    # Globals referenced by trace cells that previously had no explanation.
    'PROJECT_ROOT': "repository root (directory containing configs/market.yaml)",
    'PLOT_PRICE_MODE': "price space used for plots: 'log' or 'raw'",
    'cfg_db': "database config loaded from configs/db.yaml",
    'cfg_train': "training config loaded from configs/training.yaml",
    'df_labels_shifted': "labels table used for the feature/target merge",
    'df_dataset': "merged astro features + target, one row per date",
    'BINARY_TREND': "binary UP/DOWN target flag",
    'N_CLASSES': "number of target classes",
}


def trace_cell(title, purpose=None, used_vars=None, notes=None):
    """Print a framed header describing a notebook cell.

    Args:
        title: Cell label printed after ``[CELL]``.
        purpose: Optional one-line description of what the cell does.
        used_vars: Optional iterable of global variable names. Each name is
            printed with its ``VAR_HELP`` explanation (when one exists) and a
            ``_format_value`` summary of its current value, or ``<not set>``
            if the name is not defined in this module's globals.
        notes: Optional iterable of free-text notes, printed as a bullet list.
    """
    rule = "=" * 120
    print("\n" + rule)
    print(f"[CELL] {title}")
    if purpose:
        print(f"Purpose: {purpose}")
    if notes:
        print("Notes:")
        for note in notes:
            print(f"- {note}")
    if used_vars:
        print("Variables used in this cell:")
        scope = globals()
        for name in used_vars:
            if name not in scope:
                print(f"  {name} = <not set>")
                continue
            expl = VAR_HELP.get(name)
            label = f"{name} ({expl})" if expl else name
            print(f"  {label} = {_format_value(scope[name])}")
    print(rule)

# Check dependencies (stop notebook if missing)
import importlib.util as iu

required = ["xgboost", "sklearn", "matplotlib", "seaborn", "tqdm", "pyarrow"]
# Probe by import name (e.g. "sklearn", not "scikit-learn") without importing the modules.
missing = list(filter(lambda pkg: iu.find_spec(pkg) is None, required))

if not missing:
    print("OK: all core dependencies found")
else:
    print("Missing packages:", ", ".join(missing))
    print("Install them with:")
    print("conda install -c conda-forge xgboost scikit-learn matplotlib seaborn tqdm pyarrow jupyterlab")
    # SystemExit aborts the rest of the cell/run with a readable message.
    raise SystemExit("Stopped: install dependencies and rerun")


In [None]:
trace_cell(
    title='Cell 3',
    purpose='Base imports and environment setup',
    used_vars=['PROJECT_ROOT', 'Path', 'parent', 'pd', 'plt', 'sns', 'sys'],
)
# NOTE: trace_cell runs before the imports below, so on a fresh kernel the
# listed names print as <not set>.

# Base imports and environment setup
from pathlib import Path
import sys
import os  # NOTE(review): not referenced in the visible cells
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# Visual style
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (12, 4)

# Table display settings
pd.set_option("display.max_columns", 200)
pd.set_option("display.width", 200)

# Project root search (look for configs/market.yaml):
# use the current working directory if it contains the config, otherwise take
# the nearest parent directory that does. This lets the notebook be started
# from a subfolder of the repository.
PROJECT_ROOT = Path.cwd().resolve()
if not (PROJECT_ROOT / "configs/market.yaml").exists():
    for parent in PROJECT_ROOT.parents:
        if (parent / "configs/market.yaml").exists():
            PROJECT_ROOT = parent
            break

# Fail fast if neither the cwd nor any ancestor holds the config.
if not (PROJECT_ROOT / "configs/market.yaml").exists():
    raise FileNotFoundError("Project root not found: configs/market.yaml")

# Make project-local packages (the `src.*` imports in later cells) importable.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"PROJECT_ROOT = {PROJECT_ROOT}")


In [None]:
trace_cell(
    title='Cell 4',
    purpose='Load configs and market data (from Postgres)',
    used_vars=['NB_PLOT_PRICE_MODE', 'PLOT_PRICE_MODE', 'PROJECT_ROOT', 'Path', '_resolve_path', 'active_id', 'cfg_db', 'cfg_labels', 'cfg_market', 'conn', 'data_root', 'db_url', 'df_market', 'load_subjects', 'load_yaml', 'market_cfg', 'path', 'pd', 'psql_connection', 'reports_dir', 'subject', 'subjects', 'value'],
    notes=['Loads market config; if DB/parquet cache enabled, avoids re-download.', 'Market dataframe is the price source for labels/plots.'],
)
# NOTE(review): the first trace note mentions a DB/parquet cache, but this cell
# reads market data from Postgres only; the note may be stale.

# Load configs and market data (from Postgres)
from src.common.config import load_yaml, load_subjects
from src.db.connection import psql_connection

# All configs are read from PROJECT_ROOT/configs; a missing file fails here.
cfg_market = load_yaml(PROJECT_ROOT / "configs/market.yaml")
cfg_astro = load_yaml(PROJECT_ROOT / "configs/astro.yaml")
cfg_labels = load_yaml(PROJECT_ROOT / "configs/labels.yaml")
cfg_db = load_yaml(PROJECT_ROOT / "configs/db.yaml")
cfg_train = load_yaml(PROJECT_ROOT / "configs/training.yaml")

# load_subjects returns the subject mapping plus the id of the active subject;
# the active subject selects which market series (and natal data) is used.
subjects, active_id = load_subjects(PROJECT_ROOT / "configs/subjects.yaml")
subject = subjects[active_id]

market_cfg = cfg_market["market"]

# NOTE: if path is relative, resolve from PROJECT_ROOT
def _resolve_path(value: str | Path) -> Path:
    path = Path(value)
    if path.is_absolute():
        return path
    return (PROJECT_ROOT / path).resolve()

# Data directories: data_root comes from market.yaml (relative paths resolve
# against PROJECT_ROOT). processed/ receives the parquet outputs of later cells;
# reports/ is created up front.
data_root = _resolve_path(market_cfg["data_root"])
processed_dir = data_root / "processed"
reports_dir = data_root / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)

print(f"Active subject: {subject.subject_id}")
print(f"Data root: {data_root}")

# Market source: Postgres only
if "db" not in cfg_db or "url" not in cfg_db["db"]:
    raise KeyError("configs/db.yaml must define db.url")

db_url = cfg_db["db"]["url"]

# Load market_daily from DB
# subject_id is passed as a bound query parameter (%s placeholder), not
# formatted into the SQL text. Rows come back ordered by date with pandas'
# default RangeIndex; the labeling cell uses that index to address rows.
with psql_connection(db_url) as conn:
    df_market = pd.read_sql_query(
        "SELECT date, close FROM market_daily WHERE subject_id = %s ORDER BY date",
        conn,
        params=(subject.subject_id,),
    )

if df_market.empty:
    raise ValueError(
        f"No market data for subject_id={subject.subject_id}. "
        "Load market_daily into Postgres first."
    )

if "date" not in df_market.columns or "close" not in df_market.columns:
    raise ValueError("market_daily must have date and close columns")

df_market["date"] = pd.to_datetime(df_market["date"])
print(df_market.head())
print(f"Market range: {df_market['date'].min().date()} -> {df_market['date'].max().date()}")
print(f"Rows: {len(df_market)}")

# Plot price mode (optional override)
# NB_PLOT_PRICE_MODE wins when set; otherwise labels.price_mode from
# configs/labels.yaml (default 'log'). This affects plots only; the labeling
# cell sets its own LABEL_PRICE_MODE independently.
NB_PLOT_PRICE_MODE = None  # 'log' or 'raw'
PLOT_PRICE_MODE = str(NB_PLOT_PRICE_MODE or cfg_labels['labels'].get('price_mode', 'log')).lower()
if PLOT_PRICE_MODE not in {'log', 'raw'}:
    print(f"[WARN] Unknown PLOT_PRICE_MODE={PLOT_PRICE_MODE}, fallback to 'log'")
    PLOT_PRICE_MODE = 'log'
print(f"PLOT_PRICE_MODE = {PLOT_PRICE_MODE}")


In [None]:
trace_cell(
    title='Cell 5',
    purpose='Quick look at price and daily change distribution',
    used_vars=['PLOT_PRICE_MODE', 'ax', 'df_market', 'log_ret', 'np', 'plt', 'price_label', 'price_series'],
    notes=['Market dataframe is the price source for labels/plots.'],
)

# Quick look at price and daily change distribution
# Two stacked panels with unrelated x-axes (dates vs return values), hence sharex=False.
fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=False)

# Top panel is drawn in the configured plot space (log or raw close).
price_series = np.log(df_market['close']) if PLOT_PRICE_MODE == 'log' else df_market['close']
price_label = 'log(close)' if PLOT_PRICE_MODE == 'log' else 'close'

ax[0].plot(df_market['date'], price_series, color='tab:blue', linewidth=1)
ax[0].set_title('BTC close (daily)')
ax[0].set_xlabel('Date')
ax[0].set_ylabel(price_label)

# Log returns for a rough distribution check
# (always log returns regardless of PLOT_PRICE_MODE; the first diff is NaN and dropped)
log_ret = np.log(df_market['close']).diff().dropna()
ax[1].hist(log_ret, bins=80, color='tab:gray')
ax[1].set_title('Daily log return distribution')
ax[1].set_xlabel('log_return')
ax[1].set_ylabel('frequency')

plt.tight_layout()
plt.show()


## 1. Oracle labels (target variable)

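Idea: compute the future return over `HORIZON` days (in raw or log price space, optionally after subtracting a centered Gaussian trend), then label the largest up-moves as UP (1) and the largest down-moves as DOWN (0), taking the same number from each side. All remaining days are left unlabeled and excluded.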


In [None]:
trace_cell(
    title='Cell 7',
    purpose='Balanced binary labeling based on future return',
    used_vars=['GAUSS_STD', 'GAUSS_WINDOW', 'HORIZON', 'LABEL_MODE', 'LABEL_PRICE_MODE', 'MOVE_SHARE_TOTAL', 'TARGET_MOVE_SHARE', '_gaussian_kernel', '_gaussian_smooth_centered', 'base', 'base_series', 'cfg_labels', 'df_labels', 'df_market', 'down_idx', 'future_ret', 'labels_cfg', 'n_down', 'n_up', 'neg', 'np', 'pd', 'per_side', 'pos', 'series', 'smooth', 'std', 'total_n', 'up_idx', 'valid', 'w', 'weights', 'window', 'x'],
    notes=['Labeling mode controls how targets are constructed.', 'Labeling price space: log vs raw.', 'Market dataframe is the price source for labels/plots.', 'Oracle labels (targets) for training/eval.'],
)

# Balanced binary labeling based on future return
# This creates an UP/DOWN target with roughly balanced classes.

labels_cfg = cfg_labels['labels']
# Future-return horizon, in rows of df_market (daily bars).
HORIZON = int(labels_cfg.get('horizon', 1))
# Fraction of valid rows that receive a label, split equally between UP and
# DOWN (default 0.5 -> 25% UP + 25% DOWN; the rest stays unlabeled).
TARGET_MOVE_SHARE = float(labels_cfg.get('target_move_share', 0.5))

LABEL_MODE = 'balanced_future_return'  # 'balanced_future_return' or 'balanced_detrended'
# NOTE(review): hard-coded here, independent of labels.price_mode in the config
# (which only drives PLOT_PRICE_MODE). In 'raw' space the future return is a
# price difference whose magnitude grows with the price level, so ranking moves
# by size favors high-price periods; 'log' ranks relative moves. Confirm 'raw'
# is intended.
LABEL_PRICE_MODE = 'raw'  # 'raw' or 'log' for labeling space
MOVE_SHARE_TOTAL = float(TARGET_MOVE_SHARE)  # total share of samples kept (split UP/DOWN)

# Centered Gaussian smoothing for detrending (no lag)
# Both values are in rows; the centered window looks into the future, which is
# acceptable only because it feeds targets, never features.
GAUSS_WINDOW = int(labels_cfg.get('gauss_window', 201))  # must be odd
GAUSS_STD = float(labels_cfg.get('gauss_std', 50.0))


def _gaussian_kernel(window: int, std: float) -> np.ndarray:
    """Return a normalized, symmetric Gaussian kernel of length ``window``.

    Args:
        window: Kernel length in samples; must be a positive odd integer so the
            kernel has a single center tap.
        std: Standard deviation of the Gaussian, in samples; must be > 0.

    Returns:
        1-D float array of length ``window`` whose weights sum to 1.

    Raises:
        ValueError: if ``window`` is even or not positive, or ``std`` is not > 0
            (either would otherwise produce NaN weights via a zero division).
    """
    if window % 2 == 0:
        raise ValueError('GAUSS_WINDOW must be odd')
    if window < 1:
        raise ValueError(f'GAUSS_WINDOW must be positive, got {window}')
    if not std > 0:  # also rejects NaN
        raise ValueError(f'GAUSS_STD must be > 0, got {std}')
    x = np.arange(window) - window // 2
    w = np.exp(-(x ** 2) / (2 * (std ** 2)))
    w /= w.sum()
    return w


def _gaussian_smooth_centered(series: pd.Series, window: int, std: float) -> pd.Series:
    """Smooth ``series`` with a centered Gaussian weighted moving average.

    Each output value is the kernel-weighted average of the ``window`` values
    centered on that row, so it uses both past and future samples (no lag).

    Args:
        series: Values to smooth.
        window: Odd kernel length in rows (validated by _gaussian_kernel).
        std: Gaussian standard deviation in rows.

    Returns:
        Series aligned with ``series``. The first and last ``window // 2``
        rows, and any window containing NaN, are NaN.
    """
    weights = _gaussian_kernel(window, std)
    # full window only to avoid edge bias; edges become NaN
    return series.rolling(window=window, center=True, min_periods=window).apply(
        lambda x: np.dot(x, weights), raw=True
    )


# Validate the hand-set knobs up front: an unknown LABEL_MODE would otherwise
# silently fall through to the plain future-return branch, HORIZON < 1 gives
# all-zero returns (no labels at all), and a share outside (0, 1] makes the
# per-side quota meaningless.
if LABEL_MODE not in ('balanced_future_return', 'balanced_detrended'):
    raise ValueError(f'Unknown LABEL_MODE={LABEL_MODE}')
if HORIZON < 1:
    raise ValueError(f'HORIZON must be >= 1, got {HORIZON}')
if not 0.0 < MOVE_SHARE_TOTAL <= 1.0:
    raise ValueError(f'MOVE_SHARE_TOTAL must be in (0, 1], got {MOVE_SHARE_TOTAL}')

# Price series in the labeling space
if LABEL_PRICE_MODE == 'log':
    base_series = np.log(df_market['close']).astype(float)
elif LABEL_PRICE_MODE == 'raw':
    base_series = df_market['close'].astype(float)
else:
    raise ValueError(f'Unknown LABEL_PRICE_MODE={LABEL_PRICE_MODE}')

if LABEL_MODE == 'balanced_detrended':
    # Subtract a centered Gaussian trend (uses past AND future prices; fine for
    # oracle targets, never for features), then take the change of the residual.
    smooth = _gaussian_smooth_centered(base_series, GAUSS_WINDOW, GAUSS_STD)
    base = base_series - smooth
    future_ret = base.shift(-HORIZON) - base
else:
    # Plain future return in selected price space
    future_ret = base_series.shift(-HORIZON) - base_series

# NaN rows are the last HORIZON days (no future price) and, when detrending,
# the GAUSS_WINDOW // 2 edge days where the smoother is undefined.
valid = future_ret.dropna()
total_n = len(valid)
if total_n == 0:
    raise ValueError('No valid future returns for labeling')

# Choose top-N UP and top-N DOWN by absolute size (balanced).
# Zero returns and mid-sized moves stay unlabeled and are dropped below.
per_side = max(1, int(total_n * MOVE_SHARE_TOTAL / 2))
pos = valid[valid > 0]
neg = valid[valid < 0]
n_up = min(per_side, len(pos))
n_down = min(per_side, len(neg))
if n_up < per_side or n_down < per_side:
    print(f'[WARN] Not enough moves for a balanced split: per_side={per_side}, up={n_up}, down={n_down}')

up_idx = pos.nlargest(n_up).index
down_idx = neg.nsmallest(n_down).index  # most negative

# valid/pos/neg share df_market's index, so up_idx/down_idx address its rows.
df_labels = df_market.copy()
df_labels['target'] = np.nan
df_labels.loc[up_idx, 'target'] = 1
df_labels.loc[down_idx, 'target'] = 0

df_labels = df_labels.dropna(subset=['target']).reset_index(drop=True)
df_labels['target'] = df_labels['target'].astype(int)

BINARY_TREND = True
print(f'Label mode: {LABEL_MODE}, price_mode={LABEL_PRICE_MODE}, horizon={HORIZON}, move_share_total={MOVE_SHARE_TOTAL}')
if LABEL_MODE == 'balanced_detrended':
    print(f'Gaussian detrend: window={GAUSS_WINDOW}, std={GAUSS_STD}')
print(f'Labeled rows: UP={n_up}, DOWN={n_down} (of {total_n} valid)')
print(df_labels[['date', 'close', 'target']].head())


In [None]:
trace_cell(
    title='Cell 8',
    purpose='Class distribution (binary)',
    used_vars=['colors', 'counts', 'df_labels', 'i', 'label_map', 'plt'],
    notes=['Oracle labels (targets) for training/eval.'],
)

# Class distribution (binary): share of each class among labeled rows, in %.
label_map = {0: 'DOWN', 1: 'UP'}
# Colors are keyed by class id, not bar position, so if one class is absent the
# remaining bar still gets its own color (DOWN red, UP green).
color_map = {0: '#d62728', 1: '#2ca02c'}
counts = df_labels['target'].value_counts(normalize=True).sort_index() * 100
colors = [color_map[i] for i in counts.index]

plt.figure(figsize=(6, 4))
plt.bar([label_map[i] for i in counts.index], counts.values, color=colors)
plt.title('Class share (balanced labels)')
plt.ylabel('%')
plt.show()


In [None]:
trace_cell(
    title='Cell 9',
    purpose='Visual: price and centered Gaussian smoothing in labeling space',
    used_vars=['GAUSS_STD', 'GAUSS_WINDOW', 'LABEL_MODE', 'LABEL_PRICE_MODE', '_gaussian_smooth_centered', 'ax', 'df_market', 'np', 'plt', 'price_series', 'smooth', 'y_label'],
    notes=['Labeling mode controls how targets are constructed.', 'Labeling price space: log vs raw.', 'Market dataframe is the price source for labels/plots.'],
)

# Visual: price and centered Gaussian smoothing in labeling space
# The series is rebuilt in LABEL_PRICE_MODE (not PLOT_PRICE_MODE) so the plot
# matches what the labels were computed from; any mode other than 'log' is
# drawn as raw price. NOTE: this overwrites the global price_series.
if LABEL_PRICE_MODE == 'log':
    price_series = np.log(df_market['close']).astype(float)
    y_label = 'log(price)'
else:
    price_series = df_market['close'].astype(float)
    y_label = 'price'

# The trend line is drawn only in detrended mode; it is NaN for the first and
# last GAUSS_WINDOW // 2 rows (full-window smoothing only).
smooth = None
if LABEL_MODE == 'balanced_detrended':
    smooth = _gaussian_smooth_centered(price_series, GAUSS_WINDOW, GAUSS_STD)

fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(df_market['date'], price_series, label='price', linewidth=0.8)
if smooth is not None:
    ax.plot(df_market['date'], smooth, label=f'Gauss(w={GAUSS_WINDOW}, std={GAUSS_STD})', linewidth=1.2)
ax.set_title('Price and centered Gaussian smoothing in labeling space')
ax.set_xlabel('Date')
ax.set_ylabel(y_label)
ax.legend()
plt.show()


In [None]:
trace_cell(
    title='Cell 10',
    purpose='Future return distribution (labeling space)',
    used_vars=['GAUSS_STD', 'GAUSS_WINDOW', 'HORIZON', 'LABEL_MODE', 'LABEL_PRICE_MODE', '_gaussian_smooth_centered', 'base', 'base_series', 'df_market', 'future_ret', 'np', 'plt', 'smooth'],
    notes=['Labeling mode controls how targets are constructed.', 'Labeling price space: log vs raw.', 'Market dataframe is the price source for labels/plots.'],
)

# Future return distribution (labeling space)
# Recomputes future returns the same way as the labeling cell so this plot can
# be re-run on its own; it overwrites base_series / smooth / base / future_ret.
# Unlike the labeling cell, an unexpected LABEL_PRICE_MODE falls back to raw
# price here instead of raising.
if LABEL_PRICE_MODE == 'log':
    base_series = np.log(df_market['close']).astype(float)
else:
    base_series = df_market['close'].astype(float)

if LABEL_MODE == 'balanced_detrended':
    smooth = _gaussian_smooth_centered(base_series, GAUSS_WINDOW, GAUSS_STD)
    base = base_series - smooth
    future_ret = base.shift(-HORIZON) - base
else:
    future_ret = base_series.shift(-HORIZON) - base_series

# Histogram over all rows with a defined future return (labeled or not).
plt.figure(figsize=(7, 4))
plt.hist(future_ret.dropna(), bins=80, color='tab:gray')
plt.title('Future return distribution')
plt.xlabel('future return')
plt.ylabel('count')
plt.show()


In [None]:
trace_cell(
    title='Cell 11',
    purpose='Simple label plot (binary)',
    used_vars=['PLOT_PRICE_MODE', 'ax', 'close', 'dates', 'df_labels', 'down_mask', 'labels', 'np', 'pd', 'plot_df', 'plt', 'shade_up_down', 'title', 'up_mask', 'y_label'],
    notes=['Oracle labels (targets) for training/eval.'],
)

# Simple label plot (binary)
# plot_df holds only labeled rows, so the price line below connects labeled
# days directly and skips the unlabeled days in between.
# NOTE(review): shaded spans may likewise stretch across unlabeled gaps between
# consecutive labeled days; confirm visually before reading too much into them.
plot_df = df_labels[['date', 'close', 'target']].copy()
plot_df['date'] = pd.to_datetime(plot_df['date'])
plot_df = plot_df.sort_values('date').reset_index(drop=True)

def shade_up_down(ax, dates, close, up_mask, down_mask, title: str, y_label: str):
    """Plot ``close`` over ``dates`` on ``ax`` with UP (green) / DOWN (red) shading.

    The masks select which points are shaded; shading spans the full axis
    height (x in data coordinates, y in axes coordinates via
    get_xaxis_transform).
    """
    ax.plot(dates, close, color='black', linewidth=1.2, label='BTC close')
    ax.fill_between(dates, 0, 1, where=up_mask, transform=ax.get_xaxis_transform(),
                   color='green', alpha=0.15, label='UP')
    ax.fill_between(dates, 0, 1, where=down_mask, transform=ax.get_xaxis_transform(),
                   color='red', alpha=0.15, label='DOWN')
    ax.set_title(title)
    ax.set_ylabel(y_label)
    ax.legend(loc='upper left')

# NOTE: `labels` here is a numpy array; a later cell rebinds the same name to a DataFrame.
labels = plot_df['target'].to_numpy()
up_mask = labels == 1
down_mask = labels == 0

dates = plot_df['date'].to_numpy()
if PLOT_PRICE_MODE == 'log':
    close = np.log(plot_df['close'].to_numpy())
    y_label = 'log(price)'
else:
    close = plot_df['close'].to_numpy()
    y_label = 'Close'

fig, ax = plt.subplots(figsize=(12, 4))
shade_up_down(ax, dates, close, up_mask, down_mask, 'Balanced labels (UP/DOWN)', y_label)
ax.set_xlabel('Date')
plt.tight_layout()
plt.show()


## 2. Target horizon (already applied during labeling — no extra shift)


In [None]:
trace_cell(
    title='Cell 13',
    purpose='Target already includes horizon (future return), no additional shift needed',
    used_vars=['df_labels', 'df_labels_shifted'],
    notes=['Oracle labels (targets) for training/eval.'],
)

# The target was built from returns over the next HORIZON rows, so each label
# already looks forward; this only produces a clean, integer-typed copy under
# the name the merge step expects. dropna/astype return new frames, so
# df_labels itself is left untouched.
df_labels_shifted = (
    df_labels
    .dropna(subset=['target'])
    .reset_index(drop=True)
    .astype({'target': int})
)

print(df_labels_shifted[['date', 'target']].tail())
print(f'Rows after labeling: {len(df_labels_shifted)}')


## 3. Astro data and astro features


In [None]:
trace_cell(
    title='Cell 15',
    purpose='Compute daily astro bodies and aspects',
    used_vars=['AstroSettings', '_ephe_path', '_resolve_path', 'a', 'aspects', 'aspects_path', 'aspects_rows', 'astro_cfg', 'b', 'bodies', 'bodies_path', 'bodies_rows', 'calculate_aspects', 'calculate_daily_bodies', 'center', 'cfg_astro', 'd', 'dates', 'datetime', 'df_aspects', 'df_bodies', 'df_market', 'pd', 'processed_dir', 'set_ephe_path', 'settings', 'subject', 'time_utc', 'tqdm'],
    notes=['Uses astro config (bodies/aspects paths, ephemeris path, center).', 'Center controls geo vs helio coordinates for astro bodies/aspects.', 'Market dataframe is the price source for labels/plots.'],
)

from datetime import datetime
from tqdm import tqdm

from src.astro.engine.settings import AstroSettings
from src.astro.engine.calculator import set_ephe_path, calculate_daily_bodies
from src.astro.engine.aspects import calculate_aspects
from src.features.builder import build_features_daily

# Astro settings
astro_cfg = cfg_astro["astro"]
# Same path rules: resolve to PROJECT_ROOT
# The ephemeris path is registered once, before any body positions are computed.
_ephe_path = _resolve_path(astro_cfg["ephe_path"])
set_ephe_path(str(_ephe_path))

# Body and aspect definitions come from the files named in astro.yaml.
settings = AstroSettings(
    bodies_path=_resolve_path(astro_cfg["bodies_path"]),
    aspects_path=_resolve_path(astro_cfg["aspects_path"]),
)

# Time of day (parsed from "HH:MM:SS") at which each daily position is evaluated.
time_utc = datetime.strptime(astro_cfg["daily_time_utc"], "%H:%M:%S").time()

# Coordinate center for body positions; defaults to "geo".
center = astro_cfg.get("center", "geo")

# Output parquet paths, one set per subject (features_path is written by the next cell).
bodies_path = processed_dir / f"{subject.subject_id}_astro_bodies.parquet"
aspects_path = processed_dir / f"{subject.subject_id}_astro_aspects.parquet"
features_path = processed_dir / f"{subject.subject_id}_features.parquet"

# Ignore astro cache, recompute
print("Ignoring astro cache, recomputing...")

# Column order of the persisted tables. Passing it explicitly to DataFrame()
# keeps the schema even when no rows are produced (e.g. no aspect within orb on
# any day); otherwise an empty list yields a column-less frame that breaks the
# feature builder and the parquet schema.
BODY_COLUMNS = ["date", "body", "lon", "lat", "speed", "is_retro", "sign", "declination"]
ASPECT_COLUMNS = ["date", "p1", "p2", "aspect", "orb", "is_exact", "is_applying"]

bodies_rows = []
aspects_rows = []
# One astro snapshot per market day, so astro tables align with price dates.
dates = pd.to_datetime(df_market["date"]).dt.date

for d in tqdm(dates, desc="astro days"):
    bodies = calculate_daily_bodies(d, time_utc, settings.bodies, center=center)
    aspects = calculate_aspects(bodies, settings.aspects)
    # Each record's attributes map 1:1 onto the column names above.
    bodies_rows.extend({col: getattr(b, col) for col in BODY_COLUMNS} for b in bodies)
    aspects_rows.extend({col: getattr(a, col) for col in ASPECT_COLUMNS} for a in aspects)

df_bodies = pd.DataFrame(bodies_rows, columns=BODY_COLUMNS)
df_aspects = pd.DataFrame(aspects_rows, columns=ASPECT_COLUMNS)

bodies_path.parent.mkdir(parents=True, exist_ok=True)
df_bodies.to_parquet(bodies_path, index=False)
df_aspects.to_parquet(aspects_path, index=False)
print(f"Saved bodies: {bodies_path}")
print(f"Saved aspects: {aspects_path}")

print(df_bodies.head())
print(df_aspects.head())


In [None]:
trace_cell(
    title='Cell 16',
    purpose='Build astro features',
    used_vars=['build_features_daily', 'df_aspects', 'df_bodies', 'df_features', 'features_path'],
    notes=['Feature matrix used for XGBoost training.'],
)

# Build astro features
# Ignore features cache, recompute
# Features are derived only from the astro tables (no price input); any existing
# parquet at features_path is overwritten.
print("Ignoring features cache, recomputing...")
df_features = build_features_daily(df_bodies, df_aspects)
df_features.to_parquet(features_path, index=False)
print(f"Saved features: {features_path}")

print(df_features.head())
print(f"Features: {df_features.shape}")


## 4. Merge features and target


In [None]:
trace_cell(
    title='Cell 18',
    purpose='Merge by date',
    used_vars=['df_dataset', 'df_features', 'df_labels_shifted', 'features', 'labels', 'pd'],
    notes=['Feature matrix used for XGBoost training.'],
)

# Merge by date
features = df_features.copy()
features["date"] = pd.to_datetime(features["date"])

labels = df_labels_shifted[["date", "target"]].copy()
labels["date"] = pd.to_datetime(labels["date"])

# Date intersection only: unlabeled days and days without features are dropped.
df_dataset = pd.merge(features, labels, on="date", how="inner")

# Drop possible duplicate dates (keep the first row), and report how many were
# removed so a feature builder emitting several rows per day does not go unnoticed.
n_dup = int(df_dataset["date"].duplicated().sum())
if n_dup:
    print(f"[WARN] Dropping {n_dup} duplicate date rows after merge")
    df_dataset = df_dataset.drop_duplicates(subset=["date"])

# The time-based split relies on chronological row order. An inner merge keeps
# the left frame's key order (whatever build_features_daily returned), so sort
# explicitly.
df_dataset = df_dataset.sort_values("date", kind="stable").reset_index(drop=True)

print(df_dataset.head())
print(f"Final dataset: {df_dataset.shape}")


In [None]:
trace_cell(
    title='Cell 19',
    purpose='Feature inventory (readable table)',
    used_vars=['c', 'df_dataset', 'display', 'feature_cols', 'feature_group', 'feature_type', 'info', 'missing_pct', 'name', 'pd', 'stats'],
)

# Feature inventory (readable table)
# Every merged column except the date key and the target is a model feature.
feature_cols = [c for c in df_dataset.columns if c not in ['date', 'target']]

def feature_group(name: str) -> str:
    """Return the group of a feature, derived from its name prefix.

    'transit_aspect_*' -> 'transit_aspect', 'aspect_*' -> 'aspect', otherwise
    the text before the first underscore, or 'other' if there is none.
    """
    if name.startswith('transit_aspect_'):
        return 'transit_aspect'
    if name.startswith('aspect_'):
        return 'aspect'
    if '_' in name:
        return name.split('_', 1)[0]
    return 'other'

def feature_type(name: str) -> str:
    """Return the part of a feature name that follows its group prefix.

    Mirrors feature_group: for 'transit_aspect_*' and 'aspect_*' names the whole
    multi-word prefix is stripped (so 'transit_aspect_x' -> 'x', not
    'aspect_x'); otherwise it is the text after the first underscore, or ''
    when the name has no underscore.
    """
    for prefix in ('transit_aspect_', 'aspect_'):
        if name.startswith(prefix):
            return name[len(prefix):]
    if '_' in name:
        return name.split('_', 1)[1]
    return ''

# One row per feature with its name-derived group and type.
info = pd.DataFrame({'feature': feature_cols})
info['group'] = info['feature'].apply(feature_group)
info['type'] = info['feature'].apply(feature_type)

# Per-feature summary stats plus the percentage of missing values.
# NOTE(review): assumes the feature columns are numeric; describe() only reports
# mean/std/min/max for numeric columns.
stats = df_dataset[feature_cols].describe().T[['mean', 'std', 'min', 'max']]
missing_pct = df_dataset[feature_cols].isna().mean() * 100
stats = stats.join(missing_pct.rename('missing_%'), how='left')

info = info.merge(stats, left_on='feature', right_index=True, how='left')
info = info.sort_values(['group', 'feature']).reset_index(drop=True)

# display() is provided by IPython/Jupyter (not defined in this file).
# First table: number of features per group.
display(info.groupby('group').size().rename('count').to_frame())

# Full feature list
display(info.style.format({
    'mean': '{:.6f}',
    'std': '{:.6f}',
    'min': '{:.6f}',
    'max': '{:.6f}',
    'missing_%': '{:.2f}'
}))


## 5. Train/val/test split (time-based)


In [None]:
trace_cell(
    title='Cell 21',
    purpose='Time-based split without shuffling',
    used_vars=['df_dataset', 'n', 'test_df', 'train_df', 'train_end', 'train_ratio', 'val_df', 'val_end', 'val_ratio'],
)

# Time-based split without shuffling: oldest train_ratio of rows -> train, the
# next val_ratio -> validation (passed to model.fit), the remainder -> test.
# Assumes df_dataset is sorted by date.
# NOTE(review): there is no embargo gap. Labels near a boundary use prices up to
# HORIZON rows (and, in detrended mode, GAUSS_WINDOW // 2 rows) into the next split.
train_ratio = 0.7
val_ratio = 0.15

n = len(df_dataset)
train_end = int(n * train_ratio)
val_end = int(n * (train_ratio + val_ratio))

train_df = df_dataset.iloc[:train_end].copy()
val_df = df_dataset.iloc[train_end:val_end].copy()
test_df = df_dataset.iloc[val_end:].copy()

# Fail here with a clear message instead of deep inside training or the
# date-range prints below (NaT has no usable .date()).
empty_splits = [name for name, part in (("train", train_df), ("val", val_df), ("test", test_df)) if part.empty]
if empty_splits:
    raise ValueError(f"Empty split(s): {', '.join(empty_splits)} (n={n}); need more labeled rows")

print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
print(f"Train range: {train_df['date'].min().date()} -> {train_df['date'].max().date()}")
print(f"Val range  : {val_df['date'].min().date()} -> {val_df['date'].max().date()}")
print(f"Test range : {test_df['date'].min().date()} -> {test_df['date'].max().date()}")


## 6. Prepare X/y matrices


In [None]:
trace_cell(
    title='Cell 23',
    purpose='Feature list (astro only)',
    used_vars=['X_test', 'X_train', 'X_val', 'c', 'df_dataset', 'feature_cols', 'np', 'test_df', 'train_df', 'val_df', 'y_test', 'y_train', 'y_val'],
)

# Feature list (astro only): every merged column except the date key and the target.
feature_cols = [col for col in df_dataset.columns if col not in ("date", "target")]

# Dense float64 feature matrices and int32 label vectors, one per time split,
# in the order train / val / test.
X_train, X_val, X_test = (
    part[feature_cols].to_numpy(dtype=np.float64) for part in (train_df, val_df, test_df)
)
y_train, y_val, y_test = (
    part["target"].to_numpy(dtype=np.int32) for part in (train_df, val_df, test_df)
)

print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f"X_val  : {X_val.shape}, y_val  : {y_val.shape}")
print(f"X_test : {X_test.shape}, y_test : {y_test.shape}")


## 7. XGBoost training


In [None]:
trace_cell(
    title='Cell 25',
    purpose='Training setup (stage mode, device, label set) and metric helpers',
    used_vars=['BINARY_TREND', 'TWO_STAGE', 'acc', 'accuracy_score', 'bal', 'calc_metrics', 'cfg_train', 'counts', 'device', 'e', 'f1_score', 'f1m', 'fallback_label', 'hi', 'idx', 'info', 'k', 'label_ids', 'lbl', 'lbls', 'lo', 'm', 'majority_label', 'matthews_corrcoef', 'mcc', 'n', 'n_boot', 'np', 'out', 'pred', 'recall_score', 'rng', 'samples', 'seed', 'train_cfg', 'use_cuda', 'vals', 'xgb', 'y_pred', 'y_true'],
)

import xgboost as xgb
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
    balanced_accuracy_score,
    matthews_corrcoef,
    recall_score,
)
from src.models.xgb import XGBBaseline

# Stage mode comes from configs/training.yaml (training.two_stage, default True).
train_cfg = cfg_train.get("training", {})
TWO_STAGE = bool(train_cfg.get("two_stage", True))
# The two-stage path splits MOVE (labels != 1) from NO_MOVE (label 1 = SIDEWAYS);
# binary UP/DOWN labels have no SIDEWAYS class, so it is disabled for them.
if BINARY_TREND and TWO_STAGE:
    print("Binary trend active -> forcing SINGLE_STAGE")
    TWO_STAGE = False

# Check if this XGBoost build supports CUDA
# NOTE(review): USE_CUDA reflects how XGBoost was compiled, not whether a GPU is
# actually available on this machine; confirm behavior on CPU-only hosts.
# NOTE: this rebinds the global `info`, previously the feature inventory table.
use_cuda = False
try:
    info = xgb.build_info()
    use_cuda = bool(info.get("USE_CUDA", False))
    print(f"XGBoost build_info USE_CUDA = {info.get('USE_CUDA', None)}")
except Exception as e:
    print("Failed to read build_info:", e)

device = "cuda" if use_cuda else "cpu"
print(f"Using device={device}")

# Class ids and display names used by all metric/report helpers.
if BINARY_TREND:
    label_names = ["DOWN", "UP"]
    label_ids = [0, 1]
else:
    label_names = ["DOWN", "SIDEWAYS", "UP"]
    label_ids = [0, 1, 2]
N_CLASSES = len(label_ids)


def majority_baseline_pred(y_true, lbls):
    """Predict the most frequent label of ``y_true`` for every row.

    Ties go to the label listed first in ``lbls``. The result has the same
    shape and dtype as ``y_true``.
    """
    tallies = [int((y_true == lbl).sum()) for lbl in lbls]
    winner = lbls[int(np.argmax(tallies))]
    return np.full_like(y_true, winner)


def prev_label_baseline_pred(y_true, fallback_label: int = 0):
    """Predict each row's label as the previous row's true label.

    The first row (which has no predecessor) gets ``fallback_label``. An empty
    input yields an empty array of the same dtype.
    NOTE: with HORIZON > 1 the previous label is not yet known at prediction
    time, so this baseline is optimistic.
    """
    if not len(y_true):
        return np.array([], dtype=y_true.dtype)
    shifted = np.empty_like(y_true)
    shifted[0] = fallback_label
    shifted[1:] = y_true[:-1]
    return shifted


def calc_metrics(y_true, y_pred, lbls):
    """Compute classification metrics for ``y_pred`` against ``y_true``.

    Args:
        y_true: True class ids.
        y_pred: Predicted class ids, same length as ``y_true``.
        lbls: Class ids to average over. Both ``bal_acc`` (macro recall) and
            ``f1_macro`` average over exactly these labels, so a label missing
            from ``y_true`` (e.g. in a bootstrap resample) contributes 0 to
            both averages instead of being dropped from only one of them.

    Returns:
        dict with ``acc``, ``bal_acc``, ``mcc``, ``f1_macro`` and ``summary``
        (the mean of ``bal_acc`` and ``f1_macro``).
    """
    acc = accuracy_score(y_true, y_pred)
    # Macro recall over lbls equals balanced accuracy when every label occurs in y_true.
    bal = recall_score(y_true, y_pred, labels=lbls, average="macro", zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, labels=lbls, average="macro", zero_division=0)
    return {
        "acc": acc,
        "bal_acc": bal,
        "mcc": mcc,
        "f1_macro": f1m,
        "summary": 0.5 * (bal + f1m),
    }


def bootstrap_metrics(y_true, y_pred, lbls, n_boot=200, seed=42):
    """Percentile-bootstrap 95% confidence intervals for calc_metrics.

    Resamples (y_true, y_pred) pairs with replacement ``n_boot`` times, using a
    generator seeded with ``seed`` so results are reproducible.

    Args:
        y_true: Sequence of true class ids (converted to a numpy array).
        y_pred: Sequence of predicted class ids, same length as ``y_true``.
        lbls: Label set forwarded to calc_metrics.
        n_boot: Number of bootstrap resamples.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        dict mapping each metric name to ``(lo, hi)`` = its 2.5th / 97.5th
        percentiles, or None when there is nothing to resample (empty input or
        ``n_boot < 1``).

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    # Fancy indexing below needs arrays; lists would fail on y_true[idx].
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred length mismatch: {len(y_true)} vs {len(y_pred)}")
    n = len(y_true)
    if n == 0 or n_boot < 1:
        return None
    rng = np.random.default_rng(seed)
    samples = {"acc": [], "bal_acc": [], "mcc": [], "f1_macro": [], "summary": []}
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        m = calc_metrics(y_true[idx], y_pred[idx], lbls)
        for k in samples:
            samples[k].append(m[k])
    out = {}
    for k, vals in samples.items():
        lo, hi = np.percentile(vals, [2.5, 97.5])
        out[k] = (float(lo), float(hi))
    return out

In [None]:
trace_cell(
    title='Cell 26',
    purpose='Evaluation/plot helpers and XGBoost training',
    used_vars=['BINARY_TREND', 'N_CLASSES', 'TWO_STAGE', 'XGBBaseline', 'X_test', 'X_test_dir', 'X_train', 'X_train_dir', 'X_val', 'X_val_dir', 'base_metrics', 'base_pred', 'bootstrap_metrics', 'calc_metrics', 'ci', 'classification_report', 'cm', 'cm_title', 'cnt', 'compute_sample_weight', 'confusion_matrix', 'counts', 'device', 'dir_names', 'dir_pred', 'dir_pred_full', 'dist_parts', 'feature_cols', 'hi', 'key', 'label_ids', 'label_names', 'lbl', 'lbls', 'lo', 'low_recall', 'majority_baseline_pred', 'mask_test_dir', 'mask_train_dir', 'mask_val_dir', 'metrics', 'model', 'model_dir', 'model_move', 'move_names', 'move_pred', 'n', 'name', 'names', 'np', 'overall_title', 'pct', 'plot_confusion', 'plt', 'prev_label_baseline_pred', 'prev_metrics', 'prev_pred', 'print_basic_metrics', 'print_ci', 'report_dict', 'report_str', 'sns', 'title', 'w_train', 'w_train_dir', 'w_train_move', 'w_val', 'w_val_dir', 'w_val_move', 'warn_margin', 'y_pred', 'y_test', 'y_test_dir', 'y_test_move', 'y_train', 'y_train_dir', 'y_train_move', 'y_true', 'y_val', 'y_val_dir', 'y_val_move'],
)

def print_ci(ci, key, name):
    """Print the ``(lo, hi)`` interval stored under ``key`` in ``ci`` as ``name``.

    Does nothing when ``ci`` is None (e.g. bootstrap skipped on empty input).
    """
    if ci is not None:
        low, high = ci[key]
        print(f"  {name} 95% CI: [{low:.4f}, {high:.4f}]")


def print_basic_metrics(y_true, y_pred, lbls, names, title: str) -> None:
    """Print an evaluation summary for one set of predictions.

    Prints the headline metrics, the true-class distribution, sklearn's
    classification report, two naive baselines (majority class and previous
    label), 95% bootstrap CIs for the model metrics (including MCC), and
    sanity warnings.

    Args:
        y_true: True class ids (numpy array).
        y_pred: Predicted class ids, same length as ``y_true``.
        lbls: Class ids to report on, in the same order as ``names``.
        names: Display names for ``lbls``.
        title: Heading printed above the summary.
    """
    print()
    print(title)

    metrics = calc_metrics(y_true, y_pred, lbls)
    print("Accuracy:", metrics["acc"])
    print("Balanced acc:", metrics["bal_acc"])
    print("MCC:", metrics["mcc"])
    print("F1 macro:", metrics["f1_macro"])
    print("Summary score (avg bal_acc + f1_macro):", metrics["summary"])

    counts = [int((y_true == lbl).sum()) for lbl in lbls]
    n = len(y_true)
    dist_parts = []
    for name, cnt in zip(names, counts):
        pct = 100.0 * cnt / n if n else 0.0
        dist_parts.append(f"{name}={cnt} ({pct:.1f}%)")
    print("Class distribution:", ", ".join(dist_parts))

    # Same report twice: a string for display, a dict for the recall check below.
    report_kwargs = dict(labels=lbls, target_names=names, zero_division=0)
    report_str = classification_report(y_true, y_pred, **report_kwargs)
    report_dict = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    print("Classification report:")
    print(report_str)

    # Baseline 1: always predict majority class
    base_pred = majority_baseline_pred(y_true, lbls)
    base_metrics = calc_metrics(y_true, base_pred, lbls)
    # Baseline 2: predict previous label (naive time baseline)
    prev_pred = prev_label_baseline_pred(y_true, fallback_label=lbls[0])
    prev_metrics = calc_metrics(y_true, prev_pred, lbls)
    for label, bm in (("Majority baseline", base_metrics), ("Prev-label baseline", prev_metrics)):
        print(
            f"{label} -> acc={bm['acc']:.4f}, "
            f"bal_acc={bm['bal_acc']:.4f}, mcc={bm['mcc']:.4f}, f1_macro={bm['f1_macro']:.4f}, "
            f"summary={bm['summary']:.4f}"
        )

    # Bootstrap CI for model metrics (MCC is resampled too, so report it)
    ci = bootstrap_metrics(y_true, y_pred, lbls, n_boot=200, seed=42)
    if ci is not None:
        print("Model 95% bootstrap CI:")
        for key in ("acc", "bal_acc", "mcc", "f1_macro", "summary"):
            print_ci(ci, key, key)

    # Sanity warnings: the model should beat the better naive baseline by a margin.
    warn_margin = 0.02
    checks = (
        ("acc", "accuracy"),
        ("bal_acc", "balanced accuracy"),
        ("f1_macro", "macro F1"),
        ("summary", "summary score"),
    )
    for key, label in checks:
        if metrics[key] < max(base_metrics[key], prev_metrics[key]) + warn_margin:
            print(f"WARNING: {label} barely above naive baselines")

    low_recall = [
        f"{name} (recall={report_dict[name]['recall']:.2f})"
        for name in names
        if name in report_dict and report_dict[name]["recall"] < 0.2
    ]
    if low_recall:
        print("WARNING: low recall ->", ", ".join(low_recall))

    if counts and min(counts) < 30:
        print("WARNING: some classes have <30 samples; metrics may be unstable")


def plot_confusion(y_true, y_pred, lbls, names, title: str) -> None:
    """Draw an annotated confusion-matrix heatmap and display it.

    Rows are the true classes and columns the predicted classes, both in the
    order given by ``lbls``.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids (same length as ``y_true``).
        lbls: Class ids that define the matrix row/column order.
        names: Display names matching ``lbls`` one-to-one (tick labels).
        title: Figure title.
    """
    counts_matrix = confusion_matrix(y_true, y_pred, labels=lbls)
    heatmap_style = dict(annot=True, fmt="d", cmap="Blues")

    plt.figure(figsize=(4.5, 3.8))
    sns.heatmap(counts_matrix, xticklabels=names, yticklabels=names, **heatmap_style)
    for set_text, text in ((plt.title, title), (plt.xlabel, "Predicted"), (plt.ylabel, "True")):
        set_text(text)
    plt.tight_layout()
    plt.show()


# Guard: the two-stage path assumes 3-class labels (0=DOWN, 1=SIDEWAYS, 2=UP).
# With binary labels (0=DOWN, 1=UP), "y != 1" would mark UP rows as NO_MOVE and
# the direction stage would be trained on DOWN rows only (a single class).
if TWO_STAGE and BINARY_TREND:
    raise ValueError(
        "TWO_STAGE requires 3-class labels; set BINARY_TREND = False or TWO_STAGE = False"
    )

# Hyperparameters shared by every XGBoost model trained in this cell. Only
# n_classes and max_depth differ: the two-stage models use max_depth=6, the
# single-stage model uses max_depth=3.
XGB_COMMON_PARAMS = dict(
    device=device,
    random_state=42,
    n_estimators=300,
    learning_rate=0.01,
    subsample=0.8,
    colsample_bytree=0.8,
    tree_method="hist",
)

if TWO_STAGE:
    print("Training mode: TWO_STAGE (MOVE/NO_MOVE -> UP/DOWN)")

    # --- Stage 1: MOVE (DOWN or UP) vs NO_MOVE (SIDEWAYS == 1) ---
    y_train_move = (y_train != 1).astype(np.int32)
    y_val_move = (y_val != 1).astype(np.int32)
    y_test_move = (y_test != 1).astype(np.int32)

    # "balanced" weights offset the MOVE / NO_MOVE class imbalance.
    w_train_move = compute_sample_weight(class_weight="balanced", y=y_train_move)
    w_val_move = compute_sample_weight(class_weight="balanced", y=y_val_move)

    model_move = XGBBaseline(n_classes=2, max_depth=6, **XGB_COMMON_PARAMS)
    model_move.fit(
        X_train, y_train_move,
        X_val=X_val, y_val=y_val_move,
        feature_names=feature_cols,
        sample_weight=w_train_move,
        sample_weight_val=w_val_move,
    )

    # --- Stage 2: direction (UP=1 vs DOWN=0), trained only on MOVE rows ---
    mask_train_dir = y_train != 1
    mask_val_dir = y_val != 1
    mask_test_dir = y_test != 1

    X_train_dir = X_train[mask_train_dir]
    y_train_dir = (y_train[mask_train_dir] == 2).astype(np.int32)
    X_val_dir = X_val[mask_val_dir]
    y_val_dir = (y_val[mask_val_dir] == 2).astype(np.int32)

    w_train_dir = compute_sample_weight(class_weight="balanced", y=y_train_dir)
    w_val_dir = compute_sample_weight(class_weight="balanced", y=y_val_dir)

    model_dir = XGBBaseline(n_classes=2, max_depth=6, **XGB_COMMON_PARAMS)
    model_dir.fit(
        X_train_dir, y_train_dir,
        X_val=X_val_dir, y_val=y_val_dir,
        feature_names=feature_cols,
        sample_weight=w_train_dir,
        sample_weight_val=w_val_dir,
    )

    # --- Combine predictions to preserve the original sequence ---
    # The direction model is applied to every test row; its output is only
    # used where stage 1 predicts MOVE.
    move_pred = model_move.predict(X_test)
    dir_pred_full = model_dir.predict(X_test)

    # MOVE -> UP (2) / DOWN (0) from the direction model; NO_MOVE -> SIDEWAYS (1).
    y_pred = np.where(move_pred == 1, np.where(dir_pred_full == 1, 2, 0), 1)
else:
    if BINARY_TREND:
        print("Training mode: SINGLE_STAGE (binary)")
    else:
        print("Training mode: SINGLE_STAGE (3 classes)")

    w_train = compute_sample_weight(class_weight="balanced", y=y_train)
    w_val = compute_sample_weight(class_weight="balanced", y=y_val)

    model = XGBBaseline(n_classes=N_CLASSES, max_depth=3, **XGB_COMMON_PARAMS)
    model.fit(
        X_train, y_train,
        X_val=X_val, y_val=y_val,
        feature_names=feature_cols,
        sample_weight=w_train,
        sample_weight_val=w_val,
    )

    y_pred = model.predict(X_test)

# --- Metrics: overall (binary or 3-class, depending on BINARY_TREND) ---
overall_title = "Overall (binary) metrics" if BINARY_TREND else "Overall (3-class) metrics"
print_basic_metrics(y_test, y_pred, label_ids, label_names, overall_title)
cm_title = "Confusion matrix (binary)" if BINARY_TREND else "Confusion matrix (3-class)"
plot_confusion(y_test, y_pred, label_ids, label_names, cm_title)

if TWO_STAGE:
    # Stage 1 metrics (MOVE vs NO_MOVE)
    move_names = ["NO_MOVE", "MOVE"]
    print_basic_metrics(y_test_move, move_pred, [0, 1], move_names, "Stage 1 (MOVE vs NO_MOVE) metrics")
    plot_confusion(y_test_move, move_pred, [0, 1], move_names, "Confusion matrix (MOVE vs NO_MOVE)")

    # Stage 2 metrics (UP vs DOWN), evaluated only on rows whose TRUE label is MOVE
    if mask_test_dir.sum() > 0:
        X_test_dir = X_test[mask_test_dir]
        y_test_dir = (y_test[mask_test_dir] == 2).astype(np.int32)
        dir_pred = model_dir.predict(X_test_dir)
        dir_names = ["DOWN", "UP"]
        print_basic_metrics(y_test_dir, dir_pred, [0, 1], dir_names, "Stage 2 (UP vs DOWN) metrics")
        plot_confusion(y_test_dir, dir_pred, [0, 1], dir_names, "Confusion matrix (UP vs DOWN)")
    else:
        print()
        print("Stage 2 metrics skipped: no MOVE samples in test set.")

In [None]:
trace_cell(
    title='Cell 27',
    purpose='Plot BTC close with prediction background',
    used_vars=['BINARY_TREND', 'PLOT_END', 'PLOT_LAST_N', 'PLOT_PRICE_MODE', 'PLOT_SCOPE', 'PLOT_START', 'PRED_DIR_MASK_MOVE', 'PRED_MODE', 'PRICE_COL', 'SHOW_TRUE', 'TWO_STAGE', 'X_plot', 'ax', 'axes', 'base_df', 'close', 'dates', 'df_dataset', 'df_market', 'dir_pred', 'dir_pred_plot', 'down_mask', 'down_mask_pred', 'down_mask_true', 'feature_cols', 'market_dates', 'model', 'model_dir', 'model_move', 'move_mask', 'move_pred_plot', 'np', 'pd', 'plot_df', 'plt', 'pred_3c_plot', 'preds_3c', 'shade_up_down', 'test_df', 'title', 'title_pred', 'train_df', 'true_labels', 'up_mask', 'up_mask_pred', 'up_mask_true', 'val_df', 'y_label'],
    notes=['Market dataframe is the price source for labels/plots.'],
)

# Plot BTC close with prediction background
# Green: UP, Red: DOWN, no color: SIDEWAYS

# --- Plot options ---
PLOT_SCOPE = "test"  # "test", "val", "train", "full"
PLOT_START = None    # e.g. "2023-01-01"
PLOT_END = None      # e.g. "2024-01-01"
PLOT_LAST_N = 1400    # set None to disable
PRED_MODE = "dir_only"  # "three_class" or "dir_only"
PRED_DIR_MASK_MOVE = False  # if True, show dir preds only when MOVE predicted
SHOW_TRUE = True           # second panel with true labels
PRICE_COL = "close"

# Validate the string options up front. Otherwise an unknown value falls
# through silently: a misspelled PLOT_SCOPE would plot the test split while
# still printing the "not test" note, and an unknown PRED_MODE would quietly
# use the 3-class view.
if PLOT_SCOPE not in ("test", "val", "train", "full"):
    raise ValueError(
        f"Unknown PLOT_SCOPE={PLOT_SCOPE!r}; expected 'test', 'val', 'train' or 'full'"
    )
if PRED_MODE not in ("three_class", "dir_only"):
    raise ValueError(
        f"Unknown PRED_MODE={PRED_MODE!r}; expected 'three_class' or 'dir_only'"
    )

# Select base dataframe
if PLOT_SCOPE == "full":
    base_df = df_dataset.copy()
elif PLOT_SCOPE == "train":
    base_df = train_df.copy()
elif PLOT_SCOPE == "val":
    base_df = val_df.copy()
else:
    base_df = test_df.copy()

if PLOT_SCOPE != "test":
    print("NOTE: PLOT_SCOPE is not test; this is in-sample visualization.")

base_df["date"] = pd.to_datetime(base_df["date"])

# Compute predictions for chosen scope.
# In single-stage mode, the per-stage columns are NaN placeholders.
X_plot = base_df[feature_cols].to_numpy(dtype=np.float64)
if TWO_STAGE:
    move_pred_plot = model_move.predict(X_plot)
    dir_pred_plot = model_dir.predict(X_plot)
    # Same combination rule as training: MOVE -> UP(2)/DOWN(0), else SIDEWAYS(1).
    pred_3c_plot = np.where(move_pred_plot == 1, np.where(dir_pred_plot == 1, 2, 0), 1)
else:
    move_pred_plot = np.full(len(base_df), np.nan)
    dir_pred_plot = np.full(len(base_df), np.nan)
    pred_3c_plot = model.predict(X_plot)

plot_df = base_df[["date", "target"]].copy()
plot_df["pred_3c"] = pred_3c_plot
plot_df["pred_move"] = move_pred_plot
plot_df["pred_dir"] = dir_pred_plot

# Attach prices by date (left join); rows without a price are dropped.
market_dates = df_market[["date", PRICE_COL]].copy()
market_dates["date"] = pd.to_datetime(market_dates["date"])

plot_df = plot_df.merge(market_dates, on="date", how="left")
plot_df = plot_df.dropna(subset=[PRICE_COL]).sort_values("date").reset_index(drop=True)

# Apply date window (PLOT_LAST_N is applied after the start/end filters)
if PLOT_START is not None:
    plot_df = plot_df[plot_df["date"] >= pd.to_datetime(PLOT_START)]
if PLOT_END is not None:
    plot_df = plot_df[plot_df["date"] <= pd.to_datetime(PLOT_END)]
if PLOT_LAST_N is not None and len(plot_df) > PLOT_LAST_N:
    plot_df = plot_df.tail(PLOT_LAST_N)

if plot_df.empty:
    raise ValueError("Plot window is empty. Check PLOT_START/PLOT_END/PLOT_LAST_N")

# Helper to shade UP/DOWN zones

def shade_up_down(ax, dates, close, up_mask, down_mask, title: str, y_label: str):
    """Plot the price line on ``ax`` and shade UP / DOWN regions behind it.

    The shading uses the axes' x-axis blended transform with y from 0 to 1.
    It therefore spans the full axes height regardless of the price scale.

    Args:
        ax: Matplotlib axes to draw on.
        dates: x values shared by the price line and the shading.
        close: Price values drawn as the black line.
        up_mask: Boolean mask of points shaded green (legend "UP").
        down_mask: Boolean mask of points shaded red (legend "DOWN").
        title: Axes title.
        y_label: Y-axis label.

    NOTE(review): ``fill_between(where=...)`` fills the interval between two
    consecutive points only when the mask holds at both, so an isolated
    single-point UP/DOWN region is presumably not visibly shaded — confirm.
    """
    ax.plot(dates, close, color="black", linewidth=1.2, label="BTC close")
    # UP is drawn before DOWN, so the legend lists: BTC close, UP, DOWN.
    zones = (
        (up_mask, "green", "UP"),
        (down_mask, "red", "DOWN"),
    )
    for zone_mask, zone_color, zone_label in zones:
        ax.fill_between(
            dates, 0, 1,
            where=zone_mask,
            transform=ax.get_xaxis_transform(),
            color=zone_color,
            alpha=0.15,
            label=zone_label,
        )
    ax.set_title(title)
    ax.set_ylabel(y_label)
    ax.legend(loc="upper left")

# Choose prediction masks (green = UP, red = DOWN, unshaded = everything else).
preds_3c = plot_df["pred_3c"].to_numpy()
if BINARY_TREND:
    # Binary labels: 0 = DOWN, 1 = UP.
    up_mask_pred, down_mask_pred = preds_3c == 1, preds_3c == 0
    title_pred = "Predicted binary (UP/DOWN shaded)"
elif TWO_STAGE and PRED_MODE == "dir_only":
    # Raw direction-model output (1 = UP, 0 = DOWN) for every plotted row.
    dir_pred = plot_df["pred_dir"].to_numpy().astype(int)
    up_mask_pred, down_mask_pred = dir_pred == 1, dir_pred == 0
    title_pred = "Predicted direction (dir model, all points)"
    if PRED_DIR_MASK_MOVE:
        # Keep direction shading only where the MOVE model predicted MOVE.
        move_mask = plot_df["pred_move"].to_numpy() == 1
        up_mask_pred = up_mask_pred & move_mask
        down_mask_pred = down_mask_pred & move_mask
        title_pred = "Predicted direction (dir model, MOVE only)"
else:
    # 3-class labels: 0 = DOWN, 1 = SIDEWAYS (unshaded), 2 = UP.
    up_mask_pred, down_mask_pred = preds_3c == 2, preds_3c == 0
    title_pred = "Predicted 3-class (UP/DOWN shaded)"

dates = plot_df["date"].to_numpy()
close = plot_df[PRICE_COL].to_numpy()
if PLOT_PRICE_MODE != 'log':
    y_label = 'Close'
else:
    close = np.log(close)
    y_label = 'log(price)'

if not SHOW_TRUE:
    fig, ax = plt.subplots(figsize=(12, 4))
    shade_up_down(ax, dates, close, up_mask_pred, down_mask_pred, title_pred, y_label)
    ax.set_xlabel("Date")
else:
    # Top panel: predictions; bottom panel: true labels on the same x axis.
    fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
    shade_up_down(axes[0], dates, close, up_mask_pred, down_mask_pred, title_pred, y_label)

    true_labels = plot_df["target"].to_numpy()
    up_mask_true = true_labels == (1 if BINARY_TREND else 2)
    down_mask_true = true_labels == 0
    shade_up_down(axes[1], dates, close, up_mask_true, down_mask_true, "True labels (UP/DOWN shaded)", y_label)

    axes[1].set_xlabel("Date")
plt.tight_layout()
plt.show()



In [None]:
trace_cell(
    title='Cell 28',
    purpose='Confusion matrix',
    used_vars=['BINARY_TREND', 'cm', 'confusion_matrix', 'labels', 'lbl_ids', 'plt', 'sns', 'y_pred', 'y_test'],
)

# Confusion matrix of the final test-set predictions.
# Rows = actual class, columns = predicted class.
from sklearn.metrics import confusion_matrix

labels = ['DOWN', 'UP'] if BINARY_TREND else ['DOWN', 'SIDEWAYS', 'UP']
# Class ids are the positions of the names above: 0..len(labels)-1.
lbl_ids = list(range(len(labels)))

cm = confusion_matrix(y_test, y_pred, labels=lbl_ids)

plt.figure(figsize=(5, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=labels,
    yticklabels=labels,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()


In [None]:
trace_cell(
    title='Cell 29',
    purpose='Feature importance (top 20)',
    used_vars=['IMP_MODEL_STAGE', 'TWO_STAGE', 'feature_cols', 'imp_df', 'importances', 'model', 'model_dir', 'model_move', 'pd', 'plt', 'sns', 'stage_model', 'stage_models', 'stage_name'],
)

# Feature importance (top 20)
# In two-stage mode you can plot one stage or both.
IMP_MODEL_STAGE = "dir"  # "dir", "move", or "both"

if TWO_STAGE:
    if IMP_MODEL_STAGE == "both":
        stage_models = [
            ("MOVE vs NO_MOVE", model_move),
            ("UP vs DOWN", model_dir),
        ]
    elif IMP_MODEL_STAGE == "move":
        stage_models = [("MOVE vs NO_MOVE", model_move)]
    else:
        # Any other value (including the default "dir") shows the direction model.
        stage_models = [("UP vs DOWN", model_dir)]
else:
    # Name the single-stage model after its actual task; it was previously
    # always titled "3-class", even when trained on binary labels.
    stage_models = [("binary" if BINARY_TREND else "3-class", model)]

for stage_name, stage_model in stage_models:
    # Importances come from the wrapped estimator (`.model`). They are assumed
    # to be in the same order as feature_cols, which was passed as
    # feature_names at fit time.
    importances = stage_model.model.feature_importances_
    imp_df = pd.DataFrame({
        "feature": feature_cols,
        "importance": importances,
    }).sort_values("importance", ascending=False)

    plt.figure(figsize=(8, 6))
    sns.barplot(data=imp_df.head(20), x="importance", y="feature", color="tab:blue")
    plt.title(f"Top-20 astro features by importance ({stage_name})")
    plt.xlabel("Importance")
    plt.ylabel("Feature")
    plt.tight_layout()
    plt.show()



In [None]:
trace_cell(
    title='Cell 30',
    purpose='Save model (optional)',
    used_vars=['GAUSS_STD', 'GAUSS_WINDOW', 'HORIZON', 'LABEL_MODE', 'LABEL_PRICE_MODE', 'MOVE_SHARE_TOTAL', 'PROJECT_ROOT', 'TWO_STAGE', 'artifact', 'artifact_config', 'artifact_dir', 'dump', 'feature_cols', 'model', 'model_dir', 'model_move', 'out_path'],
    notes=['Labeling mode controls how targets are constructed.', 'Labeling price space: log vs raw.'],
)

# Save model (optional)
from joblib import dump

artifact_dir = PROJECT_ROOT / "models_artifacts"
artifact_dir.mkdir(parents=True, exist_ok=True)

# Labeling settings stored next to the model.
# binary_trend / label_names are included because a binary model (0=DOWN, 1=UP)
# and a 3-class model (0=DOWN, 1=SIDEWAYS, 2=UP) otherwise produce
# indistinguishable artifacts; the single-stage file name below does not
# encode the label set.
artifact_config = {
    "label_mode": LABEL_MODE,
    "label_price_mode": LABEL_PRICE_MODE,
    "move_share_total": MOVE_SHARE_TOTAL,
    "gauss_window": GAUSS_WINDOW,
    "gauss_std": GAUSS_STD,
    "horizon": HORIZON,
    "binary_trend": BINARY_TREND,
    "label_names": list(label_names),
}

if TWO_STAGE:
    # Each stage keeps its own estimator and scaler.
    artifact = {
        "mode": "two_stage",
        "move": {
            "model": model_move.model,
            "scaler": model_move.scaler,
        },
        "dir": {
            "model": model_dir.model,
            "scaler": model_dir.scaler,
        },
        "feature_names": feature_cols,
        "config": artifact_config,
    }
    out_path = artifact_dir / f"xgb_astro_balanced_two_stage_h{HORIZON}.joblib"
else:
    artifact = {
        "mode": "single_stage",
        "model": model.model,
        "scaler": model.scaler,
        "feature_names": feature_cols,
        "config": artifact_config,
    }
    out_path = artifact_dir / f"xgb_astro_balanced_h{HORIZON}.joblib"

# Overwrites any existing artifact at out_path.
dump(artifact, out_path)
print(f"Saved: {out_path}")


## 8. Ideas for improvement

- Pick sigma/threshold based on model metrics, not only class balance.
- Add transit-to-natal aspects as extra features.
- Use separate models for different market regimes.


In [None]:
# Final cell: trace marker only, with no variables to report.
trace_cell(used_vars=[], title='Cell 32')

