In [None]:
# --------------------------------------------
# RSI Trend-Following Analysis (OMQS-Lab)
# ----------------------------------------------
# - Fetches market data via tvDatafeed
# - Computes a classical RSI (default period 14) and a low-pass filter (LPF; default 5-period SMA) of the RSI
# - Builds distribution and forward-return tables
# - Plots price + RSI/RSI_LPF with 50/30/70 reference lines
# - (Optional) exports tables to CSV and figure to HTML
# ------------------------------------------------

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tvDatafeed import TvDatafeed, Interval


# ========== Configuration ==========
@dataclass
class RunConfig:
    """
    Parameters for one end-to-end analysis run.

    Groups the market-data request, the indicator settings (RSI period and
    the window of the moving average applied on top of it), the forward
    return horizons measured in bars, and the optional export targets.
    Field order is significant for positional construction.
    """

    # --- market data request (forwarded to tvDatafeed) ---
    symbol: str = "US500"  # e.g. "US500", "BTCUSD", "EURUSD", "XAUUSD"
    exchange: str = "CAPITALCOM"  # e.g. "CAPITALCOM", "BINANCE", "OANDA"
    interval: Interval = Interval.in_4_hour
    n_bars: int = 10000  # how many candles to request

    # --- indicators ---
    rsi_period: int = 14
    rsi_lpf_period: int = 5  # length of the low-pass (SMA) filter on the RSI

    # --- analysis ---
    horizons: Tuple[int, ...] = (1, 5, 10)  # forward-return horizons, in bars

    # --- optional outputs ---
    export_tables: bool = False
    export_dir: str = "./tables"
    export_figure_html: bool = False
    figure_html_path: str = "./figures/rsi_chart.html"
    chart_title_prefix: str = "RSI Trend-Following"


# ========== Indicator Functions ==========
def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """
    Classical RSI with simple rolling means for average gains/losses.

    Notes:
    - Wilder's original smoothing is an exponential average (alpha = 1/period).
      Here we keep simple rolling averages to match the paper's baseline.
    - A window containing gains but no losses yields RSI = 100 (the limit of
      100 - 100 / (1 + RS) as RS -> infinity). A completely flat window (no
      gains and no losses) yields the neutral 50.
    - Returns 50 for the initial warm-up NaNs to keep plotting/tables
      convenient. For a NaN-free input these are the first ``period`` rows;
      callers that need only genuine values must drop them themselves.

    Args:
        series: price series (typically closing prices).
        period: rolling window length for the average gain/loss.

    Returns:
        RSI values in [0, 100], aligned with ``series``.
    """
    delta = series.diff()
    gain = delta.clip(lower=0)
    loss = -delta.clip(upper=0)

    avg_gain = gain.rolling(period, min_periods=period).mean()
    avg_loss = loss.rolling(period, min_periods=period).mean()

    # Zero average loss is turned into NaN to avoid dividing by zero...
    rs = avg_gain / avg_loss.replace(0, np.nan)
    out = 100 - (100 / (1 + rs))
    # ...so resolve that case explicitly: only gains in the window means the
    # maximal RSI, not the neutral 50 that fillna would otherwise assign.
    out = out.mask((avg_loss == 0) & (avg_gain > 0), 100.0)
    return out.fillna(50)


def rsi_lpf(rsi_series: pd.Series, lpf_period: int = 5) -> pd.Series:
    """
    Smooth an RSI series with a trailing simple moving average.

    A value is produced only once the window holds ``lpf_period`` non-NaN
    observations; earlier positions are NaN.
    """
    window = rsi_series.rolling(window=lpf_period, min_periods=lpf_period)
    smoothed = window.mean()
    return smoothed


# ========== Data Helpers ==========
def fetch_data(cfg: RunConfig) -> pd.DataFrame:
    """
    Fetch historical data via tvDatafeed and return a tidy DataFrame.

    The result is indexed by timestamp (sorted ascending) and keeps whichever
    of the columns open, high, low, close the feed returned. ``close`` is
    mandatory because every downstream indicator and table is built from it.

    Raises:
        ValueError: if tvDatafeed returns no data, or the data lacks a
            ``close`` column.
    """
    tv = TvDatafeed()
    df = tv.get_hist(
        symbol=cfg.symbol,
        exchange=cfg.exchange,
        interval=cfg.interval,
        n_bars=cfg.n_bars,
    )
    if df is None or df.empty:
        raise ValueError("No data returned from tvDatafeed. Check symbol/exchange/interval.")

    # Keep only OHLC columns; anything else the feed returns is dropped.
    cols = [c for c in ("open", "high", "low", "close") if c in df.columns]
    if "close" not in cols:
        # Fail here with a clear message instead of a KeyError deep in run().
        raise ValueError(
            f"tvDatafeed data for {cfg.symbol!r} has no 'close' column "
            f"(got: {list(df.columns)})."
        )
    return df.loc[:, cols].sort_index()


