In [1]:
# Imports
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append("../") 

import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpyro
import pandas as pd
import arviz as az


numpyro.set_host_device_count(4)

from epimodel import run_model, default_model, arma_model, latent_nn_model, EpidemiologicalParameters, preprocess_data
from epimodel.models.model_predict_utils import *
from epimodel.models.model_build_utils import *

Importing plotly failed. Interactive plots will not work.


In [2]:
# Date window (as strings) for the data the cross-validated models are scored on.
start_date = "2020-08-01"
end_date = "2021-01-02"

# Source CSV shared by both datasets below.
# Smaller alternative for quick runs: '../data/medium_dataset.csv'
merged_data_path = '../data/all_merged_data_2021-01-22.csv'

# `data` is restricted to [start_date, end_date]; `all_data` uses
# preprocess_data's default date range.
data = preprocess_data(merged_data_path, start_date=start_date, end_date=end_date)
all_data = preprocess_data(merged_data_path)

# Featurize both datasets, in the same order as before (data first).
for dataset in (data, all_data):
    dataset.featurize()


Processing data from 2020-08-01 00:00:00 to 2021-01-02 00:00:00
Processing data from 2020-08-01 00:00:00 to 2021-01-22 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded


In [3]:
# Forecast horizon: number of days to predict beyond the last day of `data`.
look_ahead = 20  # days

# Epidemiological parameters, handed to the forecasting calls further down.
ep = EpidemiologicalParameters()

In [4]:
# Load every posterior stored in the cross-validation directory.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the order of `samples` (and of every per-sample list / plot series
# derived from it) reproducible across machines and runs.
# Hidden entries (e.g. .ipynb_checkpoints, .DS_Store) and sub-directories are
# skipped because they are not sample files.
cross_validate_dir = 'samples/nn/cross_validate'

samples = []
sample_names = []  # file name of each entry in `samples`, same order; use for labels

for sample_name in sorted(os.listdir(cross_validate_dir)):

    f_name = os.path.join(cross_validate_dir, sample_name)
    if sample_name.startswith('.') or not os.path.isfile(f_name):
        continue

    sample_names.append(sample_name)
    samples.append(netcdf_to_dict(f_name))


In [5]:
# Show the model configuration recovered from each loaded posterior sample.
for model_settings in map(samples_to_model_settings, samples):
    print(model_settings)

{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': False, 'D_layers': [10, 2], 'R_period': 1, 'bnn_regulariser': 0.3, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': False, 'D_layers': [5, 5, 5], 'R_period': 1, 'bnn_regulariser': 0.3, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': True, 'D_layers': [5, 5, 5], 'R_period': 1, 'bnn_regulariser': 0.3, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': True, 'D_layers': [10, 2], 'R_period': 1, 'bnn_regulariser': 0.3, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': True, 'D_layers': [5, 5, 5], 'R_period': 1, 'bnn_regulariser': 1, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_death': False, 'D_layers': [10, 2], 'R_period': 1, 'bnn_regulariser': 1, 'preprocessed_data': 'summary'}
{'n_days_nn_input': 21, 'infer_cfr': False, 'input_de

In [6]:
# Explicit import: `np` was otherwise only reachable through the star imports
# in the first cell, which silently breaks if those modules stop re-exporting it.
import numpy as np

# ARMA posterior used as the reference for scoring the NN forecasts.
# NOTE(review): file name suggests ARMA(p=1, q=0) fitted on all data — confirm.
arma_samples_path = 'samples/arma/p1q0_all.netcdf'
samples_all_data_arma = netcdf_to_dict(arma_samples_path)

# Median (50th percentile) over axis 0 — presumably the posterior-sample axis,
# TODO confirm — gives the reference case / death trajectories.
true_cases_arma = np.percentile(samples_all_data_arma['expected_cases'], 50, axis=0)
true_deaths_arma = np.percentile(samples_all_data_arma['expected_deaths'], 50, axis=0)

In [7]:
# Forecast `look_ahead` days with each loaded NN posterior and score the case
# and death forecasts against the ARMA reference medians (true_cases_arma /
# true_deaths_arma).
case_totals_list = []
deaths_total_list = []
Rt_total_list = []

cases_mse_list = []
deaths_mse_list = []

# Metric options forwarded to get_prediction_metrics for every sample.
skip_crps = True
verbose = False

for sample in samples:

    # NOTE: Rt_total_nn (and the other loop variables) keep the values from the
    # LAST sample after the loop; later cells may read them.
    cases_total_nn, deaths_total_nn, Rt_total_nn, cfr_total_nn = nn_predictor(sample, data, ep, look_ahead)

    # Store this sample's predictions (not the accumulator lists themselves).
    case_totals_list.append(cases_total_nn)
    deaths_total_list.append(deaths_total_nn)
    Rt_total_list.append(Rt_total_nn)

    cases_mse, cases_crps = get_prediction_metrics(
        cases_total_nn, true_cases_arma, data.nDs, verbose=verbose, skip_crps=skip_crps
    )
    deaths_mse, deaths_crps = get_prediction_metrics(
        deaths_total_nn, true_deaths_arma, data.nDs, verbose=verbose, skip_crps=skip_crps
    )

    cases_mse_list.append(cases_mse)
    deaths_mse_list.append(deaths_mse)


In [9]:
# Forecasts start after end_date, so it is passed to plot_graph as the prediction date.
prediction_date = pd.to_datetime(end_date)

# One figure per metric: each per-sample MSE collection is handed to plot_graph as a tuple.
mse_plots = (
    (cases_mse_list, 'nn_cases_mse.png'),
    (deaths_mse_list, 'nn_deaths_mse.png'),
)
for mse_values, out_file in mse_plots:
    plot_graph(tuple(mse_values), out_file, prediction_date, region=None)


<Figure size 432x288 with 0 Axes>

In [13]:
# NOTE(review): Rt_total_nn holds the forecast of the LAST entry of `samples`
# (leftover loop variable). It is passed as-is, not wrapped in a tuple as the
# MSE plots are — TODO confirm which form plot_graph expects here.
plot_graph(Rt_total_nn, '1.png', prediction_date, region=0)


<Figure size 432x288 with 0 Axes>