In [1]:
import sys
import os
import numpy as np
import multiprocessing
import pickle
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
from statsmodels.api import OLS, add_constant
import time
from scipy.stats import norm
import datetime as dt

from vax_sims_LHS_samples import *
from plot_utils import *

In [2]:
# Names of the uncertain model inputs, in the positional order that
# map_lhs_point_to_vax_sim reads them (lhs_point[0] ... lhs_point[5]).
UNCERTAINTY_PARAMS = ['vax_susc_mult', 'vax_transmission_mult', 'contacts_per_day_mult', 'outside_infection_rate_mult',
                      'cases_isolated_per_contact_trace', 'initial_ID_prevalence']

# (lower, upper) bounds for each uncertain input. sample_from_prior treats each
# range as a central 95% interval: mean = midpoint, sd = width / (2 * 1.96).
UNCERTAINTY_PARAM_RANGES = {
    'vax_susc_mult': (0.097608, 0.941192), # 0.5194 +/- 1.96 * 0.2152
    'vax_transmission_mult': (0.25, 1),
    'contacts_per_day_mult': (0.9,2.7),
    'outside_infection_rate_mult': (1, 5),
    'cases_isolated_per_contact_trace': (0.5,1.5),
    'initial_ID_prevalence': (0.003, 0.0054)
}

def sample_from_prior(rng=None, lower_bound=0.0, max_tries=1000):
    """Draw one point from the prior over the uncertainty parameters.

    Each parameter in UNCERTAINTY_PARAMS is drawn independently from a normal
    distribution whose central 95% interval is UNCERTAINTY_PARAM_RANGES[param]
    (mean = midpoint of the range, sd = range width / (2 * 1.96)). The draw is
    truncated below at ``lower_bound`` by rejection sampling.

    Why truncate: every parameter is a scale factor or a prevalence (see how
    map_lhs_point_to_vax_sim uses them), so negative values are meaningless.
    An untruncated normal produces them regularly; for example, the mean of
    vax_susc_mult is only ~2.4 sd above zero. A negative multiplier on an
    infection probability is a likely source of numpy's
    ``ValueError: lam < 0 or lam is NaN`` from a Poisson draw.

    Parameters
    ----------
    rng : numpy.random.Generator or RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state
        (the original behavior).
    lower_bound : float or None, optional
        Smallest accepted value for every parameter (default 0.0).
        Pass None to disable truncation and reproduce plain normal draws.
    max_tries : int, optional
        Maximum rejection attempts per parameter before giving up.

    Returns
    -------
    list of float
        One value per entry of UNCERTAINTY_PARAMS, in that order.

    Raises
    ------
    RuntimeError
        If no draw >= lower_bound is found within ``max_tries`` attempts.
    """
    sampler = np.random if rng is None else rng
    return_point = []
    for param in UNCERTAINTY_PARAMS:
        lb, ub = UNCERTAINTY_PARAM_RANGES[param]
        mean = (lb + ub) / 2
        sd = (ub - lb) / (2 * 1.96)
        for _ in range(max_tries):
            value = sampler.normal(mean, sd)
            if lower_bound is None or value >= lower_bound:
                break
        else:
            raise RuntimeError(
                'could not sample {} >= {} in {} tries'.format(param, lower_bound, max_tries))
        return_point.append(value)
    return return_point

