# Grid search rewrite

In [2]:
import itertools
import numpy as np
import statsmodels.api as sm
import pandas as pd

from joblib import Parallel, cpu_count, delayed
from warnings import catch_warnings, filterwarnings
from sklearn.metrics import mean_squared_error

In [3]:
# create a set of sarima configs to try
def sarima_configs(order_values=(0, 1), m=12, trends=('n', 'c', 't', 'ct')):
    """Build every SARIMA configuration to try in the grid search.

    Each configuration is a three-item list ``[order, seasonal_order, trend]``:

    * ``order`` is a ``(p, d, q)`` tuple;
    * ``seasonal_order`` is a ``(P, D, Q, m)`` tuple;
    * ``trend`` is one of ``trends``.

    Args:
        order_values: Candidate values for each of p, d, q (and likewise
            P, D, Q). The default ``(0, 1)`` reproduces the original
            ``range(0, 2)``, i.e. only 0 and 1 are tried.
        m: Seasonal period appended to every seasonal order (default 12).
        trends: Trend options passed to SARIMAX; defaults to the original
            ``'n'``, ``'c'``, ``'t'``, ``'ct'``.

    Returns:
        A list of ``len(order_values)**6 * len(trends)`` configurations,
        ordered by order, then seasonal order, then trend.
    """
    # All (p, d, q) combinations; the seasonal (P, D, Q) grid reuses them
    # instead of recomputing the same product.
    pdq = list(itertools.product(order_values, repeat=3))
    seasonal_pdq = [(p, d, q, m) for p, d, q in pdq]
    return [
        [order, seasonal_order, trend]
        for order in pdq
        for seasonal_order in seasonal_pdq
        for trend in trends
    ]

In [4]:
# sarima fit + one-step-ahead (non-dynamic) predictions from a given date
def sarima_forecast(data, config, date):
    """Fit a SARIMAX model to ``data`` and return its predictions from ``date`` on.

    Args:
        data: The series to model. Assumed to be univariate and date-indexed
            (a Series or single-column DataFrame) — TODO confirm with caller.
        config: ``[order, seasonal_order, trend]`` triple, as built by
            ``sarima_configs``.
        date: Anything ``pd.to_datetime`` accepts; first date to predict.

    Returns:
        ``predicted_mean`` of ``get_prediction(start=date, dynamic=False)``.
        These are one-step-ahead predictions. Because no ``end`` is given,
        they run through the end of ``data`` rather than covering only the
        next month.

    Stationarity and invertibility are deliberately not enforced, so the
    optimizer may settle on non-stationary / non-invertible parameters.

    NOTE(review): the model is fit on ALL of ``data``, including the period
    starting at ``date``. The returned values are therefore in-sample
    predictions. Confirm this is intended before reading scores built on
    them as out-of-sample errors.
    """
    order, seasonal_order, trend = config
    fitted = sm.tsa.statespace.SARIMAX(
        data,
        order=order,
        seasonal_order=seasonal_order,
        trend=trend,
        enforce_stationarity=False,
        enforce_invertibility=False,
    ).fit(disp=False)
    start = pd.to_datetime(date)
    return fitted.get_prediction(start=start, dynamic=False).predicted_mean

In [5]:
# root mean squared error or rmse
def measure_rmse(actual, predicted):
    """Return the root mean squared error between ``actual`` and ``predicted``.

    Args:
        actual: Observed values. A DataFrame must hold them in an
            ``'Amount'`` column. A Series (e.g. ``df['Amount']``) is used
            as-is, and other array-likes are wrapped in a Series.
        predicted: Predicted values. A Series is used as-is regardless of its
            name. Other inputs (arrays, or a DataFrame with a
            ``'predicted_mean'`` column) are read into a ``'predicted_mean'``
            column, as before.

    Returns:
        ``sqrt(mean((predicted - actual) ** 2))`` as a NumPy float. Inputs are
        aligned on their index before subtracting. Positions present in only
        one input become NaN and are skipped by the mean, so inputs with no
        overlapping index give NaN.
    """
    if isinstance(actual, pd.DataFrame):
        actual = actual['Amount']
    elif not isinstance(actual, pd.Series):
        actual = pd.Series(actual)
    if not isinstance(predicted, pd.Series):
        predicted = pd.DataFrame(predicted, columns=['predicted_mean'])['predicted_mean']
    # Series subtraction aligns on the index (e.g. dates), so only matching
    # periods are compared; pandas' mean() skips the NaN gaps.
    mse = np.square(predicted - actual).mean()
    return np.sqrt(mse)

