## Environment setup

In [1]:
import pandas as pd
from datetime import date
from os import path

from bokeh.plotting import figure, output_notebook, output_file, show
from bokeh.layouts import column, row
from bokeh.models import HoverTool, ColumnDataSource, Select
from bokeh.models.callbacks import CustomJS
from bokeh.palettes import Colorblind

## Functions that load and reshape the dataset

In [2]:
def get_us_dataset(data_dir="../data/raw",
                   url='https://github.com/datasets/covid-19/raw/master/data/us_confirmed.csv'):
    """Return the US county-level confirmed-case table, cached once per day.

    On the first call of a given day the CSV at ``url`` is downloaded and saved
    as ``<data_dir>/us_confirmed_<YYYY-MM-DD>.csv``; later calls on the same day
    read that file instead of downloading again.

    Parameters
    ----------
    data_dir : str
        Directory holding the dated cache files; created if it does not exist.
    url : str
        Location of the source CSV (anything ``pd.read_csv`` accepts).

    Returns
    -------
    pandas.DataFrame
        The table as stored in the cached CSV, with missing ``Admin2`` (county)
        values replaced by the string ``"None"``.
    """
    from os import makedirs

    filename = path.join(data_dir, "us_confirmed_%s.csv" % date.today().isoformat())

    # If we don't already have a file today, download it and write it to the data directory
    if not path.exists(filename):
        # On a fresh checkout the cache directory may be missing and to_csv would fail.
        if data_dir:
            makedirs(data_dir, exist_ok=True)
        pd.read_csv(url).to_csv(filename, index=False)

    # Always read back from disk so a fresh download and a cached file are parsed identically.
    covid_us = pd.read_csv(filename)

    # Rows without a county get an explicit "None" label so every row has a county key.
    covid_us['Admin2'] = covid_us['Admin2'].fillna("None")

    return covid_us

In [3]:
def state_county_data(window=7, data=None):
    """Pivot the raw case table into per-(state, county) daily time series.

    Parameters
    ----------
    window : int
        Number of days in the rolling mean of new cases (default 7).
    data : pandas.DataFrame, optional
        Raw table with 'Date', 'Province/State', 'Admin2' and 'Case' columns
        (as returned by ``get_us_dataset``). Defaults to the notebook-level
        ``df`` so existing calls keep working; pass it explicitly to avoid
        relying on hidden notebook state.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(total_cases, new_cases, rolling_avg)``. Each has one row per date, the
        date as its first column, and one column per (state, county) pair.
        ``new_cases`` is the row-to-row difference of the totals (first row NaN)
        and ``rolling_avg`` is its ``window``-row rolling mean.
    """
    source = df if data is None else data

    # Admin2 = County. Once grouped by (Date, State, County) each group should
    # hold a single 'Case' value; max() simply returns it (groupby needs some
    # aggregation function).
    case_df = (source[['Date', 'Province/State', 'Admin2', 'Case']]
               .groupby(['Date', 'Province/State', 'Admin2'])
               .max())

    # Move State and County from the row index into the columns, then drop the
    # redundant top 'Case' level so columns are (state, county) pairs.
    total_cases = case_df.unstack(level=[-2, -1])
    total_cases.columns = total_cases.columns.droplevel()

    # Dates may arrive as plain strings (read_csv without parse_dates); the plot
    # needs real datetimes for its datetime x-axis.
    total_cases.index = pd.to_datetime(total_cases.index)

    new_cases = total_cases.diff()
    rolling_avg = new_cases.rolling(window).mean()

    # Turn the date index back into a regular column for ColumnDataSource.
    return tuple(frame.reset_index(level=0) for frame in (total_cases, new_cases, rolling_avg))

## Download the dataset

In [5]:
# Load today's snapshot (downloaded at most once per day, then read from the dated CSV).
# `df` is a notebook-level global that the following cells read directly.
df = get_us_dataset()

## Create the dashboard

In [6]:
palette = Colorblind[6]

output_file("../county_comparison.html")

#########
# Create data sources for county data
#########
# One column per "State_County" pair (Bokeh flattens the (state, county)
# column labels), plus the date column.
total_cases, new_cases, rolling_avg = state_county_data()
totals_source = ColumnDataSource(data=total_cases)
new_cases_source = ColumnDataSource(data=new_cases)
rolling_avg_source = ColumnDataSource(data=rolling_avg)


#########
# Create list of the counties in each state
# to use in the dropdown selection boxes
#########

state_list = df['Province/State'].unique().tolist()

# A single grouped pass over df: each state's counties in order of first
# appearance (filtering df once per state would rescan every row per state).
county_list_dict = df.groupby('Province/State', sort=False)['Admin2'].unique().to_dict()

# Each state has a different number of counties.
# Create a dataframe with states as the rows, then transpose it to have states as columns.
# Shorter lists are padded with missing values so all columns have the same length;
# the JavaScript callbacks strip that padding again.
county_list_df = pd.DataFrame.from_dict(county_list_dict, orient='index')
cs_source = ColumnDataSource(county_list_df.transpose())



######
# Create HoverTool
######

hovertool = HoverTool(
    tooltips=[
        ( 'date',   '@date{%F}'            ),
        ( 'New Cases',  '@new_cases' ), # use @{ } for field names with spaces
        ( '7 Day Avg', '@rolling_avg'      ),
    ],

    formatters={
        '@date'      : 'datetime', # use 'datetime' formatter for 'date' field
    },

    # display a tooltip whenever the cursor is vertically in line with a glyph
    #mode='vline'
)




