# M2 – SIR-based model

In addition to estimating the ODE parameters, the data can also be fitted with machine learning. An LSTM model, which looks back in time and can remember values from the past, is used to learn the relationship between the model parameter values and the NPI scores. We assume that there is a delay between the first day an NPI is implemented and the day its effect actually becomes visible in the number of confirmed cases. This delay is computed with **Change Point Analysis**.

In [6]:
import os

# The helper modules imported in the next cell (core.nn.LSTM_M2, SIR_ODE) are
# presumably resolved relative to ../LSTM, so that directory becomes the cwd.
LSTM_DIR = "../LSTM"
os.chdir(LSTM_DIR)

In [7]:
from core.nn.LSTM_M2 import LSTM_M2

%load_ext autoreload
%autoreload 2

from SIR_ODE import SIR
import math
import pickle
import datetime
from numpy import array
import matplotlib.pylab as plt
import pandas as pd
import covsirphy as cs
import requests, io, json, urllib
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from itertools import cycle
import os.path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import math
from sklearn.metrics import mean_squared_error
from itertools import cycle

import seaborn as sns
sns.set()
%matplotlib inline

### Prepare datasets

In [8]:
# Analysis targets: the NPI score handed to the delay estimation and the
# compartment whose values are predicted.
TARGET_NPI = "Stringency_index"
target_column = "Infected"

# Download datasets (the loader is given the "input" directory)
data_loader = cs.DataLoader("input")
jhu_data = data_loader.jhu()
population_data = data_loader.population()
population_df = population_data.cleaned()
oxcgrt_data = data_loader.oxcgrt()

# Cleaned JHU records for all countries. old_df is kept as an untouched copy
# (later used for the "Observed" series) because df is rebound further down.
df = jhu_data.cleaned()
old_df = df.copy()

Retrieving datasets from COVID-19 Data Hub: https://covid19datahub.io/

Please set verbose=2 to see the detailed citation list.


Retrieving COVID-19 dataset in Japan from https://github.com/lisphilar/covid19-sir/data/japan


In [9]:
# OxCGRT indicators evaluated one by one. The first entry is the one handed
# to LSTM_M2.estimate_country in the evaluation loop.
NPIS = [
    "Stringency_index",
    "School_closing",
    "Workplace_closing",
    "Gatherings_restrictions",
    "Stay_home_restrictions",
    "International_movement_restrictions",
    "Testing_policy",
]

### Select Country

In [13]:
COUNTRY = "Netherlands"
# Total population of the selected country (first matching row).
country_mask = population_df["Country"] == COUNTRY
N = population_df.loc[country_mask, "Population"].to_numpy()[0]
print("Population:", N)

Population: 17231624


In [7]:
def preprocess_data(COUNTRY):
    """
    Prepare the national-level series of one country and split it into
    train / test sets at the estimated delay start date.

        COUNTRY: country name as it appears in the ``Country`` column of
        ``jhu_data.cleaned()``.

    Returns ``(df, train, test)``:
        df: national rows (Province == "-") with the last four columns
        replaced by their 7-day rolling mean and an added
        ``New Confirmed`` column (day-over-day difference of Confirmed).
        train: rows dated up to and including the delay start.
        test: rows dated from the delay start on. The delay-start day is
        in both sets.
    ``train`` and ``test`` are independent copies, so callers may add
    columns to them freely.

    Relies on the notebook globals jhu_data, population_data, oxcgrt_data,
    TARGET_NPI and the LSTM_M2 class.

    Raises ValueError if no delay start could be estimated
    (``estimate_country`` returned False for it).
    """
    cleaned = jhu_data.cleaned()
    # National totals only. The .copy() makes the column assignments below
    # operate on an independent frame instead of a slice of `cleaned`
    # (avoids SettingWithCopyWarning).
    national = (cleaned["Country"] == COUNTRY) & (cleaned["Province"] == "-")
    df = cleaned[national].copy()

    # 7-day rolling mean of the last four columns. Presumably these are the
    # case-count columns of the covsirphy cleaned data (TODO confirm).
    case_columns = df.columns[-4:]
    df[case_columns] = df[case_columns].rolling(7).mean()
    df["New Confirmed"] = df["Confirmed"].diff()

    # Only the delay start is needed here; the other estimates are discarded.
    sir_lstm = LSTM_M2(COUNTRY)
    delay_start, _df_params, _npi_dates, _days_delay = sir_lstm.estimate_country(
        jhu_data, population_data, oxcgrt_data, TARGET_NPI)
    if delay_start is False:
        raise ValueError(f"No delay start could be estimated for {COUNTRY}")

    split_date = delay_start.strftime("%Y-%m-%d")
    train = df[df["Date"] <= split_date].copy()
    test = df[df["Date"] >= split_date].copy()

    return df, train, test

