In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error


df = pd.read_excel('data/state_month_overdose.xlsx')
df['Deaths'] = df['Deaths'].apply(lambda x: 0 if x == 'Suppressed' else int(x))
df['Month'] = pd.to_datetime(df['Month'])
df = df.groupby(['Month']).agg({'Deaths': 'sum'}).reset_index()


def create_dataset(dataset, look_back=1):
    dataX, dataY = [], []
    for i in range(len(dataset)-look_back):
        a = dataset.iloc[i:(i+look_back)].values
        dataX.append(a)
        dataY.append(dataset.iloc[i + look_back]) 
    return np.array(dataX), np.array(dataY)

# def ci_overlap_percentage(pred1, pred2, ci1, ci2):
#     overlap_count = 0
#     for i in range(len(pred1)):
#         if (pred1[i] - ci1[i] <= pred2[i] + ci2[i]) and (pred1[i] + ci1[i] >= pred2[i] - ci2[i]):
#             overlap_count += 1
#     return (overlap_count / len(pred1)) * 100


def calculate_confidence_intervals(predictions, alpha=0.05):
    # Calculate mean and standard deviation
    mean_pred = np.mean(predictions)
    std_pred = np.std(predictions)
    
    # Calculate the z-score for the confidence level
    z_score = 1.96  # for 95% confidence
    margin_of_error = z_score * (std_pred / np.sqrt(len(predictions)))
    
    lower_bound = predictions - margin_of_error
    upper_bound = predictions + margin_of_error
    
    return lower_bound, upper_bound


def calculate_overlap(lower1, upper1, lower2, upper2):
    # Initialize overlap count
    overlap_count = 0

    for l1, u1, l2, u2 in zip(lower1, upper1, lower2, upper2):
        # Check for overlap
        if u1 >= l2 and l1 <= u2:
            overlap_count += 1

    # Calculate percent overlap
    percent_overlap = (overlap_count / len(lower1)) * 100
    return percent_overlap
    

validation_periods = [
    ('2019-11-01', '2020-01-01'), #if training will be up until 2019-12-01
    ('2019-09-01', '2020-01-01'),
    ('2019-07-01', '2020-01-01'),
    ('2019-01-01', '2020-01-01'),
    ('2018-07-01', '2020-01-01'),
    ('2018-01-01', '2020-01-01')
]

look_back_periods = range(3, 12, 2)  # 1, 3, 5, ..., 11 months look-back

def generate_forecast(model, initial_sequence, look_back, num_predictions=12):
    predictions = []
    for_model = initial_sequence
    
    for _ in range(num_predictions):
        # Generate the next prediction
        pred = model.predict(for_model)
        predictions.append(pred[0][0])
        
        # Update the input for the next prediction
        # We need to construct a new input array of the same shape as the original input
        new_input = np.append(for_model[:, 1:], pred[0][0])  # Shift and append the new prediction
        for_model = new_input.reshape((1, look_back, 1))

    return np.array(predictions)


results = []


