In [None]:
import sqlite3
from pathlib import Path
import numpy as np
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from arch import arch_model
import matplotlib.pyplot as plt
import warnings
import time
warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
# SQLite database that holds the processed price table (Date and Close columns are read).
DB_PATH = Path("data") / "processed" / "data_processed.sqlite"
# Number of consecutive return observations in each rolling estimation window.
WINDOW_LENGTH = 500
# Upper bounds of the ARMA(p, q) order grid; lowered from 4 to 2 to cut runtime.
P_MAX, Q_MAX = 2, 2
# =======================================================

def load_data_from_sqlite(db_path=None):
    """Load closing prices from a SQLite file and append log returns.

    The first user table listed in ``sqlite_master`` is read; it must expose
    ``Date`` and ``Close`` columns.

    Args:
        db_path: Optional path to the SQLite file. When omitted, the
            module-level ``DB_PATH`` is used.

    Returns:
        A DataFrame indexed by ``Date`` (ascending) with the ``Close`` column
        and a ``returns`` column holding ``log(Close).diff()``. Rows whose
        return is NaN (including the first row) are dropped.

    Raises:
        FileNotFoundError: If the database file does not exist.
        ValueError: If the database contains no user tables.
    """
    # Function-scope import so this block does not rely on a file-level import.
    from contextlib import closing

    path = Path(DB_PATH if db_path is None else db_path)
    # sqlite3.connect() would silently create an empty database for a missing
    # file, which later surfaces as an unrelated IndexError.
    if not path.is_file():
        raise FileNotFoundError(f"SQLite database not found: {path}")

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # closing() is what actually releases the file handle.
    with closing(sqlite3.connect(path)) as conn:
        tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table';", conn)
        # Skip SQLite's internal bookkeeping tables (e.g. sqlite_sequence).
        user_tables = [name for name in tables["name"] if not name.startswith("sqlite_")]
        if not user_tables:
            raise ValueError(f"No user tables found in {path}")
        # Double any embedded quote so the quoted identifier stays well-formed.
        table_name = user_tables[0].replace('"', '""')
        df = pd.read_sql(f'SELECT Date, Close FROM "{table_name}"', conn, parse_dates=["Date"])

    df = df.sort_values("Date").set_index("Date")
    df['returns'] = np.log(df['Close']).diff()
    return df.dropna(subset=['returns'])

def find_best_arma(returns_window, p_max, q_max):
    """Pick the ARMA(p, q) order with the lowest AIC on one window of returns.

    Every combination with ``1 <= p <= p_max`` and ``1 <= q <= q_max`` is
    fitted as ``ARIMA(order=(p, 0, q))``.

    Args:
        returns_window: Sequence of returns to fit the candidate models on.
        p_max: Largest autoregressive order to try.
        q_max: Largest moving-average order to try.

    Returns:
        The ``(p, 0, q)`` tuple of the best model, or the default
        ``(1, 0, 1)`` when the grid is empty or no candidate could be fitted.
    """
    best_aic = np.inf
    best_order = (1, 0, 1)

    for p in range(1, p_max + 1):
        for q in range(1, q_max + 1):
            try:
                # ARIMA.fit() has no ``disp`` keyword; optimizer options are
                # forwarded through ``method_kwargs``. Passing ``disp``
                # directly raised TypeError on every call, so the search
                # always fell back to the default order.
                fitted = ARIMA(returns_window, order=(p, 0, q)).fit(
                    method='statespace', method_kwargs={'disp': 0}
                )
            except Exception:
                # Best-effort search: skip orders that fail to estimate, but
                # let KeyboardInterrupt/SystemExit propagate.
                continue
            if fitted.aic < best_aic:
                best_aic = fitted.aic
                best_order = (p, 0, q)
    return best_order

