In [1]:
import sys
# Extend the import search path with the project's model sources so that
# project modules can be imported from this notebook
# (presumably where `actions`, imported further down, lives — confirm).
sys.path.append("../../src/models")

In [2]:
import pandas as pd
import networkx as nx
import numpy as np

In [3]:
from joblib import Parallel, delayed
#from tqdm import tqdm
from tqdm.auto import tqdm

In [4]:
from actions import city_restrictions

# Reference value for 'home' edges; every other edge type is expressed as a
# fraction of it.
# NOTE(review): presumably a per-contact transmission probability — confirm.
prhome = 0.06

# Base value per edge type. get_percolation_edgelist maps `edge_type` through
# this dict and scales it by (1 - restriction) to obtain the edge weight.
p_r = dict(
    home=prhome,
    neighbor=.1 * prhome,
    work=.1 * prhome,
    school=.15 * prhome,
)

In [5]:
# Modelling table that the message-passing features computed below are merged
# into, keyed on (id, simulation, week).
# NOTE(review): this same path is written back at the end of the notebook.
model_dataset = pd.read_parquet("model_dataset.parquet")

In [6]:
# Simulation results. The code below reads its `simulation`, `week`, `action`,
# `id` and `state` columns.
res_df = pd.read_parquet("simulation_results_dataset.parquet")

In [7]:
# Graph whose edges carry an `edge_type` attribute (read when the edge list is
# built below).
# NOTE(review): read_gpickle unpickles the file, which can execute arbitrary
# code — only load graphs from a trusted source. The function was also removed
# in NetworkX 3.0, so this notebook needs networkx < 3 (confirm the pinned version).
gpickle_path = "../../data/processed/SP_multiGraph_Job_Edu_Level.gpickle"
G = nx.read_gpickle(gpickle_path)

In [8]:
# Flatten the graph into one (source, target, edge_type) record per edge.
# Every edge is expected to carry an `edge_type` attribute (KeyError otherwise).
edges = [(u, v, attrs['edge_type']) for u, v, attrs in G.edges(data=True)]

# Tabular edge list; filtered and weighted later by get_percolation_edgelist.
edgelist_df = pd.DataFrame(
    data=edges,
    columns=['source', 'target', 'edge_type'],
)

In [9]:
def get_percolation_states(df, initial_states=[0], percolation_one_states=[1, 2], state_column='state'):
    """Translate each row's state into a binary percolation label.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing `state_column`.
    initial_states : list
        States labelled 0. Checked first, so it wins if a state is in both lists.
    percolation_one_states : list
        States labelled 1.
    state_column : str
        Name of the column holding the state.

    Returns
    -------
    pandas.Series
        Same index as `df`; 0, 1, or NaN for states in neither list.
    """
    def _label(state):
        if state in initial_states:
            return 0
        if state in percolation_one_states:
            return 1
        return np.nan

    return df[state_column].apply(_label)

def get_percolation_nodes(df):
    """Build an ``{id: percolation}`` lookup for the labelled rows of `df`.

    Rows whose `percolation` value is NaN are left out, and the remaining
    `id` values are cast to int before being used as dictionary keys.
    The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `id` and `percolation` columns.

    Returns
    -------
    dict
        Maps integer node id to its percolation label.
    """
    labelled = (
        df[['id', 'percolation']]
        .dropna(subset=['percolation'])
        .assign(id=lambda frame: frame['id'].astype(int))
    )

    return labelled.set_index('id')['percolation'].to_dict()

def get_percolation_edgelist(edgelist_df, nodes_percolation, action):
    """Keep the edges between labelled nodes and attach a `weight` column.

    Parameters
    ----------
    edgelist_df : pandas.DataFrame
        Edge list with `source`, `target` and `edge_type` columns.
    nodes_percolation : dict
        Node id -> label; only edges whose two endpoints are both keys of this
        dict are kept.
    action : hashable
        Key into the module-level `city_restrictions`. The selected entry is
        used to map each `edge_type` to a value that is subtracted from 1
        (presumably the fraction of contacts removed by the action — confirm).

    Returns
    -------
    pandas.DataFrame
        Copy of the retained edges with
        ``weight = (1 - restriction[edge_type]) * p_r[edge_type]``.
        `p_r` is also read from module scope.
    """
    known_ids = nodes_percolation.keys()
    both_endpoints_known = (
        edgelist_df['source'].isin(known_ids)
        & edgelist_df['target'].isin(known_ids)
    )
    edgelist_percolation = edgelist_df[both_endpoints_known].copy()

    edge_types = edgelist_percolation['edge_type']
    restriction = edge_types.map(city_restrictions[action])
    base_value = edge_types.map(p_r)
    edgelist_percolation['weight'] = (1 - restriction) * base_value

    return edgelist_percolation
        
