In [1]:
import sys
import os
import numpy as np
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path + "/src/multi_agent_sim")
from multi_agent_sim import MultiAgentSim

In [2]:
infection_dynamics_params = {
    'forward_gen_alpha': 8,
    'forward_gen_beta_hyperparams': (5,5),
    'detectability_curve_type': 'optimistic',
    'self_reporting_multiplier': 0.8,
    'self_reporting_delay': 3,
    'init_infection_rate': 0.001,
    'use_deterministic_infection_counts': True,
}

social_network_params_high_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (2,3),
}

n_agents = 1000
main_params = {
    'n_agents': 1000,
    'use_contact_tracing': False,
    'use_adaptive_testing': True,
    'use_surveillance_testing': True
}

ct_params = {
    'ct_recall_window': 8,
    'ct_delay_distribution': [1/3,1/3,1/3], # uniform over 0, 1, 2 days delay
    'ct_recall_rate': 0.5
}

at_params = {
    'at_delay_distribution': [1/5] * 5,
    'at_net_size_contact_multiplier': 10,
    'at_recall_rate': 0.9
}

st_params = {
    'st_testing_window': 3,
    'st_missed_test_rate': 0.1
}

def init_high_var_sim(R0):
    social_network_params_high_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_high_var, 
                         ct_params, at_params, st_params)

social_network_params_low_var = {
    'location_vec_dim': 2,
    'network_gamma': 3,
    'daily_contacts_distn_type': 'negative_binomial',
    'neg_bin_r': 2,
    'neg_bin_p_hyperparams': (5,6),
}

def init_low_var_sim(R0):
    social_network_params_low_var['neg_bin_r'] = R0 / 0.9
    return MultiAgentSim(main_params, infection_dynamics_params, social_network_params_low_var, 
                         ct_params, at_params, st_params)

In [3]:
def sample_trajectory(R0, at_multiplier, st_testing_window):
    st_params['st_testing_window'] = st_testing_window
    at_params['at_net_size_contact_multiplier'] = at_multiplier
    sim = init_low_var_sim(R0)
    infection_counts = [len(sim.infection.get_cum_infected_agent_ids())]
    init_infection_counts = infection_counts[0]
    for _ in range(7*12):
        sim.step()
        infection_counts.append(len(sim.infection.get_cum_infected_agent_ids()))
    total_tests = sim.get_total_tests()
    return total_tests, infection_counts, init_infection_counts, infection_counts[-1]

In [4]:
import cProfile
cProfile.run('sample_trajectory(1.5, 2, 7)')

         17780401 function calls (16778868 primitive calls) in 23.288 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      210    0.000    0.000    0.004    0.000 <__array_function__ internals>:2(any)
      210    0.000    0.000    0.003    0.000 <__array_function__ internals>:2(atleast_1d)
       39    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(count_nonzero)
       35    0.000    0.000    0.001    0.000 <__array_function__ internals>:2(cumsum)
   999000    0.740    0.000    4.040    0.000 <__array_function__ internals>:2(dot)
      630    0.001    0.000    0.020    0.000 <__array_function__ internals>:2(extract)
      630    0.001    0.000    0.003    0.000 <__array_function__ internals>:2(nonzero)
   999000    0.922    0.000   15.542    0.000 <__array_function__ internals>:2(norm)
      210    0.000    0.000    0.002    0.000 <__array_function__ internals>:2(place)
     1116    0.001    0.000    0

In [5]:
main_params['n_agents'] = 3000
cProfile.run('sample_trajectory(1.5, 2, 7)')

         155251235 function calls (146250995 primitive calls) in 200.324 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      270    0.000    0.000    0.005    0.000 <__array_function__ internals>:2(any)
      270    0.000    0.000    0.004    0.000 <__array_function__ internals>:2(atleast_1d)
       90    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(count_nonzero)
      135    0.000    0.000    0.003    0.000 <__array_function__ internals>:2(cumsum)
  8997000    6.526    0.000   35.373    0.000 <__array_function__ internals>:2(dot)
      810    0.001    0.000    0.026    0.000 <__array_function__ internals>:2(extract)
      810    0.001    0.000    0.003    0.000 <__array_function__ internals>:2(nonzero)
  8997000    8.156    0.000  135.504    0.000 <__array_function__ internals>:2(norm)
      270    0.000    0.000    0.003    0.000 <__array_function__ internals>:2(place)
     3141    0.004    0.000  

In [6]:
cProfile.run('init_low_var_sim(1.5)')

         153048321 function calls (144049881 primitive calls) in 204.198 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      120    0.000    0.000    0.003    0.000 <__array_function__ internals>:2(any)
      120    0.000    0.000    0.002    0.000 <__array_function__ internals>:2(atleast_1d)
  8997000    6.588    0.000   36.486    0.000 <__array_function__ internals>:2(dot)
      360    0.001    0.000    0.016    0.000 <__array_function__ internals>:2(extract)
      360    0.000    0.000    0.002    0.000 <__array_function__ internals>:2(nonzero)
  8997000    8.481    0.000  139.510    0.000 <__array_function__ internals>:2(norm)
      120    0.000    0.000    0.002    0.000 <__array_function__ internals>:2(place)
     3000    0.004    0.000    0.057    0.000 <__array_function__ internals>:2(prod)
      120    0.000    0.000    0.002    0.000 <__array_function__ internals>:2(putmask)
      720    0.001    0.000    0.006 