# Crypto Label Visualization

Notebook để gắn nhãn và visualize nhãn giá bằng candlestick với Plotly.

## 2 Mode Labeling:
1. **Three-Barrier**: Triple-barrier labeling (0=hold, 1=buy, 2=sell)
2. **Sliding Window**: Tìm peak (sell) và dip (buy) trong cửa sổ trượt

## Features:
- Signals: marker tam giác xanh (buy) / đỏ (sell)
- Technical indicators: BBands, SMA


In [474]:
import numpy as np
import pandas as pd
import talib
from pathlib import Path
from numba import jit
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.signal import find_peaks


In [475]:
@jit(nopython=True)
def _calculate_labels_numba(
    close_prices: np.ndarray,
    atr_values: np.ndarray,
    barrier_atr_multiplier: float,
    horizon: int,
) -> np.ndarray:
    """
    Numba-accelerated triple-barrier labeling.
    Returns multiclass labels: 0=hold, 1=buy, 2=sell
    """
    n = len(close_prices)
    label_multi = np.zeros(n, dtype=np.int64)

    for i in range(n - 1):
        entry_price = close_prices[i]
        atr = atr_values[i]
        if np.isnan(atr) or atr <= 0:
            continue

        barrier_distance = atr * barrier_atr_multiplier
        upper_barrier = entry_price + barrier_distance
        lower_barrier = entry_price - barrier_distance
        end_idx = min(i + 1 + horizon, n)

        for j in range(i + 1, end_idx):
            if close_prices[j] >= upper_barrier:
                label_multi[i] = 1  # buy
                break
            elif close_prices[j] <= lower_barrier:
                label_multi[i] = 2  # sell
                break
    return label_multi


In [476]:
def calculate_labels_sliding_window(
    close_prices: np.ndarray,
    window_size: int = 20,
    min_distance: int = 5,
    prominence: float = None,
    use_percentile: bool = True,
    percentile_threshold: float = 0.7,
) -> np.ndarray:
    """
    Sliding window method để tìm peak và dip.
    
    Args:
        close_prices: Array giá đóng cửa
        window_size: Kích thước cửa sổ trượt
        min_distance: Khoảng cách tối thiểu giữa các peaks/dips
        prominence: Độ nổi bật tối thiểu (None = tự động tính từ percentile)
        use_percentile: Nếu True, dùng percentile của price changes làm threshold
        percentile_threshold: Percentile để tính threshold (0.7 = top 30%)
    
    Returns:
        Array labels: 0=hold, 1=buy (dip), 2=sell (peak)
    """
    n = len(close_prices)
    labels = np.zeros(n, dtype=np.int64)
    
    # Tính prominence tự động nếu không được chỉ định
    if prominence is None and use_percentile:
        # Tính price changes
        price_changes = np.abs(np.diff(close_prices))
        valid_changes = price_changes[~np.isnan(price_changes)]
        if len(valid_changes) > 0:
            prominence = np.percentile(valid_changes, percentile_threshold * 100)
        else:
            prominence = 0.01
    
    # Tìm peaks (local maxima) - SELL signals
    peaks, peak_properties = find_peaks(
        close_prices,
        distance=min_distance,
        prominence=prominence,
    )
    labels[peaks] = 2  # sell at peaks
    
    # Tìm dips (local minima) - BUY signals
    # Invert signal để tìm minima
    inverted_prices = -close_prices
    dips, dip_properties = find_peaks(
        inverted_prices,
        distance=min_distance,
        prominence=prominence,
    )
    labels[dips] = 1  # buy at dips
    
    return labels


In [477]:
def load_coin_data(coin: str, data_dir: str = "./data/crypto-1200", end_date: str = None):
    """
    Loads and filters Parquet data for a single coin.
    Tham khảo từ crypto_datamodule.py
    """
    data_path = Path(data_dir)
    filename = f"{coin.lower()}.parquet"
    filepath = data_path / "data" / filename
    
    if not filepath.exists():
        raise FileNotFoundError(f"File not found: {filepath}")
    
    df = pd.read_parquet(filepath)
    df["date"] = pd.to_datetime(df["date"])
    
    if end_date:
        end_date_dt = pd.to_datetime(end_date, utc=True)
        df = df[df["date"] <= end_date_dt]
    
    if df.empty:
        raise ValueError(f"No data found for {coin}")
    
    df = df.drop_duplicates(subset=["date"]).sort_values("date")
    return df[["date", "open", "high", "low", "close", "volume"]].copy()


In [478]:
def calculate_indicators(df: pd.DataFrame):
    """
    Tính toán các chỉ báo kỹ thuật: ATR, BBands, SMA
    Tham khảo từ crypto_datamodule.py
    """
    open_p = df["open"].values
    high_p = df["high"].values
    low_p = df["low"].values
    close_p = df["close"].values
    volume = df["volume"].values
    
    # ATR (14 periods)
    atr = talib.ATR(high_p, low_p, close_p, timeperiod=14)
    
    # Bollinger Bands (20 periods)
    bb_upper, bb_middle, bb_lower = talib.BBANDS(close_p, timeperiod=20)
    
    # SMA (50 và 200 periods)
    sma_50 = talib.SMA(close_p, timeperiod=50)
    sma_200 = talib.SMA(close_p, timeperiod=200)
    
    # Thêm vào dataframe
    df = df.copy()
    df["atr"] = atr
    df["bb_upper"] = bb_upper
    df["bb_middle"] = bb_middle
    df["bb_lower"] = bb_lower
    df["sma_50"] = sma_50
    df["sma_200"] = sma_200
    
    return df


In [479]:
def calculate_labels(
    df: pd.DataFrame,
    mode: str = "three_barrier",
    # Three-barrier parameters
    barrier_atr_multiplier: float = 2.0,
    horizon: int = 4,
    # Sliding window parameters
    window_size: int = 20,
    min_distance: int = 5,
    prominence: float = None,
    use_percentile: bool = True,
    percentile_threshold: float = 0.7,
):
    """
    Tính toán labels 3-class (0=hold, 1=buy, 2=sell)
    
    Args:
        df: DataFrame với columns: date, open, high, low, close, volume, atr
        mode: "three_barrier" hoặc "sliding_window"
        barrier_atr_multiplier: Multiplier cho barrier distance (three-barrier)
        horizon: Horizon cho triple-barrier labeling
        window_size: Kích thước cửa sổ (sliding_window, không dùng nhưng giữ để tương lai)
        min_distance: Khoảng cách tối thiểu giữa peaks/dips (sliding_window)
        prominence: Độ nổi bật tối thiểu (sliding_window)
        use_percentile: Dùng percentile để tính prominence (sliding_window)
        percentile_threshold: Percentile threshold (sliding_window)
    """
    close_prices = df["close"].values
    
    if mode == "three_barrier":
        atr_values = df["atr"].values
        labels = _calculate_labels_numba(
            close_prices,
            atr_values,
            barrier_atr_multiplier,
            horizon
        )
    elif mode == "sliding_window":
        labels = calculate_labels_sliding_window(
            close_prices,
            window_size=window_size,
            min_distance=min_distance,
            prominence=prominence,
            use_percentile=use_percentile,
            percentile_threshold=percentile_threshold,
        )
    else:
        raise ValueError(f"Unknown mode: {mode}. Use 'three_barrier' or 'sliding_window'")
    
    df = df.copy()
    df["label"] = labels
    
    return df


