In [1]:
import os 

# Switch the working directory so the module import and the graph pickle paths
# used in later cells resolve relative to '../../src/models'.
# NOTE(review): this is relative to wherever the kernel was started, so running
# the cell twice moves the directory again — confirm the expected start location.
os.chdir('../../src/models')

In [2]:
from patient_evolution_matrix import *

import networkx as nx
import numpy as np
from tqdm import tqdm
from functools import partial
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime

In [3]:
# Integer codes for each disease state; the NumPy-based functions in later
# cells store these codes in column 1 of the population matrix.
state_dict = {
    'susceptible': 0,
    'exposed': 1,
    'infected': 2,
    'hospitalized': 3,
    'removed': -1
}

In [4]:
def susceptible_to_exposed(node, day):
    """Move a susceptible node (attribute dict) into the 'exposed' state.

    The node is mutated in place: its status becomes 'exposed',
    ``period_duration`` is set to the value returned by ``sample_incubation()``
    and ``infection_day`` records ``day``. Nothing is returned.

    Raises:
        ValueError: if the node is not currently 'susceptible'; the node is
            printed first to help locate the offending record.
    """
    if node['status'] == 'susceptible':
        node['status'] = 'exposed'
        node['period_duration'] = sample_incubation()
        node['infection_day'] = day
        return

    print(node)
    raise ValueError("Node status different from susceptible")

def exposed_to_infected(node):
    """Promote an exposed node to 'infected' once its countdown has run out.

    On success the node is mutated in place: status becomes 'infected' and
    ``period_duration`` is replaced by ``sample_onset_to_hosp_or_asymp()``.

    Raises:
        ValueError: if the node is not 'exposed', or if its
            ``period_duration`` is still greater than zero.
    """
    if node['status'] == 'exposed':
        if not node['period_duration'] > 0:
            node['status'] = 'infected'
            node['period_duration'] = sample_onset_to_hosp_or_asymp()
            return
        raise ValueError("Not yet time to change")

    raise ValueError("Node status different from exposed")
    
def infected_to_new_state(node):
    """Resolve an infected node whose countdown has run out.

    ``needs_hospitalization(node['age'])`` decides the outcome: a truthy result
    moves the node to 'hospitalized' with a new ``period_duration`` from
    ``sample_hospitalization_to_removed()``; otherwise the node becomes
    'removed'. The node is mutated in place and nothing is returned.

    Raises:
        ValueError: if the node is not 'infected', or if its
            ``period_duration`` is still greater than zero.
    """
    if node['status'] == 'infected':
        if node['period_duration'] > 0:
            raise ValueError("Not yet time to change")

        goes_to_hospital = needs_hospitalization(node['age'])
        if not goes_to_hospital:
            node['status'] = 'removed'
        else:
            node['status'] = 'hospitalized'
            node['period_duration'] = sample_hospitalization_to_removed()
        return

    raise ValueError("Node status different from infected")
        
def hospitalized_to_removed(node):
    """Mark a hospitalized node as 'removed' once its countdown has run out.

    The node is mutated in place; only ``status`` changes.

    Raises:
        ValueError: if the node is not 'hospitalized', or if its
            ``period_duration`` is still greater than zero.
    """
    if node['status'] == 'hospitalized':
        if node['period_duration'] > 0:
            raise ValueError("Not yet time to change")
        node['status'] = 'removed'
        return

    raise ValueError("Node status different from hospitalized")

def change_state(node):
    """Apply the transition that matches the node's current status.

    'exposed', 'infected' and 'hospitalized' nodes are handed to their
    respective transition function; any other status is left untouched.
    Always returns None.
    """
    status = node['status']

    if status == 'exposed':
        exposed_to_infected(node)
    elif status == 'infected':
        infected_to_new_state(node)
    elif status == 'hospitalized':
        hospitalized_to_removed(node)

def update_node(node):
    """Advance one node by a single time step.

    Susceptible and removed nodes are skipped. For every other node, a
    ``period_duration`` of exactly 0 triggers ``change_state``; otherwise the
    counter is decreased by one. The node is mutated in place.
    """
    status = node['status']
    if status in ('susceptible', 'removed'):
        return

    remaining = node['period_duration']
    if remaining == 0:
        change_state(node)
    else:
        node['period_duration'] = remaining - 1

