In [None]:
# Triple Barrier Sampling Notebook
# This notebook builds a dataset of triple-barrier samples for stock trajectories.
# Each sample starts either every 5 days or immediately after the previous one ends.
# Parallelized across tickers with overlap counts tracked for downstream correlation control.

# -----------------------------
# 1. Configuration
# -----------------------------

#@title Configuration
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

# Input file: long-format price/feature table, read into df_long in the load step
DF_LONG_FILE = '../artifacts/features_long.parquet'

# Barrier and sampling config
UP_MULT = 4.0      # upper barrier = entry close + UP_MULT * ATR at the entry bar
DN_MULT = 2.0      # lower barrier = entry close - DN_MULT * ATR at the entry bar
MAX_HORIZON = 21   # number of bars after the entry bar that are scanned for a barrier hit
START_EVERY = 5    # bars to advance to the next entry unless a barrier was hit sooner

# Column names
DATE_COL = 'date'      # bar timestamp; used to order each ticker's rows
PRICE_COL = 'close'    # entry price and price at horizon expiry
ATR_COL = 'atr14'      # volatility unit that scales both barriers
HIGH_COL = 'high'      # compared against the upper barrier
LOW_COL = 'low'        # compared against the lower barrier
OPEN_COL = 'open'      # only needed for the candlestick chart
SYMBOL_COL = 'symbol'  # ticker identifier; grouping key for the parallel run

N_JOBS = -1  # passed to joblib.Parallel as n_jobs

# -----------------------------
# 2. Load Data
# -----------------------------

#@title Load your long-format stock data into df_long
# The table must provide at least the columns named by
# SYMBOL_COL, DATE_COL, OPEN_COL, HIGH_COL, LOW_COL, PRICE_COL and ATR_COL.
# Row order does not matter here: each ticker is re-sorted by date before sampling.

df_long = pd.read_parquet(DF_LONG_FILE)

# Normalise dtypes. Values that cannot be parsed become NaN / NaT instead of raising.
numeric_cols = (PRICE_COL, ATR_COL, HIGH_COL, LOW_COL, OPEN_COL)
coerced = {name: pd.to_numeric(df_long[name], errors='coerce') for name in numeric_cols}
coerced[DATE_COL] = pd.to_datetime(df_long[DATE_COL], errors='coerce')
for name, values in coerced.items():
    df_long[name] = values

In [None]:
# -----------------------------
# 3. Compute Triple Barrier Samples for One Ticker
# -----------------------------

#@title Compute triple barrier samples for a single ticker using NumPy for speed