def forecast_return(returns_window, p_max, q_max):
    """Forecast the next return from one window with an AR-GARCH(1, 1) model.

    The AR lag count is the ``p`` chosen by ``find_best_arma``; the selected
    ``q`` is not used by the ``arch_model`` mean specification. Volatility is
    GARCH(1, 1) with skewed Student-t innovations.

    Args:
        returns_window: Sequence of returns used for estimation.
        p_max: Largest AR order considered during order selection.
        q_max: Largest MA order considered during order selection.

    Returns:
        The one-step-ahead mean forecast, or ``0.0`` when order selection,
        estimation or forecasting raises.
    """
    try:
        best_order = find_best_arma(returns_window, p_max, q_max)
        garch = arch_model(returns_window, mean='ARX', lags=best_order[0],
                          vol='GARCH', p=1, q=1, dist='skewt')
        fit = garch.fit(disp='off', show_warning=False)
        return fit.forecast(horizon=1, reindex=False).mean.iloc[-1, 0]
    except Exception:
        # A failed fit yields a neutral forecast. A bare ``except`` would also
        # swallow KeyboardInterrupt, making a long backtest impossible to stop.
        return 0.0

def run_strategy(df, window):
    """Run the rolling-window forecast backtest on a returns DataFrame.

    For each ``i`` a forecast is produced from ``returns[i:i + window]`` using
    the module-level ``P_MAX`` and ``Q_MAX``. The sign of the (shifted)
    forecast is the position (+1, -1 or 0) and the strategy return is that
    position times the realised return.

    Args:
        df: DataFrame with a ``returns`` column; its index labels the rows.
        window: Number of observations in each estimation window.

    Returns:
        A tuple ``(strategy cumulative returns, realised cumulative returns,
        strategy returns, realised returns)``, all indexed by the dates after
        the first ``window`` rows.

    Raises:
        ValueError: If ``window`` is not positive or ``df`` has no rows beyond
            the first window.
    """
    returns = df['returns'].values
    dates = df.index
    n = len(returns)
    fc_len = n - window

    # Fail early with a clear message instead of numpy's "negative
    # dimensions" error or an IndexError on the empty result further down.
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")
    if fc_len < 1:
        raise ValueError(f"need more than {window} return observations, got {n}")

    forecasts = np.zeros(fc_len)

    print(f"Processing {fc_len} windows (this takes time)...")
    start = time.time()

    for i in range(fc_len):
        if i % 50 == 0 and i > 0:
            elapsed = time.time() - start
            # Same estimate as (fc_len - i) / (i / elapsed), arranged so a
            # zero elapsed reading on a coarse clock cannot divide by zero.
            remaining = elapsed / i * (fc_len - i) / 60
            print(f"{i}/{fc_len} ({100*i/fc_len:.0f}%) | Est. {remaining:.1f} min remaining")

        forecasts[i] = forecast_return(returns[i:i+window], P_MAX, Q_MAX)

    total_time = (time.time() - start) / 60
    print(f"Completed in {total_time:.1f} minutes!\n")

    fc_dates = dates[window:]
    # forecasts[i] is built from rows i .. i+window-1 and indexed at row
    # i+window; shift(1) then moves it one row later still.
    # NOTE(review): this is an extra one-row lag, not look-ahead - confirm it
    # is the intended execution delay.
    fc_series = pd.Series(forecasts, index=fc_dates).shift(1)
    # NaN from the shift compares False both ways, so it maps to a flat position.
    direction = np.where(fc_series > 0, 1, np.where(fc_series < 0, -1, 0))
    direction = pd.Series(direction, index=fc_dates)

    real_ret = df.loc[fc_dates, 'returns']
    strat_ret = direction * real_ret
    strat_ret.iloc[0] = 0

    return strat_ret.cumsum(), real_ret.cumsum(), strat_ret, real_ret

def plot(strat, bh):
    """Draw both cumulative-return curves on one axis, save a PNG and show it.

    Args:
        strat: Series of cumulative strategy returns (drawn in green).
        bh: Series of cumulative buy-and-hold returns (drawn in red).
    """
    out_file = 'AAPL_arma_garch.png'
    curves = (
        (strat, 'green', 'ARIMA+GARCH'),
        (bh, 'red', 'Buy & Hold'),
    )

    _, axis = plt.subplots(figsize=(12, 6))
    for series, colour, legend_name in curves:
        series.plot(ax=axis, color=colour, label=legend_name, lw=1.5)

    # Labels are set after plotting, matching the original call order.
    axis.set_xlabel('Time')
    axis.set_ylabel('Cumulative Return')
    axis.set_title('AAPL: ARIMA+GARCH vs Buy & Hold', fontweight='bold')
    axis.legend(loc='lower left')
    axis.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_file, dpi=150)
    plt.show()
    print(f"Saved {out_file}")

