In [1]:
import pandas as pd, matplotlib.pyplot as plt, ipywidgets as widgets, numpy as np, geopandas as gpd, folium, plotly, plotly.graph_objects as go
import requests, ipyleaflet, plotly.express as px, branca, urllib, os, datetime, sys
from ipywidgets import interact, interactive, fixed, interact_manual

# Make "plotly_white" the default template for Plotly figures created after this point.
plotly.io.templates.default = "plotly_white"

def cache(url, max_hours=24, timeout=60):
    """Download `url` into the current directory (if needed) and open the local copy.

    The local file name is the last path component of the URL (query string
    ignored). The file is (re)downloaded when it does not exist yet or when its
    modification time is more than `max_hours` hours in the past.

    Parameters
    ----------
    url : str
        Address of the resource to fetch.
    max_hours : float, default 24
        Maximum age of the local copy, in hours, before it is refreshed.
    timeout : float, default 60
        Timeout in seconds handed to `requests.get`, so a stalled server cannot
        hang the notebook forever.

    Returns
    -------
    file object
        The cached file opened in binary read mode. The caller owns the handle
        and should close it (e.g. ``with cache(url) as fh: ...``).

    Raises
    ------
    ValueError
        If no file name can be derived from the URL path.
    requests.HTTPError
        If the server answers with an error status; the existing cached file
        (if any) is left untouched in that case.
    """
    # Local imports: a bare `import urllib` does not guarantee that the
    # `urllib.parse` submodule has been loaded.
    import time
    from urllib.parse import urlparse

    filename = os.path.basename(urlparse(url).path)
    if not filename:
        raise ValueError(f'Cannot derive a cache file name from URL: {url!r}')

    if os.path.exists(filename):
        # Epoch-based age; unaffected by local daylight-saving jumps.
        age_hours = (time.time() - os.path.getmtime(filename)) / 3600
    else:
        age_hours = float('inf')

    if age_hours > max_hours:
        response = requests.get(url, timeout=timeout)
        # Fail before touching the file so an error page never replaces good data.
        response.raise_for_status()
        with open(filename, 'wb') as f:
            f.write(response.content)

    return open(filename, 'rb')

# ---------------------------------------------------------------------------
# US county pipeline: builds `df`, one row per county and date, with county
# geometry, population, cumulative deaths and the derived metrics below.
# The statements depend on each other's order (filters run before the metrics).
# ---------------------------------------------------------------------------

# Census data
# Column names are taken from row index 2 of the sheet (header=2). FIPS is read
# as str, presumably to keep leading zeros - TODO confirm. The very large
# max_hours means an existing local copy is effectively never refreshed.
population = pd.read_excel(cache('https://www.ers.usda.gov/webdocs/DataFiles/48747/PopulationEstimates.xls?v=3011.3', max_hours=10000000), header=2, dtype={"FIPS": str})
# Display label "<Area_Name>, <State>".
population['County'] = population.apply(lambda x: '%s, %s' % (x.Area_Name, x.State), axis=1)
# The 2018 estimate is the population figure used for every per-capita metric.
population['Population'] = population.POP_ESTIMATE_2018

# Covid data
# Uses cache()'s default max_hours, so this file is refreshed when the local copy is over a day old.
covid = pd.read_csv(cache('https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-counties.csv'), dtype={"fips": str})
# Rows labelled 'New York City' are assigned FIPS 36061 so they can be joined
# (presumably the source leaves fips empty for that row - verify).
covid.loc[covid.county == 'New York City', 'fips'] = '36061'
# Left join keeps every census row; fillna(0) then replaces ALL missing values
# with 0, including 'date' and 'deaths' for census rows without covid data.
df = population.merge(covid, left_on='FIPS', right_on='fips', how='left').fillna(0)
# Keep only the columns used later and drop rows without a positive population.
df = df[['date', 'FIPS', 'State', 'County', 'Population', 'deaths']].query('Population>0')
#df = df[~df.State.isin(['PR', 'AK', 'HA', 'HI'])]