In [3]:
def map_lhs_point_to_vax_sim(lhs_point, param_modifiers=None):
    """Build a vaccinated/unvaccinated multigroup simulation for one uncertainty point.

    Parameters
    ----------
    lhs_point : sequence of float, or dict
        Values of the uncertainty parameters in the order of UNCERTAINTY_PARAMS:
        vax_susc_mult, vax_transmission_mult, contacts_per_day_mult,
        outside_infection_rate_mult, cases_isolated_per_contact_trace,
        initial_ID_prevalence. A dict keyed by those names is also accepted.
        Extra trailing entries in a sequence are ignored.
    param_modifiers : optional
        Forwarded unchanged to update_vax_sim_params after the sim is built.

    Returns
    -------
    The simulation object returned by generate_vax_unvax_multigroup_sim.

    Raises
    ------
    ValueError
        If a sequence has fewer values than UNCERTAINTY_PARAMS, or a dict is
        missing one of those keys. Previously this surfaced as an opaque
        IndexError from inside the function.
    """
    n_params = len(UNCERTAINTY_PARAMS)
    if isinstance(lhs_point, dict):
        missing = [name for name in UNCERTAINTY_PARAMS if name not in lhs_point]
        if missing:
            raise ValueError('lhs_point is missing uncertainty parameters: {}'.format(missing))
        lhs_point = [lhs_point[name] for name in UNCERTAINTY_PARAMS]
    elif len(lhs_point) < n_params:
        raise ValueError('lhs_point has {} values but {} are required, ordered as {}'.format(
            len(lhs_point), n_params, UNCERTAINTY_PARAMS))

    vax_susc_mult = lhs_point[0]
    vax_trans_mult = lhs_point[1]
    contacts_mult = lhs_point[2]
    outside_infection_mult = lhs_point[3]
    contact_trace_mult = lhs_point[4]
    initial_prevalence = lhs_point[5]

    # A fresh copy of the calibrated inputs is loaded on every call; the
    # adjustments below mutate base_params in place.
    base_params, base_group_names, contact_matrix, vax_rates = load_calibrated_params()

    contact_matrix = contact_matrix * contacts_mult

    # Outside-infection scaling applies to every base group.
    for params in base_params:
        params['daily_outside_infection_p'] = params['daily_outside_infection_p'] * outside_infection_mult

    # NOTE(review): only base_params[0..2] receive the contact-tracing and
    # initial-prevalence adjustments (unchanged from the original), while the
    # scaling above covers every group. PARAMS_PRE_MOVEIN lists four group
    # prefixes (ug_ga, ug_other, grad, employee). If base_params holds a fourth
    # group, it keeps its calibrated values; confirm this is intended.
    for group_idx in range(3):
        base_params[group_idx]['cases_isolated_per_contact'] *= contact_trace_mult
        base_params[group_idx]['initial_ID_prevalence'] = initial_prevalence

    vax_sim = generate_vax_unvax_multigroup_sim(base_params, base_group_names,
                                    vax_rates, contact_matrix,
                                    vax_trans_mult, vax_susc_mult)

    update_vax_sim_params(vax_sim, param_modifiers)

    return vax_sim

In [None]:
# def load_sim_output(folder, npoints, lb=0):
#     scenario_data = pd.DataFrame(columns=UNCERTAINTY_PARAMS+['student_inf_10','student_inf_50','student_inf_90']+\
#                                 ['staff_inf_10', 'staff_inf_50', 'staff_inf_90'])
#     for idx in range(lb, npoints):
#         inf_file = folder + '/list_of_infs_by_group_{}.dill'.format(idx)
#         point_file = folder + '/lhs_point_{}.dill'.format(idx)
#         with open(inf_file, 'rb') as fhandle:
#             inf_matrix = np.array(dill.load(fhandle))
#             student_infxns = np.sum(inf_matrix[:,:-2], axis=1)
#             staff_infxns = np.sum(inf_matrix[:,-2:], axis=1)
#         with open(point_file, 'rb') as fhandle:
#             uncertainty_point = dill.load(fhandle)

#         new_row = dict()
#         for index, col_name in enumerate(UNCERTAINTY_PARAMS):
#             if type(uncertainty_point) == dict:
#                 new_row[col_name] = uncertainty_point[col_name]
#             else:
#                 new_row[col_name] = uncertainty_point[index]
#         new_row['student_inf_10'] = np.quantile(student_infxns, 0.1)
#         new_row['student_inf_50'] = np.quantile(student_infxns, 0.5)
#         new_row['student_inf_90'] = np.quantile(student_infxns, 0.9)
#         new_row['staff_inf_10'] = np.quantile(staff_infxns, 0.1)
#         new_row['staff_inf_50'] = np.quantile(staff_infxns, 0.5)
#         new_row['staff_inf_90'] = np.quantile(staff_infxns, 0.9)
#         new_row['cornell_inf_10'] = np.quantile(student_infxns + staff_infxns, 0.1)
#         new_row['cornell_inf_50'] = np.quantile(student_infxns + staff_infxns, 0.5)
#         new_row['cornell_inf_90'] = np.quantile(student_infxns + staff_infxns, 0.9)
        
#         scenario_data = scenario_data.append(new_row, ignore_index=True)

#     return scenario_data

# def residential_regression_student(scenario_data):
#     columns = scenario_data.columns[0:5]
#     target = 'student_inf_50'
#     X_res = scenario_data[columns]
#     Y_res_outcomes = np.array(scenario_data[[target]])

