In [None]:
import datetime as dt
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
import seaborn as sns
import traceback

import ipywidgets as widgets
from IPython.display import display, clear_output

# Global plotting defaults for every figure drawn below.  The calls run in
# this order, so the seaborn settings are applied after the matplotlib
# 'seaborn-v0_8' style sheet.
plt.style.use('seaborn-v0_8')
sns.set_style("whitegrid")
sns.set_context("talk", font_scale=1.0)

class DataHandler:
    """Download OHLCV price bars for one ticker via ``yf.download``."""

    def __init__(self, symbol: str, start: str, end: str):
        """Remember the ticker and the date bounds passed on to the download."""
        self.symbol, self.start, self.end = symbol, start, end

    def load_data(self) -> pd.DataFrame:
        """Fetch the bars, keep the OHLCV columns and drop rows with NaNs.

        Returns:
            The Open/High/Low/Close/Volume columns with incomplete rows removed.

        Raises:
            ValueError: if the download returns an empty frame.
        """
        raw = yf.download(
            self.symbol,
            start=self.start,
            end=self.end,
            progress=False,
            auto_adjust=False,
        )
        if raw.empty:
            raise ValueError(f"No data fetched for {self.symbol}. Check symbol and dates.")

        ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        bars = raw[ohlcv_columns].dropna()

        # Report the span that was actually loaded (after dropping NaN rows).
        first_day = bars.index[0].date()
        last_day = bars.index[-1].date()
        print(f"Loaded {len(bars)} rows for {self.symbol} from {first_day} to {last_day}")
        return bars

def compute_indicators(data: pd.DataFrame, short_window: int, long_window: int):
    """Return a copy of *data* with simple and exponential moving averages added.

    Four columns are appended, in this order: ``short_ma``, ``long_ma``
    (rolling means of ``Close``) and ``short_ema``, ``long_ema``
    (``ewm(span=..., adjust=False)`` means of ``Close``).  The input frame is
    not modified.
    """
    enriched = data.copy()

    close = enriched['Close']
    # A column selection can come back as a one-column DataFrame (e.g. with
    # multi-level columns); collapse it so the indicators are plain Series.
    if isinstance(close, pd.DataFrame):
        close = close.squeeze()

    windows = (('short', short_window), ('long', long_window))
    for prefix, span in windows:
        enriched[f'{prefix}_ma'] = close.rolling(window=span).mean()
    for prefix, span in windows:
        enriched[f'{prefix}_ema'] = close.ewm(span=span, adjust=False).mean()
    return enriched

def generate_signal(df: pd.DataFrame, kind: str, long_window: int):
    """Return a copy of *df* with crossover ``signal`` and ``positions`` columns.

    ``signal`` is 0 for the first ``long_window`` rows; afterwards it is 1
    where the short average is above the long average and -1 otherwise.
    ``kind == 'SMA'`` compares the rolling means, any other value compares the
    exponential means.  ``positions`` is the row-to-row change in ``signal``
    (0 on the first row).
    """
    if kind == 'SMA':
        fast_col, slow_col = 'short_ma', 'long_ma'
    else:
        fast_col, slow_col = 'short_ema', 'long_ema'

    result = df.copy()
    result['signal'] = 0

    # Only rows from position ``long_window`` onward receive a +/-1 signal.
    active_rows = result.index[long_window:]
    fast = result[fast_col][long_window:]
    slow = result[slow_col][long_window:]
    result.loc[active_rows, 'signal'] = np.where(fast > slow, 1, -1)

    result['positions'] = result['signal'].diff().fillna(0)
    return result

