In [3]:
from datetime import datetime
import requests

import pandas as pd
import numpy as np

In [54]:
from bokeh.plotting import figure, show, output_file, output_notebook

from bokeh.models import ColumnDataSource, CustomJS, Panel, Tabs
from bokeh.models.axes import LinearAxis, LogAxis
from bokeh.models.widgets import CheckboxGroup, Dropdown, RadioGroup, MultiSelect

from bokeh.layouts import column, row
from bokeh.palettes import viridis

from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application

# Register 'app.html' as the standalone output target, then enable inline
# notebook output as well.
# NOTE(review): the displays further down are shown as Application objects
# driven by Python callbacks; output_file is presumably not needed for
# those -- confirm before removing.
output_file('app.html')
output_notebook()

In [5]:
# State-level and county-level tables, loaded with the same reader settings;
# the 'date' column of each is parsed into datetimes at load time.
GH_STATES_DATA, GH_COUNTIES_DATA = (
    pd.read_csv(csv_path, parse_dates=['date'])
    for csv_path in (r'K:\covid-19-data\us-states.csv',
                     r'K:\covid-19-data\us-counties.csv')
)

# Sorted list of every distinct value found in the 'state' column.
STATES = sorted(pd.unique(GH_STATES_DATA['state']))

In [6]:
# Population lookup table consumed by population(): rows are matched on the
# 'NAME' column and the value is read from the third column.
# NOTE(review): the file name suggests a Census ACS 5-year (2018) B01003
# extract -- confirm provenance.
POP_DATA = pd.read_csv(r'K:\ACSDT5Y2018.B01003_data_with_overlays_2020-07-10T111915.csv')

In [7]:
def population(region):
    """Return the population stored in ``POP_DATA`` for ``region``.

    Rows are selected by comparing the ``NAME`` column with ``region``; the
    third column of the first matching row is converted to ``int``.

    When no row matches, ``1`` is returned so that callers dividing by the
    result do not fail (the per-capita figure is then not meaningful).
    """
    matches = POP_DATA.loc[POP_DATA['NAME'] == region]
    if matches.empty:
        return 1
    # Third column (position 2) of the first matching row holds the count.
    return int(matches.iloc[0, 2])

In [100]:
# Per-state daily testing figures fetched from the covidtracking.com API.
# The timeout stops a stalled connection from hanging the kernel, and
# raise_for_status() prevents an HTTP error page from being parsed as data.
_tracking_response = requests.get(url='https://covidtracking.com/api/v1/states/daily.json', timeout=30)
_tracking_response.raise_for_status()
TRACKING_DATA = pd.DataFrame.from_dict(_tracking_response.json())

# 'date' is a YYYYMMDD value; parse the whole column in one vectorised call.
TRACKING_DATA['datetime'] = pd.to_datetime(TRACKING_DATA['date'].astype(str), format='%Y%m%d')
# Share of positive results in percent. A zero test count becomes NaN rather
# than producing an infinite ratio that would poison the rolling average.
TRACKING_DATA['positivity'] = (TRACKING_DATA['positive']
                               / TRACKING_DATA['totalTestResults'].replace(0, np.nan) * 100)

# Full region name -> two-letter code used in TRACKING_DATA['state'].
# NOTE(review): the DC/territory names below are assumed to match the
# spelling used in GH_STATES_DATA['state'] -- confirm against the CSV.
STATE_ABBRV = {'Alabama': 'AL', 'Alaska': 'AK', 'Arizona': 'AZ', 'Arkansas': 'AR', 'California': 'CA',
               'Colorado': 'CO', 'Connecticut': 'CT', 'Delaware': 'DE', 'Florida': 'FL', 'Georgia': 'GA',
               'Hawaii': 'HI', 'Idaho': 'ID', 'Illinois': 'IL', 'Indiana': 'IN', 'Iowa': 'IA',
               'Kansas': 'KS', 'Kentucky': 'KY', 'Louisiana': 'LA', 'Maine': 'ME', 'Maryland': 'MD',
               'Massachusetts': 'MA', 'Michigan': 'MI', 'Minnesota': 'MN', 'Mississippi': 'MS', 'Missouri': 'MO',
               'Montana': 'MT', 'Nebraska': 'NE', 'Nevada': 'NV', 'New Hampshire': 'NH', 'New Jersey': 'NJ',
               'New Mexico': 'NM', 'New York': 'NY', 'North Carolina': 'NC', 'North Dakota': 'ND', 'Ohio': 'OH',
               'Oklahoma': 'OK', 'Oregon': 'OR', 'Pennsylvania': 'PA', 'Rhode Island': 'RI', 'South Carolina': 'SC',
               'South Dakota': 'SD', 'Tennessee': 'TN', 'Texas': 'TX', 'Utah': 'UT', 'Vermont': 'VT',
               'Virginia': 'VA', 'Washington': 'WA', 'West Virginia': 'WV', 'Wisconsin': 'WI', 'Wyoming': 'WY',
               'District of Columbia': 'DC', 'Puerto Rico': 'PR', 'Guam': 'GU', 'Virgin Islands': 'VI',
               'Northern Mariana Islands': 'MP', 'American Samoa': 'AS'}