#     X = add_constant(X_res)
#     model = OLS(Y_res_outcomes,X)
#     results = model.fit()
#     return results

# def residential_regression_staff(scenario_data):
#     columns = scenario_data.columns[0:5]
#     target = 'staff_inf_50'
#     X_res = scenario_data[columns]
#     Y_res_outcomes = np.array(scenario_data[[target]])

#     X = add_constant(X_res)
#     model = OLS(Y_res_outcomes,X)
#     results = model.fit()
#     return results


In [None]:
# def calculate_pessimistic_scenario(results, q=0.99, beta=1.96):
#     # the keys in dict(results.params) specify whether this is for residential
#     # or virtual vs. residential
#     lr_results = dict(results.params)
#     sd_dict = dict()
#     pess_direction = dict()
#     params = set(lr_results.keys()) - set(['const'])
#     centre_infections = lr_results['const']

#     invquantile = norm.ppf(q)

#     for param in params:
#         sd_dict[param] = (UNCERTAINTY_PARAM_RANGES[param][1] - UNCERTAINTY_PARAM_RANGES[param][0])/(2*beta)
#         centre_infections += np.mean(UNCERTAINTY_PARAM_RANGES[param]) * lr_results[param]

#     sum_squares_Sigma_1 = 0

#     for param in params:
#         sum_squares_Sigma_1 += (lr_results[param]*sd_dict[param]) ** 2

#     for param in params:
#         pess_direction[param] = lr_results[param]*(sd_dict[param])**2 / np.sqrt(sum_squares_Sigma_1)

#     mp_pess_scenario = dict()
#     for param in params:
#         mp_pess_scenario[param] = np.mean(UNCERTAINTY_PARAM_RANGES[param]) + invquantile * pess_direction[param]

#     return mp_pess_scenario

In [None]:
# LHS_data = load_sim_output('/home/aaj54/group-testing/notebooks/vax_sims/lhs_vax_sims_test_delay:1631153818.0076091', 200)

In [4]:
# Testing / isolation settings used before move-in. Frequencies are presumably
# tests per day (1/7 = weekly, 2/7 = twice weekly); confirm against the sim.
PARAMS_PRE_MOVEIN = dict(
    ug_ga_vax_test_frequency=1/7, ug_ga_unvax_test_frequency=2/7,
    ug_other_vax_test_frequency=1/7, ug_other_unvax_test_frequency=2/7,
    grad_vax_test_frequency=1/7, grad_unvax_test_frequency=2/7,
    employee_vax_test_frequency=1/7, employee_unvax_test_frequency=2/7,
    test_delay=2, max_time_pre_ID=2,
)

# Post-move-in settings: identical to pre-move-in except that vaccinated
# ug_ga students test twice as often and the test delay drops from 2 to 1.
PARAMS_POST_MOVEIN = dict(PARAMS_PRE_MOVEIN, ug_ga_vax_test_frequency=2/7, test_delay=1)

In [5]:
def get_cum_infections(df):
    """Return total cumulative infections (mild + severe) from the final row of ``df``."""
    final_counts = df[['cumulative_mild', 'cumulative_severe']].iloc[-1]
    return final_counts.sum()


def get_cum_inf_trajectory(df):
    """Return a Series of mild + severe cumulative infections, row by row (index of ``df``)."""
    infection_cols = ['cumulative_mild', 'cumulative_severe']
    return df[infection_cols].sum(axis=1)


def run_new_trajectory(sim, T, change_t, override_premovein_params=None, post_movein_params=None):
    """Reset ``sim`` and run it for ``T`` steps, switching parameters once at ``change_t``.

    Parameters
    ----------
    sim : multigroup vax simulation
        Must provide reset_initial_state(), step(), and a ``sims`` list whose
        elements provide update_severity_levels() and a ``sim_df`` DataFrame.
    T : int
        Number of calls to sim.step().
    change_t : int
        Loop index after whose step ``post_movein_params`` are applied.
        If it is not in range(T), the switch never happens.
    override_premovein_params : dict, optional
        Parameters applied right after the reset, before the first step.
        Defaults to PARAMS_PRE_MOVEIN.
    post_movein_params : dict, optional
        Parameters applied after step ``change_t``. Defaults to
        PARAMS_POST_MOVEIN (the previously hard-coded value).

    Returns
    -------
    DataFrame
        Element-wise sum (DataFrame.add) of every group's ``sim_df``.
    """
    # Resolve defaults at call time (not in the signature), as the original did.
    if override_premovein_params is None:
        override_premovein_params = PARAMS_PRE_MOVEIN
    if post_movein_params is None:
        post_movein_params = PARAMS_POST_MOVEIN

    sim.reset_initial_state()
    update_vax_sim_params(sim, override_premovein_params)

    for t in range(T):
        sim.step()
        if t == change_t:
            update_vax_sim_params(sim, post_movein_params)

    for single_group_sim in sim.sims:
        single_group_sim.update_severity_levels()

    # Distinct loop name so the `sim` argument is not shadowed.
    combined_df = sim.sims[0].sim_df
    for group_sim in sim.sims[1:]:
        combined_df = combined_df.add(group_sim.sim_df)
    return combined_df