##########
# Method for creating a plot along with state/county selection boxes
##########

def create_plot(init_state="Maryland", init_county="Prince George's"):
    """Build one county time-series plot with linked State/County dropdowns.

    The plot starts on ``init_state`` / ``init_county``. Changing either
    dropdown redraws it in the browser through CustomJS callbacks.

    Parameters
    ----------
    init_state : str
        State shown initially; must be a key of ``county_list_dict``.
    init_county : str
        County (``Admin2`` value) shown initially within ``init_state``.

    Returns
    -------
    tuple
        ``(county_plot, select_state, select_county)``: the figure and the two
        ``Select`` widgets that drive it.

    Raises
    ------
    KeyError
        If ``"<init_state>_<init_county>"`` is not a column of the data sources,
        or ``init_state`` is not in ``county_list_dict``.
    """
    sc_key = "{}_{}".format(init_state, init_county)

    # Per-plot source holding only the currently selected county; the callbacks
    # replace its columns with another county's.
    county_source = ColumnDataSource(data=dict(date=total_cases['Date'],
                                               total=totals_source.data[sc_key],
                                               new_cases=new_cases_source.data[sc_key],
                                               rolling_avg=rolling_avg_source.data[sc_key]))

    ########
    # Create the plot
    #######

    # Same title rule as the JavaScript below, so the initial title matches later ones.
    title = init_state if init_county == "None" else init_county + " County, " + init_state

    # width/height instead of plot_width/plot_height (deprecated in Bokeh 2.4, removed in 3.0).
    county_plot = figure(title=title, width=600, height=300, x_axis_type="datetime")
    county_plot.line(x='date', y="new_cases", source=county_source,
                     color=palette[1], line_width=2, alpha=0.5, legend_label="New Cases")
    county_plot.line(x='date', y="rolling_avg", source=county_source,
                     color=palette[0], line_width=2, alpha=0.8, legend_label="7 Day Average")

    # Each figure gets its own HoverTool (same settings as the template
    # `hovertool`) instead of attaching one tool model to several plots.
    county_plot.add_tools(HoverTool(tooltips=hovertool.tooltips,
                                    formatters=hovertool.formatters))
    county_plot.legend.location = "top_left"

    #########
    # Create Widgets to select State and County
    #########

    select_state = Select(title="Choose a State:", value=init_state, options=state_list)
    select_county = Select(title="Choose a County:", value=init_county,
                           options=list(county_list_dict[init_state]))

    #######
    # Define the JavaScript Callbacks
    #######

    # Helper shared by both callbacks: show (state, county) in this plot.
    # Assigning a new object to source.data triggers the redraw by itself and
    # avoids overwriting the existing column arrays element by element.
    show_county_js = """
        function show_county(state, county) {
            const key = state + "_" + county;
            if (total_source.data[key] === undefined) {
                console.warn("No data for " + key);
                return;
            }
            county_plot.title.text = (county == 'None') ? state : county + " County, " + state;
            source.data = {
                date: source.data['date'],
                total: total_source.data[key].slice(),
                new_cases: new_source.data[key].slice(),
                rolling_avg: rolling_source.data[key].slice(),
            };
        }
    """

    state_js = show_county_js + """
        const state = cb_obj.value;
        // cs_source columns are padded to the longest state's county list.
        // Drop the padding (non-string values or the string 'NaN') in a copy,
        // so the shared cs_source is never mutated.
        const counties = Array.from(cs_source.data[state] || [])
            .filter(c => typeof c === 'string' && c !== 'NaN');
        if (counties.length == 0) {
            return;
        }
        select_county.options = counties;
        if (select_county.value === counties[0]) {
            // No 'value' change event will fire, so redraw here.
            show_county(state, counties[0]);
        } else {
            // Keep the County dropdown in sync; its callback redraws the plot.
            select_county.value = counties[0];
        }
    """

    county_js = show_county_js + """
        show_county(select_state.value, cb_obj.value);
    """

    callback_args = dict(cs_source=cs_source,
                         select_state=select_state,
                         select_county=select_county,
                         source=county_source,
                         total_source=totals_source,
                         new_source=new_cases_source,
                         rolling_source=rolling_avg_source,
                         county_plot=county_plot)

    select_county.js_on_change('value', CustomJS(args=callback_args, code=county_js))
    select_state.js_on_change('value', CustomJS(args=callback_args, code=state_js))

    return county_plot, select_state, select_county


###########
# Create a grid of four plots
###########

# One independent plot + dropdown pair per starting (state, county), created in order.
(cplot1, ss1, sc1), (cplot2, ss2, sc2), (cplot3, ss3, sc3), (cplot4, ss4, sc4) = [
    create_plot(state, county)
    for state, county in [
        ("Maryland", "Prince George's"),
        ("District of Columbia", "District of Columbia"),
        ("Virginia", "Arlington"),
        ("Ohio", "Clermont"),
    ]
]

# Each tile stacks its two dropdowns above the plot they control.
tile1, tile2, tile3, tile4 = (
    column(row(state_picker, county_picker), fig)
    for fig, state_picker, county_picker in (
        (cplot1, ss1, sc1),
        (cplot2, ss2, sc2),
        (cplot3, ss3, sc3),
        (cplot4, ss4, sc4),
    )
)

# Render the 2x2 grid.
show(column(row(tile1, tile2), row(tile3, tile4)))