In [16]:
import os
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX

We first read in our train and test data. We assume that all the data are storedd as csv files in a separate `data` directory.

In [17]:
DIR_TRAIN_BY_STATE = "./data/train_by_state/"

train_data = pd.read_csv('./data/train_round2.csv', engine='python').filter(items=['ID', 'Province_State', 'Date', 'Confirmed', 'Deaths'])
test_data = pd.read_csv('./data/test_round2.csv', engine='python')
train_data.tail()

Unnamed: 0,ID,Province_State,Date,Confirmed,Deaths
11245,11245,Virginia,11-22-2020,217796,3938
11246,11246,Washington,11-22-2020,141260,2619
11247,11247,West Virginia,11-22-2020,40478,662
11248,11248,Wisconsin,11-22-2020,376238,3150
11249,11249,Wyoming,11-22-2020,28169,176


Since we will potentially have a different model for every state, for convenience, we separate the train data into respective states to accelerate the learning steps.

In [18]:
# Get list of state names
states_names = np.unique(np.array([train_data['Province_State']]))
assert(len(states_names) == 50)
print(states_names)

def split_train_data_by_state(train_data):
    for state in states_names:
        state_data = train_data[train_data['Province_State'] == state]
        csv_name = DIR_TRAIN_BY_STATE + state + ".csv"
        state_data.to_csv(csv_name)


['Alabama' 'Alaska' 'Arizona' 'Arkansas' 'California' 'Colorado'
 'Connecticut' 'Delaware' 'Florida' 'Georgia' 'Hawaii' 'Idaho' 'Illinois'
 'Indiana' 'Iowa' 'Kansas' 'Kentucky' 'Louisiana' 'Maine' 'Maryland'
 'Massachusetts' 'Michigan' 'Minnesota' 'Mississippi' 'Missouri' 'Montana'
 'Nebraska' 'Nevada' 'New Hampshire' 'New Jersey' 'New Mexico' 'New York'
 'North Carolina' 'North Dakota' 'Ohio' 'Oklahoma' 'Oregon' 'Pennsylvania'
 'Rhode Island' 'South Carolina' 'South Dakota' 'Tennessee' 'Texas' 'Utah'
 'Vermont' 'Virginia' 'Washington' 'West Virginia' 'Wisconsin' 'Wyoming']


But we only want to do this if we haven't done it already.

In [19]:
if not os.path.exists(DIR_TRAIN_BY_STATE):
    os.mkdir(DIR_TRAIN_BY_STATE)
    
if not len(os.listdir(DIR_TRAIN_BY_STATE)):
    split_train_data_by_state(train_data)

To get the best hyperparameters, we generate candidates to do a grid search for the one with best performance.

In [20]:
p = range(1, 5)
d = [1]
q = range(1, 8)
pdq_candidates = list(itertools.product(p, d, q))
seasonal_pdq_candidates = [(x[0], x[1], x[2], 12) for x in list(itertools.product(p, d, q))]
print(pdq_candidates)
print(seasonal_pdq_candidates)

[(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4), (1, 1, 5), (1, 1, 6), (1, 1, 7), (2, 1, 1), (2, 1, 2), (2, 1, 3), (2, 1, 4), (2, 1, 5), (2, 1, 6), (2, 1, 7), (3, 1, 1), (3, 1, 2), (3, 1, 3), (3, 1, 4), (3, 1, 5), (3, 1, 6), (3, 1, 7), (4, 1, 1), (4, 1, 2), (4, 1, 3), (4, 1, 4), (4, 1, 5), (4, 1, 6), (4, 1, 7)]
[(1, 1, 1, 12), (1, 1, 2, 12), (1, 1, 3, 12), (1, 1, 4, 12), (1, 1, 5, 12), (1, 1, 6, 12), (1, 1, 7, 12), (2, 1, 1, 12), (2, 1, 2, 12), (2, 1, 3, 12), (2, 1, 4, 12), (2, 1, 5, 12), (2, 1, 6, 12), (2, 1, 7, 12), (3, 1, 1, 12), (3, 1, 2, 12), (3, 1, 3, 12), (3, 1, 4, 12), (3, 1, 5, 12), (3, 1, 6, 12), (3, 1, 7, 12), (4, 1, 1, 12), (4, 1, 2, 12), (4, 1, 3, 12), (4, 1, 4, 12), (4, 1, 5, 12), (4, 1, 6, 12), (4, 1, 7, 12)]


In [21]:
def mape(pred, gt):
    ape = np.abs(pred - gt) / np.abs(gt)
    return np.mean(ape) * 100