# Add geometry
counties = gpd.read_file(cache('https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json', max_hours=10000000))
# Build the join key by concatenating the STATE and COUNTY fields.
counties['FIPS'] = counties.apply(lambda x: f'{x.STATE}{x.COUNTY}', axis=1)
# Replace the geometry of id 36061 with the union of the five listed county ids
# (presumably the NYC boroughs, matching the FIPS override above - confirm).
counties.loc[counties.id == '36061', 'geometry']  = counties[counties.id.isin(['36061', '36005', '36085', '36047', '36081'])].dissolve(by='STATE').geometry.values
# Left join from the geometries: counties without any data row are kept with NaNs.
df = counties.merge(df, on='FIPS', how='left')
df = df[['date', 'FIPS', 'State', 'County', 'Population', 'deaths', 'geometry']]

# Clean up
df.fillna(0, inplace=True)
df['date'] = pd.to_datetime(df.date)
# NOTE(review): rows whose date was filled with 0 presumably parse to the epoch
# and are therefore removed by this filter - confirm.
df = df[df.date >= '2020-01-01']
df.rename(columns={'deaths': 'Cumulative Deaths', 'date': 'Date'}, inplace=True)
# Chronological order is required by the per-county shift(1) below.
df.sort_values('Date', inplace=True)
df = df[df['Cumulative Deaths']>=10] # Only care about counties with at least 10 dead

# Calculate metrics
df['Cumulative Deaths/1M'] = 1000000 * df['Cumulative Deaths'] / df['Population']
# Difference to the previous retained row of the same county. The first retained
# row has no predecessor (shift gives NaN -> 0), so its Daily Deaths equals its
# full cumulative count.
df['Daily Deaths'] = df['Cumulative Deaths'] - df.groupby('FIPS')['Cumulative Deaths'].shift(1).fillna(0)
df['Daily Deaths/1M'] = 1000000 * df['Daily Deaths'] / df['Population']
# The denominator (previous cumulative count) is floored at 100 by max(100, x).
# NOTE(review): this is a true percentage only once the previous cumulative
# count is at least 100 - confirm the floor is intentional.
df['Daily Increase (%)'] = 100 * df['Daily Deaths'] / (df['Cumulative Deaths'] - df['Daily Deaths']).map(lambda x: max(100, x))
# Days between each row's Date and the earliest retained Date of its county;
# since rows below 10 deaths were removed above, that is the first date with >= 10.
# reset_index()/set_index('index') carries the original row labels through the
# merge so the result aligns with `df` on assignment.
df['Days Since 10 Deaths'] = df.reset_index().merge(df[['FIPS', 'Date']].groupby('FIPS').min().rename(columns={'Date': '10 dead'}).reset_index()).set_index('index').apply(lambda x: (x['Date'] - x['10 dead']).days, axis=1)

# Lookup table (iso3 code, population, ...) per country/province; the country
# column is renamed so it matches the key used by the deaths time series.
iso_lookups = pd.read_csv(
    cache(
        'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/UID_ISO_FIPS_LookUp_Table.csv',
        max_hours=10000000,
    ),
    dtype={"FIPS": str},
).rename(columns={'Country_Region': 'Country/Region'})

# Get deaths per country
# Wide time series -> provinces summed per country -> long format with one row
# per (country, date) -> population attached from the country-level lookup rows
# (those without a Province_State).
deaths_world = (
    pd.read_csv(
        cache('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv')
    )
    .drop(columns=['Province/State', 'Lat', 'Long'])
    .groupby('Country/Region')
    .sum()
    .reset_index()
    .melt(id_vars=['Country/Region'], var_name='Date', value_name='Cumulative Deaths')
    .astype({'Date': 'datetime64[ns]'})
    .merge(
        iso_lookups.loc[
            iso_lookups['Province_State'].isnull(),
            ['iso3', 'Country/Region', 'Population'],
        ],
        on='Country/Region',
    )
    .drop(columns=['iso3'])
    .assign(Level='Country')
)

# Merge and calculate metrics
# NOTE: `deaths` is the same object as `deaths_world` (no copy is made), so the
# columns added below are visible under both names.
deaths = deaths_world
deaths['Cumulative Deaths/1M'] = 1000000 * deaths['Cumulative Deaths'] / deaths['Population']
# Difference to the previous row of the same country; the first row of each
# country has no predecessor (NaN -> 0) and keeps its cumulative value.
deaths['Daily Deaths'] = deaths['Cumulative Deaths'] - deaths.groupby('Country/Region')['Cumulative Deaths'].shift(1).fillna(0)
deaths['Daily Deaths/1M'] = 1000000 * deaths['Daily Deaths'] / deaths['Population']
# The denominator (previous cumulative count) is floored at 100 by max(100, x).
# NOTE(review): confirm the floor is intentional; below 100 this is not a true percentage.
deaths['Daily Increase (%)'] = 100 * deaths['Daily Deaths'] / (deaths['Cumulative Deaths'] - deaths['Daily Deaths']).map(lambda x: max(100, x))
# Days relative to the first date on which the country reported >= 10 cumulative
# deaths (negative before that date, NaN for countries that never reached 10).
# Vectorised: map each row's country to its first ">= 10" date and subtract,
# instead of merging and applying a Python function to every row.
deaths['Days Since 10 Deaths'] = (
    deaths['Date']
    - deaths['Country/Region'].map(
        deaths.loc[deaths['Cumulative Deaths'] >= 10]
        .groupby('Country/Region')['Date']
        .min()
    )
).dt.days

