In [399]:
import pandas as pd
import numpy as np
import altair as alt

# Set up some global config and variables
alt.renderers.enable('default')
pd.options.mode.chained_assignment = None

def align(data):
    """Left-pad each series with zeros so it lines up with a reference series.

    The reference is the series whose first value is the smallest strictly
    positive one (the earliest such series wins ties). Every other series is
    shifted right by the position, within the reference, of the value closest
    to that series' own first value.

    `data` (a list of lists) is modified in place and also returned.
    """
    # Pick the reference series: lowest non-zero starting value
    candidates = (pos for pos, series in enumerate(data) if series[0] > 0)
    ref_pos = min(candidates, key=lambda pos: data[pos][0])
    reference = data[ref_pos]

    for pos, series in enumerate(data):
        if pos == ref_pos:
            continue
        first = series[0]
        # Shift = index of the reference value nearest to this series' start
        shift = min(range(len(reference)), key=lambda k: abs(reference[k] - first))
        data[pos] = [0] * shift + series
    return data

# Takes a list of numbers and returns a list with number of rows to double
def doubling(data):
    """Return, for every position, how many rows it took the value to double.

    For position ``i`` the result is ``i - j`` where ``j`` is the closest
    earlier position with ``data[i] >= 2 * data[j]``. Positions with no such
    earlier entry get ``np.nan``.

    Args:
        data: sequence of numbers, oldest first.

    Returns:
        A list of the same length as ``data``.
    """
    result = []
    for i, current in enumerate(data):
        rows_back = np.nan
        # Walk backwards so the nearest qualifying row wins. Position 0 is a
        # valid match (the previous `index > 0` test wrongly discarded it).
        for j in range(i - 1, -1, -1):
            if current >= data[j] * 2:
                rows_back = i - j
                break
        result.append(rows_back)
    return result

def streamgraph(df, by, value, sort, limit, stack='center'):
    """Build a stacked area chart of `value` over time, one band per `by` group.

    Relies on the notebook-level `days` list: groups are ranked on `days[-1]`.

    Args:
        df: frame with 'Date', `by`, `value` and `sort` columns.
        by: name of the grouping column (one band per distinct value).
        value: column to plot; negative entries are clamped to 0.
        sort: column used to rank groups on `days[-1]` (descending).
        limit: number of top-ranked groups kept as their own band; the
            remaining ranked groups are summed per date into an 'Others' band.
        stack: passed to the y encoding; the y axis is only drawn when this
            is 'zero'.

    Returns:
        A 1200x800 Altair chart with a legend-bound selection that dims
        unselected bands.
    """
    # Drop dates whose total `value` across all rows is 10 or less.
    # numeric_only=True keeps text columns out of the sum; pandas >= 2.0 no
    # longer drops them silently.
    totals = df.groupby('Date', as_index=False).sum(numeric_only=True)
    sparse_dates = totals[totals[value] <= 10].Date
    # Explicit copy: the clamp below must not write into a view of the caller's frame
    df = df[~df['Date'].isin(sparse_dates)].copy()

    df.loc[df[value] < 0, value] = 0
    order = df[df['Date'] == days[-1]].sort_values(sort, ascending=False)
    top = order[by].values.tolist()
    data = df[df[by].isin(top[:limit])]
    others = df[df[by].isin(top[limit:])].groupby('Date', as_index=False).sum(numeric_only=True)
    others.insert(0, by, 'Others')
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    data = pd.concat([data, others], ignore_index=True)
    selection = alt.selection_multi(fields=[by], bind='legend')

    return alt.Chart(data).mark_area().encode(
        alt.X('Date:T', axis=alt.Axis(domain=False, format='%d %b', tickSize=0)),
        alt.Y(value + ':Q', stack=stack, axis=alt.Axis(title=' '.join(value.split('_'))) if stack=='zero' else None),
        alt.Color(by + ':N', scale=alt.Scale(scheme='tableau20'), sort=top[:limit]),
        # Band order follows the ranking value looked up from `order` below
        order=alt.Order('sort:Q'),
        tooltip=alt.Tooltip(['Date:T', by, value]),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.3))
    ).transform_lookup(
        lookup=by,
        from_=alt.LookupData(order, by, [sort]),
        as_=['sort']
    ).properties(
        width=1200,
        height=800
    ).add_selection(selection)

