In [1]:
import os

# simulate_pandemic / disease_states live in src/models and are imported from the
# working directory (presumably they also read input files relative to it -- verify),
# so the notebook runs from there.
# os.path.join keeps the path portable (the old '..\\..\\src\\models' only worked on
# Windows), and the guard makes the cell idempotent: re-running it no longer walks
# two more directories up and fails.
MODELS_DIR = os.path.join('..', '..', 'src', 'models')
if not os.path.normpath(os.getcwd()).endswith(os.path.join('src', 'models')):
    os.chdir(MODELS_DIR)

In [2]:
import simulate_pandemic 
import numpy as np
import pandas as pd
import plotly.express as px
from tqdm import tqdm
from disease_states import state_to_color

Loading Graph... Done!


In [3]:
# Run the whole pandemic simulation (the progress bar in this cell's output is its).
# NOTE(review): the next cell turns the result into a DataFrame whose rows are read
# as days and whose columns are read as people -- confirm against simulate_pandemic.main.
run_simulation = simulate_pandemic.main
data = run_simulation()

  7%|█████▊                                                                           | 36/499 [00:27<05:58,  1.29it/s]


In [4]:
from pyproj import Transformer

# Raw simulation output: the row index becomes the 'day' column and
# column index + 1 is used as the person's graph node id below.
df = pd.DataFrame(np.array(data))

# Source: the graph's projected home coordinates (EPSG:22523).
# Target: EPSG:4326 (WGS 84 lat/lon), which is what Mapbox tiles use. The previous
# target, EPSG:4289, is the Amersfoort (Netherlands) geographic CRS and introduced a
# spurious datum shift. With a geographic target and the default (authority) axis
# order, transform() returns (lat, lon), matching how home_x/home_y are plotted later.
transformer = Transformer.from_crs('epsg:22523', 'epsg:4326')

# Std-dev (degrees) of the random offset added to every home, presumably so people
# sharing a home location do not collapse onto a single marker.
HOME_JITTER_DEG = .001
jitter_rng = np.random.default_rng(42)  # seeded so the map is reproducible


def re_project(x, y):
    """Project one (x, y) point from EPSG:22523 into WGS 84, returned as (lat, lon)."""
    return transformer.transform(x, y)


# Loop invariants, hoisted out of the per-person loop.
n_days = len(df)
days = np.arange(n_days)

dfs = []
for col in tqdm(df.columns):
    person_id = col + 1
    node = simulate_pandemic.G.nodes[person_id]
    lat, lon = re_project(node['home_x'], node['home_y'])
    lat += jitter_rng.normal(loc=0, scale=HOME_JITTER_DEG)
    lon += jitter_rng.normal(loc=0, scale=HOME_JITTER_DEG)
    # One row per day; columns are status, id, day, home_x (lat), home_y (lon) --
    # the order the next cell assigns column names in.
    person_rows = np.transpose([
        df[col].to_numpy(),
        np.full(n_days, person_id),
        days,
        np.full(n_days, lat),
        np.full(n_days, lon),
    ])
    dfs.append(person_rows)

100%|██████████████████████████████████████████████████████████████████████████| 55492/55492 [00:05<00:00, 9753.31it/s]


In [5]:
# Long-format table: one row per (person, day) in the column order built above.
ts_columns = ['status', 'id', 'day', 'home_x', 'home_y']
df_ts = pd.DataFrame(np.concatenate(dfs), columns=ts_columns)

# Replace each disease state with the marker colour used on the map.
df_ts['status'] = [state_to_color[state] for state in df_ts['status']]

# One placeholder row per colour (id -1, day 0, no coordinates).
# NOTE(review): presumably added so every colour is present in the data; they have
# no location and so draw nothing.
fill = pd.DataFrame(
    [[colour, -1, 0, None, None] for colour in ('red', 'purple', 'orange', 'gray')],
    columns=ts_columns,
)
df_ts = pd.concat([df_ts, fill]).reset_index(drop=True)

# Copy consumed by the plotting cell.
df_full = df_ts.copy()

