Tunable components:
- scaler
- WindowSummarizer (lag and window-mean features)
- forecaster: estimator, window_length, strategy
- deseasonalize (Deseasonalizer step)
- detrend (Detrender step)

In [41]:
# import
import pandas as pd
import numpy as np
import plotly.express as px
pd.options.plotting.backend = "plotly"
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from sktime.forecasting.all import (
    MultiplexForecaster,
    AutoETS,
    AutoARIMA,
    NaiveForecaster,
    ForecastingGridSearchCV,
    ExpandingWindowSplitter,
    PolynomialTrendForecaster,
    Detrender,
    Deseasonalizer,
    TransformedTargetForecaster,
    SlidingWindowSplitter,
    ForecastingRandomizedSearchCV,
    ForecastingGridSearchCV,
    MeanAbsoluteScaledError,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredError,
    evaluate,
    )
from sktime.forecasting.compose import make_reduction
from sktime.transformations.series.date import DateTimeFeatures
from sktime.transformations.series.adapt import TabularToSeriesAdaptor
from sktime.transformations.series.summarize import WindowSummarizer

from sklearn.preprocessing import (
    Normalizer, 
    MinMaxScaler,
    )

from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor

# Error metrics used for scoring (sktime metric objects).
mase, mape, mae = (
    MeanAbsoluteScaledError(),
    MeanAbsolutePercentageError(),
    MeanAbsoluteError(),
)
rmse = MeanSquaredError(square_root=True)  # square root of the MSE, i.e. RMSE


# Data preparation

In [34]:
# load data
# NOTE(review): pd.read_pickle can execute arbitrary code -- only load trusted files.
df_store = pd.read_pickle("data/df_daily.pkl")
df_store['sales'] = df_store['sales'] / 1e6  # sales scaled to millions
df_exog = pd.read_pickle("data/df_exog.pkl")

# Company-wide daily total. Select the column *before* aggregating so pandas
# does not try to sum every other (possibly non-numeric) column of df_store.
ts_company = df_store.groupby("date")["sales"].sum()

horizon = 7  # number of steps (daily rows) to forecast



In [35]:
# Min-max scaler from sklearn, wrapped by sktime's TabularToSeriesAdaptor so it
# can be applied to series/frames; reused by data_prep and the pipeline below.
scaler = TabularToSeriesAdaptor(transformer=MinMaxScaler())

# prepare data
def _lag_config(horizon):
    '''
    Build the WindowSummarizer ``lag_config`` mapping for a forecast horizon.

    Each window is a ``[first, second]`` pair whose second element is tied to
    ``horizon`` for both feature families:

    - "lag":         ``[1, horizon - 1 + i]`` for ``i`` in ``0 .. horizon - 1``
    - "expand_mean": ``[k, horizon - 1]`` for ``k`` in ``2 .. horizon``
      (omitted when ``horizon == 1``, where that list would be empty)

    NOTE(review): assumes WindowSummarizer reads each pair as
    ``[window_length, starting_at]``, so every feature only uses values at least
    ``horizon`` steps old -- confirm against the installed sktime version.

    Raises
    ------
    ValueError: if horizon < 1
    '''
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon!r}")
    # The lag offsets used to be hard-coded to 6, which equals horizon - 1 only
    # for horizon == 7; use the same offset as the mean windows instead.
    offset = horizon - 1
    config = {"lag": ["lag", [[1, offset + i] for i in range(horizon)]]}
    mean_windows = [[length, offset] for length in range(2, horizon + 1)]
    if mean_windows:
        config["expand_mean"] = ["mean", mean_windows]
    return config


def _add_datetime_features(df_window):
    '''
    Return ``df_window`` with comprehensive daily DateTimeFeatures appended.

    NOTE(review): depending on the sktime version, DateTimeFeatures returns
    either the input columns plus the calendar features or (when
    ``keep_original_columns`` defaults to False) only the calendar features.
    Dropping any input columns from its output and joining back onto
    ``df_window`` yields the same columns in both cases.
    '''
    calendar = DateTimeFeatures(ts_freq="D", feature_scope="comprehensive").fit_transform(df_window)
    calendar = calendar.drop(columns=df_window.columns, errors="ignore")
    return df_window.join(calendar)


def data_prep(y, X, horizon):
    '''
    Extract lagged values, window means and DateTime features from y and
    combine them with the exogenous variables.

    Parameters
    ----------
    y: target time series (expected to be daily with a DatetimeIndex)
    X: exogenous variables, indexed by the same dates as y
    horizon: number of steps ahead to forecast; also sets the offsets of the
        lag / mean windows

    Returns
    -------
    y_short: raw values of y cut to the rows of X_trans; frequency set to daily
    X_trans: exogenous variables joined with the y-derived features, scaled with
        the module-level ``scaler``

    Raises
    ------
    ValueError: if horizon < 1, or if the rows kept in X_trans are not the last
        rows of y (y_short would otherwise be silently misaligned with X_trans)
    '''
    # lagged values and window means of y; rows containing NaN (where a window
    # is incomplete) are dropped
    df_window = WindowSummarizer(lag_config=_lag_config(horizon)).fit_transform(y).dropna()

    # calendar features derived from the dates of the window features
    df_from_y = _add_datetime_features(df_window)

    # inner join on the index: only dates present in both X and df_from_y remain
    df_X = X.merge(df_from_y, left_index=True, right_index=True)

    # NOTE(review): the module-level `scaler` is fitted on all rows here,
    # including rows that later serve as CV test windows (look-ahead leakage);
    # consider scaling inside the forecasting pipeline instead.
    X_trans = scaler.fit_transform(df_X)

    # equalize len y & X: keep the last len(X_trans) observations of y and make
    # sure they are the same dates, instead of silently misaligning
    y_short = y.tail(X_trans.shape[0])
    if not y_short.index.equals(X_trans.index):
        raise ValueError(
            "data_prep: the rows of X_trans are not the last rows of y; "
            "y and X cannot be aligned by length alone"
        )
    y_short.index.freq = "D"

    return y_short, X_trans