class Backtester:
    """Vectorised backtest of a per-bar ``signal`` with commission and slippage.

    Attributes are set from the constructor arguments:
    ``initial_capital`` is the starting cash, ``commission`` is the fraction
    of each trade's notional charged as a fee, and ``slippage`` is the
    fractional price penalty applied against the trader on every fill.
    """

    def __init__(self, initial_capital: float = 100000, commission: float = 0.002, slippage: float = 0.001):
        self.initial_capital = initial_capital
        self.commission = commission
        self.slippage = slippage

    def run(self, data: pd.DataFrame, allocation_pct: float = 0.5) -> tuple:
        """Simulate trading the ``signal`` column of *data*.

        Args:
            data: Frame with a ``Close`` price column and a ``signal`` column
                that multiplies the per-bar share count.
            allocation_pct: Fraction of ``initial_capital`` used to size the
                share count on each bar.

        Returns:
            ``(portfolio, trade_notional, commission_costs)`` where
            ``portfolio`` has the columns ``positions`` (signed shares),
            ``holdings`` (positions valued at the close), ``cash``, ``total``
            (cash + holdings) and ``returns`` (bar-to-bar change of total);
            ``trade_notional`` is the absolute value traded per bar at the
            slipped price and ``commission_costs`` the fee paid per bar.
        """
        close = data['Close']
        if isinstance(close, pd.DataFrame):
            close = close.squeeze()

        portfolio = pd.DataFrame(index=data.index)

        # NOTE(review): the share count is recomputed from every bar's close,
        # so a held position is resized (and pays costs) on each bar even when
        # the signal is unchanged -- confirm this sizing is intended.
        raw_shares = (self.initial_capital * allocation_pct) / close
        shares = pd.Series(np.floor(raw_shares).astype(int), index=close.index)
        portfolio['positions'] = data['signal'] * shares

        # Shares traded on each bar.  The book is flat before the first bar,
        # so a position that is already open on the first row is a real trade
        # that has to be paid for.
        previous = portfolio['positions'].shift(1, fill_value=0)
        pos_diff = (portfolio['positions'] - previous).fillna(0)

        # Slippage always works against the trader: buys fill above the
        # close, sells fill below it.
        effective_prices = close * (1 + self.slippage * np.sign(pos_diff))
        portfolio['holdings'] = portfolio['positions'] * close

        trade_notional = pos_diff.abs() * effective_prices
        commission_costs = trade_notional * self.commission

        # Signed cash flow: buying shares spends cash, selling (or shorting)
        # brings cash in; commission is paid in both directions.
        cash_flow = (-(pos_diff * effective_prices) - commission_costs).fillna(0)
        portfolio['cash'] = self.initial_capital + cash_flow.cumsum()
        portfolio['total'] = portfolio['cash'] + portfolio['holdings']
        portfolio['returns'] = portfolio['total'].pct_change().fillna(0)
        return portfolio, trade_notional, commission_costs

def performance_metrics(
    portfolio: pd.DataFrame,
    risk_free_rate: float = 0.01,
    periods_per_year: int = 252,
) -> dict:
    """Summarise a backtest portfolio as a dict of headline statistics.

    Args:
        portfolio: Frame with ``returns``, ``total`` and ``positions`` columns.
        risk_free_rate: Annual rate subtracted from the annualised return in
            the Sharpe ratio.
        periods_per_year: Number of return periods per year used to annualise
            the mean and standard deviation of ``returns``.

    Returns:
        Dict with the keys ``'Annual Return'``, ``'Annual Volatility'``,
        ``'Sharpe Ratio'`` (NaN unless volatility is positive),
        ``'Max Drawdown'`` (most negative fractional drop of ``total`` from
        its running maximum) and ``'Win Rate'`` (share of bars with a
        non-zero position whose return is positive; 0.0 if never in position).
    """
    returns = portfolio['returns']
    equity = portfolio['total']

    ann_ret = returns.mean() * periods_per_year
    ann_vol = returns.std() * np.sqrt(periods_per_year)
    # A NaN volatility (e.g. a single row) also fails this test and yields NaN.
    sharpe = (ann_ret - risk_free_rate) / ann_vol if ann_vol > 0 else np.nan

    running_peak = equity.cummax()
    max_dd = ((equity - running_peak) / running_peak).min()

    in_market = portfolio['positions'].abs() > 0
    win_rate = (returns[in_market] > 0).mean() if in_market.any() else 0.0

    return {
        'Annual Return': ann_ret,
        'Annual Volatility': ann_vol,
        'Sharpe Ratio': sharpe,
        'Max Drawdown': max_dd,
        'Win Rate': win_rate,
    }