def nontimegraph(df, by, value, sort, limit, highlight='', addon=[], xlog=True, ylog=True, rect=None):
    """Plot `value` against cumulative 'Confirmed' cases, one line per `by` group.

    Relies on the notebook-level `days` list: `days[-1]` is the date used to
    rank groups and to place the end-of-line marker and label.

    Args:
        df: frame with 'Date', 'Confirmed', `by`, `value` and `sort` columns.
        by: name of the grouping column (one line per distinct value).
        value: column plotted on the y axis; underscores become spaces in the
            axis title. Only rows with `value` > 1 and 'Confirmed' > 100 are drawn.
        sort: column used to rank groups on `days[-1]` (descending).
        limit: number of top-ranked groups to draw.
        highlight: group used as the initial legend selection; also switches
            the colour scheme ('category10' when empty, 'set2' otherwise).
        addon: extra group names drawn in addition to the top `limit`. The
            list default is only read (concatenated), never mutated.
        xlog, ylog: use a log scale (True) or a linear scale (False) per axis.
        rect: optional (start_date, end_date) pair. When `df` has a row for
            `highlight` on both dates, a shaded band between the two
            'Confirmed' values plus markers and date labels is layered on top.

    Returns:
        A 1200x800 layered Altair chart with the legend-bound selection attached.
    """
    # Groups to draw: top `limit` by `sort` on the reference date, plus any add-ons
    top = df[df['Date'] == days[-1]].sort_values(sort, ascending=False)[by].values.tolist()[:limit] + addon
    data = df[(df[value] > 1) & (df['Confirmed'] > 100) & (df[by].isin(top))]
    #data.loc[:,value] = data[value].rolling(window=3, win_type='gaussian').mean(std=3)
    selection = alt.selection_multi(fields=[by], bind='legend', init=[{by: highlight}], empty='none', nearest=True)

    # Background: thin coloured lines for every group that is NOT selected
    line = alt.Chart(data).mark_line().encode(
        alt.X('Confirmed:Q', scale=alt.Scale(type='log' if xlog else 'linear'), axis=alt.Axis(title='Cumulative Cases')),
        alt.Y(value + ':Q', scale=alt.Scale(type='log' if ylog else 'linear'), axis=alt.Axis(title=' '.join(value.split('_')))),
        color=alt.Color(by+':N', scale=alt.Scale(scheme='category10' if highlight == '' else 'set2')),
        size=alt.value(1),
        tooltip=alt.Tooltip(['Date:N', by, value])
    ).transform_filter(~selection)
    # Marker on the reference-date row of each background line
    point = line.mark_circle().encode(
        alt.X('Confirmed:Q', scale=alt.Scale(type='log' if xlog else 'linear')),
        alt.Y(value + ':Q', scale=alt.Scale(type='log' if ylog else 'linear')),
        color=alt.Color(by+':N'),
        size=alt.value(60)
    ).transform_filter(
        alt.datum.Date == days[-1]
    ).transform_filter(~selection)
    # Group name next to that marker
    text = point.mark_text(
        align='left',
        dx=7,  # Nudges text to the right so it doesn't sit on top of the marker
    ).encode(
        text=by+':N',
        size=alt.value(10),
        color=alt.Color(by+':N')
    ).transform_filter(~selection)
    background = line+point+text

    # Foreground: the same three layers in bold black, for the selected group(s) only
    line = alt.Chart(data).mark_line().encode(
        alt.X('Confirmed:Q', scale=alt.Scale(type='log' if xlog else 'linear'), axis=alt.Axis(title='Cumulative Cases')),
        alt.Y(value + ':Q', scale=alt.Scale(type='log' if ylog else 'linear'), axis=alt.Axis(title=' '.join(value.split('_')))),
        color=alt.value('black'),
        size=alt.value(3),
        tooltip=alt.Tooltip(['Date:N', by, value])
    ).transform_filter(selection)
    point = line.mark_circle().encode(
        alt.X('Confirmed:Q', scale=alt.Scale(type='log' if xlog else 'linear')),
        alt.Y(value + ':Q', scale=alt.Scale(type='log' if ylog else 'linear')),
        color=alt.value('black'),
        size=alt.value(100)
    ).transform_filter(
        alt.datum.Date == days[-1]
    ).transform_filter(selection)
    text = point.mark_text(
        align='left',
        dx=7,  # Nudges text to the right so it doesn't sit on top of the marker
    ).encode(
        text=by+':N',
        size=alt.value(15),
        color=alt.value('black')
    ).transform_filter(selection)
    foreground = line+point+text

    # Rows of the highlighted group on the two `rect` dates; empty frames when
    # no rect was requested. Read from `df`, not the filtered `data`.
    start = df[(df[by] == highlight) & (df['Date'] == rect[0])] if rect is not None else pd.DataFrame()
    end = df[(df[by] == highlight) & (df['Date'] == rect[1])] if rect is not None else pd.DataFrame()
    if not start.empty and not end.empty:
        # Single-row frame holding the two corner points of the highlighted period
        marks = pd.DataFrame([{"start_x": start.Confirmed.values[0], "end_x": end.Confirmed.values[0], "start_y": start[value].values[0], "end_y": end[value].values[0]}])
        # Shaded band spanning the 'Confirmed' range of the period
        rule_rect = alt.Chart(marks).mark_rect(opacity=0.3).encode(x='start_x:Q', x2='end_x:Q')
        rule_point_1 = rule_rect.mark_circle().encode(
            alt.X('start_x:Q', scale=alt.Scale(type='log' if xlog else 'linear')),
            alt.Y('start_y:Q', scale=alt.Scale(type='log' if ylog else 'linear')),
            color=alt.value('grey'),
            size=alt.value(60)
        )
        rule_point_2 = rule_rect.mark_circle().encode(
            alt.X('end_x:Q', scale=alt.Scale(type='log' if xlog else 'linear')),
            alt.Y('end_y:Q', scale=alt.Scale(type='log' if ylog else 'linear')),
            color=alt.value('grey'),
            size=alt.value(60)
        )
        # Rotated date labels at the start and end markers
        rule_text_1 = rule_point_1.mark_text(
            align='left',
            angle=90,
            dx=10
        ).encode(
            text=alt.value(rect[0]),
            size=alt.value(15),
            color=alt.value('black')
        )
        rule_text_2 = rule_point_2.mark_text(
            align='left',
            angle=90,
            dx=10
        ).encode(
            text=alt.value(rect[1]),
            size=alt.value(15),
            color=alt.value('black')
        )
        return (background+foreground+rule_rect+rule_point_1+rule_point_2+rule_text_1+rule_text_2).properties(
            width=1200,
            height=800
        ).add_selection(selection)
    else:
        # No usable period: plain background + foreground chart
        return (background+foreground).properties(
            width=1200, 
            height=800
        ).add_selection(selection)        

