In [1]:
import sys
import os
import numpy as np
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path + "/src/multi_agent_sim")
from multi_agent_sim import MultiAgentSim

In [2]:
infection_dynamics_params = {
    'forward_gen_alpha': 8,
    'forward_gen_beta_hyperparams': (5,5),
    'detectability_curve_type': 'optimistic',
    'self_reporting_multiplier': 0.8,
    'self_reporting_delay': 3,
    'init_infection_rate': 0.001,
    'use_deterministic_infection_counts': True,
}

social_network_params_high_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (2,3),
}

n_agents = 3000
main_params = {
    'n_agents': 3000,
    'use_contact_tracing': False,
    'use_adaptive_testing': True,
    'use_surveillance_testing': True
}

ct_params = {
    'ct_recall_window': 8,
    'ct_delay_distribution': [1/3,1/3,1/3], # uniform over 0, 1, 2 days delay
    'ct_recall_rate': 0.5
}

at_params = {
    'at_delay_distribution': [1/5] * 5,
    'at_net_size_contact_multiplier': 10,
    'at_recall_rate': 0.9
}

st_params = {
    'st_testing_window': 3,
    'st_missed_test_rate': 0.1
}

def init_high_var_sim(R0):
    social_network_params_high_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_high_var, 
                         ct_params, at_params, st_params)

social_network_params_low_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (5,6),
}

def init_low_var_sim(R0):
    social_network_params_low_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_low_var, 
                         ct_params, at_params, st_params)

In [3]:
def sample_trajectory(R0, at_multiplier, st_testing_window):
    st_params['st_testing_window'] = st_testing_window
    at_params['at_net_size_contact_multiplier'] = at_multiplier
    sim = init_low_var_sim(R0)
    infection_counts = [len(sim.infection.get_cum_infected_agent_ids())]
    init_infection_counts = infection_counts[0]
    for _ in range(7*12):
        sim.step()
        infection_counts.append(len(sim.infection.get_cum_infected_agent_ids()))
    total_tests = sim.get_total_tests()
    return total_tests, infection_counts, init_infection_counts, infection_counts[-1]

In [6]:
import cProfile
sim = init_low_var_sim(2)
cProfile.run('sim.step()')

         22115 function calls in 0.019 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(count_nonzero)
        4    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(cumsum)
        4    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(prod)
        4    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(unique)
        1    0.000    0.000    0.019    0.019 <string>:1(<module>)
        8    0.000    0.000    0.000    0.000 _asarray.py:88(asanyarray)
        8    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
        4    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
        1    0.000    0.000    0.000    0.000 adaptive_testing.py:46(step_adaptive_tests)
     9941    0.002    0.000    0.003    0.000 agent.py:10(get_param)
        4    0.000    0.000    0.000    0.00

# look at R0 vs. at_multiplier

In [None]:
R0s_to_try = [1.5, 2] #[1, 1.5, 2, 2.5, 3, 3.5, 4]

at_multipliers_to_try = [2,6,8] #[2, 4, 6, 8, 10, 12, 14]

at_pareto_results = {}

main_params['use_surveillance_testing'] = False
main_params['use_contact_tracing'] = False

ntrajectories = 10

for R0 in R0s_to_try:
    for mult in at_multipliers_to_try:
        print("on pair {}".format((R0, mult)))
        at_pareto_results[(R0, mult)] = [sample_trajectory(R0, mult, 1) for _ in range(ntrajectories)]

on pair (1.5, 2)