In [480]:
def visualize_labels(
    df: pd.DataFrame,
    coin_name: str = "BTC",
    show_last_n: int = None,
    mode: str = "three_barrier",
    **label_params
):
    """
    Visualize candlestick chart với labels, BBands, và SMA
    
    Args:
        df: DataFrame với columns: date, open, high, low, close, volume, atr, bb_upper, bb_middle, bb_lower, sma_50, sma_200, label
        coin_name: Tên coin để hiển thị
        show_last_n: Chỉ hiển thị n candles cuối (None = hiển thị tất cả)
        mode: "three_barrier" hoặc "sliding_window"
        **label_params: Các tham số cho labeling (sẽ hiển thị trong title)
    """
    # Lọc data nếu cần
    if show_last_n:
        df = df.tail(show_last_n).copy()
    
    # Tạo subplot với volume ở dưới
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=(f"{coin_name} Price Chart ({mode})", "Volume")
    )
    
    # 1. Candlestick chart
    fig.add_trace(
        go.Candlestick(
            x=df["date"],
            open=df["open"],
            high=df["high"],
            low=df["low"],
            close=df["close"],
            name="Price",
            increasing_line_color="#26a69a",
            decreasing_line_color="#ef5350",
        ),
        row=1, col=1
    )
    
    # 2. Bollinger Bands
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_upper"],
            mode="lines",
            name="BB Upper",
            line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_middle"],
            mode="lines",
            name="BB Middle",
            line=dict(color="rgba(128,128,128,0.5)", width=1),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_lower"],
            mode="lines",
            name="BB Lower",
            line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"),
            showlegend=True,
            fill="tonexty",
            fillcolor="rgba(128,128,128,0.1)",
        ),
        row=1, col=1
    )
    
    # 3. SMA
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["sma_50"],
            mode="lines",
            name="SMA 50",
            line=dict(color="blue", width=1.5),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["sma_200"],
            mode="lines",
            name="SMA 200",
            line=dict(color="orange", width=1.5),
            showlegend=True,
        ),
        row=1, col=1
    )
    
    # 4. Buy signals (triangle up, màu xanh)
    buy_mask = df["label"] == 1
    if buy_mask.any():
        buy_dates = df.loc[buy_mask, "date"]
        buy_prices = df.loc[buy_mask, "low"] * 0.999  # Đặt marker hơi dưới low để dễ nhìn
        fig.add_trace(
            go.Scatter(
                x=buy_dates,
                y=buy_prices,
                mode="markers",
                name="Buy Signal",
                marker=dict(
                    symbol="triangle-up",
                    size=12,
                    color="green",
                    line=dict(width=1, color="darkgreen"),
                ),
                showlegend=True,
            ),
            row=1, col=1
        )
    
    # 5. Sell signals (triangle down, màu đỏ)
    sell_mask = df["label"] == 2
    if sell_mask.any():
        sell_dates = df.loc[sell_mask, "date"]
        sell_prices = df.loc[sell_mask, "high"] * 1.001  # Đặt marker hơi trên high để dễ nhìn
        fig.add_trace(
            go.Scatter(
                x=sell_dates,
                y=sell_prices,
                mode="markers",
                name="Sell Signal",
                marker=dict(
                    symbol="triangle-down",
                    size=12,
                    color="red",
                    line=dict(width=1, color="darkred"),
                ),
                showlegend=True,
            ),
            row=1, col=1
        )
    
    # 6. Volume bars
    colors = ["#26a69a" if close >= open else "#ef5350" 
              for close, open in zip(df["close"], df["open"])]
    fig.add_trace(
        go.Bar(
            x=df["date"],
            y=df["volume"],
            name="Volume",
            marker_color=colors,
            showlegend=False,
        ),
        row=2, col=1
    )
    
    # Tạo title với tham số
    param_str = ""
    if mode == "three_barrier":
        param_str = f"Barrier ATR: {label_params.get('barrier_atr_multiplier', 2.0)}, Horizon: {label_params.get('horizon', 4)}"
    elif mode == "sliding_window":
        param_str = f"Min Distance: {label_params.get('min_distance', 5)}, Percentile: {label_params.get('percentile_threshold', 0.7)}"
    
    # Update layout
    fig.update_layout(
        title=f"{coin_name} Price Chart with Labels ({mode.title()}) - {param_str}",
        height=800,
        xaxis_rangeslider_visible=False,
        hovermode="x unified",
        template="plotly_white",
    )
    
    # Update y-axis labels
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_xaxes(title_text="Date", row=2, col=1)
    
    # Update x-axis để không bị overlap
    fig.update_xaxes(rangeslider_visible=False)
    
    return fig


In [481]:
# ========== CẤU HÌNH ==========
COIN = "TAO"  # Tên coin (không phân biệt hoa thường)
DATA_DIR = "../data/crypto-1200"
END_DATE = "2025-09-29"  # None để lấy tất cả dữ liệu
SHOW_LAST_N = 5000  # None để hiển thị tất cả, hoặc số lượng candles cuối cùng

# ========== CHỌN MODE LABELING ==========
# "three_barrier" hoặc "sliding_window"
# LABEL_MODE = "three_barrier"  # Đổi thành "sliding_window" để dùng mode khác
LABEL_MODE = "sliding_window"  # Đổi thành "sliding_window" để dùng mode khác

# ========== THAM SỐ THREE-BARRIER ==========
BARRIER_ATR_MULTIPLIER = 2.0
HORIZON = 4

# ========== THAM SỐ SLIDING WINDOW ==========
WINDOW_SIZE = 32  # Không dùng hiện tại, giữ cho tương lai
MIN_DISTANCE = 8  # Khoảng cách tối thiểu giữa peaks/dips (candles)
PROMINENCE = None  # None = tự động tính từ percentile
USE_PERCENTILE = True
PERCENTILE_THRESHOLD = 0.7  # 0.7 = top 30% price changes được coi là significant

## Detect NaN prices

Kiểm tra và phát hiện các giá trị NaN trong dữ liệu giá


In [482]:
def detect_nan_prices(df: pd.DataFrame):
    """
    Phát hiện và báo cáo các giá trị NaN trong dữ liệu giá.
    
    Returns:
        dict với thông tin về NaN values
    """
    price_cols = ["open", "high", "low", "close", "volume"]
    nan_info = {}
    
    print("=== NaN Detection Report ===\n")
    
    # Kiểm tra từng cột
    for col in price_cols:
        if col in df.columns:
            nan_count = df[col].isna().sum()
            nan_indices = df[df[col].isna()].index.tolist()
            nan_pct = (nan_count / len(df)) * 100
            
            nan_info[col] = {
                "count": nan_count,
                "percentage": nan_pct,
                "indices": nan_indices,
                "dates": df.loc[nan_indices, "date"].tolist() if "date" in df.columns else []
            }
            
            if nan_count > 0:
                print(f"⚠️  {col.upper()}: {nan_count:,} NaN values ({nan_pct:.2f}%)")
            else:
                print(f"✅ {col.upper()}: No NaN values")
    
    # Kiểm tra rows có bất kỳ NaN nào
    rows_with_nan = df[price_cols].isna().any(axis=1)
    nan_rows_count = rows_with_nan.sum()
    
    print(f"\n📊 Rows with any NaN: {nan_rows_count:,} ({(nan_rows_count/len(df))*100:.2f}%)")
    
    if nan_rows_count > 0:
        print("\n⚠️  Rows with NaN values:")
        nan_rows = df[rows_with_nan].copy()
        print(nan_rows[["date"] + price_cols].head(20))
        if nan_rows_count > 20:
            print(f"... and {nan_rows_count - 20} more rows")
    
    # Kiểm tra inf values
    print("\n=== Inf Detection ===")
    for col in price_cols:
        if col in df.columns:
            inf_count = np.isinf(df[col]).sum()
            if inf_count > 0:
                print(f"⚠️  {col.upper()}: {inf_count:,} Inf values")
            else:
                print(f"✅ {col.upper()}: No Inf values")
    
    # Kiểm tra giá trị bất thường (<= 0)
    print("\n=== Invalid Price Values (<= 0) ===")
    for col in ["open", "high", "low", "close"]:
        if col in df.columns:
            invalid_count = (df[col] <= 0).sum()
            if invalid_count > 0:
                print(f"⚠️  {col.upper()}: {invalid_count:,} values <= 0")
                invalid_rows = df[df[col] <= 0][["date", col]]
                print(invalid_rows.head(10))
            else:
                print(f"✅ {col.upper()}: All values > 0")
    
    print("\n" + "=" * 50)
    
    return nan_info


