In [45]:
# Imports
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../") 

import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpyro
import pandas as pd
import arviz as az

numpyro.set_host_device_count(4)

from epimodel import run_model, default_model, arma_model, latent_nn_model, EpidemiologicalParameters, preprocess_data
from epimodel.models.model_predict_utils import *
from epimodel.models.model_build_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# dates for data processing
start_date="2020-08-01"
# end_date="2021-01-02"
end_date="2020-12-10"

# Gets data
data = preprocess_data('../data/all_merged_data_2021-01-22.csv', start_date=start_date, end_date=end_date)
all_data = preprocess_data('../data/all_merged_data_2021-01-22.csv')

data.featurize()
all_data.featurize()


Processing data from 2020-08-01 00:00:00 to 2021-01-02 00:00:00
Processing data from 2020-08-01 00:00:00 to 2021-01-22 00:00:00
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded
Note: under drop_outdoor gathering aggregation, the gatherings_aggregation_type is disregarded


In [47]:
# Num days to predict after end of data

look_ahead = 23

ep = EpidemiologicalParameters()

In [51]:
samples = []

for sample_name in os.listdir('samples/arma/learn/test'):

    f_name = 'samples/arma/learn/test/' + sample_name
    try:
        samples.append(netcdf_to_dict(f_name))
    except:
        continue


In [54]:
prediction_date = pd.to_datetime(end_date)


plot_graph(tuple(cases_mse_list), '1.png', prediction_date, region=None)
plot_graph(tuple(deaths_mse_list), '2.png', prediction_date, region=None)

plot_graph(tuple(cases_crps_list), '3.png', prediction_date, region=None)
plot_graph(tuple(deaths_crps_list), '4.png', prediction_date, region=None)



<Figure size 432x288 with 0 Axes>

In [50]:
all_settings = []
for sample in samples:
    
    all_settings.append(samples_to_model_settings(sample))
    print(all_settings[-1])

In [43]:
future_cms = all_data.active_cms[:,:, data.nDs:]

cases_list = []
deaths_list = []
Rt_list = []
Rt_cms_list = []
Rt_noise_list = []

cases_mse_list = []
deaths_mse_list = []

cases_crps_list = []
deaths_crps_list = []

cases_mse_sums_list = []
deaths_mse_sums_list = []

cases_crps_sums_list = []
deaths_crps_sums_list = []

# gets expected cases and deaths

for ind, sample in enumerate(samples):
    
    cases_total_arma, deaths_total_arma, Rt_total_arma, \
            Rt_cms_total_arma, Rt_noise_total_arma = arma_noise_predictor(sample, 
                                                                          ep,                  
                                                                          data.active_cms, 
                                                                          look_ahead, 
                                                                          future_cms=future_cms,        
                                                                          ignore_last_days=0)
    cases_list.append(cases_total_arma)
    deaths_list.append(deaths_total_arma)
    Rt_list.append(Rt_total_arma)
    Rt_cms_list.append(Rt_cms_total_arma)
    Rt_noise_list.append(Rt_noise_total_arma)

    cases_mse_arma, cases_crps_arma = get_prediction_metrics(cases_total_arma, all_data.new_cases.data, data.nDs, verbose=False, skip_crps=False)
    deaths_mse_arma, deaths_crps_arma = get_prediction_metrics(deaths_total_arma, all_data.new_deaths.data, data.nDs, verbose=False, skip_crps=False)

    cases_mse_list.append(cases_mse_arma)
    deaths_mse_list.append(deaths_mse_arma)
    cases_crps_list.append(cases_crps_arma)
    deaths_crps_list.append(deaths_crps_arma)
    
    cases_mse_sums_list.append(jnp.sum(cases_mse_arma[-look_ahead:]))
    deaths_mse_sums_list.append(jnp.sum(deaths_mse_arma[-look_ahead:]))
    cases_crps_sums_list.append(jnp.sum(cases_crps_arma[-look_ahead:]))
    deaths_crps_sums_list.append(jnp.sum(deaths_crps_arma[-look_ahead:]))

In [32]:
cases_crps_sums_list

