In [1]:
import os
# Move from the notebook folder into the project's src/ directory, presumably so that
# `visualization` (imported on the next line) resolves relative to src/.
# NOTE(review): this relative chdir is not idempotent — re-running the cell moves the
# working directory again and the import/paths below break. Run once on a fresh kernel.
os.chdir('../../src/')
from visualization.SIR_Plot import plot_hospitalized

In [2]:
# Descend into src/models so `MCFS` and `simulate_pandemic` can be imported; later cells
# also rely on this cwd for relative paths such as '../../data/MCTS_Results/...'.
# NOTE(review): non-idempotent relative chdir — re-running this cell fails or changes cwd again.
os.chdir('models')
from MCFS import mcts, treeNode
import simulate_pandemic as simp

Loading Graph...Done!


In [259]:
from glob import glob
import pickle as pkl
import pandas as pd

# Experiment identifiers used in the result file names.
H = 2
N = 48
D = 364

# Sort the matches: glob's order depends on the filesystem, and the order of
# `results` / `act_list` matters for every statistic and plot below.
files = sorted(glob(f'../../data/MCTS_Results/pickles/looser_cost_H{H}_N{N}_D{D}_bfFalse*.pkl'))
# Fail here with a clear message instead of much later with an opaque unpacking error.
assert files, f'no result pickles found for H={H}, N={N}, D={D}'

results = []   # one processed simulation per file (as returned by plot_hospitalized)
act_list = []  # the action sequence chosen in each simulation
print(len(files))
for file in files:
    # NOTE: pickle.load can execute arbitrary code — only load result files you produced.
    with open(file, 'rb') as f:
        raw_states, actions, tree = pkl.load(f)

    # The second path argument mirrors the pickle path under plots/ instead of pickles/;
    # presumably plot_hospitalized writes a per-run figure there (TODO confirm).
    # 7 is passed in the same position as make_beds_graph's `step_size` (days per action).
    sim = plot_hospitalized(raw_states, actions, 7,
                            file.replace('.pkl', '').replace('pickles', 'plots'), make_df=False)
    results.append(sim)
    act_list.append(actions)

10


In [260]:
from sklearn.metrics import mean_absolute_error
import numpy as np

pop = len(simp.G.nodes())   # number of nodes in the simulation graph
capacity = 0.0015           # hospital capacity, as a fraction of `pop`


def _mae_vs_capacity(sim):
    """Mean absolute gap between column 3 of a run (scaled by `pop`) and the capacity line."""
    share = sim[3] / pop
    return mean_absolute_error(y_true=[capacity] * len(sim), y_pred=share)


mae_policy = list(map(_mae_vs_capacity, results))
for label, summarize in (('MAE AVERAGE ', np.mean), ('MAE STD ', np.std)):
    print(label, round(summarize(mae_policy) * 1e3, 3))

MAE AVERAGE  0.73
MAE STD  0.102


In [261]:
def _excess_over_capacity(sim):
    """Sum, over time steps, of the amount by which column 3 / pop exceeds `capacity`.

    Steps at or below capacity contribute nothing.
    """
    excess = sim[3] / pop - capacity
    return excess[excess > 0].sum()


over_capacity = [_excess_over_capacity(sim) for sim in results]

for label, summarize in (('CAPACITY OVERLOAD MEAN ', np.mean), ('CAPACITY OVERLOAD STD ', np.std)):
    print(label, round(summarize(over_capacity) * 1e3, 3))

CAPACITY OVERLOAD MEAN  14.669
CAPACITY OVERLOAD STD  18.19


In [262]:
from actions import costs, exposed_cost


# Each chosen action stays in force for this many time steps (days).
DAYS_PER_ACTION = 7

# Total cost of each run: for every step, the larger of exposed_cost(column 3 / pop)
# and the cost of the action in force on that step.
# NOTE(review): `exposed_cost` is applied to column 3, which the plots label as
# hospitalized, not exposed — confirm this is the intended input.
sims_costs = []
for sim, policy in zip(results, act_list):
    daily_policy = [act for act in policy for _ in range(DAYS_PER_ACTION)]
    per_day = [max(exposed_cost(count / pop), costs[daily_policy[day]])
               for day, count in enumerate(sim[3])]
    sims_costs.append(np.sum(per_day))

print(round(np.mean(sims_costs), 1))
print(round(np.std(sims_costs), 1))

68.8
5.3


In [269]:
import plotly.graph_objects as go