# Build the aligned target series and feature matrix for the company total.
y_short, X_trans = data_prep(ts_company, df_exog, horizon)


  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_MULTIINDEX_TYPES = (pd.Int64Index, pd.RangeIndex)
  VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
  VALID_MULTIINDEX_TYPES = (pd.Int64Index, pd.RangeIndex)
  VALID_MULTII

In [36]:
# define forecasting model
# Reduction approach: an XGBoost regressor over a window of the last 7 values,
# applied recursively across the forecast horizon.
forecaster_XGB = make_reduction(
    # "mae" is XGBoost's built-in metric name. The sktime metric object passed
    # here before is not one of XGBoost's documented eval_metric options (a name,
    # a list of names, or a plain (y_true, y_pred) function).
    # NOTE(review): eval_metric only matters when fit() receives an eval_set,
    # which is not the case here, so fitted models should be unaffected -- confirm.
    estimator=XGBRegressor(eval_metric="mae"),
    window_length=7,
    strategy="recursive",
    )

# model selection: one forecaster that delegates to a named candidate, chosen
# through its `selected_forecaster` parameter (tuned by the grid search)
forecaster = MultiplexForecaster(forecasters=[
    ("naive", NaiveForecaster()),
    ("ets", AutoETS()),
    ("arima", AutoARIMA(suppress_warnings=True, seasonal=False)),
    ("xgb", forecaster_XGB),
    ])

# pipeline
# Target-transformation pipeline: additive deseasonalization with period 7,
# removal of a degree-1 (linear) polynomial trend, min-max scaling, then the
# multiplexed forecaster.
# NOTE(review): the grid search and evaluation cells shown fit `forecaster`,
# not `pipe`, so these transformations are currently not applied -- confirm intent.
pipe = TransformedTargetForecaster(
    steps=[
        ("deseasonalize", Deseasonalizer(model="additive", sp=7)),
        ("detrend", Detrender(forecaster=PolynomialTrendForecaster(degree=1))),
        ("scale", scaler),
        ("forecaster", forecaster),
    ]
)

# config CV
# Sliding-window CV: a fixed-length training window followed by a test window
# of `horizon` steps, moved forward by `horizon` each fold. The training window
# is sized so that `cv_folds` test windows fit at the end of the series
# (NOTE(review): verify with cv.get_n_splits(y_short)).
cv_folds = 4
train_window = len(y_short) - horizon * cv_folds
if train_window < 1:
    raise ValueError(
        f"series of length {len(y_short)} is too short for {cv_folds} folds "
        f"with horizon {horizon}"
    )
cv = SlidingWindowSplitter(
    fh=list(range(1, horizon + 1)),
    window_length=train_window,
    step_length=horizon,
    )


In [37]:
# Model selection: evaluate each candidate of the MultiplexForecaster on the
# folds of `cv` (scoring left at the sktime default) and keep the best one.
param_grid = dict(selected_forecaster=["ets", "arima", "naive", "xgb"])

gscv = ForecastingGridSearchCV(
    forecaster=forecaster,
    param_grid=param_grid,
    cv=cv,
)
gscv.fit(y_short, X=X_trans)


ForecastingGridSearchCV(cv=SlidingWindowSplitter(fh=[1, 2, 3, 4, 5, 6, 7], step_length=7,
           window_length=1233),
                        forecaster=MultiplexForecaster(forecasters=[('naive',
                                                                     NaiveForecaster()),
                                                                    ('ets',
                                                                     AutoETS()),
                                                                    ('arima',
                                                                     AutoARIMA(seasonal=False,
                                                                    ('xgb',
                                                                     RecursiveTabularRegressionForecaster(estimator=XGBRegressor(base_score=None,
                                                                                                                                 booster=None,
                

In [46]:
# Re-score the candidate chosen by the grid search on the same folds, with MAPE,
# keeping the per-fold data (return_data=True).
# NOTE(review): these folds were also used for model selection, so the resulting
# score is an optimistic estimate.
best_forecaster = gscv.best_forecaster_.forecaster_
company_result = evaluate(
    y=y_short,
    X=X_trans,
    forecaster=best_forecaster,
    cv=cv,
    scoring=mape,
    return_data=True,
)


In [49]:
# mean MAPE across the CV folds
company_result.loc[:, "test_MeanAbsolutePercentageError"].mean()


0.19281357690853912