# Loop index (0-based, counted from the first sim.step()) after which
# run_new_trajectory switches to the post-move-in parameters. The simulation
# starts on Aug 23; the switch is assumed to happen on Aug 29.
CHANGE_T = 6
def run_multigroup_sim(sim, T, override_premovein_params=None):
    """Run one trajectory of ``sim`` and return each group's cumulative infections.

    Delegates to run_new_trajectory (switching parameters at CHANGE_T) and
    discards its combined DataFrame. Returns a list with one entry per element
    of ``sim.sims``: get_cum_inf_trajectory of that group's ``sim_df``.
    """
    run_new_trajectory(sim, T, CHANGE_T, override_premovein_params)
    return [get_cum_inf_trajectory(group_sim.sim_df) for group_sim in sim.sims]


def get_centre_point():
    """Return a dict mapping each uncertainty parameter to the midpoint of its range."""
    return {name: (lb + ub) / 2 for name, (lb, ub) in UNCERTAINTY_PARAM_RANGES.items()}


def run_multiple_trajs(sim, T, n, override_premovein_params=None):
    """Run ``n`` trajectories of ``sim`` (reset each time) and collect per-group results.

    Returns a list of length ``n``; each element is the per-group list
    produced by run_multigroup_sim.
    """
    return [run_multigroup_sim(sim, T, override_premovein_params) for _ in range(n)]

def get_timestamp():
    """Return the current Unix time in whole seconds (fraction truncated), as a string."""
    return str(int(time.time()))

In [12]:
# param_modifiers= PARAMS_PRE_MOVEIN


# nsamples = 5

# # point is either center or pessimistic
# center_dict = get_centre_point()
# center_point = [center_dict['vax_susc_mult'], center_dict['vax_transmission_mult'], center_dict['contacts_per_day_mult'],
#                 center_dict['outside_infection_rate_mult'], center_dict['cases_isolated_per_contact_trace']]
# center_vax_sim = map_lhs_point_to_vax_sim(center_point, param_modifiers)
# center_inf_trajs_by_group = run_multiple_trajs(center_vax_sim, T=112, n=nsamples)


IndexError: list index out of range

In [6]:
def sample_and_save(point_id, save_folder, nsamples=100, T=50,
                    sim_param_modifiers=None, premovein_override=None):
    """Sample one prior point, simulate it ``nsamples`` times, and pickle the result.

    Parameters
    ----------
    point_id : int
        Used in the output file name ``point_{point_id}``.
    save_folder : str
        Output directory. Joined with os.path.join, so a trailing slash is
        optional.
    nsamples : int
        Number of trajectories simulated for the sampled point.
    T : int
        Number of simulation steps per trajectory.
    sim_param_modifiers : optional
        Passed to map_lhs_point_to_vax_sim. Defaults to the notebook-level
        global ``param_modifiers``, which the original always read implicitly.
    premovein_override : optional
        Passed as ``override_premovein_params`` to run_multiple_trajs.
        Defaults to the notebook-level global ``override_params``.

    Returns
    -------
    list
        ``[prior_point, inf_trajs_by_group]``, the same object that is pickled.
    """
    # Fall back to the notebook globals so existing calls behave exactly as before.
    if sim_param_modifiers is None:
        sim_param_modifiers = param_modifiers
    if premovein_override is None:
        premovein_override = override_params

    prior_point = sample_from_prior()
    prior_sim = map_lhs_point_to_vax_sim(prior_point, sim_param_modifiers)
    inf_trajs_by_group = run_multiple_trajs(prior_sim, T=T, n=nsamples,
                                            override_premovein_params=premovein_override)
    result = [prior_point, inf_trajs_by_group]
    # The context manager closes (and flushes) the file deterministically,
    # rather than relying on garbage collection of an anonymous handle.
    with open(os.path.join(save_folder, 'point_{}'.format(point_id)), 'wb') as fhandle:
        pickle.dump(result, fhandle)
    return result

    