def compute_triples_numpy(sym, g):
    """Build triple-barrier samples for a single ticker.

    For every entry bar the upper barrier is ``close + UP_MULT * atr`` and the
    lower barrier is ``close - DN_MULT * atr``. The next ``MAX_HORIZON`` bars are
    scanned: the first bar whose high reaches the upper barrier or whose low
    reaches the lower barrier ends the sample; if neither happens the sample
    ends on the last bar of the horizon at that bar's close.

    Parameters
    ----------
    sym : hashable
        Ticker identifier, copied into the ``symbol`` column of the output.
    g : pandas.DataFrame
        Rows of one ticker. Must contain DATE_COL, PRICE_COL, ATR_COL, HIGH_COL
        and LOW_COL. It is sorted by DATE_COL internally.

    Returns
    -------
    pandas.DataFrame
        One row per sample with the columns ``symbol, t0, t_hit, hit, entry_px,
        top, bot, h_used, price_hit, ret_from_entry``. ``hit`` is 1 (upper
        barrier), -1 (lower barrier) or 0 (horizon expired). The columns are
        present even when the ticker yields no sample.
    """
    columns = [
        'symbol', 't0', 't_hit', 'hit', 'entry_px',
        'top', 'bot', 'h_used', 'price_hit', 'ret_from_entry',
    ]

    g = g.sort_values(DATE_COL).reset_index(drop=True)
    dates = g[DATE_COL].to_numpy('datetime64[ns]')
    close = g[PRICE_COL].to_numpy()
    atr   = g[ATR_COL].to_numpy()
    high  = g[HIGH_COL].to_numpy()
    low   = g[LOW_COL].to_numpy()

    n = len(g)
    # Last entry index whose full horizon (pos + 1 .. pos + MAX_HORIZON) still fits.
    # The bound is inclusive: the former strict "<" dropped this final valid entry.
    last_entry = n - MAX_HORIZON - 1
    pos = 0
    starts = []

    while pos <= last_entry:
        c0, a0 = close[pos], atr[pos]
        # An entry needs a finite close and a positive, finite ATR; otherwise try the next bar.
        if not (np.isfinite(c0) and np.isfinite(a0) and a0 > 0):
            pos += 1
            continue

        top = c0 + UP_MULT * a0
        bot = c0 - DN_MULT * a0
        start = pos + 1
        end = start + MAX_HORIZON  # end <= n is guaranteed by the loop bound

        hslice = high[start:end]
        lslice = low[start:end]
        cslice = close[start:end]

        # NaN highs/lows compare as False and therefore never count as a hit.
        hit_top = np.where(hslice >= top)[0]
        hit_bot = np.where(lslice <= bot)[0]

        # Default outcome: horizon expired, exit at the close of the last scanned bar.
        hit_type = 0
        hit_idx = MAX_HORIZON - 1
        price_hit = cslice[hit_idx] if cslice.size else np.nan

        # When both barriers are touched on the same bar the upper barrier wins.
        if hit_top.size and (hit_bot.size == 0 or hit_top[0] <= hit_bot[0]):
            hit_type = 1
            hit_idx = hit_top[0]
            price_hit = top
        elif hit_bot.size:
            hit_type = -1
            hit_idx = hit_bot[0]
            price_hit = bot

        h_used = hit_idx + 1  # bars elapsed between the entry bar and the exit bar
        t_hit = dates[start + hit_idx]
        # The log return is undefined for non-positive prices (e.g. a lower barrier below zero).
        ret_from_entry = np.log(price_hit / c0) if price_hit > 0 and c0 > 0 else np.nan

        starts.append(dict(
            symbol=sym,
            t0=dates[pos],
            t_hit=t_hit,
            hit=hit_type,
            entry_px=c0,
            top=top,
            bot=bot,
            h_used=h_used,
            price_hit=price_hit,
            ret_from_entry=ret_from_entry
        ))

        # Next entry: on the exit bar when a barrier was hit in fewer than
        # START_EVERY bars, otherwise START_EVERY bars after this entry.
        pos += h_used if hit_type != 0 and h_used < START_EVERY else START_EVERY

    return pd.DataFrame(starts, columns=columns)

In [None]:
# Smoke test on one ticker; select rows through SYMBOL_COL and preview only the first samples.
compute_triples_numpy("AAPL", df_long[df_long[SYMBOL_COL] == "AAPL"]).head()

In [None]:
# -----------------------------
# 4. Parallel Processing for All Tickers
# -----------------------------

#@title Run triple barrier sampling for all tickers in parallel

def compute_all_triples(df):
    """Run the triple-barrier sampler for every ticker in ``df`` in parallel.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table containing SYMBOL_COL plus the columns required by
        ``compute_triples_numpy``.

    Returns
    -------
    pandas.DataFrame
        All per-ticker samples stacked with a fresh integer index. When no
        ticker produces a sample (including an empty ``df``) an empty frame
        with the sample columns is returned instead of raising.
    """
    groups = list(df.groupby(SYMBOL_COL, sort=False))
    results = Parallel(n_jobs=N_JOBS, backend="loky")(
        delayed(compute_triples_numpy)(sym, g) for sym, g in groups
    )

    # Drop tickers without samples so they cannot influence the concatenated
    # dtypes, and so pd.concat is never called with an empty list.
    non_empty = [part for part in results if not part.empty]
    if not non_empty:
        return pd.DataFrame(columns=[
            'symbol', 't0', 't_hit', 'hit', 'entry_px',
            'top', 'bot', 'h_used', 'price_hit', 'ret_from_entry',
        ])
    return pd.concat(non_empty, ignore_index=True)

trajs_df = compute_all_triples(df_long)

In [None]:
from joblib import Parallel, delayed

# -----------------------------------------------
# 5. Parallelized Overlap Tracking by Symbol
# -----------------------------------------------