In [12]:
def table(level, x_axis, y_axis, window_size):
    """Display a line chart of `y_axis` against `x_axis` for the ten top-ranked regions.

    Reads the module-level `deaths` frame, keeps rows of the given `level` with a
    population above 50,000, ranks regions by the sum of `y_axis` over all dates
    and draws one trace per region for the first ten.

    Parameters
    ----------
    level : str
        Value of the 'Level' column to plot (e.g. 'Country').
    x_axis : str
        'Date', 'Days Since 10 Deaths', or a numeric column of `deaths`. Numeric
        columns other than 'Days Since 10 Deaths' are smoothed like `y_axis`.
    y_axis : str
        Numeric column of `deaths`; plotted as a rolling mean.
    window_size : int
        Rolling-mean window, in rows.

    The hover text always shows the raw (unsmoothed) values of the row.
    """
    def smooth(series):
        # Truncate to int, take the rolling mean, zero-fill the warm-up rows
        # (fewer than `window_size` observations) and truncate again.
        return series.map(int).rolling(window_size).mean().fillna(0).map(int).values

    d = deaths[deaths['Level'] == level]

    # Only look at regions with more than 50k residents
    d = d[d.Population > 50000]

    # Rank the regions by the sum of the y value over all dates and keep the top ten.
    # Only the y column is aggregated: summing the whole frame would also try to
    # sum the datetime/string columns, which recent pandas versions reject.
    # NOTE(review): the original comment said "max y value" while the code used
    # sum(); the sum is kept - confirm which ranking is intended.
    regions = d.groupby('Country/Region')[y_axis].sum().sort_values(ascending=False).index[:10]
    fig = go.FigureWidget(layout={'width': 690, 'height': 490, 'margin': {'b': 0, 'l': 0, 'r': 0, 't': 0}, 'autosize': False})

    for region in regions:
        v = d[d['Country/Region']==region]

        if x_axis == 'Date':
            x = pd.to_datetime(v[x_axis].values)
        elif x_axis == 'Days Since 10 Deaths':
            x = v[x_axis].values
        else:
            x = smooth(v[x_axis])
        y = smooth(v[y_axis])

        title = f'{region}'
        text = v.apply(lambda x: f'''
            {region}<br>
            {x.Date.date()}<br>
            Daily Deaths: {int(x['Daily Deaths'])}<br>
            Cumulative Deaths: {int(x['Cumulative Deaths'])}<br>
            Daily Deaths/1M: {int(x['Daily Deaths/1M'])}<br>
            Cumulative Deaths/1M: {int(x['Cumulative Deaths/1M'])}
            ''', axis=1).values

        # Only look at the days from the 10th death onwards. The same mask is
        # applied to the hover text so every label stays attached to its point
        # (NaN days compare False and are dropped as well).
        if x_axis == 'Days Since 10 Deaths':
            keep = x >= 0
            x, y, text = x[keep], y[keep], text[keep]

        fig.add_trace(go.Scatter(x=x, y=y, mode='lines+markers', text=text, name=title, hoverinfo='text+name', textposition='bottom right'))

    fig.update_annotations(dict(xref="x", yref="y", showarrow=True, arrowhead=7, ax=100, ay=0))
    fig.update_layout(
        title=None,
        xaxis_title=x_axis,
        yaxis_title=y_axis)

    display(fig)