In [400]:
# Load the raw daily reports and derive the active case count
df = pd.read_csv('jhu-daily-reports.csv')
df['Active'] = df['Confirmed'] - (df['Deaths'] + df['Recovered'])

# Keep only the dates on which more than one country is present
samples = df.groupby('Date')['Country'].nunique()
days = samples.loc[samples > 1].index.tolist()
df = df.loc[df['Date'].isin(days)]

# Global Data Visualizations

In [401]:
from ipywidgets import interact

# Aggregate at country level. numeric_only=True keeps text columns (e.g. State)
# out of the sum; pandas >= 2.0 no longer drops them silently.
country_level = df.groupby(['Country', 'Date'], as_index=False).sum(numeric_only=True)

# Drop 03-22-2020: the county breakdown starts on that date, which breaks continuity.
# NOTE(review): no reason is recorded for 03-12-2020 and 02-12-2020 - presumably
# similar data glitches; confirm.
excluded_dates = ['03-22-2020', '03-12-2020', '02-12-2020']
country_level = country_level[~country_level['Date'].isin(excluded_dates)]

@interact(
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(0,50,1),
)
def chart(value='Confirmed_New', sort='Active', limit=10, zero=True):
    """Stream graph of `value` for the top `limit` countries ranked by `sort`."""
    return streamgraph(country_level, 'Country', value, sort, limit, 'zero' if zero else 'center').interactive()