def plot_results(symbol: str, df: pd.DataFrame, portfolio: pd.DataFrame, kind: str, short_window: int, long_window: int):
    """Draw a three-panel figure: price with averages and trade markers,
    the equity curve, and the percentage drawdown.

    ``kind == 'SMA'`` plots the ``short_ma`` / ``long_ma`` columns of *df*;
    any other value plots ``short_ema`` / ``long_ema``.  Rows of *df* with a
    positive ``positions`` value are marked as buys, negative ones as sells.
    The equity and drawdown panels are taken from ``portfolio['total']``.
    """
    palette = {
        'price': '#2E2E2E',
        'sma': '#1f77b4',
        'ema': '#ff7f0e',
        'equity': '#4C72B0',
        'buy': '#2ca02c',
        'sell': '#d62728',
    }
    _, (price_ax, equity_ax, dd_ax) = plt.subplots(3, 1, figsize=(13, 12), sharex=True)

    # --- Panel 1: close price, the two moving averages, trade markers ---
    price_ax.plot(df.index, df['Close'], label='Close', color=palette['price'], alpha=0.85, linewidth=1.6)

    if kind == 'SMA':
        ma_color = palette['sma']
        ma_lines = (('short_ma', f'SMA {short_window}'), ('long_ma', f'SMA {long_window}'))
    else:
        ma_color = palette['ema']
        ma_lines = (('short_ema', f'EMA {short_window}'), ('long_ema', f'EMA {long_window}'))
    # The short average uses the default line style, the long one is dashed.
    for (column, label), extra in zip(ma_lines, ({}, {'linestyle': '--'})):
        price_ax.plot(df.index, df[column], label=label, color=ma_color, linewidth=2.2, **extra)

    trade_markers = (
        (df['positions'] > 0, '^', palette['buy'], 'Buy'),
        (df['positions'] < 0, 'v', palette['sell'], 'Sell'),
    )
    for mask, glyph, tint, label in trade_markers:
        hits = df[mask]
        if hits.empty:
            continue
        price_ax.scatter(hits.index, hits['Close'], marker=glyph, color=tint, s=64, label=label, zorder=5)

    price_ax.set_title(f'{symbol} Price with {kind} Crossover')
    price_ax.set_ylabel('Price')
    price_ax.legend(ncols=3, frameon=True)
    price_ax.grid(True, alpha=0.25)

    # --- Panel 2: equity curve with the first equity value as reference ---
    equity = portfolio['total']
    equity_ax.plot(portfolio.index, equity, label='Equity', color=palette['equity'], linewidth=2.6)
    equity_ax.axhline(y=equity.iloc[0], color='gray', linestyle='--', alpha=0.5, label='Initial Capital')
    equity_ax.set_title('Portfolio Equity Curve')
    equity_ax.set_ylabel('Portfolio Value')
    equity_ax.legend(frameon=True)
    equity_ax.grid(True, alpha=0.25)

    # --- Panel 3: drawdown from the running equity peak, in percent ---
    peak = equity.cummax()
    drawdown_pct = (equity - peak) / peak * 100
    dd_ax.fill_between(portfolio.index, drawdown_pct, 0, color=palette['sell'], alpha=0.25)
    dd_ax.plot(portfolio.index, drawdown_pct, color=palette['sell'], linewidth=1.25)
    dd_ax.set_title('Drawdown (%)')
    dd_ax.set_ylabel('Drawdown (%)')
    dd_ax.set_xlabel('Date')
    dd_ax.grid(True, alpha=0.25)

    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------------------
# Dashboard widgets.  These module-level names are read by on_run_clicked().
# ---------------------------------------------------------------------------

# Header shown at the top of the dashboard.
title = widgets.HTML(
    value="<h3 style='margin:8px 0;'>Moving Average Crossover Backtest</h3><p style='margin:0;color:#666;'>Select inputs and click Run Backtest.</p>"
)

# Market and strategy inputs: ticker, date range, average type and windows.
symbol_w   = widgets.Text(value='AAPL', description='Symbol:', placeholder='Ticker, e.g., AAPL', layout=widgets.Layout(width='220px'))
start_w    = widgets.DatePicker(description='Start:', value=dt.date(2024, 1, 1))
end_w      = widgets.DatePicker(description='End:', value=dt.date.today())
strategy_w = widgets.Dropdown(options=['SMA', 'EMA'], value='SMA', description='MA Kind:')
short_w    = widgets.IntSlider(description='Short MA', value=50, min=5, max=200, step=1, continuous_update=False)
long_w     = widgets.IntSlider(description='Long MA', value=200, min=20, max=400, step=5, continuous_update=False)