def detect_missing_timestamps_and_gaps(df: pd.DataFrame, expected_freq: str = "1h"):
    """
    Phát hiện missing timestamps và gaps trong dữ liệu.
    Đây là nguyên nhân chính gây ra gaps trên plot.
    
    Args:
        df: DataFrame với cột 'date'
        expected_freq: Tần suất mong đợi ('1h', '4h', '1d', etc.)
    
    Returns:
        dict với thông tin về gaps
    """
    if "date" not in df.columns:
        print("❌ No 'date' column found")
        return None
    
    print("=== Missing Timestamps & Gaps Detection ===\n")
    
    # Sắp xếp theo date
    df_sorted = df.sort_values("date").reset_index(drop=True)
    
    # Tạo expected date range
    min_date = df_sorted["date"].min()
    max_date = df_sorted["date"].max()
    
    # Chuyển đổi expected_freq thành pandas frequency
    freq_map = {
        "1h": "1H",
        "4h": "4H",
        "1d": "1D",
    }
    pandas_freq = freq_map.get(expected_freq.lower(), "1H")
    
    # Tạo expected date range
    expected_dates = pd.date_range(start=min_date, end=max_date, freq=pandas_freq, tz="UTC")
    actual_dates = set(df_sorted["date"].dt.tz_localize(None) if df_sorted["date"].dt.tz else df_sorted["date"])
    expected_dates_set = set(expected_dates)
    
    # Tìm missing timestamps
    missing_dates = sorted(expected_dates_set - actual_dates)
    missing_count = len(missing_dates)
    
    print(f"📅 Date Range: {min_date} to {max_date}")
    print(f"📊 Expected rows ({pandas_freq}): {len(expected_dates):,}")
    print(f"📊 Actual rows: {len(df_sorted):,}")
    print(f"⚠️  Missing timestamps: {missing_count:,} ({(missing_count/len(expected_dates))*100:.2f}%)")
    
    if missing_count > 0:
        print(f"\n⚠️  First 20 missing timestamps:")
        for i, date in enumerate(missing_dates[:20]):
            print(f"  {date}")
        if missing_count > 20:
            print(f"  ... and {missing_count - 20} more")
    
    # Tính gaps (khoảng cách giữa các timestamps liên tiếp)
    df_sorted["time_diff"] = df_sorted["date"].diff()
    expected_diff = pd.Timedelta(pandas_freq)
    
    # Tìm gaps lớn hơn expected
    large_gaps = df_sorted[df_sorted["time_diff"] > expected_diff * 1.5].copy()
    gap_count = len(large_gaps)
    
    print(f"\n📏 Large gaps (>1.5x expected interval): {gap_count:,}")
    if gap_count > 0:
        print(f"\n⚠️  First 10 large gaps:")
        for idx, row in large_gaps.head(10).iterrows():
            gap_duration = row["time_diff"]
            gap_size = gap_duration / expected_diff
            print(f"  {row['date']}: Gap = {gap_duration} ({gap_size:.1f}x expected)")
    
    # Tính số lượng gaps theo kích thước
    if gap_count > 0:
        gap_sizes = (large_gaps["time_diff"] / expected_diff).astype(int)
        gap_size_dist = gap_sizes.value_counts().sort_index()
        print(f"\n📊 Gap size distribution:")
        for size, count in gap_size_dist.head(10).items():
            print(f"  {size}x expected: {count:,} gaps")
    
    print("\n" + "=" * 50)
    
    return {
        "missing_dates": missing_dates,
        "missing_count": missing_count,
        "large_gaps": large_gaps,
        "gap_count": gap_count,
        "expected_dates": expected_dates,
        "expected_count": len(expected_dates),
        "actual_count": len(df_sorted),
    }


def detect_flat_price_periods(df: pd.DataFrame, min_periods: int = 5):
    """
    Phát hiện các khoảng thời gian giá không thay đổi (flat price periods).
    
    Args:
        df: DataFrame với cột 'date', 'close', 'open', 'high', 'low'
        min_periods: Số periods tối thiểu để coi là flat period
    
    Returns:
        dict với thông tin về flat periods
    """
    print("=== Flat Price Periods Detection ===\n")
    
    df_sorted = df.sort_values("date").reset_index(drop=True)
    
    # Kiểm tra giá close không đổi
    df_sorted["close_changed"] = df_sorted["close"].diff().abs() > 1e-8
    df_sorted["flat_group"] = (~df_sorted["close_changed"]).cumsum()
    
    # Đếm số periods liên tiếp có giá không đổi
    flat_groups = df_sorted.groupby("flat_group").agg({
        "date": ["first", "last", "count"],
        "close": "first",
    }).reset_index()
    
    flat_groups.columns = ["group", "start_date", "end_date", "duration", "price"]
    flat_groups = flat_groups[flat_groups["duration"] >= min_periods].sort_values("duration", ascending=False)
    
    total_flat_periods = len(flat_groups)
    total_flat_candles = flat_groups["duration"].sum()
    
    print(f"📊 Total flat price periods (≥{min_periods} candles): {total_flat_periods}")
    print(f"📊 Total candles with flat prices: {total_flat_candles:,} ({total_flat_candles/len(df_sorted)*100:.2f}%)")
    
    if total_flat_periods > 0:
        print(f"\n⚠️  Top 10 longest flat periods:")
        for idx, row in flat_groups.head(10).iterrows():
            print(f"  {row['start_date']} to {row['end_date']}: "
                  f"{int(row['duration'])} candles @ price {row['price']:.2f}")
    
    # Kiểm tra các cột khác
    for col in ["open", "high", "low"]:
        if col in df_sorted.columns:
            df_sorted[f"{col}_changed"] = df_sorted[col].diff().abs() > 1e-8
            unchanged_count = (~df_sorted[f"{col}_changed"]).sum()
            unchanged_pct = unchanged_count / len(df_sorted) * 100
            print(f"\n📊 {col.upper()} unchanged: {unchanged_count:,} candles ({unchanged_pct:.2f}%)")
    
    print("\n" + "=" * 50)
    
    return {
        "flat_periods": flat_groups,
        "total_flat_periods": total_flat_periods,
        "total_flat_candles": total_flat_candles,
    }

# Chạy detect NaN
nan_report = detect_nan_prices(df)


=== NaN Detection Report ===

✅ OPEN: No NaN values
✅ HIGH: No NaN values
✅ LOW: No NaN values
✅ CLOSE: No NaN values
✅ VOLUME: No NaN values

📊 Rows with any NaN: 0 (0.00%)

=== Inf Detection ===
✅ OPEN: No Inf values
✅ HIGH: No Inf values
✅ LOW: No Inf values
✅ CLOSE: No Inf values
✅ VOLUME: No Inf values

=== Invalid Price Values (<= 0) ===
✅ OPEN: All values > 0
✅ HIGH: All values > 0
✅ LOW: All values > 0
✅ CLOSE: All values > 0



In [483]:
# Visualize NaN trên chart
def visualize_nan_on_chart(df: pd.DataFrame, coin_name: str = "BTC", show_last_n: int = None):
    """
    Visualize candlestick chart với markers đánh dấu các điểm có NaN
    """
    # Lọc data nếu cần
    df_viz = df.tail(show_last_n).copy() if show_last_n else df.copy()
    
    # Tìm rows có NaN
    price_cols = ["open", "high", "low", "close", "volume"]
    nan_mask = df_viz[price_cols].isna().any(axis=1)
    nan_rows = df_viz[nan_mask]
    
    # Tạo figure
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=(f"{coin_name} Price Chart (NaN Detection)", "Volume")
    )
    
    # Candlestick
    fig.add_trace(
        go.Candlestick(
            x=df_viz["date"],
            open=df_viz["open"],
            high=df_viz["high"],
            low=df_viz["low"],
            close=df_viz["close"],
            name="Price",
            increasing_line_color="#26a69a",
            decreasing_line_color="#ef5350",
        ),
        row=1, col=1
    )
    
    # Mark NaN points
    if len(nan_rows) > 0:
        # Đánh dấu NaN với marker X màu đỏ
        fig.add_trace(
            go.Scatter(
                x=nan_rows["date"],
                y=nan_rows["close"].fillna(nan_rows["open"].fillna(0)),
                mode="markers",
                name="NaN Values",
                marker=dict(
                    symbol="x",
                    size=15,
                    color="red",
                    line=dict(width=2, color="darkred"),
                ),
                text=[f"NaN in: {', '.join([col for col in price_cols if pd.isna(nan_rows.iloc[i][col])])}" 
                      for i in range(len(nan_rows))],
                hovertemplate="<b>NaN Detected</b><br>Date: %{x}<br>%{text}<extra></extra>",
            ),
            row=1, col=1
        )
    
    # Volume
    colors = ["#26a69a" if close >= open else "#ef5350" 
              for close, open in zip(df_viz["close"], df_viz["open"])]
    fig.add_trace(
        go.Bar(
            x=df_viz["date"],
            y=df_viz["volume"],
            name="Volume",
            marker_color=colors,
            showlegend=False,
        ),
        row=2, col=1
    )
    
    fig.update_layout(
        title=f"{coin_name} - NaN Detection Chart (Red X = NaN values)",
        height=800,
        xaxis_rangeslider_visible=False,
        hovermode="x unified",
        template="plotly_white",
    )
    
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_xaxes(title_text="Date", row=2, col=1)
    
    return fig