In [115]:
def get_data(region, per_capita=False, data_type='cases'):
    """Return the raw and 7-row rolling-mean series to plot for one region.

    Parameters
    ----------
    region : str
        Region name. For ``'cases'``/``'deaths'`` it is compared with the
        ``state`` column of ``GH_STATES_DATA``; for ``'positivity'`` it is
        translated through ``STATE_ABBRV`` and compared with the ``state``
        column of ``TRACKING_DATA``. A region with no matching rows yields
        empty series in every mode.
    per_capita : bool, optional
        When true, case/death values are divided by ``population(region)``
        and scaled to a rate per 100,000. Ignored for ``'positivity'``.
    data_type : {'cases', 'deaths', 'positivity'}, optional
        Which quantity to return.

    Returns
    -------
    tuple
        ``(dates, avg_dates, data, avg_data, label)``: the date series, the
        date series shifted by 3.5 days for the averaged curve, the raw
        values, their ``rolling(7).mean()``, and a y-axis label string.

    Raises
    ------
    ValueError
        If ``data_type`` is not one of the supported values.
    """
    # Reject unsupported modes up front instead of failing later with an
    # unbound-variable error at the return statement.
    if data_type not in ('cases', 'deaths', 'positivity'):
        raise ValueError(
            f"data_type must be 'cases', 'deaths' or 'positivity', got {data_type!r}")

    # Shift applied to the x-coordinates of the rolling average.
    date_offset = np.timedelta64(3, 'D') + np.timedelta64(12, 'h')

    if data_type == 'positivity':
        # .get() returns None for an unmapped region, which matches no row;
        # this mirrors the cases/deaths branch, where an unknown region
        # produces empty series rather than an exception.
        subset = TRACKING_DATA[TRACKING_DATA['state'] == STATE_ABBRV.get(region)]

        dates = subset['datetime']
        # The offset is added here but subtracted for cases/deaths.
        # NOTE(review): presumably because this feed is ordered newest-first,
        # so the rolling window looks forward in time -- confirm.
        avg_dates = dates + date_offset

        data = subset['positivity']
        avg_data = data.rolling(7).mean()

        return dates, avg_dates, data, avg_data, 'Positivity (%)'

    subset = GH_STATES_DATA[GH_STATES_DATA['state'] == region]

    dates = subset['date']
    avg_dates = dates - date_offset

    # Row-to-row differences of the selected column, then their rolling mean.
    data = subset[data_type].diff()
    avg_data = data.rolling(7).mean()

    if per_capita:
        pop = population(region)
        data = data / pop * 100000
        avg_data = avg_data / pop * 100000
        label = f'{data_type.title()} per 100,000'
    else:
        label = f'Total {data_type.title()}'

    return dates, avg_dates, data, avg_data, label