In [9]:
import os

import plotly.graph_objects as go
import plotly.io as pio

# The Mapbox token comes from the environment instead of being committed with the
# notebook. NOTE(review): the token that used to be hardcoded here is in version
# history and should be revoked/rotated.
mapbox_access_token = os.environ['MAPBOX_ACCESS_TOKEN']

MAX_DAY = 300            # last simulation day to animate
N_SAMPLE = 2500          # people drawn for the map (the "2.5k" in the exported file name)
FRAME_DURATION_MS = 100  # time each day is shown while playing
sample_rng = np.random.default_rng(42)  # seeded so the same people are drawn on every run

# Sample *people*, not rows: df_full has one row per (person, day), so drawing from
# df_full.id directly returned repeated ids (fewer than N_SAMPLE people) and could
# draw the placeholder id -1. The placeholder is excluded here and always appended.
person_ids = df_full.loc[df_full.id != -1, 'id'].unique()
sample_ids = sample_rng.choice(person_ids, size=min(N_SAMPLE, len(person_ids)), replace=False)
sample_ids = np.append(sample_ids, -1)

# Filter on df_full only (the old mask mixed in df_ts).
df_plot = df_full[df_full.id.isin(sample_ids) & (df_full.day <= MAX_DAY)]

# Last day actually present in the plotted data. Previously this was reset to the
# unfiltered maximum, which produced empty frames for every day past MAX_DAY.
max_day = int(df_plot.day.max())
# Inclusive of both ends: day 0 is reachable from the slider and the last day is shown.
frame_days = range(0, max_day + 1)


def day_trace(day):
    """Build the marker trace for all sampled people on simulation day `day`.

    home_x is plotted as latitude and home_y as longitude; `status` holds the
    marker colour of each person on that day.
    """
    points = df_plot[df_plot.day == day]
    return go.Scattermapbox(
        lat=points['home_x'],
        lon=points['home_y'],
        mode='markers',
        marker=dict(size=5, color=points['status']),
    )


# Named variable instead of reusing `data`, which holds the simulation output.
initial_traces = [day_trace(0)]

frames = [go.Frame(data=[day_trace(day)], traces=[0], name='day {}'.format(day))
          for day in frame_days]

slider_steps = [
    dict(
        method='animate',
        args=[['day {}'.format(day)],
              dict(mode='immediate',
                   frame=dict(duration=FRAME_DURATION_MS, redraw=True),
                   transition=dict(duration=0))],
        label='{:d}'.format(day),
    )
    for day in frame_days
]
sliders = [dict(steps=slider_steps,
                transition=dict(duration=0),
                x=0,  # slider starting position
                y=0,
                currentvalue=dict(font=dict(size=12),
                                  prefix='Day: ',
                                  visible=True,
                                  xanchor='center'),
                len=1.0)]

play_button = dict(label='Play',
                   method='animate',
                   args=[None, dict(frame=dict(duration=FRAME_DURATION_MS, redraw=True),
                                    transition=dict(duration=0),
                                    fromcurrent=True,
                                    mode='immediate')])

layout = go.Layout(
    width=800,
    autosize=True,
    hovermode='closest',
    mapbox=dict(accesstoken=mapbox_access_token,
                bearing=0,
                center=dict(lat=-23.5505, lon=-46.6333),  # São Paulo
                pitch=0,
                zoom=9,
                style='light'),
    updatemenus=[dict(type='buttons',
                      showactive=False,
                      y=0,
                      x=1.05,
                      xanchor='right',
                      yanchor='top',
                      pad=dict(t=0, r=10),
                      buttons=[play_button])],
    sliders=sliders,
)

fig = go.Figure(data=initial_traces, layout=layout, frames=frames)
pio.renderers.default = "browser"  # open in a browser tab instead of inline
fig.show()


In [10]:
# Save the animated figure as a standalone HTML file. The relative path resolves
# against the current working directory, which the first cell changed to src/models.
export_path = "SP_spread_animation_2.5k.html"
fig.write_html(export_path)
