In [1]:
# Imports
# Reload project modules (e.g. epimodel) automatically when their source files change.
%load_ext autoreload
%autoreload 2

import sys
# Make the parent directory importable; presumably this is where the `epimodel` package lives — confirm.
sys.path.append("../") 

import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpyro
import pandas as pd

import arviz as az

# Expose 4 host (CPU) devices to JAX; presumably so MCMC chains can run in parallel.
# NOTE(review): numpyro requires this call before JAX initialises its backend — keep it above any JAX computation.
numpyro.set_host_device_count(4)

from epimodel import run_model_with_settings, default_model, latent_nn_model_legacy, arma_model, latent_nn_model, EpidemiologicalParameters, preprocess_data
# NOTE(review): the star imports below presumably supply the helpers used later (plot_graph, netcdf_to_dict,
# *_predictor, get_prediction_metrics, ...); which module provides which is not visible here — prefer explicit names.
from epimodel.models.model_predict_utils import *
from epimodel.models.model_build_utils import *

Importing plotly failed. Interactive plots will not work.


In [2]:
# Load the merged dataset twice:
#   * all_data -> the full series, used as ground truth for the forecast window;
#   * data     -> the same series truncated at `end_date`, i.e. the "known" history.
end_date = "2020-12-10"
DATA_CSV = '../data/all_merged_data_2021-01-22.csv'

all_data = preprocess_data(DATA_CSV)
all_data.featurize()

data = preprocess_data(DATA_CSV, end_date=end_date)
data.featurize()

# Default epidemiological parameters shared by every predictor below.
ep = EpidemiologicalParameters()

Processing data from 2020-08-01 00:00:00 to 2021-01-22 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded
Processing data from 2020-08-01 00:00:00 to 2020-12-10 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded


In [87]:
# Fit the latent-NN model on data up to `end_date` (the same cutoff used for `data`).
# NOTE(review): num_samples=10 / num_warmup=30 are far too small for converged MCMC; presumably a smoke-test setting.
# NOTE(review): the NN samples analysed below are read from 'samples/2.netcdf'; whether that is the file this
# call writes (save_results=True) is not visible here — confirm.
model_settings_nn = {'input_death': True,
                     # NOTE(review): this is the *string* 'None', not None — confirm which the model expects.
                     'preprocessed_data': 'None',
                     'num_percentiles': None,
                     'n_days_nn_input': 21,
                     'D_layers': [15, 10, 5],
                     'report_period': True,
                     'infer_cfr': False}

run_model_with_settings(
    # Reuse the shared cutoff rather than a duplicated date literal that could drift out of sync.
    end_date=end_date,
    num_samples=10,
    num_warmup=30,
    save_results=True,
    model_kwargs=model_settings_nn,
)

In [3]:
# Gets samples from pre-run models
# The ARMA run's file name encodes its training cutoff; build it from `end_date`
# so the loaded samples always match the data cutoff used in this notebook.
samples_arma = netcdf_to_dict(f'samples/arma/learn/cross_validate/{end_date}_p2q0_learn.netcdf')

# NOTE(review): this file name carries no cutoff — confirm it was fitted with the same end_date.
samples_nn = netcdf_to_dict('samples/2.netcdf')


In [5]:
prediction_date = pd.to_datetime(end_date)

# Plot at most 150 posterior draws. Clamp to the number actually available so a short
# run does not index past the end (numpy raises; JAX silently clamps indices).
n_plot_samples = min(150, samples_nn['expected_cases'].shape[0])
samp = list(range(n_plot_samples))

# plot infections, cases and deaths for the first three regions
for reg in range(0, 3):
    plot_graph((samples_nn['expected_cases'][samp, :, :], data.new_cases.data), 'cases_nn.png', percentiles=None, label=['Inference', 'Actual Cases'], title='Cases Inferred by NN model', region=reg)
    plot_graph((samples_nn['expected_deaths'][samp, :, :], data.new_deaths.data), 'deaths_nn.png', percentiles=None, label=['Inference', 'Actual Deaths'], title='Deaths Inferred by NN model', region=reg)
    # Single series (no observed infections), so use the default label like the R_t / CFR plots below.
    plot_graph(samples_nn['total_infections'][samp, :, :], 'infections_nn.png', title='Infections Inferred by NN model', region=reg)
    plot_graph(samples_nn['Rt'][samp, :, :], 'R_nn.png', percentiles=None, region=reg, title='R_t Inferred by Epi-NN')
    plot_graph(samples_nn['cfr'][samp, :, :], 'cfr_nn.png', percentiles=None, region=reg)
    
    