In [22]:
def train_valid_split():
    valid_data = train_data[(train_data['Date'] >= "11-01-2020") & (train_data['Date'] <= "11-22-2020")]
    train_set = train_data[(train_data['Date'] < "11-01-2020")]
    valid_confirmed_dict = {}    
    valid_death_dict = {}
    train_confirmed_dict = {}    
    train_death_dict = {}

    for state in states_names:
        state_valid = valid_data[valid_data["Province_State"] == state]
        state_train = train_set[train_set["Province_State"] == state]

        state_valid_c = np.array(state_valid["Confirmed"].tolist())
        state_valid_d = np.array(state_valid["Deaths"].tolist())
        valid_confirmed_dict[state] = state_valid_c
        valid_death_dict[state] = state_valid_d
        
        state_train_c = np.array(state_train["Confirmed"].tolist())
        state_train_d = np.array(state_train["Deaths"].tolist())
        train_confirmed_dict[state] = state_train_c
        train_death_dict[state] = state_train_d

    return train_confirmed_dict, train_death_dict, valid_confirmed_dict, valid_death_dict


In [23]:
train_confirmed_dict, train_death_dict, valid_confirmed_dict, valid_death_dict = train_valid_split()
train_confirmed_dict["Alabama"].shape[0]

203

In [36]:
def run_arima():
    predictions = []
    for state in states_names:
        state_data = pd.read_csv(DIR_TRAIN_BY_STATE + state + ".csv")
        
        mape_confirmed = 1e7
        pdq_confirmed = None
        model_confirmed = None

        mape_death = 1e7
        pdq_death = None
        model_death = None

        for pdq in pdq_candidates:
            mod = ARIMA(train_confirmed_dict[state], order=pdq, enforce_stationarity=False)
            f = mod.fit(method="statespace")
            pred_c = f.predict(start=train_confirmed_dict[state].shape[0], end=train_confirmed_dict[state].shape[0] + valid_confirmed_dict[state].shape[0] - 1)
            error = mape(np.array(pred_c.tolist()), valid_confirmed_dict[state])
            if error < mape_confirmed:
                print("Updating param: ", error, pdq)
                mape_confirmed = error
                pdq_confirmed = pdq

        print("Best parameter for ", state, "'s confirmed is PDQ: ", pdq_confirmed)
        model_confirmed = ARIMA(state_data['Confirmed'], order=pdq_confirmed, enforce_stationarity=False)

        for pdq in pdq_candidates:
            mod = ARIMA(train_death_dict[state], order=pdq, enforce_stationarity=False)

            f = mod.fit(method="statespace")
            pred_d = f.predict(start=train_death_dict[state].shape[0], end=train_death_dict[state].shape[0] + valid_death_dict[state].shape[0] - 1)
            error = mape(np.array(pred_d.tolist()), valid_death_dict[state])
            if error < mape_death:
                print("Updating param: ", error, pdq)
                mape_death = error
                pdq_death = pdq
        
        print("Best parameter for ", state, "'s death is PDQ: ", pdq_death)
        model_death = ARIMA(state_data['Deaths'], order=pdq_death, enforce_stationarity=False)

        fit_confirmed = model_confirmed.fit(method="statespace")
        predict_confirmed = fit_confirmed.predict(start=225, end=245)
        fit_death = model_death.fit(method="statespace")

        predict_death = fit_death.predict(start=225, end=245)
        predictions.append((np.around(predict_confirmed, decimals=0), np.around(predict_death, decimals=0)))

    return predictions

