In [1]:
import sys
import os
import numpy as np
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path + "/src/multi_agent_sim")
from multi_agent_sim import MultiAgentSim

In [None]:

social_network_params_high_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (2,3),
}

n_agents = 1000
expected_total_outside_infs = 5
time_horizon = 7 * 12
main_params = {
    'n_agents': n_agents,
    'use_contact_tracing': True,
    'use_adaptive_testing': True,
    'use_surveillance_testing': True
}

infection_dynamics_params = {
    'forward_gen_alpha': 8,
    'forward_gen_beta_hyperparams': (5,5),
    'detectability_curve_type': 'optimistic',
    'self_reporting_multiplier': 0.8,
    'self_reporting_delay': 3,
    'init_infection_rate': 0,
    'use_deterministic_infection_counts': True,
    'disc_gamma_transmission_mult': 1
    'outside_infection_rate': expected_total_outside_infs / (n_agents * time_horizon)
}

ct_params = {
    'ct_recall_window': 8,
    'ct_delay_distribution': [1/3,1/3,1/3], # uniform over 0, 1, 2 days delay
    'ct_recall_rate': 0.5
}

at_params = {
    'at_delay_distribution': [1/5] * 5,
    'at_net_size_contact_multiplier': 10,
    'at_recall_rate': 0.9
}

st_params = {
    'st_testing_window': 3,
    'st_missed_test_rate': 0.1
}

def init_high_var_sim(R0):
    social_network_params_high_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_high_var, 
                         ct_params, at_params, st_params)

social_network_params_low_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (5,6),
}

#  r * (alpha) / (beta - 1) = 1 = default R0

def init_low_var_sim(R0):
    social_network_params_low_var['disc_gamma_transmission_mult'] = R0 
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_low_var, 
                         ct_params, at_params, st_params)

In [8]:

social_network_params_high_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (2,3),
}

n_agents = 1000
expected_total_outside_infs = 5
time_horizon = 7 * 12
main_params = {
    'n_agents': n_agents,
    'use_contact_tracing': True,
    'use_adaptive_testing': True,
    'use_surveillance_testing': True
}


infection_dynamics_params = {
    'forward_gen_alpha': 8,
    'forward_gen_beta_hyperparams': (5,5),
    'detectability_curve_type': 'optimistic',
    'self_reporting_multiplier': 0.8,
    'self_reporting_delay': 3,
    'init_infection_rate': 0,
    'use_deterministic_infection_counts': True,
    'disc_gamma_transmission_mult': 1,
    'outside_infection_rate': expected_total_outside_infs / (n_agents * time_horizon)
}

ct_params = {
    'ct_recall_window': 4,
    'ct_delay_distribution': [1/3,1/3,1/3], # uniform over 0, 1, 2 days delay
    'ct_recall_rate': 0.5,
    'min_delay_between_traces': 14
}

at_params = {
    'at_delay_distribution': [1/3,1/3,1/3],
    'at_net_size_contact_multiplier': 10,
    'at_recall_rate': 0.9
}

st_params = {
    'st_testing_window': 3,
    'st_missed_test_rate': 0.1
}

def init_high_var_sim(R0):
    social_network_params_high_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_high_var, 
                         ct_params, at_params, st_params)

neg_bin_r = 2
social_network_params_low_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': neg_bin_r,
    'neg_bin_p_hyperparams': (5,6)
}

# avg close contacts / day = 2 * 5 / (6-1)

def init_low_var_sim(R0):
    social_network_params_low_var['disc_gamma_transmission_mult'] = R0 / neg_bin_r
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_low_var, 
                         ct_params, at_params, st_params)

In [6]:
def sample_trajectory(R0, at_multiplier, st_testing_window, 
                      use_adaptive_testing = True, use_surveillance_testing = True):
    st_params['st_testing_window'] = st_testing_window
    at_params['at_net_size_contact_multiplier'] = at_multiplier
    main_params['use_adaptive_testing'] = use_adaptive_testing
    main_params['use_surveillance_testing'] = use_surveillance_testing
    sim = init_low_var_sim(R0)
    infection_counts = [len(sim.infection.get_cum_infected_agent_ids())]
    init_infection_counts = infection_counts[0]
    for _ in range(time_horizon):
        sim.step()
        infection_counts.append(len(sim.infection.get_cum_infected_agent_ids()))
    total_tests = sim.get_total_tests()
    return total_tests, infection_counts, init_infection_counts, infection_counts[-1]

In [9]:
from multiprocessing import Process
import pickle


def sim_target_f(R0, at_mult, st_window, use_at, use_st, ntrajectories, pickle_file_loc):
    results = [sample_trajectory(R0, at_mult, st_window, use_at, use_st) for _ in range(ntrajectories)]
    pickle.dump(results, open(pickle_file_loc, "wb"))

def run_sims_new_process(R0, at_mult, st_window, use_at, use_st, ntrajectories, pickle_file_loc):
    p = Process(target = sim_target_f, args = (R0, at_mult, st_window, use_at, use_st, ntrajectories, pickle_file_loc))
    p.start()
    return p




In [None]:
R0s_to_try = [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5]

at_mults = [4,8,12,16]

st_freqs = [4,7,10,14, 18, 21]

adaptive_test_policies = []

surveillance_test_policies = []

no_use = 1

for st_freq in st_freqs:
    surveillance_test_policies.append((no_use, st_freq, False, True))
    for at_mult in at_mults:
        adaptive_test_policies.append((at_mult, st_freq, True, True))
        

ntrajectories = 250

sim_ps = {}

from datetime import datetime

all_policies = adaptive_test_policies + surveillance_test_policies