for val_start, val_end in validation_periods:
    for look_back in [look_back_periods[0]]:
        df = df.copy()
        
        # Split data
        train = df[df['Month'] <= val_start]
        val = df[(df['Month'] >= val_start) & (df['Month'] <= val_end)]
        test = df[df['Month'] >= '2020-01-01']
        
        # Create datasets for LSTM
        trainX, trainY = create_dataset(train['Deaths'], look_back)
        valX, valY = create_dataset(val['Deaths'], look_back)
        testX, testY = create_dataset(test['Deaths'], look_back)
        
        trainX = trainX.reshape(trainX.shape[0], trainX.shape[1], 1)
        valX = valX.reshape(valX.shape[0], valX.shape[1], 1)
        testX = testX.reshape(testX.shape[0], testX.shape[1], 1)
        
        # Build and train initial LSTM model on the training data
        model = Sequential()
        model.add(LSTM(50, activation='relu', input_shape=(look_back, 1)))
        model.add(Dense(1))
        model.compile(loss='mean_squared_error', optimizer='adam')
        model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=0)
        
        val_initial_sequence = np.array(train['Deaths'][-look_back:]).reshape((look_back, 1))
        val_initial_sequence = np.array([val_initial_sequence])
        valPredict = generate_forecast(model, val_initial_sequence, look_back, num_predictions=len(valX))
        
        # Retrain the LSTM model on both training and validation data
        combined_train_val = pd.concat([train, val], axis=0)
        combinedX, combinedY = create_dataset(combined_train_val['Deaths'], look_back)
        combinedX = combinedX.reshape(combinedX.shape[0], combinedX.shape[1], 1)
        model.fit(combinedX, combinedY, epochs=100, batch_size=1, verbose=0)
        
        test_initial_sequence = np.array([[valPredict[-1]]])
        test_initial_sequence = np.array([test_initial_sequence])
        testPredict = generate_forecast(model, test_initial_sequence, look_back, num_predictions=len(testX))

        trainPredict = model.predict(trainX)
        
        # LSTM metrics
        lstm_mape = mean_absolute_percentage_error(testY, testPredict)
        lstm_mse = mean_squared_error(testY, testPredict)
        lstm_rmse = np.sqrt(lstm_mse)
        
        combined_array = [0] * look_back + trainPredict.flatten().tolist() + valPredict.flatten().tolist() + testPredict.flatten().tolist()
        
        df['LSTM Predictions'] = combined_array[:len(df)]
        
        # SARIMA model retraining on combined training + validation data
        sarima_model = SARIMAX(combined_train_val['Deaths'], order=(1, 1, 1), seasonal_order=(1, 1, 1, 12),
                               enforce_stationarity=False, enforce_invertibility=False)
        sarima_result = sarima_model.fit(disp=False)
        
        sarima_predictions = sarima_result.predict(start=0, end=df.shape[0]-1, dynamic=False)
        df['SARIMA Predictions'] = sarima_predictions
        sarimaTestPredict = df[df['Month'] > '2020-01-01']['SARIMA Predictions']
        
        # SARIMA metrics
        sarima_mape = mean_absolute_percentage_error(testY, sarimaTestPredict)
        sarima_mse = mean_squared_error(testY, sarimaTestPredict)
        sarima_rmse = np.sqrt(sarima_mse)
        
        # Calculate CI overlap
        lower_bound_test, upper_bound_test = calculate_confidence_intervals(testPredict)
        lower_bound_sarima, upper_bound_sarima = calculate_confidence_intervals(sarimaTestPredict)
        
        ci_overlap = calculate_overlap(lower_bound_test, upper_bound_test, lower_bound_sarima, upper_bound_sarima)
        
        results.append({
            'Validation Period': f"{val_start} to {val_end}",
            'Look-back': look_back,
            'LSTM MAPE': lstm_mape,
            'LSTM MSE': lstm_mse,
            'LSTM RMSE': lstm_rmse,
            'SARIMA MAPE': sarima_mape,
            'SARIMA MSE': sarima_mse,
            'SARIMA RMSE': sarima_rmse,
            'CI Overlap %': ci_overlap
        })

results_df = pd.DataFrame(results)

  df['Month'] = pd.to_datetime(df['Month'])
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 197ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 185ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 179ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 158ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 152ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 149ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 157ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 203ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 166ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 165ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 187ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 216ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(


In [4]:
df.to_csv('flexible_lookback_predictions_test.csv')
results_df

Unnamed: 0,Validation Period,Look-back,LSTM MAPE,LSTM MSE,LSTM RMSE,SARIMA MAPE,SARIMA MSE,SARIMA RMSE,CI Overlap %
0,2019-11-01 to 2020-01-01,1,0.147086,1299144.0,1139.80017,0.114675,975519.6,987.683959,72.727273
1,2019-09-01 to 2020-01-01,1,0.246569,2707182.0,1645.351597,0.155781,1324804.0,1151.001302,0.0
2,2019-07-01 to 2020-01-01,1,0.20518,2034954.0,1426.518007,0.156734,1337892.0,1156.672851,9.090909
3,2019-01-01 to 2020-01-01,1,0.235765,2522153.0,1588.128715,0.169823,1521660.0,1233.555764,0.0
4,2018-07-01 to 2020-01-01,1,0.129163,1120490.0,1058.531926,0.169374,1530380.0,1237.085348,36.363636
5,2018-01-01 to 2020-01-01,1,0.085364,620789.7,787.9021,0.171116,1597333.0,1263.856518,0.0


In [5]:
for val_start, val_end in validation_periods:
    for look_back in [look_back_periods[0]]:
        print(val_start, val_end, look_back)

2019-11-01 2020-01-01 1
2019-09-01 2020-01-01 1
2019-07-01 2020-01-01 1
2019-01-01 2020-01-01 1
2018-07-01 2020-01-01 1
2018-01-01 2020-01-01 1