# for i in range(5):
#     # sample from prior
#     prior_point = sample_from_prior()
#     prior_sim = map_lhs_point_to_vax_sim(prior_point, param_modifiers)
#     inf_trajs_by_group = run_multiple_trajs(prior_sim, T=50, n=nsamples, override_premovein_params=override_params)
#     pickle.dump([prior_point, inf_trajs_by_group], open(save_folder+'point_{}'.format(i), 'wb'))

In [7]:

# Parallel prior sampling: each job draws one prior point, simulates it
# `nsamples` times for `T` steps, and pickles the result into `save_folder`.
# NOTE(review): `mp` and `gc` are not used by the live code in this cell, and
# `multiprocessing` is already imported in the first cell.
import multiprocessing as mp
import gc
from joblib import Parallel, delayed
import multiprocessing


# sample_and_save reads these two names as notebook globals, so they must keep
# exactly these names.
# param_modifiers is applied when each simulation is built (map_lhs_point_to_vax_sim).
param_modifiers = PARAMS_PRE_MOVEIN
# override_params is passed as override_premovein_params, which run_new_trajectory
# applies right after reset, before the first step.
# NOTE(review): this means the pre-move-in period also runs with
# PARAMS_POST_MOVEIN values; confirm this is intended.
override_params = PARAMS_POST_MOVEIN
nsamples = 100
T = 50

# NOTE(review): ':' is not a valid path character on Windows.
save_folder = 'samples_for_posterior:{}/'.format(get_timestamp())
os.mkdir(save_folder)

# inputs = np.arange(0.2,.8,.01)
# def processInput(i):
#     gc.collect()
#     spring_calib = SpringCalibration(i)
#     return spring_calib.run_and_score_trajectories(100)

num_cores = multiprocessing.cpu_count()
# 1000 prior points. Every job's [prior_point, trajectories] is also returned
# and kept in memory in `results`, in addition to being pickled.
# NOTE(review): the recorded run of this cell failed with
# "ValueError: lam < 0 or lam is NaN"; check that sampled prior points contain
# no negative values.
results = Parallel(n_jobs=num_cores)(delayed(sample_and_save)(i, save_folder, nsamples=nsamples, T=T) for i in range(1000))

ValueError: lam < 0 or lam is NaN

In [9]:
# Reload the first few saved points of a previous sampling run.
# Read from the hard-coded `output_folder` rather than `save_folder` (which is
# only defined after re-running the expensive sampling cell), so this cell works
# on a fresh kernel. os.path.join supplies the separator that the bare folder
# name lacks.
output_folder = 'samples_for_posterior:1633230015'
loaded_points = []
for i in range(5):
    point_path = os.path.join(output_folder, 'point_{}'.format(i))
    # pickle.load can execute arbitrary code: only load files this notebook wrote.
    with open(point_path, 'rb') as fhandle:
        [prior_point, inf_trajs_by_group] = pickle.load(fhandle)
    loaded_points.append((prior_point, inf_trajs_by_group))
    print(inf_trajs_by_group)

