# How to add a custom forecasting model

This notebook is a minimal, step-by-step example of how to add your own forecasting model to Test-of-time.
Once the model is implemented, you can run it in a Test-of-time benchmark.
For this tutorial, we implement a wrapper for the **seasonal naive forecasting model** available in the **darts** library.

**TL;DR:** To implement a new model, you need to follow these steps:

* **Step 1:** Import the model you want to wrap (or implement it yourself).
* **Step 2:** Implement a new model class that inherits from the abstract class `test_of_time.tot.models.Model`.
* **Step 3:** Implement `__post_init__()` to initialize the required class attributes.
* **Step 4:** Implement the abstract `fit()` method.
* **Step 5:** Implement the abstract `predict()` method, which returns the model's forecast.
* **Step 6:** [Optional] Implement the parent class methods `maybe_extend_df()` and `maybe_drop_added_values_from_df()` for model-specific pre-/post-processing.
* **Step 7:** Run your new model in a simple benchmark.


## Step 1 - Import model
You have two options: (1) You import an existing model from a library such as darts, gluonts, or kats. In this case the model is already provided as a class.
(2) You implement your own model in a class.
Here we go with the first option and import `NaiveSeasonal` from the **darts** library.

In [5]:
>>> from darts.models import NaiveSeasonal
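
For intuition, a seasonal naive forecast repeats the value observed `K` time steps before each forecast step. The sketch below illustrates that idea in plain Python. It is not the darts implementation, and the function name `seasonal_naive_forecast` is hypothetical.

```python
from typing import List, Sequence


def seasonal_naive_forecast(history: Sequence[float], K: int, n: int) -> List[float]:
    """Forecast n steps by cycling through the last K observed values (illustrative only)."""
    if K < 1:
        raise ValueError("K must be >= 1")
    if len(history) < K:
        raise ValueError("history must contain at least K observations")
    last_season = list(history[-K:])
    # Forecast step h (0-based) reuses the value at the same position in the last season.
    return [last_season[h % K] for h in range(n)]


history = [10.0, 20.0, 30.0, 11.0, 21.0, 31.0]
print(seasonal_naive_forecast(history, K=3, n=5))  # [11.0, 21.0, 31.0, 11.0, 21.0]
print(seasonal_naive_forecast(history, K=1, n=3))  # [31.0, 31.0, 31.0]
```

With `K=1` the forecast reduces to repeating the last observed value.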

## Step 2 - Implement new model class
* We import the abstract base class `tot.models.Model` and define a new class named `CustomSeasonalNaiveModel`
that inherits from `Model`. We want this class to be a dataclass, so we decorate it with `@dataclass` imported from
`dataclasses`. This class is a wrapper around the actual model we are adding and acts as an interface between
the test-of-time library and the model's conventions. Hence, we call it the **model wrapper** in the following.
* We set the non-optional attribute `model_name` to `CustomSeasonalNaive`.
* We set the attribute `model_class` to the model class we have imported, in this case `NaiveSeasonal`.

In [6]:
>>> from dataclasses import dataclass
>>> from tot.models import Model
>>> from typing import Type
>>> @dataclass
>>> class CustomSeasonalNaiveModel(Model):
>>>    model_name: str = "CustomSeasonalNaive"
>>>    model_class: Type = NaiveSeasonal
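
To see why the abstract methods must be implemented before the wrapper can be used, here is a self-contained sketch of the same pattern. `ToyModel` and `ToyForecaster` are hypothetical stand-ins for `tot.models.Model` and `NaiveSeasonal`, and the assumption that the base class carries a `params` dictionary is inferred from how `self.params` is used later in this tutorial.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Type


@dataclass
class ToyModel(ABC):
    """Hypothetical stand-in for the abstract base class."""
    params: dict

    @abstractmethod
    def fit(self, data): ...

    @abstractmethod
    def predict(self, data): ...


class ToyForecaster:
    """Hypothetical stand-in for a third-party model class."""


@dataclass
class IncompleteWrapper(ToyModel):
    # fit() and predict() are not implemented yet
    model_name: str = "Incomplete"
    model_class: Type = ToyForecaster


@dataclass
class CompleteWrapper(ToyModel):
    model_name: str = "Complete"
    model_class: Type = ToyForecaster

    def fit(self, data):
        return self

    def predict(self, data):
        return data


try:
    IncompleteWrapper(params={})
except TypeError as err:
    # Abstract methods are still missing, so Python refuses to instantiate the class.
    print("cannot instantiate:", type(err).__name__)

wrapper = CompleteWrapper(params={"K": 1})
print(wrapper.model_name, wrapper.model_class.__name__)  # Complete ToyForecaster
```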

## Step 3 - Implement the `__post_init__()`
We implement the `__post_init__()` to initialize the required model wrapper class attributes and local parameters.
This includes the sub-steps: (1) assign `model_params` and instantiate model, (2) assign and verify model wrapper attributes.
Let's have a look at it step-by-step:

1. First, we make sure that all model parameters relevant for fitting and predicting are assigned.
Therefore, we extract only the parameters that the custom model needs. `NaiveSeasonal` needs one input parameter to
make predictions: `K`, the seasonal period in number of time steps. To instantiate the actual
model, we pass the `model_params` to `model_class`.
2. Next, we assign the model wrapper attributes `self.freq`, `self.season_length`, and `self.n_forecasts`. Parameters like `freq`
are defined by the dataset and hence provided in `_data_params`. The other two attributes are provided by the user input, accessible via `self.params`.
For both of these attributes we verify that the inputs are valid.

Remark: This is the minimum required initialization; further attributes could be added.

In [7]:
>>> def __post_init__(self):
>>>     # extract only the parameters the darts model needs and instantiate the actual model
>>>     # K is the parameter for the seasonal period defined by darts
>>>     model_params = {"K": self.params["K"]}
>>>     self.model = self.model_class(**model_params)
>>>
>>>     # re-assign the frequency as model wrapper attribute
>>>     # remark: structure will change in the future
>>>     self.freq = self.params["_data_params"]["freq"]
>>>     # set forecast horizon as model wrapper attribute and verify
>>>     self.n_forecasts = self.params["n_forecasts"]
>>>     assert self.n_forecasts >= 1, "Model parameter n_forecasts must be >=1. "
>>>     # set season length as model wrapper attribute and verify
>>>     self.season_length = model_params["K"]
>>>     assert self.season_length is not None, (
>>>         "Dataset does not provide a seasonality. Assign a seasonality to each of the datasets "
>>>         "OR input desired season_length as model parameter to be used for all datasets "
>>>         "without specified seasonality."
>>>     )
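
The mechanics of the two sub-steps can be tried out without installing anything. The following sketch uses a hypothetical stand-in class `ToySeasonal` in place of `NaiveSeasonal` and a plain dataclass `ToyWrapper` in place of the model wrapper; the shape of the `params` dictionary is an assumption that mirrors the keys used in this tutorial.

```python
from dataclasses import dataclass
from typing import Type


class ToySeasonal:
    """Hypothetical stand-in for a model class that accepts a keyword argument K."""

    def __init__(self, K: int = 1):
        self.K = K


@dataclass
class ToyWrapper:
    params: dict
    model_class: Type = ToySeasonal

    def __post_init__(self):
        # Runs automatically after the generated __init__ has set the fields.
        # Sub-step 1: keep only the parameters the model needs and instantiate it.
        model_params = {"K": self.params["K"]}
        self.model = self.model_class(**model_params)
        # Sub-step 2: assign and verify the wrapper attributes.
        self.freq = self.params["_data_params"]["freq"]
        self.n_forecasts = self.params["n_forecasts"]
        if self.n_forecasts < 1:
            raise ValueError("n_forecasts must be >= 1")
        self.season_length = model_params["K"]


wrapper = ToyWrapper(params={"K": 7, "n_forecasts": 3, "_data_params": {"freq": "D"}})
print(wrapper.model.K, wrapper.freq, wrapper.n_forecasts, wrapper.season_length)  # 7 D 3 7
```

Passing the parameters as a dictionary and unpacking it with `**` keeps the wrapper independent of the argument order of the wrapped model.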

## Step 4 - Implement the `fit()` method
The `fit()` method of the model wrapper can be considered an interface to the `fit()` method of the actual model.
It includes the model-specific pre-processing of the data.
* The model-specific pre-processing in this case consists of checking that the dataframe contains enough samples for fitting
 by calling `_check_min_df_len()`, and converting the dataframe to the `TimeSeries` format from darts. Both functions
 are available as helper functions in test-of-time.
* We pass the series of type `TimeSeries` to the `fit()` method of the instantiated model.

In [9]:
>>> import pandas as pd
>>> from tot.utils import convert_df_to_DartsTimeSeries
>>> from tot.df_utils import _check_min_df_len
>>>
>>> def fit(self, df: pd.DataFrame, freq: str):
>>>     # check if df contains enough samples for fitting
>>>     _check_min_df_len(df=df, min_len=self.n_forecasts + self.season_length)
>>>     self.freq = freq
>>>     # convert the dataframe to a darts TimeSeries
>>>     series = convert_df_to_DartsTimeSeries(df, value_cols=df.columns.values[1:-1].tolist(), freq=self.freq)
>>>     # fit the actual darts model
>>>     self.model = self.model.fit(series)

## Step 5 - Implement the `predict()` method
The `predict()` method of the model wrapper can be considered an interface to the `predict()` method of the actual
model. It includes the model-specific pre- and post-processing of the data.
* First, we implement the model-specific pre-processing: we check the dataframe length, optionally extend the test dataframe with historic data, and make sure it has an ID column. Specific to darts models, we then set `n_req_past_obs`, which for `NaiveSeasonal` has to be at least 3, and increase it by 1 to be consistent with the prediction range of darts models
 that have retraining activated.
* Next, we compute the forecast by calling `_predict_darts_model()`. This function is a ready-made wrapper for predicting with
models from the darts library.
* Last, we implement the model-specific post-processing, which consists of dropping the previously added samples via
`maybe_drop_added_values_from_df()`.

Remarks
* Data format: The input and output of the test-of-time environment is of type `pd.DataFrame`. If we work with any
other data format in between, in this case `TimeSeries` from darts, we need to convert this data format from/to a
`pd.DataFrame`. For darts models we offer a helper function `_predict_darts_model` that handles this conversion
for the returned forecast.
* Backtesting: Test-of-time is a framework that executes backtesting by default. That means it forecasts the selected
forecast horizon in a rolling manner over the complete available data. Some libraries offer that capability along with
their models. For other libraries, this rolling procedure needs to be implemented in the `predict()` wrapper.

In [10]:
>>> from tot.utils import _predict_darts_model
>>> from tot.df_utils import _check_min_df_len, prep_or_copy_df
>>>
>>> def predict(self, df: pd.DataFrame, df_historic: pd.DataFrame = None):
>>>     # check if df contains enough samples for predicting
>>>     _check_min_df_len(df=df, min_len=1)
>>>     # optional: extend test df with historic train data. df_historic is passed from upper level
>>>     if df_historic is not None:
>>>         df = self.maybe_extend_df(df_historic, df)
>>>     # ensure that df has an ID
>>>     df, received_ID_col, received_single_time_series, _ = prep_or_copy_df(df)
>>>     # NaiveSeasonal needs at least 3 past observations; add 1 for darts models because retrain=True
>>>     n_req_past_obs = 3 if self.season_length < 3 else self.season_length
>>>     n_req_past_obs += 1
>>>     # predict
>>>     fcst_df = _predict_darts_model(df=df, model=self, n_req_past_obs=n_req_past_obs, n_req_future_obs=self.n_forecasts, retrain=True)
>>>     # drop values from extended df
>>>     if df_historic is not None:
>>>         fcst_df, df = self.maybe_drop_added_values_from_df(fcst_df, df)
>>>     return fcst_df
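
The rolling procedure mentioned in the backtesting remark can be sketched in plain Python. This is only an illustration of the idea, not how `_predict_darts_model` is implemented; the names `rolling_forecast` and `repeat_last_value` are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Forecaster = Callable[[Sequence[float], int], List[float]]


def repeat_last_value(past: Sequence[float], horizon: int) -> List[float]:
    """Toy forecaster: repeat the most recent observation."""
    return [past[-1]] * horizon


def rolling_forecast(
    series: Sequence[float],
    forecaster: Forecaster,
    n_past: int,
    horizon: int,
) -> List[Tuple[int, List[float]]]:
    """Forecast `horizon` steps from every origin that has `n_past` past values
    and `horizon` future values available in `series`."""
    results = []
    for origin in range(n_past, len(series) - horizon + 1):
        past = series[origin - n_past:origin]
        results.append((origin, forecaster(past, horizon)))
    return results


series = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
for origin, forecast in rolling_forecast(series, repeat_last_value, n_past=3, horizon=2):
    print(origin, forecast)
# 3 [3.0, 3.0]
# 4 [4.0, 4.0]
```

Each origin needs `n_past` observations behind it, which is why the first forecasts of a test set can only be produced if enough historic data is available.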

## Step 6 - Implement the parent class methods
[Optional] The abstract parent class `Model` has two methods, `maybe_extend_df()` and `maybe_drop_added_values_from_df()`, that
must be overridden if they should be active. Since we want them to be active for our custom model, we implement
them.
In `maybe_extend_df()` we add `self.season_length` samples of the train dataframe to the test dataframe. In
`maybe_drop_added_values_from_df()` we drop them again.

In [12]:
>>> from tot.df_utils import add_first_inputs_to_df, drop_first_inputs_from_df
>>>
>>> def maybe_extend_df(self, df_train, df_test):
>>>     samples = self.season_length
>>>     df_test = add_first_inputs_to_df(samples=samples, df_train=df_train, df_test=df_test)
>>>     return df_test
>>>
>>> def maybe_drop_added_values_from_df(self, predicted, df):
>>>     samples = self.season_length
>>>     predicted, df = drop_first_inputs_from_df(samples=samples, predicted=predicted, df=df)
>>>     return predicted, df
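
The extend-then-drop round trip can be illustrated with plain lists. This sketch assumes that the added samples are the last `samples` values of the training data and that they are placed in front of the test data; the helper names `extend_test` and `drop_added` are hypothetical and do not reproduce the test-of-time helpers.

```python
from typing import List, Optional, Tuple


def extend_test(train: List[float], test: List[float], samples: int) -> List[float]:
    """Prepend the last `samples` training values to the test values (assumed behavior)."""
    if samples <= 0:
        return list(test)
    return train[-samples:] + test


def drop_added(
    predicted: List[Optional[float]], extended: List[float], samples: int
) -> Tuple[List[Optional[float]], List[float]]:
    """Remove the first `samples` entries that were added before predicting."""
    return predicted[samples:], extended[samples:]


train = [1.0, 2.0, 3.0, 4.0]
test = [5.0, 6.0]

extended = extend_test(train, test, samples=2)
print(extended)  # [3.0, 4.0, 5.0, 6.0]

# one prediction slot per extended row; the values here are purely illustrative
predicted = [None, None, 4.0, 5.0]
predicted, restored = drop_added(predicted, extended, samples=2)
print(predicted, restored)  # [4.0, 5.0] [5.0, 6.0]
```

After dropping, the predictions line up with the original test data again.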

## Putting it all together

In [15]:
# imports
from dataclasses import dataclass
from typing import Type
import pandas as pd
from tot.models import Model
from tot.utils import convert_df_to_DartsTimeSeries, _predict_darts_model
from tot.df_utils import _check_min_df_len, prep_or_copy_df, add_first_inputs_to_df, drop_first_inputs_from_df
from darts.models import NaiveSeasonal

In [46]:
@dataclass
class CustomSeasonalNaiveModel(Model):
    model_name: str = "CustomSeasonalNaive"
    model_class: Type = NaiveSeasonal

    def __post_init__(self):
        # extract only the parameters the darts model needs and instantiate the actual model
        # K is the parameter for the seasonal period defined by darts
        model_params = {"K": self.params["K"]}
        self.model = self.model_class(**model_params)

        # re-assign the frequency as model wrapper attribute
        # remark: structure will change in the future
        self.freq = self.params["_data_params"]["freq"]
        # set forecast horizon as model wrapper attribute and verify
        self.n_forecasts = self.params["n_forecasts"]
        assert self.n_forecasts >= 1, "Model parameter n_forecasts must be >=1. "
        # set season length as model wrapper attribute and verify
        self.season_length = model_params["K"]
        assert self.season_length is not None, (
            "Dataset does not provide a seasonality. Assign a seasonality to each of the datasets "
            "OR input desired season_length as model parameter to be used for all datasets "
            "without specified seasonality."
        )

    def fit(self, df: pd.DataFrame, freq: str):
        # check if df contains enough samples for fitting
        _check_min_df_len(df=df, min_len=self.n_forecasts + self.season_length)
        self.freq = freq
        series = convert_df_to_DartsTimeSeries(df, value_cols=df.columns.values[1:-1].tolist(), freq=self.freq)
        self.model = self.model.fit(series)

    def predict(self, df: pd.DataFrame, df_historic: pd.DataFrame = None):
        # check if df contains enough samples for predicting
        _check_min_df_len(df=df, min_len=1)
        # optional: extend test df with historic train data. df_historic is passed from upper level
        if df_historic is not None:
            df = self.maybe_extend_df(df_historic, df)
        # ensure that df has an ID
        df, received_ID_col, received_single_time_series, _ = prep_or_copy_df(df)
        # NaiveSeasonal needs at least 3 past observations; add 1 for darts models because retrain=True
        n_req_past_obs = 3 if self.season_length < 3 else self.season_length
        n_req_past_obs += 1
        # predict
        fcst_df = _predict_darts_model(df=df, model=self, n_req_past_obs=n_req_past_obs, n_req_future_obs=self.n_forecasts, retrain=True)
        # drop values from extended df
        if df_historic is not None:
            fcst_df, df = self.maybe_drop_added_values_from_df(fcst_df, df)
        return fcst_df

    def maybe_extend_df(self, df_train, df_test):
        samples = self.season_length
        df_test = add_first_inputs_to_df(samples=samples, df_train=df_train, df_test=df_test)
        return df_test

    def maybe_drop_added_values_from_df(self, predicted, df):
        samples = self.season_length
        predicted, df = drop_first_inputs_from_df(samples=samples, predicted=predicted, df=df)
        return predicted, df

## Benchmark your own model
To run your new model in a benchmark, we first load some sample datasets.

In [43]:
data_location = "https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/"

air_passengers_df = pd.read_csv(data_location + 'air_passengers.csv')
peyton_manning_df = pd.read_csv(data_location + 'wp_log_peyton_manning.csv')
yosemite_temps_df = pd.read_csv(data_location +  'yosemite_temps.csv')
ercot_load_df = pd.read_csv(data_location +  'multivariate/load_ercot_regions.csv')

Let's set up the `SimpleBenchmark` template with our `CustomSeasonalNaiveModel` and run the benchmark.

In [48]:
from tot.dataset import Dataset
from tot.benchmark import SimpleBenchmark
dataset_list = [
    Dataset(df = air_passengers_df, name = "air_passengers", freq = "MS"),
    Dataset(df = peyton_manning_df, name = "peyton_manning", freq = "D"),
    Dataset(df = yosemite_temps_df, name = "yosemite_temps", freq = "5min"),
    # Dataset(df = ercot_load_df, name = "ercot_load", freq = "H"),
]
model_classes_and_params = [
    (CustomSeasonalNaiveModel, {"K": 1, "n_forecasts":3}),
]
benchmark = SimpleBenchmark(
    model_classes_and_params=model_classes_and_params, # iterate over this list of tuples
    datasets=dataset_list, # iterate over this list
    metrics=["MAE", "MSE", "MASE", "RMSE"],
    test_percentage=25,
)

In [51]:
results_train, results_test = benchmark.run()

In [50]:
results_test

| | data | model | params | experiment | MAE | MSE | MASE | RMSE |
|---|---|---|---|---|---|---|---|---|
| 0 | air_passengers | CustomSeasonalNaive | {'K': 1, 'n_forecasts': 3, '_data_params': {'f... | air_passengers_CustomSeasonalNaive_K_1_n_forec... | 73.016663 | 8050.25 | 3.301137 | 86.131256 |
| 1 | peyton_manning | CustomSeasonalNaive | {'K': 1, 'n_forecasts': 3, '_data_params': {'f... | peyton_manning_CustomSeasonalNaive_K_1_n_forec... | 0.573357 | 0.648822 | 1.832438 | 0.791932 |
| 2 | yosemite_temps | CustomSeasonalNaive | {'K': 1, 'n_forecasts': 3, '_data_params': {'f... | yosemite_temps_CustomSeasonalNaive_K_1_n_forec... | 0.631667 | 0.588167 | 1.627998 | 0.735115 |
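
For readers unfamiliar with the metric names, the sketch below shows common textbook definitions of MAE, MSE, RMSE, and MASE in plain Python. It is an assumption that these match the library's definitions only in spirit: the RMSE values reported above are not the square roots of the reported MSE values, so the library presumably aggregates over forecasts differently than this single-pass sketch. The function names are hypothetical.

```python
import math
from typing import Sequence


def mae(actual: Sequence[float], predicted: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


def mse(actual: Sequence[float], predicted: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


def rmse(actual: Sequence[float], predicted: Sequence[float]) -> float:
    """Root mean squared error."""
    return math.sqrt(mse(actual, predicted))


def mase(actual: Sequence[float], predicted: Sequence[float],
         train: Sequence[float], m: int = 1) -> float:
    """MAE scaled by the in-sample MAE of a seasonal naive forecast with period m."""
    scale = sum(abs(train[i] - train[i - m]) for i in range(m, len(train))) / (len(train) - m)
    return mae(actual, predicted) / scale


train = [1.0, 2.0, 4.0, 7.0]
actual = [3.0, 5.0, 7.0]
predicted = [2.0, 5.0, 9.0]

print(mae(actual, predicted))          # 1.0
print(mse(actual, predicted))
print(rmse(actual, predicted))
print(mase(actual, predicted, train))  # 0.5
```

A MASE above 1 means the forecast is worse on the test data than a naive forecast was on the training data, which is how the values in the table can be read.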
