In [1]:
import os
os.chdir('../src/models')

In [2]:
import simulate_pandemic as simp
from actions import costs, city_restrictions
from MCFS import mcts, treeNode
from CMDP import CovidState

from tqdm import tqdm
import numpy as np
import pickle as pkl
import datetime
from numpy.random import default_rng


import plotly.graph_objects as go
import pandas as pd

def run(gpickle_path,
        p_r, 
        step_size = 7, days = 364, seed=None):

    rng = default_rng(seed)
    pop_matrix, adj_list = simp.init_infection(gpickle_path)
    data = []
    actions = []
    
    for day in tqdm(range(1, days+1)):
        #if less than 20% still susceptible, break simulation
        if pop_matrix[np.where(pop_matrix[:,1] == -1)].shape[0] > pop_matrix.shape[0]*.9: break            
            
        
        # Choose a new policy at each week
        if day % step_size == 1:                    
            action = 'Unrestricted'
            actions.append(action)
            restrictions = city_restrictions[action]

        pop_matrix = simp.spread_infection(pop_matrix, adj_list, restrictions, day, rng, p_r)
        pop_matrix = simp.lambda_leak_expose(pop_matrix, day)
        pop_matrix = simp.update_population(pop_matrix)

        data.append(pop_matrix[:, 0:2])
    
    return data, actions

In [3]:
prhome = 0.06
p_r = {
    'home'    :  prhome,
    'neighbor':  .1*prhome,
    'work'    :  .1*prhome,
    'school'  :  .15*prhome,
}

gpickle_path = '../../data/processed/SP_multiGraph_Job_Edu_Level.gpickle'

In [4]:
data, actions = run(gpickle_path, p_r)

100%|████████████████████████████████████████████████████████████████████████████████| 364/364 [00:15<00:00, 23.76it/s]


In [5]:
pd.DataFrame(pd.Series(d[:, 1]).value_counts() for d in data)[3].max()

509.0