# Visualize nếu có NaN
if nan_report and any(info["count"] > 0 for info in nan_report.values()):
    print("\n📊 Visualizing NaN on chart...")
    fig_nan = visualize_nan_on_chart(df, coin_name=COIN.upper(), show_last_n=SHOW_LAST_N)
    fig_nan.show()
else:
    print("\n✅ No NaN values found - no visualization needed")



✅ No NaN values found - no visualization needed


## Visualize Gaps & Flat Periods

Visualize missing timestamps (gaps) và flat price periods trên chart


In [484]:
# Visualize gaps và flat periods trên chart
def visualize_gaps_and_flat_periods(
    df: pd.DataFrame, 
    gap_report: dict = None,
    flat_report: dict = None,
    coin_name: str = "BTC", 
    show_last_n: int = None
):
    """
    Visualize gaps (missing timestamps) và flat price periods trên chart.
    Đây là nguyên nhân chính gây ra gaps và vùng giá không đổi trên plot.
    """
    df_viz = df.tail(show_last_n).copy() if show_last_n else df.copy()
    df_viz = df_viz.sort_values("date").reset_index(drop=True)
    
    # Tạo figure
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=(f"{coin_name} - Gaps & Flat Periods", "Volume")
    )
    
    # Candlestick
    fig.add_trace(
        go.Candlestick(
            x=df_viz["date"],
            open=df_viz["open"],
            high=df_viz["high"],
            low=df_viz["low"],
            close=df_viz["close"],
            name="Price",
            increasing_line_color="#26a69a",
            decreasing_line_color="#ef5350",
        ),
        row=1, col=1
    )
    
    # Đánh dấu large gaps (missing timestamps)
    if gap_report and gap_report.get("large_gaps") is not None and len(gap_report["large_gaps"]) > 0:
        large_gaps = gap_report["large_gaps"].copy()
        # Chỉ lấy gaps trong df_viz date range
        date_min = df_viz["date"].min()
        date_max = df_viz["date"].max()
        large_gaps = large_gaps[(large_gaps["date"] >= date_min) & (large_gaps["date"] <= date_max)]
        
        if len(large_gaps) > 0:
            # Lấy close price tại các điểm gap
            gap_dates = large_gaps["date"].values
            gap_prices = []
            for date in gap_dates:
                price = df_viz[df_viz["date"] == date]["close"].values
                if len(price) > 0:
                    gap_prices.append(price[0])
                else:
                    gap_prices.append(df_viz["close"].iloc[0])
            
            fig.add_trace(
                go.Scatter(
                    x=gap_dates,
                    y=gap_prices,
                    mode="markers",
                    name="Large Gaps",
                    marker=dict(
                        symbol="diamond",
                        size=15,
                        color="orange",
                        line=dict(width=2, color="darkorange"),
                    ),
                    text=[f"Gap: {row['time_diff']}" for _, row in large_gaps.iterrows()],
                    hovertemplate="<b>Large Gap</b><br>Date: %{x}<br>Gap size: %{text}<extra></extra>",
                ),
                row=1, col=1
            )
    
    # Đánh dấu flat price periods
    if flat_report and flat_report.get("flat_periods") is not None and len(flat_report["flat_periods"]) > 0:
        flat_periods = flat_report["flat_periods"]
        date_min = df_viz["date"].min()
        date_max = df_viz["date"].max()
        
        # Lọc periods trong range
        flat_periods_filtered = flat_periods[
            (flat_periods["end_date"] >= date_min) & 
            (flat_periods["start_date"] <= date_max)
        ]
        
        if len(flat_periods_filtered) > 0:
            # Vẽ đường ngang cho các flat periods quan trọng (top 10)
            for idx, (_, row) in enumerate(flat_periods_filtered.head(10).iterrows()):
                start_date = max(row["start_date"], date_min)
                end_date = min(row["end_date"], date_max)
                
                # Vẽ đường ngang tại mức giá flat
                fig.add_trace(
                    go.Scatter(
                        x=[start_date, end_date],
                        y=[row["price"], row["price"]],
                        mode="lines",
                        name=f"Flat Period ({int(row['duration'])} candles)" if idx == 0 else "",
                        line=dict(color="red", width=2, dash="dash"),
                        showlegend=(idx == 0),
                        legendgroup="flat",
                        hovertemplate=f"Flat: {int(row['duration'])} candles @ {row['price']:.2f}<extra></extra>",
                    ),
                    row=1, col=1
                )
    
    # Volume
    colors = ["#26a69a" if close >= open else "#ef5350" 
              for close, open in zip(df_viz["close"], df_viz["open"])]
    fig.add_trace(
        go.Bar(
            x=df_viz["date"],
            y=df_viz["volume"],
            name="Volume",
            marker_color=colors,
            showlegend=False,
        ),
        row=2, col=1
    )
    
    fig.update_layout(
        title=f"{coin_name} - Gaps (Orange Diamonds) & Flat Periods (Red Dashed)",
        height=800,
        xaxis_rangeslider_visible=False,
        hovermode="x unified",
        template="plotly_white",
    )
    
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_xaxes(title_text="Date", row=2, col=1)
    
    return fig

# Visualize gaps và flat periods (NGUYÊN NHÂN GÂY GAPS TRÊN PLOT)
print("\n📊 Visualizing gaps and flat periods on chart...")

# Kiểm tra và tự động chạy detections nếu chưa có
try:
    gap_report
    gap_report_exists = True
except NameError:
    print("⚠️  gap_report not found. Running detection...")
    gap_report = detect_missing_timestamps_and_gaps(df, expected_freq="1h")
    gap_report_exists = True

try:
    flat_report
    flat_report_exists = True
except NameError:
    print("⚠️  flat_report not found. Running detection...")
    flat_report = detect_flat_price_periods(df, min_periods=3)
    flat_report_exists = True

# Visualize
fig_gaps = visualize_gaps_and_flat_periods(
    df, 
    gap_report=gap_report if gap_report_exists else None, 
    flat_report=flat_report if flat_report_exists else None,
    coin_name=COIN.upper(), 
    show_last_n=SHOW_LAST_N
)
fig_gaps.show()



📊 Visualizing gaps and flat periods on chart...
⚠️  gap_report not found. Running detection...
=== Missing Timestamps & Gaps Detection ===

📅 Date Range: 2023-04-18 08:00:00+00:00 to 2025-09-29 00:00:00+00:00
📊 Expected rows (1H): 21,473
📊 Actual rows: 21,473
⚠️  Missing timestamps: 21,473 (100.00%)

⚠️  First 20 missing timestamps:
  2023-04-18 08:00:00+00:00
  2023-04-18 09:00:00+00:00
  2023-04-18 10:00:00+00:00
  2023-04-18 11:00:00+00:00
  2023-04-18 12:00:00+00:00
  2023-04-18 13:00:00+00:00
  2023-04-18 14:00:00+00:00
  2023-04-18 15:00:00+00:00
  2023-04-18 16:00:00+00:00
  2023-04-18 17:00:00+00:00
  2023-04-18 18:00:00+00:00
  2023-04-18 19:00:00+00:00
  2023-04-18 20:00:00+00:00
  2023-04-18 21:00:00+00:00
  2023-04-18 22:00:00+00:00
  2023-04-18 23:00:00+00:00
  2023-04-19 00:00:00+00:00
  2023-04-19 01:00:00+00:00
  2023-04-19 02:00:00+00:00
  2023-04-19 03:00:00+00:00
  ... and 21453 more