In [13]:
def _legend_scale(max_value):
    """Return `(step, stop)` for legend/colour bins `range(0, stop + 1, step)`.

    `step` is the power of ten at or below `max_value`, halved when that would
    give fewer than five bins; `stop` is `max_value` rounded up to a multiple of
    the un-halved step. `step` is never below 1 and `stop` is at least
    `2 * step`, so the range always has a non-zero step and at least two bins,
    also for maxima below 1, zero or NaN.
    """
    if not np.isfinite(max_value) or max_value <= 0:
        return 1, 2
    # A maximum below 1 would truncate to a step of 0; floor it at 1.
    step = max(int(np.power(10, np.floor(np.log10(max_value)))), 1)
    stop = int(np.ceil(max_value / step) * step)
    if stop / step < 5:
        # Halving a step of 1 used to truncate to 0 and crash range().
        step = max(step // 2, 1)
    return step, max(stop, 2 * step)


def map(level, y_axis, window_size):
    """Display a choropleth world map of `y_axis` for the latest available date.

    Reads the module-level `deaths` frame, smooths the metric columns with a
    per-country rolling mean, joins the latest date onto country geometries and
    renders an ipyleaflet map with a stepped colour legend and a hover panel.

    Parameters
    ----------
    level : str
        Value of the 'Level' column to show (e.g. 'Country').
    y_axis : str
        Metric column used for the fill colour and the legend caption.
    window_size : int
        Rolling-mean window, in rows, applied per country.

    NOTE: the function name shadows the built-in `map`; it is kept because the
    widget wiring further down refers to it by this name.
    """
    d = (deaths
        .query(f'Level=="{level}"')
        .groupby('Country/Region')
        [['Cumulative Deaths', 'Cumulative Deaths/1M', 'Daily Deaths', 'Daily Deaths/1M', 'Daily Increase (%)']]
        .rolling(window_size)
        .mean()
        .fillna(0)
        .reset_index(level='Country/Region')
        # Re-attach the unsmoothed columns by original row label.
        .merge(deaths[['Date', 'Population', 'Level', 'Days Since 10 Deaths']], left_index=True, right_index=True)
    )

    # Add geometries
    # The handle returned by cache() is closed as soon as the file has been read.
    with cache('https://raw.githubusercontent.com/johan/world.geo.json/master/countries.geo.json', max_hours=10000000) as geo_file:
        world = gpd.read_file(geo_file)
    deaths_world_geo = (world
        .rename(columns={'id': 'iso3'})
        .merge(iso_lookups[['iso3', 'Country/Region']].drop_duplicates())
        # Only the most recent date is drawn.
        .merge(d[d.Date == d.Date.max()])
        .drop(columns=['name', 'iso3', 'Date'])
    )

    m = ipyleaflet.Map(basemap=ipyleaflet.basemaps.CartoDB.Positron, center=(50.6252978589571, 0.34580993652344), zoom=2)

    step, stop = _legend_scale(deaths_world_geo[y_axis].max())
    colorscale = branca.colormap.linear.YlOrRd_09.to_step(index=range(0, stop+1, step))
    colorscale.caption = y_axis

    geo_json = ipyleaflet.GeoJSON(
        data=deaths_world_geo.__geo_interface__,
        style_callback=lambda x: {'fillColor': colorscale(x['properties'][y_axis])},
        style={'opacity': 1, 'weight': 0, 'fillOpacity': 0.7},
        hover_style={'weight': 1, 'dashArray': '9', 'fillOpacity': 0.5}
    )
    m.add_layer(geo_json)

    # Legend
    def col_to_bkg(c):
        # Text colour for a legend cell: black when the mean of the R, G and B
        # channels of the cell colour is above 100, white otherwise.
        return 'black' if (int(c[1:3], 16) + int(c[3:5], 16) + int(c[5:7], 16))/3 > 100 else 'white'

    colors = ''.join([
        f'<div style="color: {col_to_bkg(colorscale(i))}; text-align: center; background-color:{colorscale(i)}; width: 40px; float: left;">{i}</div><br>' 
        for i in range(0, stop+1, step)
    ])
    m.add_control(ipyleaflet.WidgetControl(widget=widgets.HTML(colors, layout={'margin': '0px 0px 0px 0px'}), position='topright'))

    # Hover
    html = widgets.HTML('Hover over a country/region', layout={'margin': '0px 10px 10px 20px;'})
    hover = ipyleaflet.WidgetControl(widget=html, position='bottomright')
    def update_html(properties, **kwargs):
        # Fill the hover panel with the properties of the feature under the cursor.
        html.value = f"""
        <div style="margin: 0px 0px 10px 0px;">
            <div style="height: 16px"><b>{properties['Country/Region']}</b></div>
            <div style="height: 16px">Daily Deaths: {int(properties['Daily Deaths'])}</div>
            <div style="height: 16px">Daily Deaths/1M: {int(properties['Daily Deaths/1M'])}</div>
            <div style="height: 16px">Cumulative Deaths: {int(properties['Cumulative Deaths'])}</div>
            <div style="height: 16px">Cumulative Deaths/1M: {int(properties['Cumulative Deaths/1M'])}</div>
            <div style="height: 16px">Daily Increase: {int(properties['Daily Increase (%)'])}%</div>
            <div style="height: 16px">Population: {int(properties['Population']):,}</div>
        </div>
        """
    geo_json.on_hover(update_html)
    # The hover panel is only part of the map while the cursor is over a feature.
    geo_json.on_mouseover(lambda **x: m.add_control(hover))
    geo_json.on_mouseout(lambda **x: m.remove_control(hover))

    m.layout=widgets.Layout(border='1px black solid', height='500px', padding='0', margin='0')
    display(m)
    
def dropdown(options, description, value=None):
    """Create a dropdown with its caption stacked above it.

    Parameters
    ----------
    options : iterable
        Choices offered by the dropdown.
    description : str
        Caption shown in a separate label above the dropdown (the dropdown's
        own description is left empty).
    value : optional
        Initial selection. Only `None` means "use the widget's default";
        falsy values such as 0 or '' are passed through as a real selection.

    Returns
    -------
    tuple
        `(dropdown_widget, box)` where `box` is a VBox holding the label and
        the dropdown. Observe the first, place the second in a layout.
    """
    # Pass `value` only when one was given, so the widget picks its own default otherwise.
    selection = {} if value is None else {'value': value}
    v = widgets.Dropdown(
        options=options,
        description='',
        layout=widgets.Layout(padding_left='10px', margin_left='10px', width='200px'),
        **selection,
    )
    return v, widgets.VBox([
        widgets.Label(description),
        v
    ])

# Selector widgets: each call returns the Dropdown itself (observed below) and
# a boxed version with its caption (placed in the header).
level, level_widget = dropdown(options=['Country'], description='Level:')
x_axis, x_axis_widget = dropdown(options=['Days Since 10 Deaths', 'Cumulative Deaths/1M', 'Cumulative Deaths', 'Date'], description='X-axis:')
y_axis, y_axis_widget = dropdown(options=['Daily Deaths/1M', 'Daily Deaths', 'Cumulative Deaths/1M', 'Cumulative Deaths', 'Daily Increase (%)'], description='Y-axis:')
window_size, window_size_widget = dropdown(options=range(1, 15), value=7, description='Window size:')
# Narrow empty label used as a horizontal spacer between the selectors.
f = widgets.Label('', layout=widgets.Layout(width='5px'))
selectors_widget = widgets.HBox([level_widget, f, x_axis_widget, f, y_axis_widget, f, window_size_widget], layout=widgets.Layout(justify_content='flex-end', width='100%'))

# The title follows the selected y-axis metric. The initial text and every later
# update come from the same function, so the no-wrap markup is applied consistently.
title = widgets.HTML()
def update_title(*args):
    """Write the currently selected y-axis metric into the page title."""
    title.value = f'<H1><NOBR>COVID-19 - {y_axis.value}</NOBR></H1>'
update_title()
y_axis.observe(update_title, 'value')
header_widgets = widgets.HBox([title, selectors_widget])

# Both outputs are re-rendered whenever one of the linked dropdowns changes.
table_widget = widgets.HBox(
    [widgets.interactive_output(table, {'level': level, 'x_axis': x_axis, 'y_axis': y_axis, 'window_size': window_size})],
    layout=widgets.Layout(border='1px black solid', width='700px', height='500px', padding='0', margin='0')
)
map_widget = widgets.interactive_output(map, {'level': level, 'y_axis': y_axis, 'window_size': window_size})

# Last expression of the cell: the assembled dashboard is displayed as its output.
widgets.AppLayout(
    header=header_widgets,
    left_sidebar=map_widget,
    center=widgets.Label(''),
    right_sidebar=table_widget,
    footer=None,
    pane_widths=[4, '10px', '700px'],
    pane_heights=['80px', 20, 0]
)

AppLayout(children=(HBox(children=(HTML(value='<H1><NOBR>COVID-19 - Daily Deaths/1M</H1>'), HBox(children=(VBo…