interactive(children=(Dropdown(description='value', index=4, options=('Active', 'Confirmed', 'Deaths', 'Recove…

In [402]:
# Keep only countries that exceeded 10,000 confirmed cases on at least one date
large_countries = country_level.loc[country_level['Confirmed'] > 10000, 'Country'].unique()
country_level = country_level[country_level['Country'].isin(large_countries)]
# NOTE(review): this list used to be passed to interact as `add=countries`, but
# `chart` has no `add` parameter, so the keyword matched nothing. The name is
# kept; wire it to a parameter if a country picker was intended.
countries = country_level[country_level['Confirmed'] > 1000]['Country'].unique().tolist()

@interact(
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(1,50,1),
    start=days[::-1],
    end=days[::-1],
)
def chart(value='Confirmed_New', sort='Active', limit=10, highlight='India', start='04-16-2020', end='05-01-2020', xlog=True, ylog=True):
    """New-vs-cumulative chart for countries, with the start..end period marked.

    The period is only drawn when both dates are set and the highlighted
    country has a row on each of them.
    """
    # Pass the chosen period through. The previous version computed `rect` and
    # then sent a hard-coded (None, None), so the date pickers had no effect.
    rect = (start, end) if start is not None and end is not None else None
    return nontimegraph(country_level, 'Country', value, sort, limit, highlight=highlight, xlog=xlog, ylog=ylog, rect=rect).interactive()

interactive(children=(Dropdown(description='value', index=4, options=('Active', 'Confirmed', 'Deaths', 'Recove…

# State Level Visualizations

In [403]:
# Countries that have state-level rows with more than 1,000 confirmed cases
countries = df[(df['State'].notnull()) & (df['Confirmed'] > 1000)]['Country'].unique()

def state_data(country):
    """Return per-state, per-date totals of the numeric columns for `country`.

    Reads the notebook-level `df`. For 'US', two dates are removed; for 'US'
    and 'India', only dates from '03-01-2020' on are kept.

    NOTE(review): the date filter compares against a string literal; if 'Date'
    holds 'MM-DD-YYYY' strings (as the literals suggest) the comparison is
    lexicographic and only orders dates correctly within one calendar year.
    """
    # numeric_only=True keeps text columns (Country, County, ...) out of the
    # sum; pandas >= 2.0 no longer drops them silently.
    state_level = (
        df[df['Country'] == country]
        .groupby(['State', 'Date'], as_index=False)
        .sum(numeric_only=True)
    )
    if country == 'US':
        # Drop 03-22-2020: the county breakdown starts on that date, which breaks continuity.
        # NOTE(review): no reason is recorded for dropping 03-18-2020; confirm.
        state_level = state_level[~state_level['Date'].isin(['03-22-2020', '03-18-2020'])]
    if country in ('US', 'India'):
        state_level = state_level[state_level['Date'] >= '03-01-2020']
    return state_level

@interact(
    country=countries,
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(0,50,1),
)
def chart(country='US', value='Confirmed_New', sort='Active', limit=10, zero=True):
    """Stream graph of `value` for the top `limit` states of `country`."""
    stacking = 'zero' if zero else 'center'
    state_level = state_data(country)
    return streamgraph(state_level, 'State', value, sort, limit, stacking).interactive()

interactive(children=(Dropdown(description='country', index=7, options=('Australia', 'Canada', 'China', 'Denma…

In [404]:
@interact(
    country=countries,
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(1,50,1),
)
def chart(country='US', value='Confirmed_New', sort='Active', limit=10, xlog=True, ylog=True):
    """New-vs-cumulative chart for the states of `country`; NE, MN and CA are always included."""
    state_level = state_data(country)
    figure = nontimegraph(state_level, 'State', value, sort, limit, addon=['NE', 'MN', 'CA'], xlog=xlog, ylog=ylog)
    return figure.interactive()

interactive(children=(Dropdown(description='country', index=7, options=('Australia', 'Canada', 'China', 'Denma…

# US County Level Visualizations

In [405]:
us_state_level = state_data('US')
# States that exceeded 1,000 confirmed cases on at least one date
states = us_state_level[us_state_level['Confirmed'] > 1000]['State'].unique()

@interact(
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(0,50,1),
    state=states,
)
def chart(state='CA', value='Confirmed_New', sort='Active', limit=10, zero=True):
    """Stream graph of `value` for the top `limit` counties of `state`."""
    # Rows without a county are attributed to the state itself. numeric_only=True
    # keeps text columns out of the sum (pandas >= 2.0 no longer drops them).
    county_level = (
        df[df['State'] == state]
        .fillna({'County': state})
        .groupby(['County', 'Date'], as_index=False)
        .sum(numeric_only=True)
    )
    return streamgraph(county_level, 'County', value, sort, limit, 'zero' if zero else 'center').interactive()

interactive(children=(Dropdown(description='state', index=3, options=('AL', 'AR', 'AZ', 'CA', 'CO', 'CT', 'DC'…

In [406]:
# States that exceeded 1,000 confirmed cases on at least one date
states = us_state_level[us_state_level['Confirmed'] > 1000]['State'].unique()

@interact(
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(1,50,1),
    state=states,
)
def chart(state='CA', value='Confirmed_New', sort='Confirmed', limit=10, xlog=True, ylog=True):
    """New-vs-cumulative chart for the top `limit` counties of `state`."""
    # Rows without a county are attributed to the state itself. numeric_only=True
    # keeps text columns out of the sum (pandas >= 2.0 no longer drops them).
    county_level = (
        df[df['State'] == state]
        .fillna({'County': state})
        .groupby(['County', 'Date'], as_index=False)
        .sum(numeric_only=True)
    )
    return nontimegraph(county_level, 'County', value, sort, limit, xlog=xlog, ylog=ylog).interactive()

interactive(children=(Dropdown(description='state', index=3, options=('AL', 'AR', 'AZ', 'CA', 'CO', 'CT', 'DC'…

In [422]:
# Per-area ("County, State") totals, built once from a derived frame. The
# previous version added an 'Area' column to the shared `df` and redid the
# groupby on every widget change. numeric_only=True keeps text columns out of
# the sum (pandas >= 2.0 no longer drops them).
area_level = (
    df.assign(Area=df.County + ', ' + df.State)
    .groupby(['Area', 'Date'], as_index=False)
    .sum(numeric_only=True)
)

# The former `state=` widget option was removed: `chart` has no `state`
# parameter, so the keyword matched nothing.
@interact(
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    limit=(1,50,1),
)
def chart(value='Confirmed_New', sort='Confirmed', limit=10, xlog=True, ylog=True):
    """New-vs-cumulative chart for the top `limit` "County, State" areas."""
    return nontimegraph(area_level, 'Area', value, sort, limit, xlog=xlog, ylog=ylog).interactive()

interactive(children=(Dropdown(description='value', index=4, options=('Active', 'Confirmed', 'Deaths', 'Recove…

# Cross Country Comparisons

In [408]:
# NOTE(review): 'Slope' is offered as a sort key, but nothing visible here adds
# a 'Slope' column to the frames returned by state_data - confirm it exists,
# otherwise selecting it will fail inside nontimegraph.
@interact(
    country=countries,
    value=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New'],
    sort=['Active', 'Confirmed', 'Deaths', 'Recovered', 'Confirmed_New', 'Deaths_New', 'Recovered_New', 'Slope'],
    limit=(1,50,1),
)
def chart(country='India', value='Confirmed_New', sort='Active', limit=10, xlog=True, ylog=True):
    """Compare the states of `country` against New York on the new-vs-cumulative chart."""
    data = state_data(country)
    # Add NY as the reference curve. When the US itself is selected NY is
    # already present, so appending it again would duplicate its rows.
    if country != 'US':
        us = state_data('US')
        data = pd.concat([data, us[us['State'] == 'NY']])
    return nontimegraph(data, 'State', value, sort, limit, xlog=xlog, ylog=ylog).interactive()

interactive(children=(Dropdown(description='country', index=4, options=('Australia', 'Canada', 'China', 'Denma…

# Slope Comparisons

In [409]:
from scipy import stats
from altair import datum

def slope_chart(data, by, offset, xscale='linear', limit=400, scale=1, value='Confirmed_New', window=7):
    """Scatter each group's recent trend slope against its cumulative cases.

    For every `by` group present on `days[-1]` (notebook-level list), a linear
    regression of `value` on 'Confirmed' is fitted over the group's last
    `window` rows (ordered by 'Date'); its slope is the y coordinate.

    Args:
        data: frame with 'Date', 'Confirmed', `by` and `value` columns.
        by: name of the grouping column (one point per group).
        offset: offset applied to the x axis.
        xscale: scale type of the x axis.
        limit: only groups whose `value` on `days[-1]` exceeds this are drawn;
            labels are shown above twice this threshold.
        scale: multiplier for the point size range.
        value: column regressed against 'Confirmed' and used for point size.
        window: number of most recent rows used for each regression.

    Returns:
        A layered Altair chart: points, labels and a dashed linear trend line.
    """
    # Copy so the Slope column is not written into a slice of the caller's frame
    source = data[data['Date'] == days[-1]].copy()

    slopes = {}
    for var in source[by].unique():
        values = data[data[by] == var].sort_values('Date').tail(window)[['Confirmed', value]]
        if len(values) < 2 or values.Confirmed.nunique() < 2:
            # No slope is defined for fewer than two distinct x values. Current
            # SciPy raises on constant x; treat it as missing (becomes 0 below).
            slopes[var] = float('nan')
        else:
            slopes[var] = stats.linregress(values.Confirmed, values[value])[0]
    source['Slope'] = source[by].map(slopes)

    source = source.fillna(0)
    source = source[source[value] > limit]

    base = alt.Chart(source).mark_point(filled=True, stroke='grey').encode(
        alt.X('Confirmed:Q', scale=alt.Scale(type=xscale), axis=alt.Axis(offset=offset)),
        y='Slope:Q',
        color=alt.Color(by+':N', scale=alt.Scale(scheme='category20'), legend=alt.Legend(columns=2, clipHeight=20, padding=10)),
        # The size domain follows the plotted `value` column (it was hard-coded
        # to Confirmed_New, which mis-scaled the points for any other `value`).
        size=alt.Size(value+':Q', scale=alt.Scale(domain=[source[value].min(), source[value].max()], range=[100*scale, 3000*scale])),
        tooltip=[by, 'Confirmed', 'Slope', value]
    )
    text = base.mark_text().encode(
        text=by+':N',
        size=alt.value(12),
        color=alt.value('black')
    ).transform_filter(datum[value] > limit*2)
    regression = base.transform_regression('Confirmed', 'Slope', method="poly", order=1).mark_line(strokeDash=[6,8]).encode(color=alt.value('grey'), size=alt.value(2))

    return (base+text+regression)

In [420]:
# Slope chart across US states with the default settings
data = state_data('US')
slope_chart(data, by='State', offset=-364).properties(width=1200, height=800).interactive()

In [411]:
# California's last 21 rows: new cases against cumulative cases, with a linear fit
state = (
    alt.Chart(data[data['State'] == 'CA'].sort_values('Date').tail(21))
    .mark_line()
    .encode(x='Confirmed:Q', y='Confirmed_New:Q')
)
reg = state.transform_regression('Confirmed', 'Confirmed_New', method='linear').mark_line()
(state + reg).interactive()

In [412]:
# Slope chart across countries, log-scaled on cumulative cases
slope_chart(
    country_level,
    by='Country',
    offset=-160,
    xscale='log',
    limit=500,
    scale=5,
).properties(width=1200, height=800).interactive()

In [413]:
# France's last 7 rows: new cases against cumulative cases, with a linear fit
data = country_level
state = (
    alt.Chart(data[data['Country'] == 'France'].sort_values('Date').tail(7))
    .mark_line()
    .encode(x='Confirmed:Q', y='Confirmed_New:Q')
)
reg = state.transform_regression('Confirmed', 'Confirmed_New', method='linear').mark_line()
(state + reg).interactive()

In [426]:
state = 'CA'
# Per-county totals for the state; rows without a county are attributed to the
# state itself. numeric_only=True keeps text columns out of the sum
# (pandas >= 2.0 no longer drops them silently).
county_level = (
    df[df['State'] == state]
    .fillna({'County': state})
    .groupby(['County', 'Date'], as_index=False)
    .sum(numeric_only=True)
)

slope_chart(county_level, 'County', -300, xscale='log', limit=1, scale=5).properties(
    width=1200,
    height=800
).interactive()

In [415]:
# Alameda county's last 30 rows: new cases against cumulative cases, with a linear fit
data = county_level
state = (
    alt.Chart(data[data['County'] == 'Alameda'].sort_values('Date').tail(30))
    .mark_line()
    .encode(x='Confirmed:Q', y='Confirmed_New:Q')
)
reg = state.transform_regression('Confirmed', 'Confirmed_New', method='linear').mark_line()
(state + reg).interactive()

In [416]:
# Per-state daily figures from covidtracking.com. The `date` column is parsed
# as YYYYMMDD and re-formatted to the MM-DD-YYYY form used by the merge key.
dfh = (
    pd.read_csv('https://covidtracking.com/api/v1/states/daily.csv')
    .assign(date=lambda frame: pd.to_datetime(frame['date'], format='%Y%m%d').dt.strftime('%m-%d-%Y'))
    .rename(columns={'date': 'Date', 'state': 'State'})
)
# Outer join: rows present in only one of the two sources are kept
data = state_data('US').merge(dfh, on=['Date', 'State'], how='outer')

slope_chart(
    data,
    'State',
    -582,
    xscale='log',
    limit=200,
    scale=2,
    value='hospitalizedCurrently',
    window=7,
).properties(width=1200, height=800).interactive()