📏 Large gaps (>1.5x expected interval): 0

⚠️  flat_report not found. Running dete


'H' is deprecated and will be removed in a future version, please use 'h' instead.


'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



In [485]:
## Xử lý NaN (Optional)


def clean_nan_prices(df: pd.DataFrame, method: str = "forward_fill"):
    """
    Xử lý NaN values trong dữ liệu giá.
    
    Args:
        df: DataFrame với dữ liệu giá
        method: 
            - "forward_fill": Fill NaN với giá trị trước đó
            - "backward_fill": Fill NaN với giá trị sau đó
            - "interpolate": Nội suy tuyến tính
            - "drop": Xóa các rows có NaN
    """
    df_clean = df.copy()
    price_cols = ["open", "high", "low", "close", "volume"]
    
    original_count = len(df_clean)
    nan_before = df_clean[price_cols].isna().sum().sum()
    
    if method == "forward_fill":
        df_clean[price_cols] = df_clean[price_cols].ffill()
        print("✅ Applied forward fill")
    elif method == "backward_fill":
        df_clean[price_cols] = df_clean[price_cols].bfill()
        print("✅ Applied backward fill")
    elif method == "interpolate":
        df_clean[price_cols] = df_clean[price_cols].interpolate(method="linear")
        print("✅ Applied linear interpolation")
    elif method == "drop":
        df_clean = df_clean.dropna(subset=price_cols)
        print("✅ Dropped rows with NaN")
    else:
        print("❌ Unknown method, no cleaning applied")
        return df_clean
    
    nan_after = df_clean[price_cols].isna().sum().sum()
    final_count = len(df_clean)
    
    print(f"\n📊 Cleaning Summary:")
    print(f"  NaN before: {nan_before:,}")
    print(f"  NaN after: {nan_after:,}")
    print(f"  Rows before: {original_count:,}")
    print(f"  Rows after: {final_count:,}")
    print(f"  Rows removed: {original_count - final_count:,}")
    
    return df_clean

# Xử lý NaN nếu được cấu hình
if HANDLE_NAN_METHOD and HANDLE_NAN_METHOD in ["forward_fill", "backward_fill", "interpolate", "drop"]:
    print(f"🔧 Cleaning NaN using method: {HANDLE_NAN_METHOD}")
    df = clean_nan_prices(df, method=HANDLE_NAN_METHOD)
    print("\n✅ Data cleaned! Continuing with cleaned data...\n")
elif HANDLE_NAN_METHOD:
    print(f"⚠️  Unknown method '{HANDLE_NAN_METHOD}', skipping cleaning")
else:
    print("ℹ️  NaN cleaning skipped (HANDLE_NAN_METHOD = None)")
    print("   Set HANDLE_NAN_METHOD to 'forward_fill', 'backward_fill', 'interpolate', or 'drop' to clean NaN")


ℹ️  NaN cleaning skipped (HANDLE_NAN_METHOD = None)
   Set HANDLE_NAN_METHOD to 'forward_fill', 'backward_fill', 'interpolate', or 'drop' to clean NaN


In [486]:
# Load dữ liệu
print(f"Loading data for {COIN}...")
df = load_coin_data(COIN, data_dir=DATA_DIR, end_date=END_DATE)
print(f"Loaded {len(df)} candles")
print(f"Date range: {df['date'].min()} to {df['date'].max()}")
df.head()


Loading data for TAO...
Loaded 22571 candles
Date range: 2023-03-03 14:00:00+00:00 to 2025-09-29 00:00:00+00:00


Unnamed: 0,date,open,high,low,close,volume
0,2023-03-03 14:00:00+00:00,80.0,119.0,80.0,98.39,5355.21
1,2023-03-03 15:00:00+00:00,98.39,101.92,91.07,101.65,3922.54
2,2023-03-03 16:00:00+00:00,101.65,107.44,100.0,105.94,3908.69
3,2023-03-03 17:00:00+00:00,105.94,107.39,93.6,97.35,3767.04
4,2023-03-03 18:00:00+00:00,97.35,105.0,95.0,102.7,1883.76


In [487]:
# Tính toán indicators
print("Calculating indicators...")
df = calculate_indicators(df)
print("Done!")
df.head()


Calculating indicators...
Done!


Unnamed: 0,date,open,high,low,close,volume,atr,bb_upper,bb_middle,bb_lower,sma_50,sma_200
0,2023-03-03 14:00:00+00:00,80.0,119.0,80.0,98.39,5355.21,,,,,,
1,2023-03-03 15:00:00+00:00,98.39,101.92,91.07,101.65,3922.54,,,,,,
2,2023-03-03 16:00:00+00:00,101.65,107.44,100.0,105.94,3908.69,,,,,,
3,2023-03-03 17:00:00+00:00,105.94,107.39,93.6,97.35,3767.04,,,,,,
4,2023-03-03 18:00:00+00:00,97.35,105.0,95.0,102.7,1883.76,,,,,,


In [488]:
# Tính toán labels
print(f"Calculating labels using {LABEL_MODE} mode...")
if LABEL_MODE == "three_barrier":
    df = calculate_labels(
        df,
        mode=LABEL_MODE,
        barrier_atr_multiplier=BARRIER_ATR_MULTIPLIER,
        horizon=HORIZON,
    )
elif LABEL_MODE == "sliding_window":
    df = calculate_labels(
        df,
        mode=LABEL_MODE,
        window_size=WINDOW_SIZE,
        min_distance=MIN_DISTANCE,
        prominence=PROMINENCE,
        use_percentile=USE_PERCENTILE,
        percentile_threshold=PERCENTILE_THRESHOLD,
    )
print("Done!")

# Thống kê labels
label_counts = df["label"].value_counts().sort_index()
print("\nLabel distribution:")
print(f"  Hold (0): {label_counts.get(0, 0):,} ({label_counts.get(0, 0)/len(df)*100:.2f}%)")
print(f"  Buy  (1): {label_counts.get(1, 0):,} ({label_counts.get(1, 0)/len(df)*100:.2f}%)")
print(f"  Sell (2): {label_counts.get(2, 0):,} ({label_counts.get(2, 0)/len(df)*100:.2f}%)")
df.head()


Calculating labels using sliding_window mode...
Done!

Label distribution:
  Hold (0): 20,847 (92.36%)
  Buy  (1): 852 (3.77%)
  Sell (2): 872 (3.86%)


Unnamed: 0,date,open,high,low,close,volume,atr,bb_upper,bb_middle,bb_lower,sma_50,sma_200,label
0,2023-03-03 14:00:00+00:00,80.0,119.0,80.0,98.39,5355.21,,,,,,,0
1,2023-03-03 15:00:00+00:00,98.39,101.92,91.07,101.65,3922.54,,,,,,,0
2,2023-03-03 16:00:00+00:00,101.65,107.44,100.0,105.94,3908.69,,,,,,,2
3,2023-03-03 17:00:00+00:00,105.94,107.39,93.6,97.35,3767.04,,,,,,,0
4,2023-03-03 18:00:00+00:00,97.35,105.0,95.0,102.7,1883.76,,,,,,,0


## So sánh 2 mode

Có thể chạy lại với mode khác để so sánh


In [489]:
# Xem các điểm có buy signal
buy_signals = df[df["label"] == 1].copy()
print(f"Total buy signals: {len(buy_signals)}")
if len(buy_signals) > 0:
    print("\nFirst 10 buy signals:")
    print(buy_signals[["date", "open", "high", "low", "close", "atr", "label"]].head(10))


Total buy signals: 852

First 10 buy signals:
                         date    open    high    low  close       atr  label
