In [1]:
import pandas as pd
import numpy as np
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_error, mean_squared_error
import warnings

# Suppress every warning category for the remainder of the session.
warnings.filterwarnings("ignore")

# Read the workbook, then normalise the two columns used below:
#   Deaths -> int, with the literal marker 'Suppressed' counted as 0
#   Month  -> pandas datetime
df = pd.read_excel('data/state_month_overdose.xlsx')
df['Deaths'] = df['Deaths'].map(lambda value: 0 if value == 'Suppressed' else int(value))
df['Month'] = pd.to_datetime(df['Month'])

# Create train, val, test splits
def create_train_val_test_split(df, train_end='2019-01-01', val_end='2020-01-01', test_end='2020-12-01'):
    """Partition ``df`` chronologically on its 'Month' column.

    Args:
        df: DataFrame with a 'Month' column comparable to the date bounds.
        train_end: Exclusive upper bound of the training window.
        val_end: Exclusive upper bound of the validation window.
        test_end: Inclusive upper bound of the test window.

    Returns:
        Tuple ``(train, val, test)`` of row subsets of ``df`` where
        train has Month < train_end, val has train_end <= Month < val_end,
        and test has val_end <= Month <= test_end.
    """
    months = df['Month']
    in_train = months < train_end
    in_val = (months >= train_end) & (months < val_end)
    in_test = (months >= val_end) & (months <= test_end)
    return df[in_train], df[in_val], df[in_test]

train, val, _ = create_train_val_test_split(df)

# Target series: models are fitted on train_series and scored on val_series
train_series, val_series = train['Deaths'], val['Deaths']

# Search space: every non-seasonal (p, d, q) and seasonal (P, D, Q) term
# ranges over 0, 1, 2
p = d = q = range(3)
P = D = Q = range(3)
s = 12  # Seasonal period (monthly data)

# Enumerate all combinations (27 non-seasonal x 27 seasonal)
param_grid = [(p_i, d_i, q_i) for p_i in p for d_i in d for q_i in q]
seasonal_grid = [(P_i, D_i, Q_i) for P_i in P for D_i in D for Q_i in Q]
results = []

# Exhaustive SARIMA grid search: fit every (order, seasonal_order) pair on
# train_series, forecast over the validation window, and record the
# information criteria and forecast errors in `results`.

# Configurations whose fit or scoring raised are recorded here, with the
# error, instead of being dropped silently.
failed_fits = []

# Forecast window: the len(val_series) steps right after the training sample.
forecast_start = len(train_series)
forecast_end = forecast_start + len(val_series) - 1

for order in param_grid:
    for seasonal_order in seasonal_grid:
        try:
            model = SARIMAX(train_series,
                            order=order,
                            seasonal_order=seasonal_order + (s,),
                            enforce_stationarity=False,
                            enforce_invertibility=False)
            fit_model = model.fit(disp=False)

            forecast = fit_model.predict(start=forecast_start, end=forecast_end)
            mae = mean_absolute_error(val_series, forecast)
            rmse = np.sqrt(mean_squared_error(val_series, forecast))
            aic, bic = fit_model.aic, fit_model.bic
        except Exception as e:
            # Best effort: one failing configuration must not abort the
            # search, but keep a trace of what failed and why.
            failed_fits.append({
                'order': order,
                'seasonal_order': seasonal_order,
                'error': repr(e)
            })
            continue

        results.append({
            'order': order,
            'seasonal_order': seasonal_order,
            'AIC': aic,
            'BIC': bic,
            'MAE': mae,
            'RMSE': rmse
        })

if failed_fits:
    print("{} of {} SARIMA configurations failed and were skipped; see failed_fits".format(
        len(failed_fits), len(param_grid) * len(seasonal_grid)))

# Tabulate the grid-search records, ranked by BIC (ascending)
results_df = pd.DataFrame(results).sort_values(by='BIC')