def make_beds_graph(data, actions, step_size, title, color_map=None, make_df=True):
    fig = go.Figure()
    
    if make_df:
        data = pd.DataFrame([pd.Series(d).value_counts() for d in data])
        data.fillna(0, inplace=True)

    color_map = {
        'Lockdown':          'rgb(0.83, 0.13, 0.15)',
        'Hard Quarantine':    'rgb(0.85, 0.35, 0.13)',
        'Light Quarantine':   'rgb(0.97, 0.91, 0.56)',
        'Social Distancing':  'rgb(0.67, 0.88, 0.69)',
        'Unrestricted':        'rgb(0.86, 0.86, 0.86)'    
    }

    pop = 55492

    x = list(range(len(data[0])+1))
  

    # HOSPITALIZED LINE #####################################################################################
    #\definecolor{royalblue(web)}{rgb}{0.25, 0.41, 0.88}
    for ix,d in enumerate(data):
        if ix == int(len(data)/2):
            pass
        else:
            if ix == 0:
                fig.add_trace(go.Scatter(x=x, y=d[3]/pop, line_color = 'rgb(0.25, 0.41, 0.88)',
                                    line=dict(width=2), opacity=.75, name="Simulations",
                                    # yaxis="y2"
                                    ))
                #fig.add_trace(go.Scatter(x=x, y=d[1]/pop, line_color = 'firebrick',
                #                    line=dict(width=2), opacity=.75, name='Exposed',
                #                    # yaxis="y2"
                #                    ))
            else:
                fig.add_trace(go.Scatter(x=x, y=d[3]/pop, line_color = 'rgb(0.25, 0.41, 0.88)',
                                    line=dict(width=2), opacity=.75, showlegend=False,
                                     #yaxis="y2"
                                    ))
               #fig.add_trace(go.Scatter(x=x, y=d[1]/pop, line_color = 'firebrick',
               #                     line=dict(width=2), showlegend=False, opacity=.75, #name='exposed', 
               #                     # yaxis="y2"
               #                     ))
            
    ix = int(len(data)/2)
    print(ix)
    fig.add_trace(go.Scatter(x=x, y=data[ix][3]/pop, line_color = 'black',
                        line=dict(width=3), name='Median Performer'
                        # yaxis="y2"
                        ))
    act = list(map(color_map.get,  actions[ix]))

    fig.update_layout(
            shapes=[
                dict(
                    type="rect",
                    # x-reference is assigned to the x-values
                    xref="x",
                    # y-reference is assigned to the plot paper [0,1]
                    yref="paper",
                    x0=step_size*i,
                    y0=0,
                    x1=step_size*(i+1),
                    y1=1,
                    fillcolor=a,
                    opacity=0.7,
                    layer="below",
                    line_width=0,
                ) for i,a in enumerate(act)] 
        )
            
    # CAPACITY LINE ################################################################################
    fig.add_trace(go.Scatter(x=x, y=len(x)*[0.0015], name='Capacity', line_color = 'black',
                            line=dict(dash='dash', width = 2)
                             #, yaxis="y2"
                            ))
    #################################################################################################
    

    # GHOST TRACES FOR LEGEND#############################
    #for k,v in color_map.items():
    #    fig.add_trace(go.Bar(x=[None], y=[None], marker=dict(color=v), name = k))

    fig.update_layout( xaxis={'showgrid': False,},
                      yaxis = {'showgrid': False, 'zeroline': False,
                               'title':'Proportion of Hospitalized'
                              },
                      showlegend=False, hovermode="x",  font=dict(family="sans-serif",),
                      )
    
    fig.update_layout(
        xaxis = dict(
            tickmode = 'array',
            tickvals = list(range(0, len(data[0])+1, 28)),
            ticktext = list(range(0, int(len(data[0])/7) + 1, 4)),
            title = 'Weeks',
            ticks='outside',
            titlefont=dict(
                size=36,
            ),
            tickfont=dict(
                size=36,
                color="black"
            )),
        font=dict(
            color="black",
            size=35
        ),
        yaxis = dict(
            ticks='',
            showticklabels=True,
            dtick = 0.001,
            tickfont=dict(
                size=24,
                color="black"
            ),
          range=[0,0.00325],  # sets the range of xaxis

        ),
        showlegend=False,
        autosize=False,
        width=900,
        height=700,
        margin=dict(
            l=110,
            r=10,
            b=80,
            t=50,
            pad=0
        ),
    )
    
    fig.update_layout(
        title={
            'text': 'H=2, S=48',
            'y':.99,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'}
    )
    fig.update_layout(
        legend=dict(
            x=.67,
            y=0.995,
            traceorder="normal",
            font=dict(
                size=28,
                color="black"
            ),
            bgcolor="white",
            bordercolor="Black",
            borderwidth=2
        )
    )


    fig.update_yaxes(automargin=False)
    fig.write_image(f"{title}.pdf")

    fig.show()

In [270]:
# Sort runs by total policy cost (sims_costs), not by their action lists, so that the
# middle run — drawn by make_beds_graph as 'Median Performer' — is the median-cost run.
ord_res, ord_act, ord_cost = zip(*sorted(zip(results, act_list, sims_costs), key=lambda run: run[2]))
make_beds_graph(ord_res, ord_act, 7, f'../../data/MCTS_Results/plots/Avg_H{H}_N{N}_D{D}', make_df=False)

5


In [108]:
def _first_release_step(policy):
    """Index of the first step that follows 'Light Quarantine' with 'Unrestricted' or
    'Social Distancing'; NaN if that never happens.

    Pairs each action with its successor, so a run ending in 'Light Quarantine' is
    handled without indexing past the end.
    """
    for step, (current, following) in enumerate(zip(policy, policy[1:]), start=1):
        if current == 'Light Quarantine' and following in ('Unrestricted', 'Social Distancing'):
            return step
    return np.nan


# Mean step at which Light Quarantine is first relaxed; runs that never relax it are skipped.
np.nanmean([_first_release_step(l) for l in act_list])

12.3

In [109]:
def _first_intervention_step(policy):
    """Index of the first action that is not 'Unrestricted'; NaN if the run never intervenes."""
    return next((step for step, act in enumerate(policy) if act != 'Unrestricted'), np.nan)


# Mean step of the first intervention; fully unrestricted runs are skipped instead of raising.
np.nanmean([_first_intervention_step(l) for l in act_list])

4.2

In [11]:



# Average number of decision steps each action is used per run. An action a run never
# chose is missing from that run's value_counts (NaN); count it as 0 so every mean is
# taken over all runs, not only over the runs that used the action.
pd.DataFrame([pd.Series(l).value_counts() for l in act_list]).fillna(0).mean()

Light Quarantine     27.8
Social Distancing    18.3
Unrestricted          3.4
Hard Quarantine       2.5
dtype: float64