In [1]:
import os, sys
# Directory prepended to sys.path so the project-local modules imported below
# (presumably `actions` and `simulate_pandemic`) can be found.
# The default is a machine-specific absolute path; set the COVIDMDP_MODELS_DIR
# environment variable to point at CovidMDP/src/models on another machine.
MODELS_DIR = os.environ.get(
    'COVIDMDP_MODELS_DIR',
    '/media/yamba/193b2d55-8f5d-44c9-8078-07220e0aecba/yamba/Documents/CovidMDP/src/models',
)
sys.path.insert(0, os.path.abspath(MODELS_DIR))

In [2]:
from actions import city_restrictions
import numpy as np
import simulate_pandemic as simp
import pandas as pd

# Seed for the generator used for the weekly policy draws and passed to the
# simulation. None gives a fresh, non-reproducible run every time; set an int
# to make a run reproducible.
SEED = None
rng = np.random.default_rng(SEED)

In [3]:
from tqdm import tqdm

In [4]:
import pickle as pkl

In [5]:
def make_prediction(data, model):
    """Predict from the last four end-of-week snapshots of a running simulation.

    Parameters
    ----------
    data : list of numpy.ndarray
        One 2-D array per simulated day, in order. Column 1 holds each
        individual's state code, mapped to a name as
        -1 removed, 0 susceptible, 1 exposed, 2 infected, 3 hospitalized.
    model : fitted XGBoost scikit-learn-style estimator
        Must expose ``get_booster().feature_names`` and ``predict``.

    Returns
    -------
    numpy.ndarray
        ``model.predict`` output for a single feature row. Features are the
        per-state counts of the last four complete weeks: suffix ``_1`` is the
        latest week, ``_4`` the oldest, plus the unsuffixed latest counts and
        ``week``.

    Raises
    ------
    ValueError
        If ``data`` covers fewer than four complete 7-day weeks.
    """
    days_per_week = 7
    n_lags = 4
    state_names = {
        -1: 'removed',
        0: 'susceptible',
        1: 'exposed',
        2: 'infected',
        3: 'hospitalized',
    }

    # One row per day, one column per state: number of individuals in it.
    # Row labels are not used below (pandas >= 2 labels every row 'count'),
    # so all row selection is positional.
    daily = pd.DataFrame(
        [pd.Series(d[:, 1]).value_counts() for d in data]
    ).rename(columns=state_names)
    # A state that never occurred has no column at all; add it as zeros so the
    # model's feature lookup cannot fail with a KeyError.
    for name in state_names.values():
        if name not in daily.columns:
            daily[name] = 0

    # Position of the last day of each complete week (days 7, 14, ... ->
    # rows 6, 13, ...). Only complete weeks count, so a partial trailing week
    # never indexes past the end.
    week_ends = [end - 1 for end in range(days_per_week, len(daily) + 1, days_per_week)][-n_lags:]
    if len(week_ends) < n_lags:
        raise ValueError(
            f'make_prediction needs at least {n_lags} complete weeks of data, '
            f'got {len(daily)} days'
        )

    window = daily.iloc[week_ends].reset_index(drop=True).fillna(0)
    # NOTE(review): 'week' is the position inside the 4-week window, so the
    # row fed to the model always has week == 4 regardless of the true
    # simulation week. Confirm this matches how 'week' was built for training.
    window['week'] = window.index + 1

    state_cols = [col for col in window.columns if col != 'week']
    latest = window.iloc[-1]
    features = {col: latest[col] for col in window.columns}
    # Lagged copies of the state counts: lag 0 -> suffix _1 (latest week), ...
    for lag in range(n_lags):
        row = window.iloc[-1 - lag]
        for col in state_cols:
            features[f'{col}_{lag + 1}'] = row[col]

    X = pd.DataFrame([features], dtype=float)
    return model.predict(X[model.get_booster().feature_names])

