In [2]:
# Imports
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../") 
sys.path.append("../../") 

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import pandas as pd

numpyro.set_host_device_count(4)

from epimodel import run_model_with_settings, default_model, latent_nn_model_legacy, arma_model, latent_nn_model, EpidemiologicalParameters, preprocess_data
from epimodel.models.model_predict_utils import *
from epimodel.models.model_build_utils import *

Importing plotly failed. Interactive plots will not work.


In [3]:
# Load the full dataset (ground truth for the forecast window) and a copy
# truncated at `end_date` (the data the models are fit on).
# end_date = "2020-12-10"  # alternative cut-off, kept for reference
end_date = "2021-01-02"

# Single source of truth for the input file: both datasets must be built from
# the same CSV so the truncated one covers the start of the full one.
DATA_PATH = '../../data/all_merged_data_2021-01-22.csv'

all_data = preprocess_data(DATA_PATH)
all_data.featurize()

data = preprocess_data(DATA_PATH, end_date=end_date)
data.featurize()

ep = EpidemiologicalParameters()

Processing data from 2020-08-01 00:00:00 to 2021-01-22 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded
Processing data from 2020-08-01 00:00:00 to 2021-01-02 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded


In [None]:
# Settings for the latent neural-network (Epi-NN) model.
# NOTE(review): 'preprocessed_data' is the *string* 'None', not the None object.
# A later cell loads '../Predict/None.netcdf', which suggests this string ends
# up in the saved file name — confirm that is intended.
model_settings_nn = dict(
    input_death=True,
    preprocessed_data='None',
    num_percentiles=None,
    n_days_nn_input=21,
    D_layers=[15, 10, 5],
    report_period=False,
    infer_cfr=True,
)

# NOTE(review): 10 samples / 30 warmup steps per chain is very small —
# presumably a smoke-test configuration rather than a production fit.
run_model_with_settings(
    model_kwargs=model_settings_nn,
    save_results=True,
    num_warmup=30,
    num_samples=10,
    end_date=end_date,
)

Processing data from 2020-08-01 00:00:00 to 2021-01-02 00:00:00
Running 4 chains, 10 per chain with 30 warmup steps
Warmup Started: 2022-08-21 11:54:17


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [21]:
# Load posterior samples of previously run models from NetCDF files into
# dict-like objects keyed by variable name (e.g. 'Rt', 'cfr').
# NOTE(review): the Epi-NN file name looks like the output of the
# run_model_with_settings call above — confirm it is the intended run.
samples_arma = netcdf_to_dict(
    '../samples/arma/learn/test/2021-01-02_p2q1_learn.netcdf'
)
samples_nn = netcdf_to_dict('../Predict/None.netcdf')

In [23]:
# Shape of the Epi-NN Rt samples (output: (1000, 114, 155)); axes are
# presumably (posterior samples, regions, days) — TODO confirm.
jnp.shape(samples_nn["Rt"])

(1000, 114, 155)

In [56]:
# DISABLED: mean / std, over (region, day), of the correlation between the
# posterior samples of cfr and Rt at each (region, day). The printed output
# below came from an earlier, uncommented run and is not regenerated on re-run.
# NOTE(review): 114 regions and days 15..131 are hard-coded, and growing R with
# np.concatenate inside the loop is quadratic — collect values in a list and
# convert once if this is re-enabled.
# R = np.array([])

# for region in range(114):
#     for time in range(15,132):
#         new = (np.array([np.array(np.corrcoef(samples_nn['cfr'][:,region,time], samples_nn['Rt'][:,region,time])[0,1])]))
#         R = np.concatenate((R,new))

        
# print(np.mean(R))
# print(np.std(R))

0.01490266668910671
0.3570980510029074


In [14]:
# Plot the Epi-NN posterior over the inference window for the first few regions.
# File names include the region index so that later loop iterations do not
# overwrite earlier regions' figures.
# NOTE(review): assumes plot_graph's 2nd positional argument is the output
# file name — confirm.
n_regions_to_plot = 3