[[0        2.0
1        7.0
2       20.0
3       36.0
4       41.0
5       50.0
6       60.0
7       71.0
8       81.0
9       95.0
10     115.0
11     141.0
12     164.0
13     196.0
14     241.0
15     287.0
16     344.0
17     431.0
18     504.0
19     598.0
20     677.0
21     774.0
22     882.0
23     978.0
24    1057.0
25    1142.0
26    1245.0
27    1362.0
28    1465.0
29    1540.0
30    1552.0
31    1552.0
32    1552.0
33    1552.0
34    1552.0
35    1552.0
36    1552.0
37    1552.0
38    1552.0
39    1552.0
40    1552.0
41    1552.0
42    1552.0
43    1552.0
44    1552.0
45    1552.0
46    1552.0
47    1552.0
48    1552.0
49    1552.0
50    1552.0
dtype: float64, 0      0.0
1      0.0
2      1.0
3      1.0
4      1.0
5      1.0
6      1.0
7      1.0
8      1.0
9      1.0
10     1.0
11     2.0
12     2.0
13     2.0
14     3.0
15     3.0
16     3.0
17     3.0
18     3.0
19     3.0
20     3.0
21     3.0
22     5.0
23     5.0
24    10.0
25    10.0
26    11.0
27    13.0
28    13.0


[[0        2.0
1        7.0
2       13.0
3       16.0
4       21.0
5       32.0
6       67.0
7      112.0
8      169.0
9      232.0
10     324.0
11     425.0
12     553.0
13     714.0
14     866.0
15    1079.0
16    1308.0
17    1556.0
18    1835.0
19    2147.0
20    2439.0
21    2702.0
22    2955.0
23    3117.0
24    3147.0
25    3167.0
26    3177.0
27    3184.0
28    3185.0
29    3185.0
30    3186.0
31    3186.0
32    3186.0
33    3188.0
34    3196.0
35    3196.0
36    3197.0
37    3197.0
38    3197.0
39    3197.0
40    3197.0
41    3197.0
42    3197.0
43    3197.0
44    3197.0
45    3197.0
46    3197.0
47    3197.0
48    3197.0
49    3197.0
50    3197.0
dtype: float64, 0      0.0
1      0.0
2      0.0
3      0.0
4      1.0
5      1.0
6      1.0
7      1.0
8      1.0
9      1.0
10     1.0
11     1.0
12     1.0
13     1.0
14     2.0
15     3.0
16     5.0
17     7.0
18     8.0
19    12.0
20    16.0
21    18.0
22    24.0
23    27.0
24    35.0
25    36.0
26    36.0
27    36.0
28    36.0


[[0        0.0
1        4.0
2       13.0
3       24.0
4       49.0
5       58.0
6       77.0
7      109.0
8      133.0
9      150.0
10     193.0
11     223.0
12     263.0
13     331.0
14     389.0
15     440.0
16     502.0
17     566.0
18     674.0
19     771.0
20     864.0
21     972.0
22    1095.0
23    1210.0
24    1367.0
25    1519.0
26    1707.0
27    1855.0
28    1980.0
29    2126.0
30    2313.0
31    2493.0
32    2698.0
33    2816.0
34    2875.0
35    2911.0
36    2941.0
37    2963.0
38    2971.0
39    2979.0
40    2990.0
41    3019.0
42    3021.0
43    3021.0
44    3021.0
45    3021.0
46    3021.0
47    3021.0
48    3021.0
49    3021.0
50    3021.0
dtype: float64, 0      0.0
1      1.0
2      1.0
3      1.0
4      1.0
5      1.0
6      1.0
7      1.0
8      1.0
9      1.0
10     1.0
11     1.0
12     1.0
13     1.0
14     1.0
15     1.0
16     1.0
17     1.0
18     1.0
19     1.0
20     1.0
21     1.0
22     1.0
23     3.0
24     3.0
25     3.0
26     3.0
27     5.0
28     8.0


[[0        1.0
1        3.0
2       10.0
3       19.0
4       35.0
5       55.0
6       82.0
7      116.0
8      187.0
9      240.0
10     327.0
11     443.0
12     567.0
13     713.0
14     889.0
15    1147.0
16    1429.0
17    1759.0
18    2091.0
19    2448.0
20    2810.0
21    3074.0
22    3295.0
23    3326.0
24    3330.0
25    3336.0
26    3339.0
27    3340.0
28    3343.0
29    3343.0
30    3343.0
31    3343.0
32    3343.0
33    3344.0
34    3345.0
35    3346.0
36    3346.0
37    3346.0
38    3346.0
39    3346.0
40    3346.0
41    3346.0
42    3346.0
43    3346.0
44    3346.0
45    3347.0
46    3347.0
47    3347.0
48    3347.0
49    3347.0
50    3347.0
dtype: float64, 0      0.0
1      0.0
2      1.0
3      2.0
4      2.0
5      2.0
6      2.0
7      2.0
8      2.0
9      3.0
10     4.0
11     4.0
12     4.0
13     4.0
14     4.0
15     4.0
16     6.0
17    12.0
18    17.0
19    19.0
20    23.0
21    26.0
22    34.0
23    40.0
24    49.0
25    54.0
26    66.0
27    68.0
28    68.0


In [8]:
# Load the calibrated model inputs into the notebook namespace for inspection.
(base_params,
 base_group_names,
 contact_matrix,
 vax_rates) = load_calibrated_params()

In [28]:
# base_params[0]
# Inspect the calibrated daily outside-infection probability of base group index 2
# (before any uncertainty multiplier is applied).
base_params[2]['daily_outside_infection_p']

6.45e-06