In [29]:
import itertools
import pyreadr
import numpy as np
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics

# Read the prepared dataset from its R serialisation. The result is indexed by
# object name; the frame used here is stored under the key `None`.
data = pyreadr.read_r("./data/data_final.rds")

# Build the modelling frame in a single non-mutating chain so the frame held in
# `data` stays exactly as it was read and this cell can be re-run safely:
#   1. rename the columns positionally ('ds' / 'y' are the names Prophet expects),
#   2. parse 'ds' as datetimes and replace 'y' with its natural logarithm,
#   3. drop the columns that are not used as regressors further down.
df = (
    data[None]
    .set_axis(
        ['ds', 'y', 'PRCP', 'TMAX', 'TMIN', 'SNOW', 'TMAX_PRCP', 'TMIN_PRCP'],
        axis='columns',
    )
    .assign(
        ds=lambda frame: pd.to_datetime(frame['ds']),
        # NOTE(review): np.log yields -inf / NaN for zero or negative values;
        # assumes 'y' is strictly positive — confirm against the raw data.
        y=lambda frame: np.log(frame['y']),
    )
    .drop(columns=['TMAX', 'SNOW', 'TMAX_PRCP'])
)
df.head()

Unnamed: 0,ds,y,PRCP,TMIN,TMIN_PRCP
0,1989-08-01,6.423247,3,139,417
1,1989-08-02,6.419995,0,100,0
2,1989-08-03,6.357842,0,50,0
3,1989-08-04,6.33328,0,50,0
4,1989-08-05,6.306275,0,61,0


In [34]:
# Hyperparameter grid evaluated for every regressor choice below. Only the
# seasonality prior scale and the seasonality mode vary; growth and weekly
# seasonality are pinned to a single value each.
param_grid = {
    'seasonality_prior_scale': [0.001, 0.01, .1, 1, 10], # default 10
    'growth': ['flat'],
    'weekly_seasonality': [False],
    'seasonality_mode': ['additive', 'multiplicative'],
}
# Cartesian product of the grid: one keyword-argument dict per Prophet model.
all_params = [dict(zip(param_grid.keys(), v)) for v in itertools.product(*param_grid.values())]

# Model variants to compare. 'none' is the baseline fitted without an extra
# regressor; every other entry is passed to `add_regressor` by name and must
# therefore be a column of `df`.
regressor_names = ['none', 'PRCP', 'TMIN', 'TMIN_PRCP']

# Cross-validation windows handed to `cross_validation`.
CV_INITIAL = f'{365*15} days'  # 5475 days of history before the first cutoff
CV_PERIOD = '365 days'         # spacing between successive cutoffs
CV_HORIZON = '365 days'        # forecast horizon scored at each cutoff

# Maps regressor name -> DataFrame with one row per parameter set and its
# rmse / mae / mape. Entries are added only once a regressor has finished, so an
# interrupted run leaves just the completed results instead of None placeholders.
regressors = {}

for regressor in regressor_names:
    metric_rows = []
    for params in all_params:
        m = Prophet(**params)
        if regressor != 'none':
            m.add_regressor(name=regressor)
        # The model has to be fitted before it can be cross-validated.
        m.fit(df)
        df_cv = cross_validation(
            m,
            initial=CV_INITIAL,
            period=CV_PERIOD,
            horizon=CV_HORIZON,
            parallel='processes',
        )
        # rolling_window=1 aggregates over the whole horizon; np.mean then
        # reduces each metric column of the result to a single scalar.
        # NOTE(review): the metrics are on the scale of df['y'] as passed in; if
        # 'y' was log-transformed upstream these are log-scale errors (and MAPE
        # is a percentage of the log value) — confirm before reporting them.
        df_p = performance_metrics(df_cv, rolling_window=1)
        metric_rows.append({
            **params,
            'rmse': np.mean(df_p['rmse']),
            'mae': np.mean(df_p['mae']),
            'mape': np.mean(df_p['mape']),
        })

    # One row per parameter set: the grid columns followed by the three metrics.
    tuning_results = pd.DataFrame(metric_rows)
    regressors[regressor] = tuning_results


15:04:18 - cmdstanpy - INFO - Chain [1] start processing
15:04:18 - cmdstanpy - INFO - Chain [1] done processing
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - INFO - Chain [1] done processing
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - INFO - Chain [1] start processing
15:04:20 - cmdstanpy - IN

In [35]:
# Tag each regressor's tuning table with the regressor name in a 'key' column.
# `assign` returns a new frame, so the tables stored in `regressors` are untouched.
dfs_with_key = [
    results.assign(key=key)
    for key, results in regressors.items()
]

# Stack all tables into one frame with a fresh 0..n-1 index, then display the
# rows ordered from lowest to highest cross-validated RMSE.
combined_df = pd.concat(dfs_with_key, ignore_index=True)
combined_df.sort_values(by='rmse')

Unnamed: 0,seasonality_prior_scale,growth,weekly_seasonality,seasonality_mode,rmse,mae,mape,key
24,0.1,flat,False,additive,0.260258,0.184281,0.028135,TMIN
29,10.0,flat,False,multiplicative,0.260258,0.184281,0.028135,TMIN
26,1.0,flat,False,additive,0.260259,0.18428,0.028135,TMIN
28,10.0,flat,False,additive,0.260259,0.184281,0.028135,TMIN
27,1.0,flat,False,multiplicative,0.260259,0.184281,0.028135,TMIN
25,0.1,flat,False,multiplicative,0.26026,0.184284,0.028136,TMIN
22,0.01,flat,False,additive,0.260289,0.184395,0.028156,TMIN
23,0.01,flat,False,multiplicative,0.260334,0.184548,0.028184,TMIN
15,0.1,flat,False,multiplicative,0.263936,0.186747,0.028516,PRCP
17,1.0,flat,False,multiplicative,0.263936,0.186749,0.028516,PRCP


In [17]:
# Display whatever `tuning_results` currently holds in the kernel.
# NOTE(review): the saved output of this cell (execution count 17) lists a
# `changepoint_prior_scale` column and both 'linear' and 'flat' growth, which
# the tuning grid defined earlier in this notebook does not produce —
# presumably stale output from an earlier grid. Re-run after the tuning cell
# (Restart Kernel -> Run All) to refresh it, or remove the cell.
tuning_results

Unnamed: 0,changepoint_prior_scale,growth,seasonality_mode,rmse,mae,mape
0,0.001,linear,additive,0.280118,0.204488,0.031478
1,0.001,linear,multiplicative,0.281102,0.203783,0.031336
2,0.001,flat,additive,0.265715,0.186218,0.028303
3,0.001,flat,multiplicative,0.265715,0.186217,0.028303
4,0.01,linear,additive,0.292469,0.219416,0.033876
5,0.01,linear,multiplicative,0.294512,0.21982,0.033896
6,0.01,flat,additive,0.265715,0.186218,0.028303
7,0.01,flat,multiplicative,0.265715,0.186217,0.028303
8,0.05,linear,additive,0.296802,0.2245,0.034681
9,0.05,linear,multiplicative,0.299949,0.225749,0.034818