# Portfolio and cost inputs.
# NOTE(review): the labels say '%', but the callback passes these values
# straight through as fractions (0.50 = half the capital, 0.002 = 0.2%) --
# confirm the labels are what users should see.
capital_w    = widgets.FloatText(description='Capital', value=100000.0, step=1000.0)
alloc_w      = widgets.FloatSlider(description='Alloc %', value=0.50, min=0.05, max=1.0, step=0.05, readout_format='.2f', continuous_update=False)
commission_w = widgets.FloatText(description='Comm %', value=0.002, step=0.001)
slippage_w   = widgets.FloatText(description='Slip %', value=0.001, step=0.001)

# Trigger button and the output area the callback prints and plots into.
run_btn = widgets.Button(description='Run Backtest', button_style='success', icon='play')
out = widgets.Output()

# Layout: three rows of controls, the button row, then the output area,
# stacked vertically inside a bordered box.
controls_top = widgets.HBox([symbol_w, strategy_w, short_w, long_w], layout=widgets.Layout(justify_content='space-between', width='100%', gap='10px'))
controls_mid = widgets.HBox([start_w, end_w], layout=widgets.Layout(gap='12px'))
controls_bot = widgets.HBox([capital_w, alloc_w, commission_w, slippage_w], layout=widgets.Layout(justify_content='space-between', width='100%', gap='10px'))
toolbar = widgets.HBox([run_btn], layout=widgets.Layout(justify_content='flex-start', padding='4px 0'))
dashboard = widgets.VBox([title, controls_top, controls_mid, controls_bot, toolbar, out],
                          layout=widgets.Layout(border='1px solid #eee', padding='8px', width='100%'))

def on_run_clicked(b):
    """Button callback: validate the inputs, run the backtest, show the results.

    All output (validation messages, metrics, the figure, and any traceback)
    is written into the ``out`` widget, which is cleared at the start of each
    run.  Invalid inputs print a message and return without running anything.

    Args:
        b: The argument supplied by ``run_btn.on_click`` (unused).
    """
    with out:
        clear_output(wait=True)
        try:
            # ---- input validation -------------------------------------
            if start_w.value is None or end_w.value is None:
                print("Please select both start and end dates.")
                return
            if start_w.value >= end_w.value:
                print("Start date must be earlier than end date.")
                return
            symbol = symbol_w.value.strip().upper()
            if not symbol:
                print("Please enter a ticker symbol.")
                return
            kind = strategy_w.value
            short_n = int(short_w.value)
            long_n = int(long_w.value)
            if short_n >= long_n:
                print("Short MA window must be less than long MA window.")
                return
            initial_capital = float(capital_w.value)
            # Written as "not > 0" so that a NaN entry is rejected as well.
            if not initial_capital > 0:
                print("Capital must be greater than zero.")
                return
            alloc_pct = float(alloc_w.value)
            commission = float(commission_w.value)
            slippage = float(slippage_w.value)

            # ---- data and signals -------------------------------------
            handler = DataHandler(symbol, start_w.value.isoformat(), end_w.value.isoformat())
            data = handler.load_data()
            # generate_signal() only assigns signals from row ``long_n``
            # onward, so with fewer bars the backtest would silently stay flat.
            if len(data) <= long_n:
                print(f"Only {len(data)} rows loaded; more than {long_n} are needed for the long MA window. "
                      "Widen the date range or shorten the window.")
                return
            data = compute_indicators(data, short_n, long_n)
            data = generate_signal(data, kind, long_n)

            # ---- backtest and report ----------------------------------
            bt = Backtester(initial_capital=initial_capital, commission=commission, slippage=slippage)
            portfolio, _, _ = bt.run(data, allocation_pct=alloc_pct)
            metrics = performance_metrics(portfolio)
            print("=== Performance Metrics ===")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}" if isinstance(v, (float, np.floating)) else f"{k}: {v}")
            plot_results(symbol, data, portfolio, kind, short_n, long_n)
        except Exception as e:
            # Top-level UI boundary: show the failure in the output area
            # instead of letting it escape the callback.
            traceback.print_exc()
            print(f"Error: {str(e)}")

# Register the callback on the button, then render the dashboard.
run_btn.on_click(on_run_clicked)
display(dashboard)