for reg in range(n_regions_to_plot):
    # Observed series are passed together with the samples as a tuple.
    plot_graph((samples_nn['expected_cases'], data.new_cases.data), f'cases_nn_{reg}.png',
               percentiles=None, label=['Inference', 'Actual Cases'],
               title='Cases Inferred by Epi-NN', region=reg)
    plot_graph((samples_nn['expected_deaths'], data.new_deaths.data), f'deaths_nn_{reg}.png',
               percentiles=None, label=['Inference', 'Actual Deaths'],
               title='Deaths Inferred by Epi-NN', region=reg)
    # Infections have no observed counterpart, so they are plotted alone (like
    # Rt and cfr below) and get no 'Actual ...' label.
    plot_graph(samples_nn['total_infections'], f'infections_nn_{reg}.png',
               percentiles=None, title='Infections Inferred by Epi-NN', region=reg)
    plot_graph(samples_nn['Rt'], f'R_nn_{reg}.png',
               percentiles=None, title='R_t Inferred by Epi-NN', region=reg)
    # NOTE(review): the first 7 days of cfr are skipped; the reason is not
    # visible here — confirm.
    plot_graph(samples_nn['cfr'][:, :, 7:], f'cfr_nn_{reg}.png',
               percentiles=None, title='cfr_t Inferred by Epi-NN', region=reg)

    



<Figure size 432x288 with 0 Axes>

In [4]:
# Forecast horizon: how many days past the end of `data` each method predicts.
look_ahead: int = 20

In [16]:
# PREDICTION USING ARMA METHOD
# Forecast cases and deaths (plus Rt and its CM / noise components) with Epi-ARMA.

# Active CMs for the days after the training window, taken from the full dataset.
# NOTE(review): all_data ends 2021-01-22 and data ends 2021-01-02 (see the
# loading output), so this presumably spans exactly `look_ahead` days — confirm.
future_cms = all_data.active_cms[:, :, data.nDs:]

(
    cases_total_arma,
    deaths_total_arma,
    Rt_total_arma,
    Rt_cms_total_arma,
    Rt_noise_total_arma,
) = arma_noise_predictor(
    samples_arma,
    ep,
    data.active_cms,
    look_ahead,
    future_cms=future_cms,
    ignore_last_days=0,
)


In [17]:
# PREDICTION USING NEURAL NET METHOD
# Forecast cases, deaths, Rt and cfr `look_ahead` days ahead from the latent
# Epi-NN posterior samples.
(
    cases_total_nn,
    deaths_total_nn,
    Rt_total_nn,
    cfr_total_nn,
) = nn_predictor(samples_nn, data, ep, look_ahead)

In [17]:
# PREDICTION USING EpiNow2
# EpiNow2 case forecasts are not computed here; they are fetched by a helper
# (presumably from results produced outside this notebook).
cases_total_epinow2 = get_data_from_epinow2()

In [18]:
# PREDICTION USING PROPHET
# Baseline: Prophet forecasts of cases and deaths over the same horizon.
# NOTE(review): nS is presumably the number of forecast samples drawn — confirm.
cases_total_prophet, deaths_total_prophet = prophet_predictor(
    data, look_ahead, end_date, nS=100
)

INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
  components = components.append(new_comp)
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)
INFO:prophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=

In [7]:
# Flags shared by every get_prediction_metrics2 call below.
skip_crps, verbose = True, False