def stats(strat_ret, bh_ret):
    """Summarise strategy and buy-and-hold performance from per-period log returns.

    Args:
        strat_ret: Series of strategy log returns.
        bh_ret: Series of buy-and-hold log returns.

    Returns:
        Dict with total return in percent and Sharpe ratio (scaled by
        ``sqrt(252)``) for both series, plus the strategy's win rate and the
        count of periods with a non-zero strategy return (key ``'Trades'``).
    """
    def total_return_pct(log_returns):
        # Summed log returns converted to a simple percentage return.
        return (np.exp(log_returns.sum())-1)*100

    def sharpe(log_returns):
        return log_returns.mean()/log_returns.std()*np.sqrt(252)

    winning_periods = (strat_ret>0).sum()
    active_periods = (strat_ret!=0).sum()

    summary = {}
    summary['Strategy Return (%)'] = total_return_pct(strat_ret)
    summary['Buy&Hold Return (%)'] = total_return_pct(bh_ret)
    summary['Strategy Sharpe'] = sharpe(strat_ret)
    summary['Buy&Hold Sharpe'] = sharpe(bh_ret)
    summary['Win Rate'] = winning_periods/active_periods
    summary['Trades'] = active_periods
    return summary

if __name__ == "__main__":
    # Script entry point: load prices, run the backtest, print the metrics
    # table, then draw and save the comparison chart.
    divider = "="*50
    print(f"\nAAPL ARIMA+GARCH Strategy\n{divider}\n")

    df = load_data_from_sqlite()
    first_day = df.index[0].date()
    last_day = df.index[-1].date()
    print(f"Loaded {len(df)} days: {first_day} to {last_day}\n")

    strat_curve, bh_curve, strat_ret, bh_ret = run_strategy(df, WINDOW_LENGTH)

    print("PERFORMANCE:")
    print(divider)
    metrics = stats(strat_ret, bh_ret)
    for label, value in metrics.items():
        # Label padded with dots to 35 characters, value right-aligned to 2 decimals.
        print(f"{label:.<35} {value:>12.2f}")
    print(f"{divider}\n")

    plot(strat_curve, bh_curve)


AAPL ARIMA+GARCH Strategy

Loaded 6322 days: 2000-10-17 to 2025-12-05

Processing 5822 windows (this takes time)...
50/5822 (1%) | Est. 5.3 min remaining
100/5822 (2%) | Est. 4.6 min remaining
150/5822 (3%) | Est. 5.9 min remaining
200/5822 (3%) | Est. 7.0 min remaining
250/5822 (4%) | Est. 7.6 min remaining
300/5822 (5%) | Est. 7.6 min remaining
350/5822 (6%) | Est. 7.5 min remaining
400/5822 (7%) | Est. 7.0 min remaining
450/5822 (8%) | Est. 6.8 min remaining
500/5822 (9%) | Est. 6.9 min remaining
550/5822 (9%) | Est. 7.3 min remaining
600/5822 (10%) | Est. 7.8 min remaining
650/5822 (11%) | Est. 7.8 min remaining
700/5822 (12%) | Est. 7.6 min remaining
750/5822 (13%) | Est. 7.7 min remaining
800/5822 (14%) | Est. 7.6 min remaining
850/5822 (15%) | Est. 7.6 min remaining
900/5822 (15%) | Est. 7.6 min remaining
950/5822 (16%) | Est. 7.5 min remaining
1000/5822 (17%) | Est. 7.3 min remaining
1050/5822 (18%) | Est. 7.3 min remaining
1100/5822 (19%) | Est. 7.2 min remaining
1150/5822 (2