In [1]:
# =====================================================
# ✅ Quick Test for BaseLSTM (single-ticker prototype)
# =====================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from volsense_core.forecasters.forecaster_api import VolSenseForecaster
from volsense_core.models.lstm_forecaster import TrainConfig, train_baselstm
from volsense_core.data_fetching.fetch_yf import fetch_ohlcv, compute_returns_vol
from volsense_core.data_fetching.feature_engineering import build_features

In [2]:
# --------------------------
# 1️⃣ Fetch and preprocess data
# --------------------------
TICKER = "AAPL"
print(f"🔍 Fetching {TICKER} historical data...")
raw_df = fetch_ohlcv(TICKER, start="2005-01-01", end="2024-12-31")
raw_df = compute_returns_vol(raw_df, window=21, ticker=TICKER)
multi_df = build_features(raw_df)

# Ensure necessary columns exist
assert "realized_vol" in multi_df.columns, "Missing realized_vol after feature engineering"
assert "return" in multi_df.columns, "Missing return after feature engineering"

🔍 Fetching AAPL historical data...


In [5]:
# ==============================================================
# 🧪 Unified Forecasting Test: BaseLSTM | GlobalLSTM | GARCH
# ==============================================================

device = "cuda" if torch.cuda.is_available() else "cpu"

# ==============================================================
# 🧩 BaseLSTM (Single-Ticker LSTM)
# ==============================================================
print("\n================ BaseLSTM Test ================")

lstm_forecaster = VolSenseForecaster(
    method="lstm",
    window=30,
    horizons=[1, 5, 10],
    epochs=3,
    device=device,
)
lstm_forecaster.fit(multi_df[multi_df["ticker"] == TICKER])

# Eval mode — produces full rolling predictions + realized vols
df_lstm_eval = lstm_forecaster.predict(mode="eval")
# Inference mode — only latest window forecasts
df_lstm_live = lstm_forecaster.predict(mode="inference")

print("\n✅ BaseLSTM Outputs:")
display(df_lstm_eval.head(5))
display(df_lstm_live)


🧩 Training BaseLSTM Forecaster...
⚠️ Dropping constant features: ['market_stress_1d_lag', 'market_stress']
⚠️ Dropping constant features: ['market_stress_1d_lag', 'market_stress']
Epoch 1/3 | LR: 3.75e-04 | Train: 0.610743 | Val: 0.131602
Epoch 2/3 | LR: 1.26e-04 | Train: 0.343486 | Val: 0.110889
Epoch 3/3 | LR: 1.00e-06 | Train: 0.246532 | Val: 0.076221

✅ BaseLSTM Outputs:


Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2023-02-14,2023-02-15,AAPL,1,0.258597,0.247493,BaseLSTM
462,2023-02-14,2023-02-21,AAPL,5,0.269924,0.256443,BaseLSTM
924,2023-02-14,2023-02-28,AAPL,10,0.274504,0.235008,BaseLSTM
1,2023-02-15,2023-02-16,AAPL,1,0.241848,0.251398,BaseLSTM
463,2023-02-15,2023-02-22,AAPL,5,0.251198,0.255062,BaseLSTM


Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2024-12-13,2024-12-16,AAPL,1,0.141135,,BaseLSTM
1,2024-12-13,2024-12-20,AAPL,5,0.151823,,BaseLSTM
2,2024-12-13,2024-12-27,AAPL,10,0.159008,,BaseLSTM


In [6]:
# ==============================================================
# 🌐 GlobalVolForecaster (Multi-Ticker LSTM)
# ==============================================================
print("\n================ GlobalLSTM Test ================")

device = "cuda" if torch.cuda.is_available() else "cpu"

global_forecaster = VolSenseForecaster(
    method="global_lstm",
    window=30,
    horizons=[1, 5, 10],
    epochs=3,
    device=device,
)
global_forecaster.fit(multi_df)

# Eval mode — rolling multi-horizon predictions across windows
df_global_eval = global_forecaster.predict(data=multi_df, mode="eval")
# Inference mode — only latest window forecasts
df_global_live = global_forecaster.predict(data=multi_df, mode="inference")

print("\n✅ GlobalVolForecaster Outputs:")
display(df_global_eval.head(5))
display(df_global_live)


🌐 Training GlobalVolForecaster...

🚀 Training GlobalVolForecaster on 1 tickers...

Epoch 1/3 | Train Loss: 0.5189 | Val Loss: 0.1910
Epoch 2/3 | Train Loss: 0.2036 | Val Loss: 0.1321
Epoch 3/3 | Train Loss: 0.1913 | Val Loss: 0.1552

✅ Training complete with feature set: ['return']



Rolling eval forecasts: 100%|██████████| 1/1 [00:11<00:00, 11.25s/it]


✅ GlobalVolForecaster Outputs:





Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2005-03-16,2005-03-17,AAPL,1,-0.005159,,GlobalVolForecaster
997,2005-03-16,2005-03-23,AAPL,5,-0.408811,,GlobalVolForecaster
1994,2005-03-16,2005-03-30,AAPL,10,-0.167082,,GlobalVolForecaster
1,2005-03-23,2005-03-24,AAPL,1,-0.002897,,GlobalVolForecaster
998,2005-03-23,2005-03-30,AAPL,5,-0.408228,,GlobalVolForecaster


Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2024-12-30,2024-12-31,AAPL,1,0.001974,,GlobalVolForecaster
1,2024-12-30,2025-01-06,AAPL,5,-0.409342,,GlobalVolForecaster
2,2024-12-30,2025-01-13,AAPL,10,-0.175042,,GlobalVolForecaster


In [7]:
len(df_lstm_eval), len(df_lstm_live), len(df_global_eval), len(df_global_live)

(1386, 3, 2991, 3)

In [3]:
print("\n================ GARCH Test ================")

garch_forecaster = VolSenseForecaster(method="garch", p=1, q=1)
garch_forecaster.fit(multi_df[multi_df["ticker"] == "AAPL"])

df_garch_eval = garch_forecaster.predict(mode="eval")
df_garch_live = garch_forecaster.predict(mode="inference", horizon=[1, 5, 10])

display(df_garch_eval.head())
display(df_garch_live)


📈 Fitting GARCH Forecaster...
✅ GARCH fit complete for AAPL (5011 obs).
🌀 Running rolling 1-step-ahead GARCH evaluation on AAPL...
✅ GARCH evaluation complete (4981 rows).
✅ GARCH inference complete (3 horizons).


Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2005-03-17,2005-03-18,AAPL,1,0.377967,0.372598,GARCH
1,2005-03-18,2005-03-21,AAPL,1,0.377967,0.370795,GARCH
2,2005-03-21,2005-03-22,AAPL,1,0.377967,0.365991,GARCH
3,2005-03-22,2005-03-23,AAPL,1,0.377967,0.370504,GARCH
4,2005-03-23,2005-03-24,AAPL,1,0.377967,0.366032,GARCH


Unnamed: 0,asof_date,date,ticker,horizon,forecast_vol,realized_vol,model
0,2024-12-30,2024-12-31,AAPL,1,0.20367,,GARCH
1,2024-12-30,2025-01-06,AAPL,5,0.211269,,GARCH
2,2024-12-30,2025-01-13,AAPL,10,0.220089,,GARCH