In [53]:
def run_seasonal_arima():
    predictions = []
    for state in states_names:
        state_data = pd.read_csv(DIR_TRAIN_BY_STATE + state + ".csv")

        mape_confirmed = 1e7
        pdq_confirmed = None
        seasonal_pdq_confirmed = None
        model_confirmed = None

        mape_death = 1e7
        pdq_death = None
        seasonal_pdq_death = None
        model_death = None

        for pdq in pdq_candidates:
            try:
                mod = SARIMAX(train_confirmed_dict[state], order=pdq,# seasonal_order=seasonal_pdq,
                              enforce_stationarity=False, enforce_invertibility=False)

                f = mod.fit(disp=False, method='powell')
                pred_c = f.predict(start=train_confirmed_dict[state].shape[0], end=train_confirmed_dict[state].shape[0] + valid_confirmed_dict[state].shape[0] - 1)
                error = mape(np.array(pred_c.tolist()), valid_confirmed_dict[state])
                if error < mape_confirmed:
                    print("Updating confirmed param: ", error, pdq)
                    mape_confirmed = error
                    pdq_confirmed = pdq
            except:
                continue
       
        print("Best parameter for ", state, "'s confirmed is PDQ: ", pdq_confirmed)
        model_confirmed = SARIMAX(state_data['Confirmed'], order=pdq_confirmed,# seasonal_order=seasonal_pdq_confirmed,
                                  enforce_stationarity=False, enforce_invertibility=False)

        for pdq in pdq_candidates:
                try:
                    mod = SARIMAX(train_death_dict[state], order=pdq,# seasonal_order=seasonal_pdq,
                                  enforce_stationarity=False, enforce_invertibility=False)
                    f = mod.fit(disp=False, method='powell')
                    pred_d = f.predict(start=train_death_dict[state].shape[0], end=train_death_dict[state].shape[0] + valid_death_dict[state].shape[0] - 1)
                    error = mape(np.array(pred_d.tolist()), valid_death_dict[state])
                    if error < mape_death:
                        print("Updating death param: ", error, pdq)
                        mape_death = error
                        pdq_death = pdq
                except:
                    continue
        
        print("Best parameter for ", state, "'s deaths is PDQ: ", pdq_death)
        model_death = SARIMAX(state_data['Deaths'],order=pdq_death,# seasonal_order=seasonal_pdq_death,
                                  enforce_stationarity=False, enforce_invertibility=False)

        fit_c = model_confirmed.fit(disp=False, method='powell')
        y_pred_confirmed = fit_c.predict(start=239, end=245)
        fit_d = model_death.fit(disp=False, method='powell')

        y_pred_deaths = fit_d.predict(start=239, end=245)
        predictions.append((np.around(y_pred_confirmed, decimals=0), np.around(y_pred_deaths, decimals=0)))

    return predictions

In [54]:
# predictions = run_arima()
predictions = run_seasonal_arima()

predictions

Updating confirmed param:  0.6774556777337909 (1, 1, 1)
Updating confirmed param:  0.6769914314694178 (1, 1, 3)
Updating confirmed param:  0.6751151804847361 (2, 1, 1)
Updating confirmed param:  0.6683758927423774 (2, 1, 2)
Updating confirmed param:  0.6675154099749339 (3, 1, 3)
Updating confirmed param:  0.641834075566313 (4, 1, 3)
Best parameter for  Alabama 's confirmed is PDQ:  (4, 1, 3)
Updating death param:  2.0474580561732436 (1, 1, 1)
Updating death param:  1.9806573099945355 (1, 1, 2)
Updating death param:  1.9785982262792188 (1, 1, 3)
Updating death param:  1.881865264048907 (2, 1, 2)
Updating death param:  1.6226314745785977 (3, 1, 4)
Best parameter for  Alabama 's deaths is PDQ:  (3, 1, 4)
Updating confirmed param:  10.60378562451849 (1, 1, 1)
Updating confirmed param:  9.29597159060175 (1, 1, 3)
Updating confirmed param:  9.25385117652727 (1, 1, 4)
Updating confirmed param:  8.288351586127307 (2, 1, 3)
Updating confirmed param:  6.956529129539603 (4, 1, 1)
Best parameter f

Best parameter for  Illinois 's confirmed is PDQ:  (4, 1, 1)
Updating death param:  3.7929740208973945 (1, 1, 1)
Updating death param:  3.3843497138070333 (1, 1, 4)
Updating death param:  3.352701285507554 (1, 1, 5)
Updating death param:  3.2577296647015737 (1, 1, 7)
Best parameter for  Illinois 's deaths is PDQ:  (1, 1, 7)
Updating confirmed param:  3.712871722445235 (1, 1, 1)
Best parameter for  Indiana 's confirmed is PDQ:  (1, 1, 1)
Updating death param:  3.650356147616493 (1, 1, 1)
Updating death param:  2.362762110419291 (1, 1, 3)
Updating death param:  1.8636747058529957 (1, 1, 5)
Updating death param:  1.6388536486968 (2, 1, 1)
Updating death param:  1.2851402736374429 (2, 1, 5)
Best parameter for  Indiana 's deaths is PDQ:  (2, 1, 5)
Updating confirmed param:  6.2840822820016795 (1, 1, 1)
Updating confirmed param:  2.0168341961576997 (2, 1, 6)
Best parameter for  Iowa 's confirmed is PDQ:  (2, 1, 6)
Updating death param:  4.724807295198071 (1, 1, 1)
Updating death param:  4.07