# Persist the full ranking, then display the ten lowest-BIC configurations
results_df.to_csv('sarima_gridsearch_results.csv', index=False)
results_df.head(10)

Unnamed: 0,order,seasonal_order,AIC,BIC,MAE,RMSE
23,"(0, 0, 0)","(2, 1, 2)",25916.773737,25945.712721,65.469877,86.574793
71,"(0, 0, 2)","(1, 2, 2)",23421.05554,23455.74738,65.843706,86.589528
552,"(2, 0, 2)","(1, 1, 0)",22434.456243,22469.210337,65.569664,86.606415
543,"(2, 0, 2)","(0, 1, 0)",23150.675294,23179.659695,65.586964,86.890458
22,"(0, 0, 0)","(2, 1, 1)",25961.546957,25984.699803,65.406615,86.960717
68,"(0, 0, 2)","(1, 1, 2)",23385.554165,23420.275967,65.342886,87.088986
38,"(0, 0, 1)","(1, 0, 2)",24719.056669,24748.018413,65.77453,87.092361
77,"(0, 0, 2)","(2, 1, 2)",23387.541497,23428.050266,65.338006,87.09516
76,"(0, 0, 2)","(2, 1, 1)",23412.259887,23446.989156,65.337155,87.104067
67,"(0, 0, 2)","(1, 1, 1)",23496.842357,23525.802036,65.335511,87.109098


In [None]:
# Re-rank the existing results table by BIC (ascending) and preview it.
# The table is neither rebuilt from `results` nor re-exported to CSV here.
results_df = results_df.sort_values('BIC')

results_df.head(10)  # ten lowest-BIC rows

In [None]:
from tqdm import tqdm
import warnings
import random

# Mute all warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Reload the workbook and coerce the columns used downstream:
# 'Suppressed' death counts become 0, everything else is cast to int,
# and Month is parsed to datetime.
df = pd.read_excel('data/state_month_overdose.xlsx')
df['Deaths'] = df['Deaths'].apply(lambda cell: int(cell) if cell != 'Suppressed' else 0)
df['Month'] = pd.to_datetime(df['Month'])

def create_train_val_test_split(df, train_end='2019-01-01', val_end='2020-01-01', test_end='2020-12-01'):
    """Split ``df`` into train / validation / test subsets by 'Month'.

    Window boundaries:
        train: Month <  train_end
        val:   train_end <= Month <  val_end
        test:  val_end   <= Month <= test_end   (upper bound inclusive)

    Returns:
        The three subsets as a ``(train, val, test)`` tuple.
    """
    month = df['Month']
    before_train_end = month < train_end
    before_val_end = month < val_end
    train = df.loc[before_train_end]
    val = df.loc[~before_train_end & before_val_end]
    test = df.loc[(month >= val_end) & (month <= test_end)]
    return train, val, test

train_df, val_df, _ = create_train_val_test_split(df)

# Target series with a fresh 0..n-1 index
train_series = train_df['Deaths'].reset_index(drop=True)
val_series = val_df['Deaths'].reset_index(drop=True)

# SARIMA search space: each term ranges over 0, 1, 2; seasonal period is 12
p = d = q = range(3)
P = D = Q = range(3)
s = 12
param_grid = [(p_i, d_i, q_i) for p_i in p for d_i in d for q_i in q]
seasonal_grid = [(P_i, D_i, Q_i) for P_i in P for D_i in D for Q_i in Q]

# Number of repeated fits per configuration
N_TRIALS = 30

# One summary record per configuration is collected here
all_results = []

# Repeated-fit grid search: every (order, seasonal_order) pair is fitted
# N_TRIALS times on train_series and scored against val_series; the mean and
# standard deviation of each metric are stored in `all_results`.
#
# NOTE(review): the NumPy seed is re-drawn before each fit, but nothing in
# this loop draws random numbers itself. If SARIMAX fitting is deterministic
# for a fixed series and order, all trials of a configuration return the same
# metrics (std_* == 0) and N_TRIALS only multiplies the run time — confirm.

