# 05 - Baseline Strategy

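This notebook implements and backtests a 5/15 EMA crossover strategy with a regime filter: long entries are taken only in an uptrend (regime +1), short entries only in a downtrend (regime -1), and no trades are opened in a sideways market (regime 0).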

In [None]:
import sys
# Make the project's ../src directory importable so the local `strategy` and
# `backtest` modules below resolve. The relative path is interpreted against the
# current working directory, so this assumes the notebook is run from its own folder.
sys.path.append('../src')

from strategy import Strategy
from backtest import calculate_metrics
import pandas as pd
import numpy as np

In [None]:
# Load Data
# Read the 5-minute NIFTY feature file and parse its timestamps.
data_dir = '../data'
df = pd.read_csv(f'{data_dir}/nifty_features_5min.csv')
df['timestamp'] = pd.to_datetime(df['timestamp'])
# EMA crossovers and the positional train/test split later on both assume rows
# are in time order. A stable (mergesort) sort keeps the file order of any rows
# that share a timestamp, and is a no-op when the file is already sorted.
df = df.sort_values('timestamp', kind='mergesort').reset_index(drop=True)

# Add regime if not present.
# This is a RANDOM PLACEHOLDER, not a real regime model: a backtest run on it
# says nothing about the regime filter. It is seeded so reruns are reproducible
# (the previous unseeded draw changed the results on every execution).
REGIME_SEED = 42
if 'regime' not in df.columns:
    print("WARNING: 'regime' column missing - using random placeholder regimes "
          f"(seed={REGIME_SEED}); backtest results are not meaningful.")
    rng = np.random.default_rng(REGIME_SEED)
    df['regime'] = rng.choice([1, -1, 0], size=len(df), p=[0.3, 0.3, 0.4])

print(f"Data Shape: {df.shape}")

In [None]:
# Strategy Rules
# Human-readable summary of the intended entry/exit rules, printed for reference.
STRATEGY_RULES = """
=== 5/15 EMA Crossover Strategy ===

LONG Entry:
- 5 EMA crosses above 15 EMA
- Regime = +1 (Uptrend)

SHORT Entry:
- 5 EMA crosses below 15 EMA
- Regime = -1 (Downtrend)

EXIT:
- Opposite crossover

NO TRADES in Regime 0 (Sideways)
"""
print(STRATEGY_RULES)

In [None]:
# Initialize Strategy: fast EMA span 5, slow EMA span 15.
strategy = Strategy(ema_fast=5, ema_slow=15)

# Generate Signals on the full dataset (df is replaced by the annotated frame).
# NOTE(review): this runs before the train/test split; confirm generate_signals
# only uses past bars, otherwise the test set leaks future information.
df = strategy.generate_signals(df)

# Entry candidates are the rows where 'entry_signal' was populated (non-null).
has_entry = df['entry_signal'].notna()
signals = df.loc[has_entry]
print(f"Total Entry Signals: {len(signals)}")

In [None]:
# Train/Test Split
# Positional split: the first 70% of rows form the train set and the remaining
# rows form the test set. Nothing is shuffled, so this is a chronological split
# as long as df is in time order. In the visible cells train_df is only
# reported, not used for fitting.
TRAIN_FRACTION = 0.70
split_idx = int(len(df) * TRAIN_FRACTION)
# Take explicit copies: adding columns to a bare iloc slice (e.g. inside
# strategy.backtest or a later cell) can raise pandas' SettingWithCopyWarning
# and leaves it unclear whether df itself is modified.
train_df = df.iloc[:split_idx].copy()
test_df = df.iloc[split_idx:].copy()

print(f"Train: {len(train_df)} rows")
print(f"Test: {len(test_df)} rows")

In [None]:
# Backtest on Test Set: evaluate the strategy on the held-out rows only.
trades_df = strategy.backtest(test_df)
n_trades = len(trades_df)
print(f"Generated {n_trades} trades")
# Last expression in the cell, so the notebook displays the first 10 trades.
trades_df.head(10)

In [None]:
# Performance Metrics: summarize the test-set trades and print each entry of
# the mapping returned by calculate_metrics as "name: value".
metrics = calculate_metrics(trades_df)

print("\n=== Performance Metrics ===")
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Save Results
import os

results_dir = '../results'
# Create the output directory on first run; DataFrame.to_csv does not create
# missing parent directories and would fail otherwise.
os.makedirs(results_dir, exist_ok=True)
trades_df.to_csv(f'{results_dir}/backtest_trades.csv', index=False)
print("Trades saved!")

In [None]:
# Equity Curve (if trades exist)
if len(trades_df) > 0:
    # Running total of per-trade PnL, in the row order of trades_df.
    trades_df['cumulative_pnl'] = trades_df['pnl'].cumsum()

    # Guard only the import: a missing matplotlib is reported, but errors while
    # plotting or saving (e.g. a bad output path) now surface instead of being
    # masked by a bare except as "matplotlib not available".
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available")
    else:
        import os

        plots_dir = '../plots'
        # savefig does not create missing directories.
        os.makedirs(plots_dir, exist_ok=True)

        plt.figure(figsize=(12, 5))
        plt.plot(trades_df['cumulative_pnl'])
        plt.xlabel('Trade Number')
        plt.ylabel('Cumulative PnL')
        plt.title('Equity Curve')
        plt.axhline(y=0, color='r', linestyle='--')
        plt.savefig(f'{plots_dir}/equity_curve.png', dpi=100)
        plt.show()