<Figure size 432x288 with 0 Axes>

In [24]:
# Forecast horizon: how many days past the end of the training data each method predicts.
look_ahead: int = 23

In [26]:
# PREDICTION USING ARMA METHOD
# Forecast future cases and deaths with the Epi-ARMA noise model.

# Intervention (CM) activations for the days after the training cutoff
# (presumably data.nDs is the number of training days — confirm).
future_cms = all_data.active_cms[:, :, data.nDs:]

# Unpack expected cases/deaths plus total R_t and its CM and noise components.
(cases_total_arma, deaths_total_arma, Rt_total_arma,
 Rt_cms_total_arma, Rt_noise_total_arma) = arma_noise_predictor(
    samples_arma,
    ep,
    data.active_cms,
    look_ahead,
    future_cms=future_cms,
    ignore_last_days=6,
)


In [35]:
# PREDICTION USING NEURAL NET METHOD
# Roll the latent-NN posterior forward `look_ahead` days past the training data,
# yielding cases, deaths, R_t and CFR trajectories.
(cases_total_nn, deaths_total_nn,
 Rt_total_nn, cfr_total_nn) = nn_predictor(samples_nn, data, ep, look_ahead)

In [16]:
# PREDICTION USING EpiNow2
# Use EpiNow2 to get the future instead
# Only case forecasts are obtained for this method (no deaths). The helper takes no arguments and is
# not shown here; presumably it loads pre-computed EpiNow2 output — confirm.
# NOTE(review): later cells trim the last 8 days of this array, so it presumably extends past the others' horizon.
cases_total_epinow2 = get_data_from_epinow2()

In [141]:
# PREDICTION USING PROPHET
# Baseline forecast from Prophet over the same horizon, drawing 100 samples.
cases_total_prophet, deaths_total_prophet = prophet_predictor(
    data, look_ahead, end_date, nS=100
)

INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
  components = components.append(new_comp)
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=

In [27]:
# Options shared by every get_prediction_metrics call below.
# NOTE(review): CRPS is skipped, yet CRPS plots are drawn later — confirm what those values contain.
skip_crps, verbose = True, False