def infect_node(node, day):
    """Expose a single node on ``day`` via ``susceptible_to_exposed``; returns None."""
    susceptible_to_exposed(node, day)

def infect_graph(Graph, node_list, day):
    """Expose every node listed in ``node_list`` on the given ``day``.

    Each id is looked up in ``Graph.nodes`` and the resulting attribute dict is
    passed to ``infect_node``. Returns None.
    """
    for node_id in node_list:
        attributes = Graph.nodes[node_id]
        infect_node(attributes, day)

def update_graph(Graph):
    """Advance every node of ``Graph`` by one time step using ``update_node``."""
    for _node_id, attributes in Graph.nodes(data=True):
        update_node(attributes)
        

In [18]:
def change_state(person):
    """Apply the transition matching the status code in ``person[1]``.

    ``person`` is one row of the population matrix
    (``[id, status, infection_day, remaining_days, age]``). Exposed, infected
    and hospitalized rows are passed to their transition function.

    Returns:
        The row after the transition. Rows in any other state are returned
        unchanged instead of falling through to an implicit ``None``, so the
        function always yields a row of the same length when it is applied
        row-wise with ``np.apply_along_axis``.
    """
    status = person[1]

    if status == state_dict['exposed']:
        return exposed_to_infected(person)
    if status == state_dict['infected']:
        return infected_to_new_state(person)
    if status == state_dict['hospitalized']:
        return hospitalized_to_removed(person)

    # No transition is defined for this state: hand the row back untouched.
    return person

def susceptible_to_exposed(person, day):
    """Turn a susceptible population row into an exposed one.

    Writes the 'exposed' code into ``person[1]``, ``day`` into ``person[2]``
    and the value of ``sample_incubation()`` into ``person[3]``; the row is
    modified in place and also returned.

    Raises:
        ValueError: if ``person[1]`` is not the 'susceptible' code (the row is
            printed beforehand).
    """
    if person[1] == state_dict['susceptible']:
        person[1] = state_dict['exposed']
        person[2] = day
        person[3] = sample_incubation()
        return person

    print('Failed Person')
    print(person)
    raise ValueError("Node status different from susceptible")

def exposed_to_infected(person):
    """Promote an exposed population row to 'infected'.

    On success ``person[1]`` receives the 'infected' code and ``person[3]`` is
    replaced by ``sample_onset_to_hosp_or_asymp()``. The row is modified in
    place and returned.

    Raises:
        ValueError: if the row is not exposed, or if ``person[3]`` is still
            greater than zero.
    """
    if person[1] == state_dict['exposed']:
        if person[3] > 0:
            raise ValueError("Not yet time to change")

        person[1] = state_dict['infected']
        person[3] = sample_onset_to_hosp_or_asymp()
        return person

    raise ValueError("person status different from exposed")

def infected_to_new_state(person):
    """Resolve an infected population row whose countdown has run out.

    ``needs_hospitalization(person[4])`` picks the outcome: a truthy result
    stores the 'hospitalized' code in ``person[1]`` and a new countdown from
    ``sample_hospitalization_to_removed()`` in ``person[3]``; otherwise the
    row gets the 'removed' code. The row is modified in place and returned.

    Raises:
        ValueError: if the row is not infected, or if ``person[3]`` is still
            greater than zero.
    """
    if person[1] == state_dict['infected']:
        if person[3] > 0:
            raise ValueError("Not yet time to change")

        goes_to_hospital = needs_hospitalization(person[4])
        if not goes_to_hospital:
            person[1] = state_dict['removed']
        else:
            person[1] = state_dict['hospitalized']
            person[3] = sample_hospitalization_to_removed()
        return person

    raise ValueError("person status different from infected")