def calc_rangebreaks_from_index(idx: pd.Index) -> Iterable[Dict[str, Iterable[str]]]:
    """
    Derive Plotly x-axis 'rangebreaks' from gaps in a sorted datetime index.

    A gap is any step between consecutive timestamps that exceeds 1.5x the
    median step. Each gap becomes one {"bounds": [start, end]} entry, shrunk
    by 100 microseconds on both sides so the bars bordering the gap stay
    visible. Returns an empty list when the index has fewer than two points.
    """
    stamps = pd.Series(idx)
    steps = stamps.diff()
    if steps.isna().all():
        return []

    cutoff = 1.5 * steps.median()
    pad = pd.Timedelta(microseconds=100)

    breaks = []
    for pos in np.flatnonzero((steps > cutoff).to_numpy()):
        gap_start = stamps.iloc[pos - 1] + pad
        gap_end = stamps.iloc[pos] - pad
        breaks.append({"bounds": [str(gap_start), str(gap_end)]})
    return breaks


# ========== Statistical Tables ==========
def forward_returns(close: pd.Series, horizons: Iterable[int]) -> pd.DataFrame:
    """
    Forward simple returns of ``close`` for each horizon in ``horizons``.

    Column ``fr_{h}`` holds close[t + h] / close[t] - 1 as a fraction (not
    multiplied by 100). For a positive h the last h rows are NaN because the
    future price is not available.
    """
    return pd.DataFrame(
        {f"fr_{h}": close.shift(-h) / close - 1.0 for h in horizons}
    )


def regime_masks(rsi_series: pd.Series) -> Dict[str, pd.Series]:
    """
    Map each RSI regime label to a boolean mask over ``rsi_series``.

    Covers the classic oscillator extremes (above 70 / below 30) and the two
    sides of the 50 mid-line. All comparisons are strict, so a value exactly
    on a threshold is in neither regime, and NaN is never in any regime.
    """
    overbought, oversold, midline = 70, 30, 50
    masks: Dict[str, pd.Series] = {}
    masks["RSI>70"] = rsi_series.gt(overbought)
    masks["RSI<30"] = rsi_series.lt(oversold)
    masks["RSI>50"] = rsi_series.gt(midline)
    masks["RSI<50"] = rsi_series.lt(midline)
    return masks


def table_distribution(rsi_series: pd.Series) -> pd.DataFrame:
    """
    Tabulate how many observations fall into each RSI regime.

    Columns are ``Count`` (observations in the regime) and
    ``% of observations`` (share of the non-NaN RSI values, rounded to two
    decimals; NaN when there are none). Regimes overlap (RSI>70 is a subset
    of RSI>50), so the shares need not sum to 100. Rows are sorted by label.
    """
    total = len(rsi_series.dropna())

    records = []
    for regime, in_regime in regime_masks(rsi_series).items():
        hits = int(in_regime.sum())
        pct = np.nan if not total else hits / total * 100
        records.append(
            {"Regime": regime, "Count": hits, "% of observations": round(pct, 2)}
        )

    return pd.DataFrame(records).set_index("Regime").sort_index()


