In [1]:
import os

os.chdir('..\\src\\models')

In [3]:
from MCTS import mcts
from CMDP import CovidState
import simulate_pandemic as simp
from policies import policies_restrictions_by_value as policy_restrictions
from tqdm import tqdm
import numpy as np

Loading Graph... Done!


In [4]:
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px

In [5]:
def run_full_mcst(iterationLimit=100, horizon=5, n_jobs = -1, step_size = 7, days = 211):

    pop_matrix = simp.init_infection(.0001)
    data = []
    actions = []

    tree = mcts(iterationLimit=iterationLimit, horizon=horizon)

    for day in tqdm(range(1, days)):
        #if less than 20% still susceptible, break simulation
        if pop_matrix[np.where(pop_matrix[:,1] == -1)].shape[0] > pop_matrix.shape[0]*.9: break

        # Choose a new policy at each week
        if day % step_size == 1:
            current_state = CovidState(pop_matrix, day, step_size)
            action = tree.search(initialState=current_state)
            actions.append(action)
            print(action)

            restrictions = policy_restrictions[action]

        pop_matrix = simp.spread_infection(pop_matrix, restrictions, day)
        pop_matrix = simp.lambda_leak_expose(pop_matrix, day)
        pop_matrix = simp.update_population(pop_matrix)

        data.append(np.array(sorted(pop_matrix,key=lambda x: x[0]))[:,1]) 
        
    
    return data, actions, tree

In [None]:
data, actions, tree = run_full_mcst(iterationLimit=20)

  2%|█▌                                                                                | 4/210 [00:24<57:44, 16.82s/it]

0


  5%|███▊                                                                             | 10/210 [00:50<28:25,  8.53s/it]

0


  8%|██████▌                                                                          | 17/210 [01:40<40:05, 12.46s/it]

0


 10%|████████▎                                                                      | 22/210 [03:33<1:59:55, 38.27s/it]

3


 13%|██████████▊                                                                      | 28/210 [03:34<14:09,  4.67s/it]

8


 17%|█████████████▌                                                                   | 35/210 [07:59<28:51,  9.90s/it]

1


 20%|████████████████▏                                                                | 42/210 [12:27<29:21, 10.48s/it]

4


 23%|██████████████████▉                                                              | 49/210 [21:56<57:10, 21.31s/it]

4


 27%|█████████████████████                                                          | 56/210 [35:08<1:17:12, 30.08s/it]

8


 30%|███████████████████████▋                                                       | 63/210 [47:45<1:12:28, 29.58s/it]

8


 33%|███████████████████████████                                                      | 70/210 [56:45<50:44, 21.74s/it]

7


 37%|████████████████████████████▉                                                  | 77/210 [1:02:40<32:06, 14.49s/it]

8


 40%|███████████████████████████████▏                                             | 85/210 [1:09:32<1:59:50, 57.52s/it]

9


 44%|█████████████████████████████████▋                                           | 92/210 [1:11:16<1:10:28, 35.84s/it]

10


 47%|█████████████████████████████████████▏                                         | 99/210 [1:12:13<36:26, 19.70s/it]

0


 50%|██████████████████████████████████████▋                                       | 104/210 [1:12:13<08:29,  4.81s/it]

In [12]:
d_counts = pd.DataFrame([pd.Series(d).value_counts() for d in data])
d_counts.fillna(0, inplace=True)

fig = go.Figure()

fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[-1]/55e3, name='removed',line_color='green'))
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[0]/55e3, name='susceptible', line_color='blue'))
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[1]/55e3, name='exposed',line_color ='orange'))
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[2]/55e3, name='infected', line_color = 'red'))
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[3]/55e3, name='hospitalized', line_color = 'purple'))
fig.show()

In [13]:
fig = go.Figure()


fig.add_trace(go.Scatter(x=[None], y=[None], showlegend=None,
        marker=dict(
            colorscale='deep',
            showscale=True,
            cmin=0,
            cmax=10,
            colorbar=dict(
            thickness=5, tickvals=[0, 10],
            ticktext=['Relaxed', 'Restricted'], outlinewidth=0,
            ticks="outside")), hoverinfo='none'),                              

        )

fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[3]/55e3, name='hospitalized', line_color = 'purple'))
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=len(d_counts)*[0.0025], name='capacity', line_color = 'orange'))
fig.update_layout(
    shapes=[
        # 1st highlight during Feb 4 - Feb 6
        dict(
            type="rect",
            # x-reference is assigned to the x-values
            xref="x",
            # y-reference is assigned to the plot paper [0,1]
            yref="paper",
            x0=7*i,
            y0=0,
            x1=7*(i+1),
            y1=1,
            fillcolor=px.colors.sequential.deep[a],
            opacity=0.5,
            layer="below",
            line_width=0,
        ) for i,a in enumerate(actions)] 
)


fig.update_layout(coloraxis = {'colorscale':'deep'},  showlegend=False)

fig.show()

In [14]:
pop_matrix


NameError: name 'pop_matrix' is not defined