def expand_trajectories_for_symbol(sym, g):
    """Expand each trajectory of one symbol into one row per business day it is open.

    A trajectory is treated as open from its entry date ``t0`` up to, but not
    including, its exit date ``t_hit``. The span is taken from the two dates
    themselves rather than from ``h_used`` business days, so windows that
    contain a market holiday or a missing bar are still covered up to the
    real exit.

    Parameters
    ----------
    sym : hashable
        Symbol written into the ``symbol`` column and the trajectory id.
    g : pandas.DataFrame
        Trajectories of that symbol; must contain datetime columns ``t0`` and
        ``t_hit``.

    Returns
    -------
    pandas.DataFrame
        Columns ``symbol``, ``date`` and ``trajectory_id`` (``"<sym>_<t0 date>"``).
        Empty, with those columns, when ``g`` has no rows.
    """
    expanded = []
    for t0, t_hit in zip(g['t0'], g['t_hit']):
        span = pd.date_range(start=t0, end=t_hit, freq='B')
        span = span[span < t_hit]  # the exit day itself is not part of the open window
        expanded.append(pd.DataFrame({
            'symbol': sym,
            'date': span,
            'trajectory_id': f"{sym}_{t0.date()}"
        }))

    if not expanded:
        return pd.DataFrame(columns=['symbol', 'date', 'trajectory_id'])
    return pd.concat(expanded, ignore_index=True)

# Group by symbol first
symbol_groups = list(trajs_df.groupby('symbol'))

# Run in parallel, using the same worker setting as the sampling step
expanded_parts = Parallel(n_jobs=N_JOBS, backend='loky', verbose=0)(
    delayed(expand_trajectories_for_symbol)(sym, g) for sym, g in symbol_groups
)

# Combine the expanded per-symbol dataframes
overlap_df = pd.concat(expanded_parts, ignore_index=True)

# Count how many trajectories are open on each (symbol, date)
overlap_counts = (
    overlap_df.groupby(['symbol', 'date'])['trajectory_id']
    .count()
    .reset_index(name='n_overlapping_trajs')
)

# Annotate each entry point with the count on its entry date.
# Any count column left by a previous run of this cell is dropped first, so
# re-running does not create n_overlapping_trajs_x / _y duplicates.
trajs_df = (
    trajs_df
    .drop(columns='n_overlapping_trajs', errors='ignore')
    .merge(
        overlap_counts.rename(columns={'date': 't0'}),
        on=['symbol', 't0'],
        how='left'
    )
)

In [None]:
# -----------------------------
# 6. Output Preview
# -----------------------------

#@title View the final trajectory samples
# Print the size first, then leave the preview as the cell's last expression
# so the notebook renders it as a table instead of plain text.
print("\nSample count:", len(trajs_df))
trajs_df.head()

In [None]:
# Exit dates of the AAPL samples. The plain 't_hit' column is produced by the
# sampling step and lives on trajs_df, so it is read from there.
trajs_df.loc[trajs_df['symbol'] == 'AAPL', 't_hit']

In [None]:
import matplotlib.pyplot as plt
import mplfinance as mpf
import pandas as pd
import numpy as np

# === USER CONFIGURABLE ===
TICKER   = "AAPL"                      # Example ticker
START_DT = pd.Timestamp("2025-06-30")  # Entry point to visualize

# Suffix of the trajectory-label columns stored in df_long, e.g. 't0__up4.0_dn2.0_h21_s5'.
# NOTE(review): the suffix looks like it encodes UP_MULT / DN_MULT / MAX_HORIZON /
# START_EVERY — confirm before deriving it from the config constants.
LABEL_SUFFIX = "__up4.0_dn2.0_h21_s5"


def label_col(name):
    """Return the df_long column name that holds trajectory field ``name``."""
    return f"{name}{LABEL_SUFFIX}"


# --- Get the relevant trajectory row
matches = df_long[(df_long[SYMBOL_COL] == TICKER) & (df_long[label_col('t0')] == START_DT)]
if matches.empty:
    raise ValueError("No trajectory found for given ticker and date")
# Take the first match explicitly; squeeze() would hand back a DataFrame when several rows match.
row = matches.iloc[0]

