In [1]:
%load_ext autoreload
%autoreload 2

from datetime import datetime, timedelta
from sqlalchemy.orm import sessionmaker
import os, sys
from enum import Enum
from functools import partial
from collections import defaultdict
from typing import List, Any, Optional, Tuple

project_dir = os.path.abspath('..')
sys.path.insert(0, project_dir)

from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import pandas_market_calendars as mcal

import unittest
from unittest.mock import patch
from volatility_forecast.data.dataloader import (
    TiingoEoDDataLoader,
    PriceVolumeDatabaseLoader,
    TiingoEodDataLoaderProd,
)
from volatility_forecast.data.datamanager import (
    LagReturnDataManager,
    LagAbsReturnDataManager, 
    LagSquareReturnDataManager,
    SquareReturnDataManager,
)
from volatility_forecast.data.date_util import get_closest_next_business_day, get_closest_prev_business_day
from volatility_forecast.data.base import Field, DataSet, DateLike
from volatility_forecast.data.dataset import PriceVolume
from volatility_forecast.data.persistence import persist_data, load_data_from_db
from volatility_forecast.data.database import engine, Base

from volatility_forecast.model.stes_model import STESModel
from volatility_forecast.model.xgboost_stes_model import XGBoostSTESModel, DEFAULT_XGBOOST_PARAMS

from volatility_forecast.evaluation.model_evaluator import evaluate_model, compare_models, root_mean_squared_error, generate_model_forecasts


In [2]:
class ModelName(Enum):
    """Identifiers for the model variants compared in this notebook.

    Members are numbered 1..8 in declaration order, which is the numbering
    Enum's functional API gives to a space-separated list of names.
    """

    ES = 1
    STES_AE = 2
    STES_SE = 3
    STES_ESE = 4
    STES_EAE = 5
    STES_AESE = 6
    STES_EAESE = 7
    XGBoost_STES = 8

In [3]:
# Run configuration: these three values are bound into the data provider
# for every model evaluated further down.
tickers = ("SPY", )  # one-element tuple; the trailing comma is what makes it a tuple
start_date = "2000-01-01"
# NOTE(review): the window end is the moment this cell is executed, so a
# re-run on a later day covers a longer sample. Pin a fixed date if the
# reported numbers need to be reproducible.
end_date = pd.Timestamp.today()

In [4]:
def equity_data_provider(tickers: Tuple[str, ...], start_date: DateLike, end_date: DateLike, model_name: ModelName) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build ``(features, realized_var, returns)`` for one model variant.

    Four data managers are queried with the same ``tickers`` /
    ``start_date`` / ``end_date``:

    * ``LagReturnDataManager`` scaled by 1e2 -> ``returns`` and feature column 0
    * ``LagAbsReturnDataManager`` scaled by 1e2 -> feature column 1
    * ``LagSquareReturnDataManager`` scaled by 1e4 -> feature column 2
    * ``SquareReturnDataManager`` scaled by 1e4 -> ``realized_var``

    NOTE(review): the column indices 0/1/2 assume each manager yields exactly
    one column (i.e. a single ticker) -- confirm before passing several tickers.

    Args:
        tickers: Symbols forwarded unchanged to every data manager.
        start_date: Start of the requested window.
        end_date: End of the requested window.
        model_name: Variant that decides which feature columns are returned.

    Returns:
        A ``(features, realized_var, returns)`` tuple where ``features`` is

        * a single column of ones for ``ES``;
        * a column of ones followed by the selected feature columns for the
          ``STES_*`` variants;
        * the stacked features without a column of ones for ``XGBoost_STES``.

    Raises:
        ValueError: If ``model_name`` is not one of the handled variants.
    """
    # Feature columns kept for each STES variant; every one of them is also
    # prefixed with a column of ones below.
    stes_columns = {
        ModelName.STES_AE: [1],
        ModelName.STES_SE: [2],
        ModelName.STES_EAE: [0, 1],
        ModelName.STES_ESE: [0, 2],
        ModelName.STES_AESE: [1, 2],
        ModelName.STES_EAESE: slice(None),  # keep every stacked column
    }

    # Reject unknown names before any data is loaded. Membership in a tuple
    # compares with ``==``, exactly like the original if/elif chain did.
    handled = (ModelName.ES, ModelName.XGBoost_STES, *stes_columns)
    if model_name not in handled:
        raise ValueError(f"Unknown model name: {model_name}")

    returns = LagReturnDataManager().get_data(tickers, start_date, end_date) * 1e2
    realized_var = SquareReturnDataManager().get_data(tickers, start_date, end_date) * 1e4
    feature_sets = np.hstack([
        # Reuse the series fetched above instead of querying the same
        # manager a second time with identical arguments.
        returns,
        LagAbsReturnDataManager().get_data(tickers, start_date, end_date) * 1e2,
        LagSquareReturnDataManager().get_data(tickers, start_date, end_date) * 1e4,
    ])

    if model_name == ModelName.XGBoost_STES:
        return feature_sets, realized_var, returns

    # Constant column shared by ES and all STES variants
    # (presumably the intercept term -- confirm against STESModel).
    ones_column = np.ones((len(returns), 1))
    if model_name == ModelName.ES:
        return ones_column, realized_var, returns

    selected = feature_sets[:, stes_columns[model_name]]
    return np.hstack([ones_column, selected]), realized_var, returns


In [5]:
# One fresh model instance per variant: the XGBoost variant gets its own class
# configured with the default parameters, every other variant is an STESModel.
spy_models = {
    variant: (
        XGBoostSTESModel(**DEFAULT_XGBOOST_PARAMS)
        if variant == ModelName.XGBoost_STES
        else STESModel()
    )
    for variant in ModelName
}

In [None]:
# Evaluate every model variant and collect its scores keyed by ModelName.
spy_results = {}
for model_name in ModelName:
    # Reset NumPy's global RNG before each model so every variant starts
    # from the same random state.
    np.random.seed(0)

    model = spy_models[model_name]
    # Pre-bind all four provider arguments by keyword so the callable handed
    # to evaluate_model needs no further arguments.
    data_provider = partial(
        equity_data_provider, 
        tickers=tickers, 
        start_date=start_date, 
        end_date=end_date, 
        model_name=model_name
    )

    # NOTE(review): the meaning of the positional 10 and 4000 is fixed by
    # evaluate_model's signature, which is not visible here -- consider
    # passing them by keyword. os_res / is_res presumably hold the
    # out-of-sample and in-sample scores -- confirm.
    os_res, is_res = evaluate_model(
        data_provider,
        model, 
        root_mean_squared_error,
        10, 4000
    )
    spy_results[model_name] = (os_res, is_res)