def hospitalized_to_removed(person):
    """Mark a hospitalized population row as 'removed'.

    Only ``person[1]`` changes; the row is modified in place and returned.

    Raises:
        ValueError: if the row is not hospitalized, or if ``person[3]`` is
            still greater than zero.
    """
    if person[1] == state_dict['hospitalized']:
        if person[3] > 0:
            raise ValueError("Not yet time to change")

        person[1] = state_dict['removed']
        return person

    raise ValueError("person status different from hospitalized")

def infect_population(pop_matrix, infected, day):
    """Expose the rows of ``pop_matrix`` whose id (column 0) is in ``infected``.

    Args:
        pop_matrix: 2-D population array; column 0 holds the person id.
        infected: ids to expose (array or any sequence). Ids are compared as
            integers, as in the original implementation.
        day: simulation day stored as the infection day of the new cases.

    Returns:
        A new array with the same shape as ``pop_matrix``. Untouched rows come
        first, followed by the newly exposed rows, so row order is NOT
        preserved.

    Raises:
        ValueError: propagated from ``susceptible_to_exposed`` when a selected
            row is not susceptible.
    """
    infected = np.asarray(infected)
    if infected.size == 0:
        # np.apply_along_axis refuses to run over zero rows; nobody to expose.
        return pop_matrix.copy()

    # Build the mask once so the "change" and "keep" parts are exact
    # complements (the original cast the ids to int for one mask only).
    to_expose = np.isin(pop_matrix[:, 0], infected.astype(int))
    matrix_change = pop_matrix[to_expose]
    matrix_keep = pop_matrix[~to_expose]

    matrix_change = np.apply_along_axis(susceptible_to_exposed, 1, matrix_change, day=day)

    new_matrix = np.concatenate((matrix_keep, matrix_change))
    assert new_matrix.shape == pop_matrix.shape

    return new_matrix

def update_population(pop_matrix):
    """Advance every active row of the population matrix by one day.

    Rows whose status (column 1) is susceptible or removed are carried over
    unchanged. For all other rows the countdown in column 3 is decreased by
    one; rows whose countdown has reached zero (or dropped below it) are passed
    to ``change_state``.

    Args:
        pop_matrix: 2-D population array
            (``[id, status, infection_day, remaining_days, age]`` per row).

    Returns:
        A new array with the same shape as ``pop_matrix``; rows are regrouped
        (inactive, still counting, just transitioned), so row order is NOT
        preserved. The input array is not modified.
    """
    inactive_codes = [state_dict['susceptible'], state_dict['removed']]
    inactive = np.isin(pop_matrix[:, 1], inactive_codes)

    matrix_keep = pop_matrix[inactive]
    # Boolean indexing copies, so the decrement below never touches pop_matrix.
    matrix_active = pop_matrix[~inactive]

    matrix_active[:, 3] = matrix_active[:, 3].astype(int) - 1
    remaining = matrix_active[:, 3].astype(int)

    matrix_no_change = matrix_active[remaining > 0]
    # "<= 0" rather than "== 0": a row that started the day at 0 is now at -1
    # and must still transition; otherwise it would be dropped from the result
    # and the shape assertion below would fail.
    matrix_change = matrix_active[remaining <= 0]

    if matrix_change.shape[0] > 0:
        matrix_change = np.apply_along_axis(change_state, 1, matrix_change)

    new_matrix = np.concatenate((matrix_keep, matrix_no_change, matrix_change))
    assert new_matrix.shape == pop_matrix.shape

    return new_matrix

def spread_infection(pop_matrix, restrictions, day):
    """Let every currently infected person try to infect their contacts.

    Each id with the 'infected' status code is passed to
    ``spread_through_contacts``; the contacts it returns are de-duplicated,
    reduced to those that are still susceptible in ``pop_matrix`` and then
    exposed through ``infect_population``.

    Returns:
        ``pop_matrix`` itself when nobody is infectious or nobody new gets
        infected; otherwise the array returned by ``infect_population``.
    """
    infectious_rows = np.flatnonzero(pop_matrix[:, 1] == state_dict['infected'])
    current_i = pop_matrix[infectious_rows][:, 0]
    if current_i.shape[0] == 0:
        return pop_matrix

    spread = partial(spread_through_contacts, restrictions=restrictions)
    per_spreader = [spread(spreader) for spreader in current_i]
    candidates = np.unique(np.array([contact for hits in per_spreader for contact in hits]))

    # Keep only the candidates that are present in the matrix and susceptible.
    touched = pop_matrix[np.isin(pop_matrix[:, 0], candidates)]
    still_susceptible = np.isin(touched[:, 1], state_dict['susceptible'])
    infected = touched[:, 0][still_susceptible]
    if len(infected) == 0:
        return pop_matrix

    return infect_population(pop_matrix, infected, day)