8   2023-03-03 22:00:00+00:00   96.67   99.00  93.00  94.63       NaN      1
23  2023-03-04 13:00:00+00:00  102.52  103.00  97.00  97.25  5.886452      1
32  2023-03-04 22:00:00+00:00   99.94  103.00  99.22  99.22  5.038258      1
42  2023-03-05 08:00:00+00:00   99.30  100.27  95.00  95.84  4.458416      1
50  2023-03-05 16:00:00+00:00   95.92   97.00  90.00  92.31  5.511194      1
66  2023-03-06 08:00:00+00:00   94.67   95.78  90.81  90.87  4.062998      1
570 2023-03-27 08:00:00+00:00   66.89   68.74  66.35  68.54  2.037143      1
580 2023-03-27 18:00:00+00:00   69.98   70.98  66.50  67.10  2.816216      1
589 2023-03-28 03:00:00+00:00   68.49   69.24  67.94  67.94  2.208944      1
612 2023-03-29 02:00:00+00:00   66.61   67.52  61.00  62.65  2.345792      1


In [490]:
# Xem các điểm có sell signal
sell_signals = df[df["label"] == 2].copy()
print(f"Total sell signals: {len(sell_signals)}")
if len(sell_signals) > 0:
    print("\nFirst 10 sell signals:")
    print(sell_signals[["date", "open", "high", "low", "close", "atr", "label"]].head(10))


Total sell signals: 872

First 10 sell signals:
                         date    open    high     low   close           atr  \
2   2023-03-03 16:00:00+00:00  101.65  107.44  100.00  105.94           NaN   
12  2023-03-04 02:00:00+00:00   98.82  101.20   95.00  101.20           NaN   
21  2023-03-04 11:00:00+00:00  102.58  105.00  100.00  104.80  6.022158e+00   
35  2023-03-05 01:00:00+00:00  102.50  103.00   98.55  102.96  4.676579e+00   
44  2023-03-05 10:00:00+00:00   97.64  101.99   96.04  101.90  4.587614e+00   
56  2023-03-05 22:00:00+00:00   99.27   99.50   96.52   99.45  4.847901e+00   
67  2023-03-06 09:00:00+00:00   90.87   93.87   87.74   93.68  4.210641e+00   
319 2023-03-16 21:00:00+00:00   94.87   94.87   94.87   94.87  3.619275e-08   
573 2023-03-27 11:00:00+00:00   70.41   72.66   68.50   72.42  2.211550e+00   
583 2023-03-27 21:00:00+00:00   68.10   69.90   67.43   69.72  2.706077e+00   

     label  
2        2  
12       2  
21       2  
35       2  
44       2  
56  

## So sánh side-by-side (2 mode cùng lúc)

Chạy cell này để xem cả 2 mode cùng lúc


In [491]:
# Tính labels cho cả 2 mode
print("Calculating labels for both modes...")
df_tb = calculate_labels(
    df.drop(columns=["label"], errors="ignore"),
    mode="three_barrier",
    barrier_atr_multiplier=BARRIER_ATR_MULTIPLIER,
    horizon=HORIZON,
)
df_sw = calculate_labels(
    df.drop(columns=["label"], errors="ignore"),
    mode="sliding_window",
    window_size=WINDOW_SIZE,
    min_distance=MIN_DISTANCE,
    prominence=PROMINENCE,
    use_percentile=USE_PERCENTILE,
    percentile_threshold=PERCENTILE_THRESHOLD,
)

# So sánh thống kê
print("\n=== Three-Barrier ===")
tb_counts = df_tb["label"].value_counts().sort_index()
print(f"  Hold (0): {tb_counts.get(0, 0):,} ({tb_counts.get(0, 0)/len(df_tb)*100:.2f}%)")
print(f"  Buy  (1): {tb_counts.get(1, 0):,} ({tb_counts.get(1, 0)/len(df_tb)*100:.2f}%)")
print(f"  Sell (2): {tb_counts.get(2, 0):,} ({tb_counts.get(2, 0)/len(df_tb)*100:.2f}%)")

print("\n=== Sliding Window ===")
sw_counts = df_sw["label"].value_counts().sort_index()
print(f"  Hold (0): {sw_counts.get(0, 0):,} ({sw_counts.get(0, 0)/len(df_sw)*100:.2f}%)")
print(f"  Buy  (1): {sw_counts.get(1, 0):,} ({sw_counts.get(1, 0)/len(df_sw)*100:.2f}%)")
print(f"  Sell (2): {sw_counts.get(2, 0):,} ({sw_counts.get(2, 0)/len(df_sw)*100:.2f}%)")


Calculating labels for both modes...



=== Three-Barrier ===
  Hold (0): 20,120 (89.14%)
  Buy  (1): 1,795 (7.95%)
  Sell (2): 656 (2.91%)

=== Sliding Window ===
  Hold (0): 20,847 (92.36%)
  Buy  (1): 852 (3.77%)
  Sell (2): 872 (3.86%)


In [492]:
# Visualize cả 2 mode side-by-side
from plotly.subplots import make_subplots

# Lọc data nếu cần
df_tb_viz = df_tb.tail(SHOW_LAST_N).copy() if SHOW_LAST_N else df_tb.copy()
df_sw_viz = df_sw.tail(SHOW_LAST_N).copy() if SHOW_LAST_N else df_sw.copy()

# Tạo figure với 2 rows, 1 col cho mỗi mode
fig = make_subplots(
    rows=2, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    row_heights=[0.5, 0.5],
    subplot_titles=(
        f"{COIN.upper()} - Three-Barrier Method",
        f"{COIN.upper()} - Sliding Window Method"
    )
)