def table_forward_returns_by_regime(
    close: pd.Series,
    rsi_series: pd.Series,
    horizons: Iterable[int]
) -> pd.DataFrame:
    """
    Mean forward returns (%) conditional on raw RSI regimes (70/30 and 50).

    Args:
        close: price series the forward returns are computed from.
        rsi_series: RSI values; expected to share ``close``'s index, since the
            regime masks select rows of the forward-return frame by label.
        horizons: forward horizons in bars. Any iterable is accepted,
            including a one-shot generator.

    Returns:
        One row per regime, sorted by label. Columns:
        - ``N obs``: number of bars in the regime.
        - ``Avg FR{h}``: mean forward return in %, rounded to 3 decimals, or
          NaN when the regime is empty.
        The mean skips bars whose forward return is unavailable (the last
        ``h`` bars), so it can rest on fewer than ``N obs`` observations.
    """
    # Materialise once: horizons is consumed by forward_returns() AND re-read
    # for every regime below, so a generator would be exhausted after the
    # first pass and the Avg FR columns would silently go missing.
    horizons = tuple(horizons)
    fr = forward_returns(close, horizons)
    masks = regime_masks(rsi_series)

    rows = []
    for label, mask in masks.items():
        nobs = int(mask.sum())
        row = {"Regime": label, "N obs": nobs}
        for h in horizons:
            m = fr.loc[mask, f"fr_{h}"].mean()
            row[f"Avg FR{h}"] = np.nan if nobs == 0 else float(np.round(m * 100.0, 3))
        rows.append(row)

    return pd.DataFrame(rows).set_index("Regime").sort_index()


def table_forward_returns_50_lpf(
    close: pd.Series,
    rsi_lpf_series: pd.Series,
    horizons: Iterable[int]
) -> pd.DataFrame:
    """
    Mean forward returns (%) conditional on the proposed trend-following read:
    RSI_LPF above/below 50.

    Args:
        close: price series the forward returns are computed from.
        rsi_lpf_series: low-pass-filtered RSI; expected to share ``close``'s
            index. NaN values fall in neither regime (comparisons are strict).
        horizons: forward horizons in bars. Any iterable is accepted,
            including a one-shot generator.

    Returns:
        Rows ``RSI_LPF<50`` and ``RSI_LPF>50`` (sorted by label). Columns:
        - ``N obs``: number of bars in the regime.
        - ``Avg FR{h}``: mean forward return in %, rounded to 3 decimals, or
          NaN when the regime is empty.
    """
    # Materialise once: horizons is iterated by forward_returns() and again
    # per regime; a generator would otherwise be empty on the second pass.
    horizons = tuple(horizons)
    fr = forward_returns(close, horizons)
    masks = {
        "RSI_LPF>50": (rsi_lpf_series > 50),
        "RSI_LPF<50": (rsi_lpf_series < 50),
    }

    rows = []
    for label, mask in masks.items():
        nobs = int(mask.sum())
        row = {"Regime": label, "N obs": nobs}
        for h in horizons:
            m = fr.loc[mask, f"fr_{h}"].mean()
            row[f"Avg FR{h}"] = np.nan if nobs == 0 else float(np.round(m * 100.0, 3))
        rows.append(row)

    return pd.DataFrame(rows).set_index("Regime").sort_index()


# ========== Plotting ==========
def plot_price_and_rsi(df: pd.DataFrame, cfg: RunConfig) -> go.Figure:
    """
    Build a 2-row Plotly figure: price (row 1) and RSI/RSI_LPF (row 2).

    Row 2 carries dashed horizontal reference lines at 30/50/70. Index gaps
    longer than 1.5x the median bar spacing (see calc_rangebreaks_from_index)
    are hidden on both x-axes for readability.

    Args:
        df: frame indexed by timestamp with ``close``, ``RSI`` and ``RSI_LPF``
            columns.
        cfg: run configuration. Supplies the indicator periods used in the
            legend labels and the symbol/interval used in the title.

    Returns:
        The assembled figure (not shown).
    """
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05,
        row_heights=[0.6, 0.4]
    )

    # Row 1: Price
    fig.add_trace(
        go.Scatter(
            x=df.index, y=df["close"], name="Close",
            line=dict(width=1)
        ),
        row=1, col=1
    )

    # Row 2: RSI lines. Both labels follow the configured periods, so the
    # legend stays correct when rsi_period is not 14.
    fig.add_trace(
        go.Scatter(
            x=df.index, y=df["RSI"], name=f"RSI({cfg.rsi_period})",
            mode="lines", line=dict(width=1)
        ),
        row=2, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=df.index, y=df["RSI_LPF"], name=f"RSI-LPF({cfg.rsi_lpf_period})",
            mode="lines", line=dict(width=1)
        ),
        row=2, col=1
    )

    # Reference lines: 50 = trend mid-line, 30/70 = oscillator extremes
    for y in (50, 30, 70):
        fig.add_hline(y=y, line_dash="dash", opacity=0.7, row=2, col=1)

    # Titles & layout
    title = f"{cfg.chart_title_prefix} — {cfg.symbol} ({cfg.interval.name})"
    fig.update_layout(
        title=title,
        height=900, width=1200,
        showlegend=True,
        yaxis_title="Close Price",
        yaxis2_title="RSI",
        hovermode="x unified"
    )

    # Rangebreaks from index gaps, applied to both shared x-axes
    rangebreaks = list(calc_rangebreaks_from_index(df.index))
    if rangebreaks:
        fig.update_xaxes(rangebreaks=rangebreaks, row=1, col=1)
        fig.update_xaxes(rangebreaks=rangebreaks, row=2, col=1)

    return fig