# Configurations with fewer successful fits than this are not reported.
MIN_SUCCESSFUL_TRIALS = 5
# Dropped configurations are recorded here (with the last error seen) rather
# than vanishing silently.
skipped_configs = []

# Forecast window: the len(val_series) steps right after the training sample.
forecast_start = len(train_series)
forecast_end = forecast_start + len(val_series) - 1

print("Running SARIMA grid search with {} trials per configuration...".format(N_TRIALS))
for order in tqdm(param_grid):
    for seasonal_order in seasonal_grid:
        metrics = {'AIC': [], 'BIC': [], 'MAE': [], 'RMSE': []}
        last_error = None
        for _trial in range(N_TRIALS):
            try:
                # Randomized seed to introduce variability
                random_seed = random.randint(0, 10000)
                np.random.seed(random_seed)

                model = SARIMAX(train_series,
                                order=order,
                                seasonal_order=seasonal_order + (s,),
                                enforce_stationarity=False,
                                enforce_invertibility=False)
                # Named fit_result so the fitted model does not overwrite an
                # unrelated `results` variable in the notebook namespace.
                fit_result = model.fit(disp=False)

                preds = fit_result.predict(start=forecast_start, end=forecast_end)
                preds = preds[:len(val_series)]  # safeguard

                # Compute every metric before appending any of them, so a
                # scoring failure cannot leave the four lists out of sync.
                aic, bic = fit_result.aic, fit_result.bic
                mae = mean_absolute_error(val_series, preds)
                rmse = np.sqrt(mean_squared_error(val_series, preds))
            except Exception as exc:
                # Best effort: skip this trial but remember why it failed.
                last_error = repr(exc)
                continue

            metrics['AIC'].append(aic)
            metrics['BIC'].append(bic)
            metrics['MAE'].append(mae)
            metrics['RMSE'].append(rmse)

        num_successful = len(metrics['RMSE'])
        if num_successful >= MIN_SUCCESSFUL_TRIALS:  # Keep only those with enough valid runs
            all_results.append({
                'order': order,
                'seasonal_order': seasonal_order,
                'mean_RMSE': np.mean(metrics['RMSE']),
                'std_RMSE': np.std(metrics['RMSE']),
                'mean_MAE': np.mean(metrics['MAE']),
                'std_MAE': np.std(metrics['MAE']),
                'mean_AIC': np.mean(metrics['AIC']),
                'std_AIC': np.std(metrics['AIC']),
                'mean_BIC': np.mean(metrics['BIC']),
                'std_BIC': np.std(metrics['BIC']),
                'num_successful_trials': num_successful
            })
        else:
            skipped_configs.append({
                'order': order,
                'seasonal_order': seasonal_order,
                'num_successful_trials': num_successful,
                'last_error': last_error
            })

if skipped_configs:
    print("{} configuration(s) had fewer than {} successful trials and were dropped; see skipped_configs".format(
        len(skipped_configs), MIN_SUCCESSFUL_TRIALS))

# Build the aggregated table once from the per-configuration summaries and
# export it under both rankings. The sorts are applied to results_df_agg
# itself, which is the table that holds the mean_RMSE / mean_BIC columns.
results_df_agg = pd.DataFrame(all_results)

# Ranking 1: lowest mean validation RMSE first
results_df_agg.sort_values(by='mean_RMSE').to_csv('sarima_gridsearch_30trials_rmse.csv', index=False)

# Ranking 2: lowest mean BIC first; results_df_agg is left in this order
results_df_agg = results_df_agg.sort_values(by='mean_BIC')
results_df_agg.to_csv('sarima_gridsearch_30trials_bic.csv', index=False)

Running SARIMA grid search with 30 trials per configuration...


  0%|                                                                                    | 0/27 [00:00<?, ?it/s]