[DeviceArray(2547.1824, dtype=float32),
 DeviceArray(2513.879, dtype=float32),
 DeviceArray(2489.246, dtype=float32),
 DeviceArray(2596.3713, dtype=float32),
 DeviceArray(2505.2095, dtype=float32),
 DeviceArray(3113.7966, dtype=float32),
 DeviceArray(2625.2368, dtype=float32),
 DeviceArray(2511.5679, dtype=float32),
 DeviceArray(2592.6082, dtype=float32),
 DeviceArray(3070.517, dtype=float32),
 DeviceArray(2528.1345, dtype=float32),
 DeviceArray(2494.2434, dtype=float32),
 DeviceArray(2636.453, dtype=float32),
 DeviceArray(2664.6462, dtype=float32)]

In [36]:
deaths_crps_sums_list

[DeviceArray(77.01107, dtype=float32),
 DeviceArray(76.43733, dtype=float32),
 DeviceArray(76.26774, dtype=float32),
 DeviceArray(76.34344, dtype=float32),
 DeviceArray(76.34181, dtype=float32),
 DeviceArray(75.58963, dtype=float32),
 DeviceArray(75.9007, dtype=float32),
 DeviceArray(76.3267, dtype=float32),
 DeviceArray(76.924866, dtype=float32),
 DeviceArray(72.96018, dtype=float32),
 DeviceArray(76.70638, dtype=float32),
 DeviceArray(76.28216, dtype=float32),
 DeviceArray(76.106834, dtype=float32),
 DeviceArray(77.377464, dtype=float32)]

In [42]:
prediction_date = pd.to_datetime(end_date)


plot_graph(tuple(cases_mse_list), 'arma_p_q_(ignore_last_0)_cases_mse.png', prediction_date, region=None)
plot_graph(tuple(deaths_mse_list), 'arma_p_q_(ignore_last_0)_deaths_mse.png', prediction_date, region=None)

plot_graph(tuple(cases_crps_list), 'arma_p_q_(ignore_last_0)_cases_crps.png', prediction_date, region=None)
plot_graph(tuple(deaths_crps_list), 'arma_p_q_(ignore_last_0)_deaths_crps.png', prediction_date, region=None)



<Figure size 432x288 with 0 Axes>

In [None]:
# BEST IS p = max, q = min so take this sample

sample = netcdf_to_dict('samples/arma/learn/cross_validate/')


In [44]:
future_cms = all_data.active_cms[:,:, data.nDs:]

cases_list = []
deaths_list = []
Rt_list = []
Rt_cms_list = []
Rt_noise_list = []

cases_mse_list = []
deaths_mse_list = []



for ignore_last_days in range(0,28,2):
    
    cases_total_arma, deaths_total_arma, Rt_total_arma, \
            Rt_cms_total_arma, Rt_noise_total_arma = arma_noise_predictor(sample, 
                                                                          ep,                  
                                                                          data.active_cms, 
                                                                          look_ahead, 
                                                                          future_cms=future_cms,        
                                                                          ignore_last_days=ignore_last_days)
    cases_list.append(cases_total_arma)
    deaths_list.append(deaths_total_arma)
    Rt_list.append(Rt_total_arma)
    Rt_cms_list.append(Rt_cms_total_arma)
    Rt_noise_list.append(Rt_noise_total_arma)

    cases_mse_arma, cases_crps_arma = get_prediction_metrics(cases_total_arma, true_cases_arma, data.nDs, verbose=False, skip_crps=True)
    deaths_mse_arma, deaths_crps_arma = get_prediction_metrics(deaths_total_arma, true_deaths_arma, data.nDs, verbose=False, skip_crps=True)

    cases_mse_list.append(cases_mse_arma)
    deaths_mse_list.append(deaths_mse_arma)

NameError: name 'true_cases_arma' is not defined

In [11]:
prediction_date = pd.to_datetime(end_date)


plot_graph(tuple(cases_mse_list), 'arma_ignore_last_cases_mse.png', prediction_date, region=None)
plot_graph(tuple(deaths_mse_list), 'arma_ignore_last_deaths_mse.png', prediction_date, region=None)


<Figure size 432x288 with 0 Axes>

In [None]:
# BEST IS IGNORE_LAST = 4/5 days