In [20]:
# Epi-ARMA forecast errors against the full (untruncated) observations;
# data.nDs presumably marks where the forecast window starts.
cases_mse_arma, cases_crps_arma = get_prediction_metrics2(
    cases_total_arma, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_arma, deaths_crps_arma = get_prediction_metrics2(
    deaths_total_arma, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [18]:
# Epi-NN forecast errors against the full (untruncated) observations.
cases_mse_nn, cases_crps_nn = get_prediction_metrics2(
    cases_total_nn, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_nn, deaths_crps_nn = get_prediction_metrics2(
    deaths_total_nn, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [21]:
# Prophet forecast errors against the full (untruncated) observations.
cases_mse_prophet, cases_crps_prophet = get_prediction_metrics2(
    cases_total_prophet, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)
deaths_mse_prophet, deaths_crps_prophet = get_prediction_metrics2(
    deaths_total_prophet, all_data.new_deaths.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [None]:
# EpiNow2 forecast errors (cases only; no EpiNow2 death forecasts are loaded).
cases_mse_epinow2, cases_crps_epinow2 = get_prediction_metrics2(
    cases_total_epinow2, all_data.new_cases.data, data.nDs,
    skip_crps=skip_crps, verbose=verbose,
)

In [21]:
# Total Epi-NN case NMSE over the last `look_ahead` days (presumably the
# forecast window). look_ahead must be > 0: a [-0:] slice selects everything.
jnp.sum(cases_mse_nn[-look_ahead:])

DeviceArray(27.804903, dtype=float32)

In [8]:
# Load precomputed per-day forecast metrics (NMSE / NCRPS) for every method.
# NOTE(review): this overwrites the *_mse_* / *_crps_* variables produced by the
# get_prediction_metrics2 cells, so the plots below use these CSV values.
df = pd.read_csv('arma_test.csv')

cases_mse_arma = df['arma_case_nmse'].to_numpy()
deaths_mse_arma = df['arma_death_nmse'].to_numpy()
cases_crps_arma = df['arma_case_ncrps'].to_numpy()
deaths_crps_arma = df['arma_death_ncrps'].to_numpy()

cases_mse_prophet = df['prophet_case_nmse'].to_numpy()
deaths_mse_prophet = df['prophet_death_nmse'].to_numpy()
cases_crps_prophet = df['prophet_case_ncrps'].to_numpy()
deaths_crps_prophet = df['prophet_death_ncrps'].to_numpy()

cases_mse_epinow2 = df['EpiNow2_case_nmse'].to_numpy()
cases_crps_epinow2 = df['EpiNow2_case_ncrps'].to_numpy()

# Epi-NN columns are converted to NumPy like the other methods, so slices such
# as [-n:] are unambiguously positional and plot_graph gets the same type.
cases_mse_nn = df['nn_case_nmse'].to_numpy()
deaths_mse_nn = df['nn_death_nmse'].to_numpy()
cases_crps_nn = df['nn_case_ncrps'].to_numpy()
deaths_crps_nn = df['nn_death_ncrps'].to_numpy()

In [20]:
# Total Epi-NN death NCRPS over the last `look_ahead` days (presumably the
# forecast window). look_ahead must be > 0: a [-0:] slice selects everything.
np.sum(deaths_crps_nn[-look_ahead:])

14.0606425

In [15]:
prediction_date = pd.to_datetime(end_date)

# Per-day forecast metrics (NMSE, NCRPS) of every method; region=None
# presumably means aggregated over regions.
# NOTE(review): Prophet is trimmed by [:-1] and EpiNow2 by [:-8] here, while the
# baseline-only plotting cell uses no trim / [:-3] — confirm which alignment
# with the Epi-NN / Epi-ARMA series is correct.

# Per-region forecast plots (disabled):
# for reg in range(58,59):
#     plot_graph((cases_total_nn, all_data.new_cases.data), 'cases_nn.png', prediction_date, title='Cases Predicted by Epi-NN' ,label=['Prediction', 'Actual Cases'], region=reg)
#     plot_graph((deaths_total_nn, all_data.new_deaths.data), 'deaths_nn.png', prediction_date, title='Deaths Predicted by Epi-NN', region=reg)
#     plot_graph(Rt_total_nn, 'R_nn.png', prediction_date, region=reg, percentiles=None, label=False)
#     plot_graph(cfr_total_nn, 'cfr_nn.png', prediction_date, region=reg)
#
#     plot_graph((cases_total_arma, all_data.new_cases.data), 'cases_arma.png', prediction_date, title='Cases Predicted by Epi-ARMA' ,label=['Prediction', 'Actual Cases'], region=reg)
#     plot_graph((deaths_total_arma, all_data.new_deaths.data), 'deaths_arma.png', prediction_date, title='Deaths Predicted by Epi-ARMA', label=['Prediction', 'Actual Deaths'], region=reg)
#     plot_graph(Rt_total_arma, 'R_arma.png', prediction_date, title='R_t Predicted by Epi-ARMA', region=reg)
#     plot_graph(Rt_noise_total_arma, 'R_noise_arma.png', prediction_date, title='R_{noise,t} Predicted by Epi-ARMA', region=reg)
#     plot_graph(Rt_cms_total_arma, 'R_cms_arma.png', prediction_date, title='R_cms Predicted by Epi-ARMA', region=reg)
#
#     plot_graph((cases_total_prophet, all_data.new_cases.data), 'cases_prophet.png', prediction_date, title='Cases Predicted by Prophet', label=['Prediction', 'Actual Cases'], region=reg)
#     plot_graph((deaths_total_prophet, all_data.new_deaths.data), 'deaths_prophet.png', prediction_date, title='Deaths Predicted by Prophet', label=['Prediction', 'Actual Deaths'], region=reg)
#
#     plot_graph((cases_total_epinow2[:,:,:-8], all_data.new_cases.data), 'cases_EpiNow2.png', prediction_date, title='Cases Predicted by EpiNow2', label=['Prediction', 'Actual Cases'], region=reg)

plot_graph(
    (cases_mse_nn, cases_mse_arma, cases_mse_prophet[:-1], cases_mse_epinow2[:-8]),
    'cases_nmse.png', prediction_date,
    title='Cases NMSE', label=['Epi-NN', 'Epi-ARMA', 'Prophet', 'EpiNow2'], region=None,
)
plot_graph(
    (deaths_mse_nn, deaths_mse_arma, deaths_mse_prophet[:-1]),
    'deaths_nmse.png', prediction_date,
    title='Deaths NMSE', label=['Epi-NN', 'Epi-ARMA', 'Prophet'], region=None,
)

plot_graph(
    (cases_crps_nn, cases_crps_arma, cases_crps_prophet[:-1], cases_crps_epinow2[:-8]),
    'cases_ncrps.png', prediction_date,
    title='Cases NCRPS', label=['Epi-NN', 'Epi-ARMA', 'Prophet', 'EpiNow2'], region=None,
)
plot_graph(
    (deaths_crps_nn, deaths_crps_arma, deaths_crps_prophet[:-1]),
    'deaths_ncrps.png', prediction_date,
    title='Deaths NCRPS', label=['Epi-NN', 'Epi-ARMA', 'Prophet'], region=None,
)

    

<Figure size 432x288 with 0 Axes>

In [23]:
prediction_date = pd.to_datetime(end_date)

# Baselines only (no Epi-NN): per-day NMSE / NCRPS of Epi-ARMA, Prophet and
# EpiNow2; region=None presumably means aggregated over regions.
# NOTE(review): EpiNow2 is trimmed by [:-3] here but by [:-8] (and Prophet by
# [:-1]) in the all-methods plotting cell — confirm the intended alignment.
plot_graph(
    (cases_mse_arma, cases_mse_prophet, cases_mse_epinow2[:-3]),
    'arma_cases_nmse.png', prediction_date,
    title='Cases NMSE', label=['Epi-ARMA', 'Prophet', 'EpiNow2'], region=None,
)
plot_graph(
    (deaths_mse_arma, deaths_mse_prophet),
    'arma_deaths_nmse.png', prediction_date,
    title='Deaths NMSE', label=['Epi-ARMA', 'Prophet'], region=None,
)

plot_graph(
    (cases_crps_arma, cases_crps_prophet, cases_crps_epinow2[:-3]),
    'arma_cases_ncrps.png', prediction_date,
    title='Cases NCRPS', label=['Epi-ARMA', 'Prophet', 'EpiNow2'], region=None,
)
plot_graph(
    (deaths_crps_arma, deaths_crps_prophet),
    'arma_deaths_ncrps.png', prediction_date,
    title='Deaths NCRPS', label=['Epi-ARMA', 'Prophet'], region=None,
)

    

<Figure size 432x288 with 0 Axes>

In [10]:
# Total Epi-ARMA case NMSE over the last `look_ahead` days (presumably the
# forecast window). look_ahead must be > 0: a [-0:] slice selects everything.
jnp.sum(cases_mse_arma[-look_ahead:])

DeviceArray(14.342705, dtype=float32)

In [11]:
# Total Epi-ARMA case NCRPS over the last `look_ahead` days (presumably the
# forecast window). look_ahead must be > 0: a [-0:] slice selects everything.
jnp.sum(cases_crps_arma[-look_ahead:])

DeviceArray(9.071289, dtype=float32)