# Helper function để thêm traces
def add_candlestick_traces(fig, df, row):
    # Candlestick
    fig.add_trace(
        go.Candlestick(
            x=df["date"],
            open=df["open"],
            high=df["high"],
            low=df["low"],
            close=df["close"],
            name="Price",
            increasing_line_color="#26a69a",
            decreasing_line_color="#ef5350",
            showlegend=(row == 1),
        ),
        row=row, col=1
    )
    
    # BBands
    fig.add_trace(
        go.Scatter(x=df["date"], y=df["bb_upper"], mode="lines", name="BB Upper",
                  line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"), showlegend=False),
        row=row, col=1
    )
    fig.add_trace(
        go.Scatter(x=df["date"], y=df["bb_middle"], mode="lines", name="BB Middle",
                  line=dict(color="rgba(128,128,128,0.5)", width=1), showlegend=False),
        row=row, col=1
    )
    fig.add_trace(
        go.Scatter(x=df["date"], y=df["bb_lower"], mode="lines", name="BB Lower",
                  line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"),
                  fill="tonexty", fillcolor="rgba(128,128,128,0.1)", showlegend=False),
        row=row, col=1
    )
    
    # SMA
    fig.add_trace(
        go.Scatter(x=df["date"], y=df["sma_50"], mode="lines", name="SMA 50",
                  line=dict(color="blue", width=1.5), showlegend=False),
        row=row, col=1
    )
    fig.add_trace(
        go.Scatter(x=df["date"], y=df["sma_200"], mode="lines", name="SMA 200",
                  line=dict(color="orange", width=1.5), showlegend=False),
        row=row, col=1
    )
    
    # Buy signals
    buy_mask = df["label"] == 1
    if buy_mask.any():
        buy_dates = df.loc[buy_mask, "date"]
        buy_prices = df.loc[buy_mask, "low"] * 0.995
        fig.add_trace(
            go.Scatter(x=buy_dates, y=buy_prices, mode="markers", name="Buy",
                      marker=dict(symbol="triangle-up", size=10, color="green",
                                line=dict(width=1, color="darkgreen")), showlegend=False),
            row=row, col=1
        )
    
    # Sell signals
    sell_mask = df["label"] == 2
    if sell_mask.any():
        sell_dates = df.loc[sell_mask, "date"]
        sell_prices = df.loc[sell_mask, "high"] * 1.005
        fig.add_trace(
            go.Scatter(x=sell_dates, y=sell_prices, mode="markers", name="Sell",
                      marker=dict(symbol="triangle-down", size=10, color="red",
                                line=dict(width=1, color="darkred")), showlegend=False),
            row=row, col=1
        )

# Thêm traces cho cả 2 mode
add_candlestick_traces(fig, df_tb_viz, 1)
add_candlestick_traces(fig, df_sw_viz, 2)

# Update layout
fig.update_layout(
    title=f"{COIN.upper()} - Label Comparison: Three-Barrier vs Sliding Window",
    height=1200,
    xaxis_rangeslider_visible=False,
    hovermode="x unified",
    template="plotly_white",
)

fig.update_yaxes(title_text="Price", row=1, col=1)
fig.update_yaxes(title_text="Price", row=2, col=1)
fig.update_xaxes(title_text="Date", row=2, col=1)

fig.show()


In [493]:
import numpy as np
import pandas as pd
import talib
from pathlib import Path
from numba import jit
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [494]:
@jit(nopython=True)
def _calculate_labels_numba(
    close_prices: np.ndarray,
    atr_values: np.ndarray,
    barrier_atr_multiplier: float,
    horizon: int,
) -> np.ndarray:
    """
    Numba-accelerated triple-barrier labeling.
    Returns multiclass labels: 0=hold, 1=buy, 2=sell
    """
    n = len(close_prices)
    label_multi = np.zeros(n, dtype=np.int64)

    for i in range(n - 1):
        entry_price = close_prices[i]
        atr = atr_values[i]
        if np.isnan(atr) or atr <= 0:
            continue

        barrier_distance = atr * barrier_atr_multiplier
        upper_barrier = entry_price + barrier_distance
        lower_barrier = entry_price - barrier_distance
        end_idx = min(i + 1 + horizon, n)

        for j in range(i + 1, end_idx):
            if close_prices[j] >= upper_barrier:
                label_multi[i] = 1  # buy
                break
            elif close_prices[j] <= lower_barrier:
                label_multi[i] = 2  # sell
                break
    return label_multi


In [495]:
def load_coin_data(coin: str, data_dir: str = "./data/crypto-1200", end_date: str = None):
    """
    Loads and filters Parquet data for a single coin.
    Tham khảo từ crypto_datamodule.py
    """
    data_path = Path(data_dir)
    filename = f"{coin.lower()}.parquet"
    filepath = data_path / "data" / filename
    
    if not filepath.exists():
        raise FileNotFoundError(f"File not found: {filepath}")
    
    df = pd.read_parquet(filepath)
    df["date"] = pd.to_datetime(df["date"])
    
    if end_date:
        end_date_dt = pd.to_datetime(end_date, utc=True)
        df = df[df["date"] <= end_date_dt]
    
    if df.empty:
        raise ValueError(f"No data found for {coin}")
    
    df = df.drop_duplicates(subset=["date"]).sort_values("date")
    return df[["date", "open", "high", "low", "close", "volume"]].copy()


In [496]:
def calculate_indicators(df: pd.DataFrame):
    """
    Tính toán các chỉ báo kỹ thuật: ATR, BBands, SMA
    Tham khảo từ crypto_datamodule.py
    """
    open_p = df["open"].values
    high_p = df["high"].values
    low_p = df["low"].values
    close_p = df["close"].values
    volume = df["volume"].values
    
    # ATR (14 periods)
    atr = talib.ATR(high_p, low_p, close_p, timeperiod=14)
    
    # Bollinger Bands (20 periods)
    bb_upper, bb_middle, bb_lower = talib.BBANDS(close_p, timeperiod=20)
    
    # SMA (50 và 200 periods)
    sma_50 = talib.SMA(close_p, timeperiod=50)
    sma_200 = talib.SMA(close_p, timeperiod=200)
    
    # Thêm vào dataframe
    df = df.copy()
    df["atr"] = atr
    df["bb_upper"] = bb_upper
    df["bb_middle"] = bb_middle
    df["bb_lower"] = bb_lower
    df["sma_50"] = sma_50
    df["sma_200"] = sma_200
    
    return df


In [497]:
def calculate_labels(df: pd.DataFrame, barrier_atr_multiplier: float = 2.0, horizon: int = 4):
    """
    Tính toán labels 3-class (0=hold, 1=buy, 2=sell) sử dụng triple-barrier method
    """
    close_prices = df["close"].values
    atr_values = df["atr"].values
    
    labels = _calculate_labels_numba(
        close_prices,
        atr_values,
        barrier_atr_multiplier,
        horizon
    )
    
    df = df.copy()
    df["label"] = labels
    
    return df


In [498]:
def visualize_labels(
    df: pd.DataFrame,
    coin_name: str = "BTC",
    show_last_n: int = None,
    barrier_atr_multiplier: float = 2.0,
    horizon: int = 4
):
    """
    Visualize candlestick chart với labels, BBands, và SMA
    
    Args:
        df: DataFrame với columns: date, open, high, low, close, volume, atr, bb_upper, bb_middle, bb_lower, sma_50, sma_200, label
        coin_name: Tên coin để hiển thị
        show_last_n: Chỉ hiển thị n candles cuối (None = hiển thị tất cả)
        barrier_atr_multiplier: Multiplier cho barrier distance
        horizon: Horizon cho triple-barrier labeling
    """
    # Lọc data nếu cần
    if show_last_n:
        df = df.tail(show_last_n).copy()
    
    # Tạo subplot với volume ở dưới
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=(f"{coin_name} Price Chart", "Volume")
    )
    
    # 1. Candlestick chart
    fig.add_trace(
        go.Candlestick(
            x=df["date"],
            open=df["open"],
            high=df["high"],
            low=df["low"],
            close=df["close"],
            name="Price",
            increasing_line_color="#26a69a",
            decreasing_line_color="#ef5350",
        ),
        row=1, col=1
    )
    
    # 2. Bollinger Bands
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_upper"],
            mode="lines",
            name="BB Upper",
            line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_middle"],
            mode="lines",
            name="BB Middle",
            line=dict(color="rgba(128,128,128,0.5)", width=1),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["bb_lower"],
            mode="lines",
            name="BB Lower",
            line=dict(color="rgba(128,128,128,0.3)", width=1, dash="dash"),
            showlegend=True,
            fill="tonexty",
            fillcolor="rgba(128,128,128,0.1)",
        ),
        row=1, col=1
    )
    
    # 3. SMA
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["sma_50"],
            mode="lines",
            name="SMA 50",
            line=dict(color="blue", width=1.5),
            showlegend=True,
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=df["sma_200"],
            mode="lines",
            name="SMA 200",
            line=dict(color="orange", width=1.5),
            showlegend=True,
        ),
        row=1, col=1
    )
    
    # 4. Buy signals (triangle up, màu xanh)
    buy_mask = df["label"] == 1
    if buy_mask.any():
        buy_dates = df.loc[buy_mask, "date"]
        buy_prices = df.loc[buy_mask, "low"] * 0.995  # Đặt marker hơi dưới low để dễ nhìn
        fig.add_trace(
            go.Scatter(
                x=buy_dates,
                y=buy_prices,
                mode="markers",
                name="Buy Signal",
                marker=dict(
                    symbol="triangle-up",
                    size=12,
                    color="green",
                    line=dict(width=1, color="darkgreen"),
                ),
                showlegend=True,
            ),
            row=1, col=1
        )
    
    # 5. Sell signals (triangle down, màu đỏ)
    sell_mask = df["label"] == 2
    if sell_mask.any():
        sell_dates = df.loc[sell_mask, "date"]
        sell_prices = df.loc[sell_mask, "high"] * 1.005  # Đặt marker hơi trên high để dễ nhìn
        fig.add_trace(
            go.Scatter(
                x=sell_dates,
                y=sell_prices,
                mode="markers",
                name="Sell Signal",
                marker=dict(
                    symbol="triangle-down",
                    size=12,
                    color="red",
                    line=dict(width=1, color="darkred"),
                ),
                showlegend=True,
            ),
            row=1, col=1
        )
    
    # 6. Volume bars
    colors = ["#26a69a" if close >= open else "#ef5350" 
              for close, open in zip(df["close"], df["open"])]
    fig.add_trace(
        go.Bar(
            x=df["date"],
            y=df["volume"],
            name="Volume",
            marker_color=colors,
            showlegend=False,
        ),
        row=2, col=1
    )
    
    # Update layout
    fig.update_layout(
        title=f"{coin_name} Price Chart with Labels (Barrier ATR Multiplier: {barrier_atr_multiplier}, Horizon: {horizon})",
        height=800,
        xaxis_rangeslider_visible=False,
        hovermode="x unified",
        template="plotly_white",
    )
    
    # Update y-axis labels
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_xaxes(title_text="Date", row=2, col=1)
    
    # Update x-axis để không bị overlap
    fig.update_xaxes(rangeslider_visible=False)
    
    return fig