output_dir = "jan_05_multiagent_sims/"

for R0 in R0s_to_try:
    print("on R0 {}\n".format(R0))
    for policy in all_policies:
        at_mult, st_window, use_at, use_st = policy
        print("on policy {}\n".format(policy))
        print("Timestamp: {}\n".format(datetime.now()))
        sim_ps[(R0, policy)] = run_sims_new_process(R0, at_mult, st_window, 
                                                    use_at, use_st, ntrajectories, 
                                                    output_dir + str((R0, policy)) + ".pickle")
    
print("launched {} processes".format(sim_ps.values()))
for p in sim_ps.values():
    p.join()

on R0 1

on policy (4, 4, True, True)

Timestamp: 2021-01-05 22:45:03.791786

on policy (8, 4, True, True)

Timestamp: 2021-01-05 22:45:03.811086

on policy (12, 4, True, True)

Timestamp: 2021-01-05 22:45:03.818907

on policy (16, 4, True, True)

Timestamp: 2021-01-05 22:45:03.827410

on policy (4, 7, True, True)

Timestamp: 2021-01-05 22:45:03.834658

on policy (8, 7, True, True)

Timestamp: 2021-01-05 22:45:03.841408

on policy (12, 7, True, True)

Timestamp: 2021-01-05 22:45:03.847713

on policy (16, 7, True, True)

Timestamp: 2021-01-05 22:45:03.852947

on policy (4, 10, True, True)

Timestamp: 2021-01-05 22:45:03.859117

on policy (8, 10, True, True)

Timestamp: 2021-01-05 22:45:03.865734

on policy (12, 10, True, True)

Timestamp: 2021-01-05 22:45:03.871914

on policy (16, 10, True, True)

Timestamp: 2021-01-05 22:45:03.878113

on policy (4, 14, True, True)

Timestamp: 2021-01-05 22:45:03.883701

on policy (8, 14, True, True)

Timestamp: 2021-01-05 22:45:03.890247

on policy (12

on policy (1, 14, False, True)

Timestamp: 2021-01-05 22:45:04.921517

on policy (1, 18, False, True)

Timestamp: 2021-01-05 22:45:04.939313

on policy (1, 21, False, True)

Timestamp: 2021-01-05 22:45:04.951280

on R0 3

on policy (4, 4, True, True)

Timestamp: 2021-01-05 22:45:04.962913

on policy (8, 4, True, True)

Timestamp: 2021-01-05 22:45:04.973057

on policy (12, 4, True, True)

Timestamp: 2021-01-05 22:45:04.986931

on policy (16, 4, True, True)

Timestamp: 2021-01-05 22:45:04.997554

on policy (4, 7, True, True)

Timestamp: 2021-01-05 22:45:05.021935

on policy (8, 7, True, True)

Timestamp: 2021-01-05 22:45:05.033437

on policy (12, 7, True, True)

Timestamp: 2021-01-05 22:45:05.043292

on policy (16, 7, True, True)

Timestamp: 2021-01-05 22:45:05.053207

on policy (4, 10, True, True)

Timestamp: 2021-01-05 22:45:05.065102

on policy (8, 10, True, True)

Timestamp: 2021-01-05 22:45:05.073102

on policy (12, 10, True, True)

Timestamp: 2021-01-05 22:45:05.123145

on policy (

on policy (1, 4, False, True)

Timestamp: 2021-01-05 22:45:07.214562

on policy (1, 7, False, True)

Timestamp: 2021-01-05 22:45:07.243949

on policy (1, 10, False, True)

Timestamp: 2021-01-05 22:45:07.256860

on policy (1, 14, False, True)

Timestamp: 2021-01-05 22:45:07.270954

on policy (1, 18, False, True)

Timestamp: 2021-01-05 22:45:07.284576

on policy (1, 21, False, True)

Timestamp: 2021-01-05 22:45:07.299101

on R0 5

on policy (4, 4, True, True)

Timestamp: 2021-01-05 22:45:07.317373

on policy (8, 4, True, True)

Timestamp: 2021-01-05 22:45:07.359746

on policy (12, 4, True, True)

Timestamp: 2021-01-05 22:45:07.384629

on policy (16, 4, True, True)

Timestamp: 2021-01-05 22:45:07.416674

on policy (4, 7, True, True)

Timestamp: 2021-01-05 22:45:07.463390

on policy (8, 7, True, True)

Timestamp: 2021-01-05 22:45:07.491385

on policy (12, 7, True, True)

Timestamp: 2021-01-05 22:45:07.508333

on policy (16, 7, True, True)

Timestamp: 2021-01-05 22:45:07.555320

on policy (

launched dict_values([<Process(Process-1, started)>, <Process(Process-2, started)>, <Process(Process-3, started)>, <Process(Process-4, started)>, <Process(Process-5, started)>, <Process(Process-6, started)>, <Process(Process-7, started)>, <Process(Process-8, started)>, <Process(Process-9, started)>, <Process(Process-10, started)>, <Process(Process-11, started)>, <Process(Process-12, started)>, <Process(Process-13, started)>, <Process(Process-14, started)>, <Process(Process-15, started)>, <Process(Process-16, started)>, <Process(Process-17, started)>, <Process(Process-18, started)>, <Process(Process-19, started)>, <Process(Process-20, started)>, <Process(Process-21, started)>, <Process(Process-22, started)>, <Process(Process-23, started)>, <Process(Process-24, started)>, <Process(Process-25, started)>, <Process(Process-26, started)>, <Process(Process-27, started)>, <Process(Process-28, started)>, <Process(Process-29, started)>, <Process(Process-30, started)>, <Process(Process-31, starte