def spread_through_contacts(spreader, restrictions):
    """Return the contacts of ``spreader`` that get infected on this step.

    Reads two module-level names: the contact graph ``G`` and the per-relation
    transmission probabilities ``p_r``. They are only read, so no ``global``
    declaration is needed (the original declared a misspelled ``pr``).

    Args:
        spreader: node id of the infectious person in ``G``.
        restrictions: mapping relation type -> restriction level; the
            transmission probability of a relation is scaled by
            ``1 - restrictions[relation]``. Edges whose ``edge_type`` is not a
            key of this mapping never transmit.

    Returns:
        list of neighbor ids, one entry per edge that transmitted (a neighbor
        can appear more than once).
    """
    contacts = [(neighbor, attrs['edge_type'])
                for _, neighbor, attrs in G.edges(spreader, data=True)]

    # One random draw per edge, taken only for edges of the relation at hand.
    infected = [neighbor
                for relation in restrictions
                for neighbor, edge_type in contacts
                if edge_type == relation
                and np.random.random() < p_r[relation] * (1 - restrictions[relation])]

    return infected
#    infected = np.array(list(map(partial(infect_through_relation, restrictions=restrictions), contacts)))
#    return contacts[infected][:,0].astype(int)

def infect_through_relation(edge, restrictions):
    """Draw whether a single contact edge transmits the infection.

    ``edge[1]`` is the relation type; the draw succeeds when a uniform random
    number falls below ``p_r[relation] * (1 - restrictions[relation])``, where
    ``p_r`` is read from module scope.
    """
    draw = np.random.random()
    relation = edge[1]
    return draw < p_r[relation] * (1 - restrictions[relation])
    
# Load the contact graph and seed the initial infections.
def init_graph(initial_infection = .0001, file_path='SP_multiGraph.gpickle'):
    """Read the pickled graph at ``file_path`` and build its population matrix.

    ``initial_infection`` is forwarded to ``init_infection`` as the share of
    nodes to infect on day 0.

    NOTE(review): ``nx.read_gpickle`` unpickles the file (only load trusted
    files) and was removed in NetworkX 3.0 — confirm the pinned version.

    Returns:
        tuple ``(graph, pop_matrix)``.
    """
    graph = nx.read_gpickle(file_path)
    return graph, init_infection(graph, initial_infection)

def init_infection(G, pct):
    """Build the day-0 population matrix for graph ``G``.

    A share ``pct`` of the nodes (a fraction, e.g. ``.0001``, not a percent)
    is picked at random without replacement and exposed on day 0; everyone
    else starts susceptible.

    Args:
        G: graph whose nodes carry an ``'age'`` attribute.
        pct: fraction of the nodes to infect; the count is
            ``int(len(G.nodes) * pct)``.

    Returns:
        2-D array with one row per node:
        ``[node id, status code, infection day, remaining days, age]``.
        Untouched rows come first, followed by the seeded ones. When the
        requested count rounds down to zero, the all-susceptible matrix is
        returned.
    """
    size = int(len(G.nodes) * pct)

    infected = np.random.choice(G.nodes(), size = size, replace = False)

    pop_matrix = np.array([[node, state_dict['susceptible'], -1, -1, data['age']]
                           for node, data in G.nodes(data=True)])

    if size == 0:
        # Nothing to seed; np.apply_along_axis would raise on an empty selection.
        return pop_matrix

    # Single mask so the seeded and untouched parts are exact complements.
    seeded = np.isin(pop_matrix[:, 0], infected)
    matrix_change = pop_matrix[seeded]
    matrix_keep = pop_matrix[~seeded]

    matrix_change = np.apply_along_axis(susceptible_to_exposed, 1, matrix_change, day=0)

    new_matrix = np.concatenate((matrix_keep, matrix_change))
    assert new_matrix.shape == pop_matrix.shape

    return new_matrix