In [118]:
class StateDisplay:
    """Bokeh application that overlays the averaged curves of several states.

    The widgets choose which states to draw, whether values are per capita,
    which quantity is shown, and whether the linear or the logarithmic figure
    is visible. Pass ``run`` to a ``FunctionHandler`` to serve it.
    """

    def __init__(self):
        """Create the control widgets; the data source and figures are built
        on the first call to ``update``."""

        self.state_selection = MultiSelect(title='States:', options=STATES, value=['Alabama'], height=550)
        self.per_capita = RadioGroup(labels=['Total', 'Per Capita'], active=0)
        self.data_getter = RadioGroup(labels=['Cases', 'Deaths', 'Positivity'], active=0)
        self.plot_type = RadioGroup(labels=['Linear', 'Logarithmic'], active=0)

        self.src = None   # shared ColumnDataSource feeding both figures
        self.p = None     # figure with the default y-axis
        self.logp = None  # figure with a log y-axis

    def make_dataset(self, state_list):
        """Collect one row per state for a ``multi_line`` glyph.

        Parameters
        ----------
        state_list : iterable of str
            Names of the states to include; may be empty.

        Returns
        -------
        tuple
            ``(label, source)`` where ``label`` is the y-axis label reported
            by ``get_data`` (``None`` when ``state_list`` is empty) and
            ``source`` is a ``ColumnDataSource`` with the columns ``date``,
            ``avg_date``, ``data``, ``avg_data``, ``state`` and ``color``.
        """

        by_state = {'date': [],
                    'avg_date': [],
                    'data': [],
                    'avg_data': [],
                    'state': [],
                    'color': []}

        # One colour per known state so a state keeps its colour regardless
        # of what else is selected.
        palette = viridis(len(STATES))

        # Widget state is the same for every state in the list.
        per_capita = self.per_capita.active == 1
        data_getter = self.data_getter.labels[self.data_getter.active].lower()

        # Stays None when nothing is selected, so the return below never
        # touches an unbound name.
        label = None

        for state_name in state_list:

            dates, avg_dates, data, avg_data, label = get_data(state_name, per_capita, data_getter)

            by_state['date'].append(dates.values)
            by_state['avg_date'].append(avg_dates.values)
            by_state['data'].append(data.values)
            by_state['avg_data'].append(avg_data.values)

            by_state['state'].append(state_name)
            by_state['color'].append(palette[STATES.index(state_name)])

        return label, ColumnDataSource(by_state)

    def _build_figure(self, **axis_options):
        """Create one figure drawing the averaged curves from ``self.src``."""

        fig = figure(title='COVID-19 Cases', x_axis_label='Date',
                     x_axis_type='datetime', y_axis_label='Total Cases',
                     **axis_options)

        fig.multi_line(source=self.src, xs='avg_date', ys='avg_data',
                       legend_field='state', color='color', line_width=2)

        fig.legend.location = 'top_left'

        return fig

    def make_plot(self):
        """Build the linear and logarithmic figures on the shared source."""

        self.p = self._build_figure()
        self.logp = self._build_figure(y_axis_type='log')

    def update(self, attr, old, new):
        """Widget callback: rebuild the data and refresh both figures.

        The ``attr``/``old``/``new`` arguments required by the callback
        signature are not used; the current widget values are read instead.
        """

        states_to_plot = sorted(self.state_selection.value)

        label, new_src = self.make_dataset(states_to_plot)

        if self.src is None:
            self.src = new_src
            self.make_plot()
        else:
            self.src.data.update(new_src.data)

        show_linear = self.plot_type.active == 0
        self.p.visible = show_linear
        self.logp.visible = not show_linear

        # Keep both figures labelled consistently. With an empty selection
        # there is no label to report, so the previous one is left in place.
        if label is not None:
            for fig in (self.p, self.logp):
                fig.yaxis.axis_label = label

    def run(self, doc):
        """Wire the callbacks, build the layout and add it to ``doc``."""

        self.state_selection.on_change('value', self.update)

        self.per_capita.on_change('active', self.update)
        self.data_getter.on_change('active', self.update)
        self.plot_type.on_change('active', self.update)

        controls = column([self.state_selection, self.per_capita,
                           self.data_getter, self.plot_type])

        # First update creates the source and both figures.
        self.update(None, None, None)

        plots = column(self.p, self.logp)

        layout = row(controls, plots)
        doc.add_root(layout)

In [121]:
# Wrap the multi-state display's ``run`` callback in a Bokeh application and
# hand it to ``show`` for rendering.
multi_state_handler = FunctionHandler(StateDisplay().run)
multi_state_app = Application(multi_state_handler)
show(multi_state_app)