### Compute associated SIR model parameters for a given Non-pharmaceutical Intervention (NPI).

In [16]:
def calc_param(df, dates, params=("theta", "kappa", "rho", "sigma")):
    """
    Compute model parameters associated with NPIs, implemented
    on a given date range.

        df: input dataframe with total model parameters over
        time for a given country. It needs a ``Date`` column and one
        column per entry of ``params``.
        dates: the dates (in timestamp format) that cover the dates of
        implementing an NPI. Any list-like accepted by ``Series.isin``
        works (Series, list, array).
        params: names of the parameter columns to average.

    Returns a dict mapping each parameter name to an average over the rows
    of ``df`` whose date is in ``dates``. The average is computed in two steps:
        1. Rows sharing a date are averaged per date, skipping NaN.
        2. Each matching row contributes its date's mean once.
    The result is NaN when no row matches, or when every value of some
    matching date is NaN.
    """
    # NaT dates never match (NaT != NaT), so exclude them explicitly.
    in_range = df["Date"].notna() & df["Date"].isin(dates)
    matched = df[in_range]

    calc_params_df = {}
    for param in params:
        # One entry per matching row, in df row order: the mean of that
        # row's date group. Vectorised replacement for the per-row
        # filtering loop, which was quadratic in the number of rows.
        per_row = matched.groupby("Date")[param].transform("mean").to_numpy()
        calc_params_df[param] = np.mean(per_row) if per_row.size else np.nan
    return calc_params_df