In [6]:
# score one config on the period starting at a fixed date
def walk_forward_validation(data, cfg, start_date='2023-12-01'):
    """Return the RMSE of configuration ``cfg`` over ``data`` from ``start_date`` on.

    Args:
        data: DataFrame holding the target in an ``'Amount'`` column. It is
            assumed to be indexed by date (required for the ``start_date``
            slice) — TODO confirm.
        cfg: ``[order, seasonal_order, trend]`` triple.
        start_date: First date of the scored period. The default keeps the
            previously hard-coded ``'2023-12-01'``.

    NOTE(review): despite the name, this is not a walk-forward (refit per
    step) evaluation. ``sarima_forecast`` fits once on all of ``data``, so the
    scored predictions are in-sample.
    """
    predictions = sarima_forecast(data, cfg, date=start_date)
    # Pass the DataFrame slice rather than ``[...]['Amount']``: measure_rmse
    # selects the 'Amount' column from a DataFrame itself.
    error = measure_rmse(data.loc[start_date:], predictions)
    return error

In [7]:
# score a model, return None on failure
def score_model(data, cfg, debug=False):
    """Evaluate one configuration and return ``(str(cfg), rmse_or_None)``.

    Args:
        data: Passed through to ``walk_forward_validation``.
        cfg: Configuration to evaluate. Its ``str()`` becomes the result key.
        debug: If True, warnings are shown and any exception propagates.
            If False, warnings are suppressed and an ``Exception`` raised
            while evaluating is treated as an unstable config, so the score
            is None.

    Successful scores are also printed as `` > Model[key] rmse``.
    """
    key = str(cfg)
    result = None
    if debug:
        result = walk_forward_validation(data, cfg)
    else:
        try:
            # Warnings are too noisy across a whole grid search.
            with catch_warnings():
                filterwarnings("ignore")
                result = walk_forward_validation(data, cfg)
        except Exception:
            # Catch Exception, not everything: a bare except would also
            # swallow KeyboardInterrupt/SystemExit and make the search
            # impossible to stop.
            result = None
    if result is not None:
        print(' > Model[%s] %.3f' % (key, result))
    return (key, result)

In [8]:
# grid search configs
def grid_search(data, cfg_list, n_test, parallel=True):
    """Score every configuration in ``cfg_list`` and return the successes, best first.

    Args:
        data: Series data passed through to ``score_model``.
        cfg_list: Iterable of configurations to evaluate.
        n_test: Accepted for backward compatibility but unused. The scored
            period is chosen inside ``walk_forward_validation``, and
            ``score_model`` takes no such argument.
        parallel: If True, evaluate the configs in one process per CPU via
            joblib's multiprocessing backend; otherwise evaluate them serially.

    Returns:
        A list of ``(key, rmse)`` tuples for configs that produced a score,
        sorted by ascending RMSE.
    """
    # score_model's signature is (data, cfg, debug=False); cfg must be the
    # second positional argument and debug must stay at its default.
    if parallel:
        executor = Parallel(n_jobs=cpu_count(), backend='multiprocessing')
        tasks = (delayed(score_model)(data, cfg) for cfg in cfg_list)
        scores = executor(tasks)
    else:
        scores = [score_model(data, cfg) for cfg in cfg_list]
    # drop configs that failed (score None)
    scores = [r for r in scores if r[1] is not None]
    # sort configs by error, ascending
    scores.sort(key=lambda tup: tup[1])
    return scores