# --- Price data for the ticker
px = df_long[df_long[SYMBOL_COL] == TICKER].set_index(DATE_COL).sort_index()

# --- Extract key trajectory values
# Fields read from df_long: t0, t_hit, h_used, entry_px, top, bot, hit, price_hit
# (each stored under the suffixed column name).
entry_dt = pd.to_datetime(row[label_col('t0')])
hit_dt   = pd.to_datetime(row[label_col('t_hit')])
horizon  = row[label_col('h_used')]
entry_px = row[label_col('entry_px')]
top_px   = row[label_col('top')]
bot_px   = row[label_col('bot')]
hit      = row[label_col('hit')]
hit_px   = row[label_col('price_hit')]

# --- Get price window around trajectory (10 bars of padding on each side)
idx_start = px.index.get_loc(entry_dt)
idx_end   = min(idx_start + int(horizon) + 10, len(px))
px_window = px.iloc[max(0, idx_start - 10):idx_end].copy()

# Required OHLC format for mplfinance
ohlc = px_window[[OPEN_COL, HIGH_COL, LOW_COL, PRICE_COL]].copy()
ohlc.columns = ['Open', 'High', 'Low', 'Close']

# --- Count overlapping trajectories
# trajs_df uses the plain column names written by the sampling step. A trajectory
# counts as active when it started on or before the entry date and ends on or after it.
active_mask = (
    (trajs_df['symbol'] == TICKER)
    & (trajs_df['t0'] <= entry_dt)
    & ((trajs_df['t_hit'] >= entry_dt) | trajs_df['t_hit'].isna())
)
overlap_count = int(active_mask.sum())

# --- Custom Lines and Markers
add_lines = [
    mpf.make_addplot([top_px] * len(ohlc), color='green', linestyle='--'),
    mpf.make_addplot([bot_px] * len(ohlc), color='red', linestyle='--')
]

# --- Custom marker for entry and hit (skipped when the date is outside the window)
entry_idx = ohlc.index.get_loc(entry_dt) if entry_dt in ohlc.index else None
hit_idx   = ohlc.index.get_loc(hit_dt) if hit_dt in ohlc.index else None

if entry_idx is not None:
    add_lines.append(
        mpf.make_addplot(
            [np.nan if i != entry_idx else entry_px for i in range(len(ohlc))],
            type='scatter', markersize=70, marker='o', color='black'
        )
    )

if hit_idx is not None:
    add_lines.append(
        mpf.make_addplot(
            [np.nan if i != hit_idx else hit_px for i in range(len(ohlc))],
            type='scatter', markersize=70, marker='x', color='purple'
        )
    )

# --- Plot the candlestick chart
mpf.plot(
    ohlc,
    type='candle',
    style='yahoo',
    addplot=add_lines,
    title=f"{TICKER} — Triple Barrier Trajectory",
    ylabel='Price',
    datetime_format='%Y-%m-%d',
    xrotation=15,
    tight_layout=True,
    volume=False,
    alines=dict(
        alines=[[(entry_dt, ohlc['Low'].min()), (entry_dt, ohlc['High'].max())]],
        colors=['gray'], linestyle=':', linewidths=1
    )
)

# --- Metadata display below the chart
print(f"\nMetadata for {TICKER} on {entry_dt.date()}")
print("-" * 40)
print(f"Entry Date        : {entry_dt.date()}")
print(f"Hit Date          : {hit_dt.date()}")
print(f"Hit Type          : {hit} ({'Top' if hit==1 else 'Bottom' if hit==-1 else 'Horizon'})")
print(f"Entry Price       : {entry_px:.2f}")
print(f"Top Barrier       : {top_px:.2f}")
print(f"Bottom Barrier    : {bot_px:.2f}")
print(f"Overlapping Count : {overlap_count}")
print(f"Horizon Used      : {horizon} bars")

In [None]:
# Entry-date label column for the AAPL rows, selected with a single .loc lookup.
df_long.loc[df_long['symbol'] == 'AAPL', 't0__up4.0_dn2.0_h21_s5']

In [None]:
# Print every column label of df_long, one per line.
# Iterating a DataFrame yields its column labels.
for column_name in df_long:
    print(column_name)