In [17]:
def get_res_df(NPI, selection, plot=True, plot_start="2020-08-01"):
    """
    Simulate the test period with the SIR model, the LSTM and the combined
    M2 model for one NPI, and attach the predictions to ``test``.

        NPI: name of the intervention. It is the suffix of the
        ``"SIR Infected" + NPI`` column. If it is also a key of
        ``NPI_dates``, that NPI's own SIR simulation is used.
        selection: last training row. Provides the initial infected
        (``target_column``) and ``Recovered`` values for the SIR model.
        plot: if True, plot the tail of the training data and all predictions.
        plot_start: first training date shown in the plot.

    Returns the notebook-global ``test`` frame, which is modified in place.
    It gains the columns "SIR Infected<NPI>", "LSTM", "M2" and "Observed".

    Relies on the notebook globals NPI_dates, df_params, N, test, COUNTRY,
    DELAY_START, df, days_delay, old_df, train and target_column.

    Raises ValueError when no entry of NPI_dates yields a usable
    (non-NaN) parameter set.
    """
    params_total = {}
    sir_params_total = {}
    last_results = None
    for p in NPI_dates:
        res = calc_param(df_params, pd.Series(NPI_dates[p]))

        # Check if parameter exists (NaN when no estimated parameters fall
        # on this NPI's dates).
        if not math.isnan(res["kappa"]):
            params_total[p] = res
            # NOTE(review): gamma receives theta while sigma receives sigma;
            # confirm against the parameter meanings of SIR_ODE.SIR.
            sir = SIR(N=N, I0=selection[target_column], R0=selection["Recovered"],
                      beta=res["rho"], gamma=res["theta"], rho=res["rho"], sigma=res["sigma"],
                      days=len(test))
            last_results = sir.simulate(target="Infected", plot=False)
            sir_params_total[p] = last_results

    if last_results is None:
        raise ValueError(
            f"No usable SIR parameters found in NPI_dates for {COUNTRY} ({NPI})")

    # Previously the last simulated entry was always used, so every NPI got
    # the same SIR curve. Prefer the simulation keyed by this NPI when
    # NPI_dates is keyed by NPI name; otherwise keep the old fallback.
    SIR_results = sir_params_total.get(NPI, last_results)
    test["SIR Infected" + NPI] = SIR_results["I"]

    SIR_LSTM = LSTM_M2(COUNTRY, DELAY_START, FUTURE_DAYS=len(test))
    SIR_LSTM.input_data(df)
    results = SIR_LSTM.simulate()

    test["LSTM"] = results["pred"]
    test["M2"] = SIR_LSTM.update_predictions(test["SIR Infected" + NPI], tau=days_delay)

    split_date = DELAY_START.strftime("%Y-%m-%d")
    subset_old_df = old_df[old_df["Country"] == COUNTRY]
    # Raw (unsmoothed) Infected values. The assignment aligns on index
    # labels, so only rows sharing test's index end up in the column.
    test["Observed"] = subset_old_df[subset_old_df["Date"] >= split_date]["Infected"]

    if plot:
        train_copy = train[train["Date"] >= pd.to_datetime(plot_start)].copy()
        ax = train_copy.plot(x="Date", y="Infected", label="Train")
        test.plot(x="Date", y=["Infected", "SIR Infected" + NPI, "LSTM", "M2", "Observed"], ax=ax)
        ax.axvline(x=split_date, color="grey")
        ax.set(title=f"{COUNTRY}: {NPI}", ylabel="Infected")
    return test

In [18]:
COUNTRIES = ["United Kingdom", "Netherlands", "Japan", "United States", "China", "Australia"]
ERROR_COLUMNS = ["RMSE", "MAE", "MAPE"]

# Pickled error tables are written here; make sure the directory exists.
os.makedirs("./results", exist_ok=True)

for COUNTRY in COUNTRIES:
    SIR_LSTM = LSTM_M2(COUNTRY)

    N = population_df[population_df["Country"] == COUNTRY]["Population"].values[0]
    print('Population in', COUNTRY, ":", N)
    DELAY_START, df_params, NPI_dates, days_delay = SIR_LSTM.estimate_country(jhu_data,
                                                                              population_data,
                                                                              oxcgrt_data,
                                                                              NPIS[0])
    # Collect one frame per intervention and concatenate once at the end
    # (DataFrame.append was removed in pandas 2.0).
    error_frames = []
    if DELAY_START is False:
        # No delay start could be estimated: record NaN errors for every NPI.
        for intervention in NPIS:
            print(intervention)
            error_frames.append(pd.DataFrame([[np.nan] * len(ERROR_COLUMNS)],
                                             columns=ERROR_COLUMNS,
                                             index=pd.Index([intervention])))
            print("too little data", intervention)
    else:
        # The split only depends on the country, so run the (expensive)
        # preprocessing once per country instead of once per intervention.
        base_df, base_train, base_test = preprocess_data(COUNTRY)
        selection = base_train.iloc[-1]
        for intervention in NPIS:
            print(intervention)
            # get_res_df adds columns to the global `test` in place and passes
            # the global `df` to LSTM_M2.input_data, so give every intervention
            # fresh copies (as re-running preprocess_data used to).
            df, train, test = base_df.copy(), base_train.copy(), base_test.copy()
            get_res_df(NPI=intervention, selection=selection, plot=True)
            errors = SIR_LSTM.compute_errors(N, test)
            error_frames.append(errors.set_index(pd.Index([intervention])))
    results = pd.concat(error_frames)
    display(results)
    results.to_pickle('./results/' + COUNTRY + "_errors_SIR")