Updating death param:  3.6068650032242346 (1, 1, 6)
Updating death param:  3.319241731958697 (1, 1, 7)
Updating death param:  3.2197334279871694 (3, 1, 7)
Best parameter for  Missouri 's deaths is PDQ:  (3, 1, 7)
Updating confirmed param:  4.834775497822874 (1, 1, 1)
Updating confirmed param:  4.7797477230822745 (1, 1, 2)
Updating confirmed param:  3.3348549770205476 (1, 1, 4)
Updating confirmed param:  3.172474064639552 (1, 1, 5)
Updating confirmed param:  3.169835821750818 (1, 1, 6)
Updating confirmed param:  1.5261121679600644 (1, 1, 7)
Updating confirmed param:  1.4233356436472668 (2, 1, 7)
Best parameter for  Montana 's confirmed is PDQ:  (2, 1, 7)
Updating death param:  20.120697689798998 (1, 1, 1)
Updating death param:  19.995903165244922 (2, 1, 3)
Updating death param:  18.596628633431518 (3, 1, 3)
Best parameter for  Montana 's deaths is PDQ:  (3, 1, 3)
Updating confirmed param:  6.084775164143013 (1, 1, 1)
Updating confirmed param:  5.644351279208759 (1, 1, 7)
Updating confir

Updating confirmed param:  6.377057252183256 (1, 1, 3)
Updating confirmed param:  5.584260610704412 (1, 1, 6)
Updating confirmed param:  5.4654564261557885 (4, 1, 5)
Best parameter for  Rhode Island 's confirmed is PDQ:  (4, 1, 5)
Updating death param:  1.233038728526294 (1, 1, 1)
Updating death param:  1.2136264610769372 (1, 1, 3)
Updating death param:  1.1977540563323275 (1, 1, 6)
Updating death param:  0.9107697572941041 (1, 1, 7)
Updating death param:  0.8904225454293822 (2, 1, 7)
Best parameter for  Rhode Island 's deaths is PDQ:  (2, 1, 7)
Updating confirmed param:  1.3532760676435152 (1, 1, 1)
Updating confirmed param:  1.2435840775090008 (1, 1, 4)
Updating confirmed param:  1.1969891343223793 (1, 1, 5)
Updating confirmed param:  0.7295726080416387 (3, 1, 3)
Updating confirmed param:  0.7192076505951744 (4, 1, 3)
Best parameter for  South Carolina 's confirmed is PDQ:  (4, 1, 3)
Updating death param:  2.0614035106261572 (1, 1, 1)
Updating death param:  2.057475364485243 (1, 1, 2

[(239    261321.0
  240    262865.0
  241    264413.0
  242    265964.0
  243    267518.0
  244    269077.0
  245    270638.0
  dtype: float64,
  239    5208.0
  240    5249.0
  241    5291.0
  242    5333.0
  243    5375.0
  244    5417.0
  245    5460.0
  dtype: float64),
 (239    12366.0
  240    12444.0
  241    12522.0
  242    12600.0
  243    12678.0
  244    12757.0
  245    12836.0
  dtype: float64,
  239    282.0
  240    289.0
  241    295.0
  242    301.0
  243    308.0
  244    315.0
  245    322.0
  dtype: float64),
 (239    93959.0
  240    93375.0
  241    92855.0
  242    92399.0
  243    92010.0
  244    91688.0
  245    91434.0
  dtype: float64,
  239    6158.0
  240    6160.0
  241    6163.0
  242    6165.0
  243    6167.0
  244    6169.0
  245    6171.0
  dtype: float64),
 (239    73771.0
  240    73793.0
  241    73814.0
  242    73835.0
  243    73855.0
  244    73875.0
  245    73894.0
  dtype: float64,
  239    4275.0
  240    4348.0
  241    4422.0
  242    44

In [214]:
test_submission = test_data.sort_values(["Province_State", "Date"])
confirmed = []
deaths = []
for i in range(50):
    confirmed.append(predictions[i][0].astype(int).tolist())
    deaths.append(predictions[i][1].astype(int).tolist())

confirmed = list(itertools.chain.from_iterable(confirmed))
deaths = list(itertools.chain.from_iterable(deaths))

test_submission.loc[:, "Confirmed"] = confirmed
test_submission.loc[:, "Deaths"] = deaths

test_submission = test_submission.sort_index().filter(items=['ForecastID', 'Confirmed', 'Deaths'])
test_submission.to_csv("Team15_arima_round2.csv", index=False)