In [499]:
# Cấu hình
COIN = "PEPE"  # Tên coin (không phân biệt hoa thường)
DATA_DIR = "../data/crypto-1200"
END_DATE = "2025-09-29"  # None để lấy tất cả dữ liệu
BARRIER_ATR_MULTIPLIER = 2.0
HORIZON = 2
SHOW_LAST_N = 1000  # None để hiển thị tất cả, hoặc số lượng candles cuối cùng


In [500]:
# Load dữ liệu
print(f"Loading data for {COIN}...")
df = load_coin_data(COIN, data_dir=DATA_DIR, end_date=END_DATE)
print(f"Loaded {len(df)} candles")
print(f"Date range: {df['date'].min()} to {df['date'].max()}")
df.head()


Loading data for PEPE...
Loaded 21473 candles
Date range: 2023-04-18 08:00:00+00:00 to 2025-09-29 00:00:00+00:00


Unnamed: 0,date,open,high,low,close,volume
0,2023-04-18 08:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,10000000.0
1,2023-04-18 09:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,0.0
2,2023-04-18 10:00:00+00:00,2.5e-08,8.565e-08,2.5e-08,8.543e-08,520410900000.0
3,2023-04-18 11:00:00+00:00,8.543e-08,9.18e-08,7.892e-08,9.17e-08,526105400000.0
4,2023-04-18 12:00:00+00:00,9.17e-08,9.827e-08,8.383e-08,9.186e-08,398488100000.0


In [501]:
# Tính toán indicators
print("Calculating indicators...")
df = calculate_indicators(df)
print("Done!")
df.head()


Calculating indicators...
Done!


Unnamed: 0,date,open,high,low,close,volume,atr,bb_upper,bb_middle,bb_lower,sma_50,sma_200
0,2023-04-18 08:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,10000000.0,,,,,,
1,2023-04-18 09:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,0.0,,,,,,
2,2023-04-18 10:00:00+00:00,2.5e-08,8.565e-08,2.5e-08,8.543e-08,520410900000.0,,,,,,
3,2023-04-18 11:00:00+00:00,8.543e-08,9.18e-08,7.892e-08,9.17e-08,526105400000.0,,,,,,
4,2023-04-18 12:00:00+00:00,9.17e-08,9.827e-08,8.383e-08,9.186e-08,398488100000.0,,,,,,


In [502]:
# Tính toán labels
print("Calculating labels...")
df = calculate_labels(df, barrier_atr_multiplier=BARRIER_ATR_MULTIPLIER, horizon=HORIZON)
print("Done!")

# Thống kê labels
label_counts = df["label"].value_counts().sort_index()
print("\nLabel distribution:")
print(f"  Hold (0): {label_counts.get(0, 0):,} ({label_counts.get(0, 0)/len(df)*100:.2f}%)")
print(f"  Buy  (1): {label_counts.get(1, 0):,} ({label_counts.get(1, 0)/len(df)*100:.2f}%)")
print(f"  Sell (2): {label_counts.get(2, 0):,} ({label_counts.get(2, 0)/len(df)*100:.2f}%)")
df.head()


Calculating labels...
Done!

Label distribution:
  Hold (0): 20,591 (95.89%)
  Buy  (1): 492 (2.29%)
  Sell (2): 390 (1.82%)


Unnamed: 0,date,open,high,low,close,volume,atr,bb_upper,bb_middle,bb_lower,sma_50,sma_200,label
0,2023-04-18 08:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,10000000.0,,,,,,,0
1,2023-04-18 09:00:00+00:00,2.5e-08,2.5e-08,2.5e-08,2.5e-08,0.0,,,,,,,0
2,2023-04-18 10:00:00+00:00,2.5e-08,8.565e-08,2.5e-08,8.543e-08,520410900000.0,,,,,,,0
3,2023-04-18 11:00:00+00:00,8.543e-08,9.18e-08,7.892e-08,9.17e-08,526105400000.0,,,,,,,0
4,2023-04-18 12:00:00+00:00,9.17e-08,9.827e-08,8.383e-08,9.186e-08,398488100000.0,,,,,,,0


In [None]:
# Visualize
fig = visualize_labels(
    df,
    coin_name=COIN.upper(),
    show_last_n=SHOW_LAST_N,
    barrier_atr_multiplier=BARRIER_ATR_MULTIPLIER,
    horizon=HORIZON
)
fig.show()


## Phân tích chi tiết

Có thể xem chi tiết tại các điểm có signal


In [None]:
# Xem các điểm có buy signal
buy_signals = df[df["label"] == 1].copy()
print(f"Total buy signals: {len(buy_signals)}")
if len(buy_signals) > 0:
    print("\nFirst 10 buy signals:")
    print(buy_signals[["date", "open", "high", "low", "close", "atr", "label"]].head(10))


Total buy signals: 492

First 10 buy signals:
                         date          open          high           low  \
18  2023-04-19 02:00:00+00:00  1.740300e-07  2.084000e-07  1.740300e-07   
19  2023-04-19 03:00:00+00:00  1.906400e-07  2.334300e-07  1.893800e-07   
104 2023-04-22 16:00:00+00:00  2.738500e-07  2.764200e-07  2.626100e-07   
250 2023-04-28 18:00:00+00:00  2.783500e-07  2.783500e-07  2.683600e-07   
269 2023-04-29 13:00:00+00:00  3.256700e-07  3.345700e-07  3.238700e-07   
270 2023-04-29 14:00:00+00:00  3.297300e-07  3.386900e-07  3.250000e-07   
277 2023-04-29 21:00:00+00:00  4.226300e-07  4.492700e-07  4.062500e-07   
278 2023-04-29 22:00:00+00:00  4.168400e-07  4.187800e-07  4.000000e-07   
280 2023-04-30 00:00:00+00:00  4.800000e-07  4.990000e-07  4.379800e-07   
281 2023-04-30 01:00:00+00:00  4.662500e-07  4.912700e-07  4.609900e-07   

            close           atr  label  
18   1.906400e-07  3.433792e-08      1  
19   2.204800e-07  3.503164e-08      1  
104  

In [None]:
# Xem các điểm có sell signal
sell_signals = df[df["label"] == 2].copy()
print(f"Total sell signals: {len(sell_signals)}")
if len(sell_signals) > 0:
    print("\nFirst 10 sell signals:")
    print(sell_signals[["date", "open", "high", "low", "close", "atr", "label"]].head(10))


Total sell signals: 390

First 10 sell signals:
                         date          open          high           low  \
23  2023-04-19 07:00:00+00:00  2.861000e-07  3.748200e-07  2.701600e-07   
297 2023-04-30 17:00:00+00:00  7.749000e-07  8.200000e-07  7.100000e-07   
416 2023-05-05 16:00:00+00:00  3.857400e-06  4.440000e-06  3.748700e-06   
560 2023-05-11 16:00:00+00:00  1.576600e-06  1.615900e-06  1.574200e-06   
568 2023-05-12 00:00:00+00:00  1.393500e-06  1.407900e-06  1.374600e-06   
569 2023-05-12 01:00:00+00:00  1.377900e-06  1.386100e-06  1.324700e-06   
638 2023-05-14 22:00:00+00:00  1.805700e-06  1.810300e-06  1.771300e-06   
639 2023-05-14 23:00:00+00:00  1.786000e-06  1.808600e-06  1.740000e-06   
693 2023-05-17 05:00:00+00:00  1.673600e-06  1.674300e-06  1.653700e-06   
694 2023-05-17 06:00:00+00:00  1.654900e-06  1.658800e-06  1.643900e-06   

            close           atr  label  
23   3.500000e-07  4.567277e-08      2  
297  7.585800e-07  8.004270e-08      2  
416