KeyboardInterrupt



In [36]:
def simulate_pandemic_numpy(i):
    """Run one simulation of 499 days and record the daily status column.

    Args:
        i: run index; only printed, so parallel runs can be told apart.

    Returns:
        list with one array per simulated day (days 1..499) holding the status
        code of every person. Rows are regrouped every day, so positions are
        not comparable across days — only the per-status counts are.

    Side effects:
        Assigns the module-level names ``G`` and ``p_r``.
    """
    # spread_through_contacts reads G and p_r from module scope. Without this
    # declaration they would be locals here and the spreading step would use
    # whatever stale values the kernel holds (or fail with NameError).
    global G, p_r

    print(i)
    G, pop_matrix = init_graph(file_path='SP_multiGraph_intID.gpickle')

    data = []

    restrictions={'work':0, 'school': 0, 'home':0, 'neighbor':0}
    p_r={'neighbor':.3/80, 'work':.3/40, 'school':.6/40, 'home':.3}

    for day in tqdm(range(1, 500)):
        pop_matrix = update_population(pop_matrix)
        pop_matrix = spread_infection(pop_matrix, restrictions, day)
        # Disabled restriction schedule, kept for reference:
        #if day == 28:
        #    restrictions={'work':.6, 'school': 1, 'home':0, 'neighbor':.4}
        #if day >= 97:
        #    restrictions={'work':0, 'school': 0, 'home':0, 'neighbor':0}

        # Copy the column: a plain slice is a view that would keep every
        # day's full matrix alive for the whole run.
        data.append(pop_matrix[:,1].copy())

    return data

In [None]:
from joblib import Parallel, delayed
# Run 12 independent simulations (run indices 0..11) with n_jobs=6; the result
# is a list holding the `data` list returned by each run.
end_policy_may_pr04 = Parallel(n_jobs=6)(delayed(simulate_pandemic_numpy)(i) for i in range(12))

In [28]:
# NOTE(review): `data` is not assigned at module level in any visible cell
# (simulate_pandemic_numpy returns it and the parallel run stores its results
# in `end_policy_may_pr04`) — confirm which run this cell is meant to plot.
# One row per entry of `data`, one column per status code, values = head counts.
d_counts = pd.DataFrame([pd.Series(d).value_counts() for d in data])
# A status code that does not occur in an entry comes out as NaN; count it as 0.
d_counts.fillna(0, inplace=True)

fig = go.Figure()

#fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[-1], name='removed',line_color='green'))
#fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[0], name='susceptible', line_color='blue'))
#fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[1], name='exposed',line_color ='orange'))
# Column 2 is the 'infected' code in state_dict.
# NOTE(review): the 12e6/50e3 factor presumably rescales a 50k-node graph to a
# 12M population — TODO confirm and move to a named constant.
fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[2]*12e6/50e3, name='infected', line_color = 'red'))
#fig.add_trace(go.Scatter(x=d_counts.index.to_list(), y=d_counts[3], name='hospitalized', line_color = 'purple'))

In [8]:
# Scratch cell: list the contacts of one hard-coded node.
# NOTE(review): `G` is not assigned at module level in any visible cell, and
# this module-level `p_r` uses different values from the one set inside
# simulate_pandemic_numpy — confirm which values are intended.
restrictions={'work':0, 'school': 0, 'home':0, 'neighbor':0}
p_r={'neighbor':.1, 'work':.1, 'school':.5, 'home':.9}
spreader = '00010008101'
# One [neighbor id, edge type] pair per edge incident to `spreader`.
contacts = np.array([[y,v['edge_type']] for x,y,v in G.edges(spreader, data=True)])


In [9]:
# Scratch cell: print every edge of the hard-coded node as "source _ target _ attrs".
# The string concatenation requires the node ids to be strings.
for x,y,z in G.edges('00010008101', data=True):
    print(x + ' _ ' + y + ' _ ' + str(z))