# ========== Orchestration ==========
def run(cfg: RunConfig) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, go.Figure]:
    """
    Execute the full analysis described by ``cfg``.

    Steps: fetch data, compute RSI and RSI_LPF, drop the indicator warm-up,
    build the three tables, plot, optionally export, then display the tables
    and the figure.

    Returns:
        (df, tbl_dist, tbl_fr, tbl_fr_lpf, fig): the enriched price frame,
        the RSI regime distribution, the forward-return tables for raw RSI
        regimes and for RSI_LPF around 50, and the Plotly figure.

    Raises:
        ValueError: if no bars remain once the indicator warm-up is removed
            (plus anything fetch_data raises).
    """
    # 1) Fetch data
    df = fetch_data(cfg)

    # 2) Indicators
    df["RSI"] = rsi(df["close"], period=cfg.rsi_period)
    df["RSI_LPF"] = rsi_lpf(df["RSI"], lpf_period=cfg.rsi_lpf_period)

    # Drop the warm-up explicitly. rsi() fills its first `rsi_period` rows
    # (for a NaN-free close series) with a placeholder 50, so dropna() alone
    # would keep them. RSI_LPF then needs `rsi_lpf_period - 1` further rows
    # before its window contains no placeholders. dropna() still removes any
    # other missing values.
    warmup = cfg.rsi_period + cfg.rsi_lpf_period - 1
    df = df.iloc[warmup:].dropna().copy()
    if df.empty:
        raise ValueError(
            f"No bars left after the RSI({cfg.rsi_period})/LPF({cfg.rsi_lpf_period}) "
            f"warm-up; increase n_bars."
        )

    # 3) Tables
    tbl_dist = table_distribution(df["RSI"])
    tbl_fr = table_forward_returns_by_regime(df["close"], df["RSI"], cfg.horizons)
    tbl_fr_lpf = table_forward_returns_50_lpf(df["close"], df["RSI_LPF"], cfg.horizons)

    # 4) Plot
    fig = plot_price_and_rsi(df, cfg)

    # 5) Optional exports
    file_tag = f"{cfg.symbol}_{cfg.interval.name}"
    if cfg.export_tables:
        outdir = Path(cfg.export_dir)
        outdir.mkdir(parents=True, exist_ok=True)
        tbl_dist.to_csv(outdir / f"dist_{file_tag}.csv")
        tbl_fr.to_csv(outdir / f"forward_returns_raw_{file_tag}.csv")
        tbl_fr_lpf.to_csv(outdir / f"forward_returns_lpf_{file_tag}.csv")

    if cfg.export_figure_html:
        Path(cfg.figure_html_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(cfg.figure_html_path)

    # 6) Display (for notebook use). `display` is not imported in this file;
    # presumably it comes from the IPython/Jupyter namespace.
    print(f"\nTABLE – Distribution of RSI({cfg.rsi_period}) regimes")
    display(tbl_dist)
    print(f"\nTABLE – Mean forward returns (%) conditional on RSI({cfg.rsi_period}) regimes")
    display(tbl_fr)
    print("\nTABLE – Mean forward returns (%) conditional on RSI-LPF around 50")
    display(tbl_fr_lpf)

    fig.show()

    return df, tbl_dist, tbl_fr, tbl_fr_lpf, fig


# ========== Run (Notebook Cell) ==========
# Example usage in a notebook cell:
# Every value below equals the RunConfig default; spelled out so the cell is
# easy to tweak. run() displays its own tables and figure.
cfg = RunConfig(
    # market data request
    symbol="US500",
    exchange="CAPITALCOM",
    interval=Interval.in_4_hour,
    n_bars=10000,
    # indicator and analysis settings
    rsi_period=14,
    rsi_lpf_period=5,
    horizons=(1, 5, 10),
    # outputs (both exports disabled)
    export_figure_html=False,
    export_tables=False,
)
_ = run(cfg)