In [30]:
# Score the Epi-ARMA forecasts against the full (untruncated) observations from day data.nDs on.
cases_mse_arma, cases_crps_arma = get_prediction_metrics(
    cases_total_arma, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_arma, deaths_crps_arma = get_prediction_metrics(
    deaths_total_arma, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [36]:
# Score the Epi-NN forecasts against the full (untruncated) observations from day data.nDs on.
cases_mse_nn, cases_crps_nn = get_prediction_metrics(
    cases_total_nn, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_nn, deaths_crps_nn = get_prediction_metrics(
    deaths_total_nn, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [154]:
# Score the Prophet baseline against the full (untruncated) observations from day data.nDs on.
cases_mse_prophet, cases_crps_prophet = get_prediction_metrics(
    cases_total_prophet, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_prophet, deaths_crps_prophet = get_prediction_metrics(
    deaths_total_prophet, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [152]:
# EpiNow2 provides cases only, so there is no deaths score for it.
cases_mse_epinow2, cases_crps_epinow2 = get_prediction_metrics(
    cases_total_epinow2, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [166]:
prediction_date = pd.to_datetime(end_date)

# EpiNow2 series are trimmed by this many trailing days so they line up with the other methods.
# NOTE(review): value taken from the slicing previously hard-coded here — confirm against get_data_from_epinow2.
epinow2_extra_days = 8

# plot infections, cases and deaths (only region 19 is plotted here)
# NOTE(review): several file names ('cases_nn.png', 'R_nn.png', ...) are also used by the inference plots
# earlier in the notebook; if plot_graph saves to disk, these calls overwrite those files.
for reg in range(19, 20):
    # Epi-NN
    plot_graph((cases_total_nn, all_data.new_cases.data), 'cases_nn.png', prediction_date, title='Cases Predicted by Epi-NN', label=['Prediction', 'Actual Cases'], region=reg)
    plot_graph((deaths_total_nn, all_data.new_deaths.data), 'deaths_nn.png', prediction_date, title='Deaths Predicted by Epi-NN', label=['Prediction', 'Actual Deaths'], region=reg)
    plot_graph(Rt_total_nn, 'R_nn.png', prediction_date, title='R_t Predicted by Epi-NN', region=reg, percentiles=None, label=False)
    plot_graph(cfr_total_nn, 'cfr_nn.png', prediction_date, title='CFR Predicted by Epi-NN', region=reg)

    # Epi-ARMA
    plot_graph((cases_total_arma, all_data.new_cases.data), 'cases_arma.png', prediction_date, title='Cases Predicted by Epi-ARMA', label=['Prediction', 'Actual Cases'], region=reg)
    plot_graph((deaths_total_arma, all_data.new_deaths.data), 'deaths_arma.png', prediction_date, title='Deaths Predicted by Epi-ARMA', label=['Prediction', 'Actual Deaths'], region=reg)
    plot_graph(Rt_total_arma, 'R_arma.png', prediction_date, title='R_t Predicted by Epi-ARMA', region=reg)
    plot_graph(Rt_noise_total_arma, 'R_noise_arma.png', prediction_date, title='R_{noise,t} Predicted by Epi-ARMA', region=reg)
    plot_graph(Rt_cms_total_arma, 'R_cms_arma.png', prediction_date, title='R_cms Predicted by Epi-ARMA', region=reg)

    # Prophet
    plot_graph((cases_total_prophet, all_data.new_cases.data), 'cases_prophet.png', prediction_date, title='Cases Predicted by Prophet', label=['Prediction', 'Actual Cases'], region=reg)
    plot_graph((deaths_total_prophet, all_data.new_deaths.data), 'deaths_prophet.png', prediction_date, title='Deaths Predicted by Prophet', label=['Prediction', 'Actual Deaths'], region=reg)

    # EpiNow2 (cases only)
    plot_graph((cases_total_epinow2[:, :, :-epinow2_extra_days], all_data.new_cases.data), 'cases_EpiNow2.png', prediction_date, title='Cases Predicted by EpiNow2', label=['Prediction', 'Actual Cases'], region=reg)

# Cross-method error curves (all regions). EpiNow2 has no deaths forecast.
plot_graph((cases_mse_nn, cases_mse_arma, cases_mse_prophet, cases_mse_epinow2[:-epinow2_extra_days]), 'cases_mse.png', prediction_date, title='Cases MSE', label=['Epi-NN', 'Epi-ARMA', 'Prophet', 'EpiNow2'], region=None)
plot_graph((deaths_mse_nn, deaths_mse_arma, deaths_mse_prophet), 'deaths_mse.png', prediction_date, title='Deaths MSE', label=['Epi-NN', 'Epi-ARMA', 'Prophet'], region=None)

# NOTE(review): with skip_crps=True the CRPS values may be placeholders (or None) — confirm before relying on these plots.
plot_graph((cases_crps_nn, cases_crps_arma, cases_crps_prophet, cases_crps_epinow2[:-epinow2_extra_days]), 'cases_crps.png', prediction_date, title='Cases CRPS', label=['Epi-NN', 'Epi-ARMA', 'Prophet', 'EpiNow2'], region=None)
plot_graph((deaths_crps_nn, deaths_crps_arma, deaths_crps_prophet), 'deaths_crps.png', prediction_date, title='Deaths CRPS', label=['Epi-NN', 'Epi-ARMA', 'Prophet'], region=None)

    



NameError: name 'cases_nrmse_nn' is not defined

<Figure size 432x288 with 0 Axes>

In [39]:
prediction_date = pd.to_datetime(end_date)

# Epi-NN vs Epi-ARMA only: pass exactly one legend label per plotted series
# (previously Prophet/EpiNow2 labels were passed although those series are absent).
# NOTE(review): these file names match the all-method metric plots elsewhere; if plot_graph saves, they overwrite.
nn_vs_arma_labels = ['Epi-NN', 'Epi-ARMA']

plot_graph((cases_mse_nn, cases_mse_arma), 'cases_mse.png', prediction_date, title='Cases MSE', label=nn_vs_arma_labels, region=None)
plot_graph((deaths_mse_nn, deaths_mse_arma), 'deaths_mse.png', prediction_date, title='Deaths MSE', label=nn_vs_arma_labels, region=None)

plot_graph((cases_crps_nn, cases_crps_arma), 'cases_crps.png', prediction_date, title='Cases CRPS', label=nn_vs_arma_labels, region=None)
plot_graph((deaths_crps_nn, deaths_crps_arma), 'deaths_crps.png', prediction_date, title='Deaths CRPS', label=nn_vs_arma_labels, region=None)

    

<Figure size 432x288 with 0 Axes>