In [6]:
def simluate_and_predict(model,
                         g_pickle='../../data/processed/SP_multiGraph_Job_Edu_Level.gpickle',
                         days=364,
                         prhome=0.06):
    """Run one epidemic simulation under randomly drawn weekly policies and
    record the model's predictions along the way.

    On days 1, 8, 15, ... a policy is drawn uniformly from
    ``city_restrictions`` with the module-level ``rng``. From day 29 on, on
    those same days, the model is also asked for a prediction built from all
    days simulated so far. The run stops early once more than 90% of the
    population is in state -1 (removed).

    Parameters
    ----------
    model : fitted XGBoost scikit-learn-style estimator
        Must expose ``get_booster().feature_names`` and ``predict``.
    g_pickle : str, optional
        Graph file passed to ``simp.init_infection``.
    days : int, optional
        Maximum number of days to simulate.
    prhome : float, optional
        Value of ``p_r['home']``; the neighbor, work and school entries are
        0.1, 0.1 and 0.15 times it. (Presumably per-contact transmission
        probabilities by setting -- confirm in simulate_pandemic.)

    Returns
    -------
    df : pandas.DataFrame
        One row per simulated day (row i = state after day i + 1), one column
        per state name holding the number of individuals in it; missing
        counts are 0.
    predictions : list of [int, value]
        ``[day, model.predict(...)[0]]`` for every prediction day.
    """
    step_size = 7
    p_r = {
        'home'    :  prhome,
        'neighbor':  .1*prhome,
        'work'    :  .1*prhome,
        'school'  :  .15*prhome,
    }
    print('Features:')
    print(model.get_booster().feature_names)
    pop_matrix, edge_list = simp.init_infection(g_pickle)
    data = []
    predictions = []

    for day in tqdm(range(1, days+1)):
        # Stop once more than 90% of the population is removed (state -1).
        if np.count_nonzero(pop_matrix[:, 1] == -1) > pop_matrix.shape[0]*.9:
            break

        # First day of every step (days 1, 8, 15, ...): draw a new policy.
        if (day - 1) % step_size == 0:
            action = rng.choice(list(city_restrictions.keys()))
            restrictions = city_restrictions[action]

            # On day 29, 36, ... `data` already holds 28, 35, ... full days,
            # i.e. at least four complete weeks for make_prediction.
            if day >= 28:
                prediction = make_prediction(data, model)
                predictions.append([day, prediction[0]])

        pop_matrix = simp.spread_infection(pop_matrix, edge_list, restrictions, day, rng, p_r)
        pop_matrix = simp.lambda_leak_expose(pop_matrix, day)
        pop_matrix = simp.update_population(pop_matrix)

        # A slice is a view: copy it so stored days cannot alias a later state
        # if the simp helpers ever update pop_matrix in place (and so the full
        # matrix of every day is not kept alive).
        data.append(pop_matrix[:, 0:2].copy())

    df = (
        pd.DataFrame([pd.Series(d[:, 1]).value_counts() for d in data])
        .rename(columns={
            -1 : 'removed',
             0 : 'susceptible',
             1 : 'exposed',
             2 : 'infected',
             3 : 'hospitalized'
        })
        .fillna(0)
        # pandas >= 2 names value_counts() results 'count', which would become
        # every row label; use 0..n-1 (day - 1) instead.
        .reset_index(drop=True)
    )

    return df, predictions

In [7]:
import plotly.graph_objects as go

def make_plot(predictions, df):
    """Plot the simulated daily hospitalized count and the model's predictions.

    Parameters
    ----------
    predictions : list of [day, value]
        Prediction pairs; may be empty when the simulation stopped before the
        first prediction day, in which case only the simulated curve is drawn.
    df : pandas.DataFrame
        Daily state counts, one row per simulated day in order (row i = state
        after day i + 1). A missing 'hospitalized' column is plotted as zeros.
    """
    fig = go.Figure()

    # Plot against day numbers so the x axis does not depend on df's row
    # labels and shares units with the prediction days.
    sim_days = np.arange(1, len(df) + 1)
    hospitalized = df['hospitalized'] if 'hospitalized' in df.columns else np.zeros(len(df))

    fig.add_trace(go.Scatter(
        x=sim_days,
        y=hospitalized,
        name='hospitalized',
        # royalblue: plotly's rgb() takes 0-255 channels, so the former
        # 'rgb(0.25, 0.41, 0.88)' rendered almost black.
        line=dict(color='rgb(65, 105, 225)', width=3.5),
    ))

    if len(predictions) > 0:
        preds = np.array(predictions)
        fig.add_trace(go.Scatter(
            x=preds[:, 0],
            y=preds[:, 1],
            name='predictions',
            line=dict(color='purple', width=1.5),
            mode='lines+markers',
        ))

    fig.update_layout(xaxis_title='day', yaxis_title='individuals')
    fig.show()

In [8]:
def _load_pickle(path):
    """Return the object unpickled from ``path``.

    pickle.load can execute arbitrary code; only use it on files you trust.
    """
    with open(path, 'rb') as fh:
        return pkl.load(fh)


known_vars_model = _load_pickle('known_vars_model.pkl')
full_model = _load_pickle('data_model_noact.pkl')

In [11]:
# Model whose features are all five state counts over the last four weeks.
df, predictions = simluate_and_predict(model=full_model)
make_plot(df=df, predictions=predictions)

Features:
['susceptible_1', 'exposed_1', 'infected_1', 'removed_1', 'hospitalized_1', 'susceptible_2', 'exposed_2', 'infected_2', 'removed_2', 'hospitalized_2', 'susceptible_3', 'exposed_3', 'infected_3', 'removed_3', 'hospitalized_3', 'susceptible_4', 'exposed_4', 'infected_4', 'removed_4', 'hospitalized_4']


100%|██████████| 364/364 [00:12<00:00, 28.34it/s]


In [12]:
# Model restricted to the removed and hospitalized counts plus 'week'.
df, predictions = simluate_and_predict(model=known_vars_model)
make_plot(df=df, predictions=predictions)

Features:
['removed_1', 'hospitalized_1', 'removed_2', 'hospitalized_2', 'removed_3', 'hospitalized_3', 'removed_4', 'hospitalized_4', 'week']


100%|██████████| 364/364 [00:12<00:00, 29.02it/s]