def message_pass(edgelist_percolation, nodes_percolation):
    """Propagate the node labels one, two and three hops over the weighted graph.

    Parameters
    ----------
    edgelist_percolation : pandas.DataFrame
        Edge list with `source`, `target` and `weight` columns.
    nodes_percolation : dict
        Node id -> percolation label. Must contain every node that appears in
        the edge list (KeyError otherwise).

    Returns
    -------
    pandas.DataFrame
        One row per node of the graph, with columns `mp1_hospitalized`,
        `mp2_hospitalized`, `mp3_hospitalized` (A @ x, A @ A @ x, A @ A @ A @ x,
        where A is the weighted adjacency matrix and x the label vector) and `id`.
        The column names are fixed regardless of which states were labelled 1.
        Nodes without any edge in the edge list do not appear in the output.
    """
    # NOTE(review): from_pandas_edgelist builds a simple nx.Graph by default, so
    # if the edge list has several rows for the same node pair only one weight
    # survives — confirm that this is intended.
    percolation_graph = nx.from_pandas_edgelist(edgelist_percolation, edge_attr='weight')

    # Fix the node order once so the matrix rows, the label vector and the
    # `id` column are guaranteed to line up.
    node_ids = list(percolation_graph.nodes())

    # NOTE(review): to_scipy_sparse_matrix was removed in NetworkX 3.0
    # (replaced by to_scipy_sparse_array); this requires networkx < 3.
    adj_matrix = nx.to_scipy_sparse_matrix(percolation_graph, nodelist=node_ids, weight='weight')

    initial_states = np.array([nodes_percolation[n] for n in node_ids])

    mp1 = adj_matrix @ initial_states
    mp2 = adj_matrix @ mp1
    mp3 = adj_matrix @ mp2

    # Build the frame column by column: stacking the ids into the same ndarray
    # as the features would coerce them to float (or everything to strings for
    # non-numeric ids).
    message_passing_df = pd.DataFrame({
        'mp1_hospitalized': mp1,
        'mp2_hospitalized': mp2,
        'mp3_hospitalized': mp3,
        'id': node_ids,
    })

    return message_passing_df


def all_steps_message_passing(res_df, simulation, week, initial_states=[0], percolation_one_states=[1, 2]):
    """Run the whole message-passing pipeline for one (simulation, week) slice.

    Steps: select the slice, label each node from its state, restrict the
    module-level `edgelist_df` to labelled nodes and weight it with the slice's
    action, then propagate the labels with `message_pass`.

    Parameters
    ----------
    res_df : pandas.DataFrame
        Simulation results with `simulation`, `week`, `action`, `id` and
        `state` columns.
    simulation, week
        Values selecting the slice. The action is taken from the slice's first
        row (IndexError if the slice is empty).
    initial_states : list
        States labelled 0.
    percolation_one_states : list
        States labelled 1.

    Returns
    -------
    pandas.DataFrame
        Output of `message_pass` with `week` and `simulation` columns appended.
    """
    in_simulation = res_df['simulation'] == simulation
    in_week = res_df['week'] == week
    sim_week_df = res_df[in_simulation & in_week].copy()

    action = sim_week_df['action'].iloc[0]

    sim_week_df['percolation'] = get_percolation_states(sim_week_df, initial_states, percolation_one_states)
    nodes_percolation = get_percolation_nodes(sim_week_df)

    # NOTE: `edgelist_df` is read from module scope, not passed in.
    edgelist_percolation = get_percolation_edgelist(edgelist_df, nodes_percolation, action)

    return message_pass(edgelist_percolation, nodes_percolation).assign(
        week=week,
        simulation=simulation,
    )

In [None]:
# Compute the message-passing features for every (simulation, week) pair,
# weeks 1..16, on 16 joblib workers. In this run states 0, 1 and 2 are
# labelled 0 and state 3 is labelled 1.
message_passing_dfs = Parallel(n_jobs=16)(
    delayed(all_steps_message_passing)(
        res_df,
        simulation,
        week,
        initial_states=[0, 1, 2],
        percolation_one_states=[3],
    )
    for simulation, week in tqdm(
        [
            (sim, w)
            for sim in res_df['simulation'].unique()
            for w in range(1, 17)
        ]
    )
)

# Stack the per-(simulation, week) frames into one long table.
message_passing_final_df = pd.concat(message_passing_dfs)

  0%|          | 0/256 [00:00<?, ?it/s]

## Serial (no joblib) version of the parallel run above, for testing/debugging only

message_passing_dfs = [all_steps_message_passing(res_df, simulation, week,
                                           initial_states=[0, 1, 2], 
                                           percolation_one_states=[3])
                                          for simulation, week in 
                                          tqdm([(sim, w) for sim in res_df['simulation'].unique() for w in range(1, 17)])]

message_passing_final_df = pd.concat(message_passing_dfs)

In [None]:
# Attach the message-passing features to the modelling table.
#
# model_dataset.parquet is both read at the start and written at the end of
# this notebook, so on a re-run (of this cell or of the whole notebook) the
# table may already contain the feature columns. Drop them first; otherwise
# pandas would keep both copies as suffixed duplicates
# (`mp1_hospitalized_x` / `mp1_hospitalized_y`, ...).
#
# NOTE(review): this is an inner join, so rows of `model_dataset` that have no
# message-passing features are dropped — confirm that this is intended.
merge_keys = ['id', 'simulation', 'week']
model_dataset = pd.merge(
    model_dataset.drop(
        columns=message_passing_final_df.columns.difference(merge_keys),
        errors='ignore',
    ),
    message_passing_final_df,
    on=merge_keys,
)

In [None]:
# Persist the enriched table with a fresh 0..n-1 index.
# NOTE(review): this overwrites the same file that was loaded at the top of the
# notebook; consider writing to a new path so the input stays reproducible.
model_dataset.reset_index(drop=True).to_parquet("model_dataset.parquet")

In [14]:
# Quick sanity check: (rows, columns) of the dataset after the merge.
model_dataset.shape

(431921, 50)