# Forecast Combination

In [1]:
import os
os.chdir("../../")

import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from scripts.python.tsa.mtsmodel import *
from scripts.python.tsa.ts_eval import *

import seaborn as sns
sns.set_style("whitegrid")
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

As [Timmermann (2004)](https://doi.org/10.1016/S1574-0706(05)01004-9) summarizes Stock and Watson (2001)'s relative performance weights. Let $MSE_{t+h,t,i} = (1/v)\sum_{\tau=t-v}^{t} e^{2}_{\tau,\tau−h,i}$ be the $i$th forecasting model’s MSE at time $t$, computed over a window of the previous $v$ periods. Then

$$ \hat{y}_{t+h,t} = \sum_{i=1}^{N} \hat{\omega}_{t+h,t,i} \hat{y}_{t+h,t,i}, \text{ where } \hat{\omega}_{t+h,t,i} = \frac{(1/MSE_{t+h,t,i})}{\sum_{j=1}^{N} (1/MSE_{t+h,t,j})}$$

Below are the functions to calculate the relative performance weight where $i \in \{sarimax, lf, var\}$.

In [2]:
def calculate_mse(predictions_df: pd.DataFrame, method: str) -> pd.Series:
    total = predictions_df["total"]
    prediction = predictions_df[method]
    mse = np.square(total - prediction).cumsum() / (predictions_df.index + 1)
    return mse


def calculate_rpw(predictions_df: pd.DataFrame, methods: list) -> pd.Series:
    mse_dict = {method: calculate_mse(predictions_df, method)
                for method in methods}
    denominator = sum(1 / mse_dict[method] for method in methods)
    rpw_dict = {}
    for method in methods:
        numerator = 1 / mse_dict[method]
        omega = numerator / denominator
        rpw_dict[method] = omega
    return pd.Series(rpw_dict)


def get_rpw(pred_df: pd.DataFrame, 
            methods: list = ["sarimax", "var", "lf"]) -> pd.Series:
    predictions_df = pred_df.copy()
    rpw_series = calculate_rpw(predictions_df, methods)

    combo_cols = []
    for method in methods:
        weight = str(method) + "_weight"
        predictions_df[weight] = predictions_df[method] * rpw_series[method]
        combo_cols.append(weight)

    rpw = predictions_df[combo_cols].sum(axis=1)
    return rpw

In [3]:
for country in ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]:

    folderpath = os.getcwd() + "/data/tourism/" + str(country) + "/model/"
    mappings = [("sarimax", "train_pred"),
                ("var", "pred_total"), ("lf", "pred_mean")]

    country_pred = pd.DataFrame()
    for mapping in mappings:
        model, column = mapping
        filepath = folderpath + str(model) + "_" + str(country) + ".csv"
        pred_df = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
        pred_df["date"] = pd.to_datetime(pred_df["date"])

        model_col = (pred_df[["date", "total", column]]
                     .rename({column: model}, axis=1))

        if country_pred.empty:
            country_pred = model_col
        else:
            country_pred = country_pred.merge(model_col).fillna(0)

    # Mean
    country_pred["mean_ensemble"] = (
        country_pred[["sarimax", "var", "lf"]].mean(axis=1))

    # Relative Performance Weights
    country_pred["rpw"] = get_rpw(country_pred)

    # OLS (regularized)
    ols = smf.ols("total~sarimax+var+lf", data=country_pred)
    ols_res = ols.fit()
    ols_reg = ols.fit_regularized()
    country_pred["ols"] = ols_res.fittedvalues
    country_pred["ols_regularized"] = ols_reg.fittedvalues
    
    #
    country_pred.to_csv(folderpath+"forecast_combo.csv",
                        encoding="utf-8")

    evals = pd.DataFrame()
    for col in ["sarimax", "var", "lf", "mean_ensemble", "rpw", "ols", "ols_regularized"]:
        mod_eval = pd.DataFrame(calculate_evaluation(country_pred["total"], country_pred[col]),
                                index=[col])
        evals = pd.concat([evals, mod_eval], axis=0)

    evals.columns.name = str(country)
    evals = evals.style.apply(
        lambda x: ['background-color: yellow' if v == x.min() else '' for v in x])
    display(evals)

palau,MSE,RMSE,MAE,SMAPE
sarimax,1586348.581999,1259.503308,701.765452,53.857541
var,1127006.129408,1061.605449,554.935555,38.243616
lf,494503.384915,703.209346,389.892735,40.776367
mean_ensemble,532519.96134,729.739653,412.870015,33.837139
rpw,403353.773427,635.101388,346.9387,33.486236
ols,439753.349762,663.139012,387.886317,52.265748
ols_regularized,470422.071425,685.873218,388.073317,45.216361


samoa,MSE,RMSE,MAE,SMAPE
sarimax,8076301.733757,2841.883483,1410.434675,141.993878
var,10290887.267182,3207.941282,1757.557409,141.614714
lf,2107650.783058,1451.775046,763.990597,131.108537
mean_ensemble,3794368.839523,1947.913971,1093.282593,135.822954
rpw,2242890.336299,1497.628237,746.332632,131.773854
ols,1960603.16188,1400.215398,801.094036,131.276441
ols_regularized,2003690.119502,1415.517615,835.530337,131.80255


tonga,MSE,RMSE,MAE,SMAPE
sarimax,776703.265602,881.307702,381.831288,80.236281
var,1226895.371895,1107.653092,472.092955,54.359031
lf,678451.09755,823.68143,319.198937,102.021722
mean_ensemble,329857.066939,574.331844,223.2797,82.645615
rpw,219927.81439,468.96462,201.580152,91.059784
ols,192916.53801,439.222652,183.856214,61.485168
ols_regularized,193816.099256,440.245499,183.024479,60.651486


solomon_islands,MSE,RMSE,MAE,SMAPE
sarimax,47538.016148,218.032145,149.903809,32.72611
var,65021.669854,254.99347,154.338247,17.694166
lf,48753.426027,220.80178,150.232442,26.614403
mean_ensemble,33204.357994,182.22063,125.903946,23.072342
rpw,29425.623378,171.538985,122.944008,23.240392
ols,32887.208317,181.348307,127.21158,32.506166
ols_regularized,38922.471964,197.28779,139.112394,44.268386


vanuatu,MSE,RMSE,MAE,SMAPE
sarimax,669250.974064,818.077609,372.363525,133.369731
var,1600258.286965,1265.013157,554.204122,134.493067
lf,637546.169465,798.464883,495.963575,128.001919
mean_ensemble,495178.430077,703.689157,350.782283,131.534514
rpw,332258.497424,576.418682,289.61672,130.571867
ols,392860.854657,626.786132,349.258667,131.047249
ols_regularized,419613.535471,647.775837,339.038167,131.555661