In [119]:
class SingleStateDisplay:
    """Bokeh application showing raw bars and the averaged line for one state.

    The state is picked from a dropdown; radio groups choose per-capita
    scaling, the quantity shown, and whether the linear or the logarithmic
    figure is visible. Pass ``run`` to a ``FunctionHandler`` to serve it.
    """

    def __init__(self):
        """Create the control widgets; the data source and figures are built
        on the first call to ``update``."""

        self.state_selection = Dropdown(label='Select a state', menu=STATES)
        self.per_capita = RadioGroup(labels=['Total', 'Per Capita'], active=0)
        self.data_getter = RadioGroup(labels=['Cases', 'Deaths', 'Positivity'], active=0)
        self.plot_type = RadioGroup(labels=['Linear', 'Logarithmic'], active=0)

        self.state = None  # currently selected state; None until one is picked

        self.src = None    # shared ColumnDataSource feeding both figures
        self.p = None      # figure with the default y-axis
        self.logp = None   # figure with a log y-axis

    def make_dataset(self, state_name=''):
        """Build the data source for ``state_name`` from the widget settings.

        Parameters
        ----------
        state_name : str or None, optional
            State to load. When empty or ``None`` (nothing selected yet) an
            empty source is returned without querying ``get_data``.

        Returns
        -------
        tuple
            ``(label, source)`` where ``label`` is the y-axis label reported
            by ``get_data`` (``None`` when no state is selected) and
            ``source`` is a ``ColumnDataSource`` with the columns ``date``,
            ``avg_date``, ``data`` and ``avg_data``.
        """

        # Before the first selection there is nothing to look up; return
        # empty columns so the figures can still be constructed and no
        # lookup is attempted with a missing state name.
        if not state_name:
            return None, ColumnDataSource(dict(date=[], avg_date=[], data=[], avg_data=[]))

        per_capita = self.per_capita.active == 1
        data_getter = self.data_getter.labels[self.data_getter.active].lower()

        dates, avg_dates, data, avg_data, label = get_data(state_name, per_capita, data_getter)

        columns = dict(date=dates.values, avg_date=avg_dates.values,
                       data=data.values, avg_data=avg_data.values)

        return label, ColumnDataSource(columns)

    def _build_figure(self, bar_options=None, **axis_options):
        """Create one figure with raw-value bars and the averaged line."""

        fig = figure(title='COVID-19 Cases', x_axis_label='Date',
                     x_axis_type='datetime', y_axis_label='Total Cases',
                     **axis_options)

        fig.vbar(source=self.src, x='date', top='data', color='orange',
                 **(bar_options or {}))
        fig.line(source=self.src, x='avg_date', y='avg_data', line_width=2)

        fig.legend.visible = False

        return fig

    def make_plot(self):
        """Build the linear and logarithmic figures on the shared source."""

        self.p = self._build_figure()
        # On the log axis the bars start from a tiny positive value instead
        # of the default bottom.
        self.logp = self._build_figure(bar_options={'bottom': 1e-10}, y_axis_type='log')

    def update(self, attr, old, new):
        """Widget callback: reload the data and refresh both figures.

        The ``attr``/``old``/``new`` arguments required by the callback
        signature are not used; the current widget values are read instead.
        """

        label, new_src = self.make_dataset(self.state)

        if self.src is None:
            self.src = new_src
            self.make_plot()
        else:
            self.src.data.update(new_src.data)

        show_linear = self.plot_type.active == 0
        self.p.visible = show_linear
        self.logp.visible = not show_linear

        # Keep both figures labelled consistently; with no state selected
        # there is no label to report, so the existing one is kept.
        if label is not None:
            for fig in (self.p, self.logp):
                fig.yaxis.axis_label = label

    def update_selection(self, event):
        """Dropdown click handler: remember the chosen state, show it on the
        button and redraw."""
        self.state = event.item
        self.state_selection.label = self.state
        self.update(None, None, None)

    def run(self, doc):
        """Wire the callbacks, build the layout and add it to ``doc``."""

        self.state_selection.on_click(self.update_selection)

        self.per_capita.on_change('active', self.update)
        self.data_getter.on_change('active', self.update)
        self.plot_type.on_change('active', self.update)

        controls = column([self.state_selection, self.per_capita, self.data_getter,
                           self.plot_type])

        # First update creates the source and both figures.
        self.update(None, None, None)

        plots = column(self.p, self.logp)

        layout = row(controls, plots)
        doc.add_root(layout)

In [120]:
# Wrap the single-state display's ``run`` callback in a Bokeh application and
# hand it to ``show`` for rendering.
single_state_handler = FunctionHandler(SingleStateDisplay().run)
single_state_app = Application(single_state_handler)
show(single_state_app)