In [1]:
from datetime import date, datetime, timedelta
from itertools import cycle
import os
import requests
import socket

import pandas as pd
import numpy as np

from tqdm import tqdm

# Root directory holding the census population CSV and the covid-19-data checkout.
# NOTE(review): machine-specific absolute paths keyed on the hostname — consider an env var.
PREFIX = 'C:\\Users\\watso' if socket.gethostname() == 'DESKTOP-VD3TK5G' else 'K:\\'

In [2]:
from bokeh.plotting import figure, show, output_file, output_notebook

from bokeh.io import export_png
from bokeh.models import ColumnDataSource, CustomJS, Panel, Tabs, ColorBar, LogColorMapper, LogTicker
from bokeh.models.axes import LinearAxis, LogAxis
from bokeh.models.widgets import CheckboxGroup, Dropdown, RadioGroup, MultiSelect, DatePicker, Button
from bokeh.events import MenuItemClick

from bokeh.layouts import column, row
from bokeh.palettes import viridis, Category20_20, linear_palette, Inferno256

from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application

from bokeh.sampledata.us_states import data as US_STATES
from bokeh.sampledata.us_counties import data as US_COUNTIES

# Drop the Hawaii and Alaska outlines so the state map shows only the contiguous states.
del US_STATES['HI']
del US_STATES['AK']

# Render Bokeh output inline in the notebook.
output_notebook()

In [3]:
def compute_states_data():
    """Add daily-difference and 7-day-average columns to ``GH_STATES_DATA`` in place.

    Rows are grouped by ``fips`` and ordered by ``date`` before differencing, so:

    * ``diff_cases`` / ``diff_deaths`` are day-over-day changes of the cumulative
      ``cases`` / ``deaths`` columns;
    * ``avg_cases`` / ``avg_deaths`` are trailing 7-row means of those changes;
    * ``*_pc`` columns are the same values per 100,000 residents (via ``population``);
    * ``avg_dates`` is ``date`` shifted back by 3.5 days (half the 7-day window).
    """
    for fips in tqdm(GH_STATES_DATA['fips'].unique()):

        # Sort the per-state slice so diff()/rolling() follow chronological order
        # regardless of the row order of the source file.
        subset = GH_STATES_DATA.loc[GH_STATES_DATA['fips'] == fips, :].sort_values('date')

        state = subset['state'].values[0]
        pop = population(state)

        avg_dates = subset['date'] - (timedelta(days=3) + timedelta(hours=12))
        diff_cases = subset['cases'].diff()
        diff_deaths = subset['deaths'].diff()
        avg_cases = diff_cases.rolling(7).mean()
        avg_deaths = diff_deaths.rolling(7).mean()

        # .loc assignment aligns on the index, so the sorted Series land on the right rows.
        idx = subset.index
        GH_STATES_DATA.loc[idx, 'diff_cases'] = diff_cases
        GH_STATES_DATA.loc[idx, 'diff_deaths'] = diff_deaths
        GH_STATES_DATA.loc[idx, 'diff_cases_pc'] = diff_cases / pop * 100000
        GH_STATES_DATA.loc[idx, 'diff_deaths_pc'] = diff_deaths / pop * 100000
        GH_STATES_DATA.loc[idx, 'avg_dates'] = avg_dates
        GH_STATES_DATA.loc[idx, 'avg_cases'] = avg_cases
        GH_STATES_DATA.loc[idx, 'avg_deaths'] = avg_deaths
        GH_STATES_DATA.loc[idx, 'avg_cases_pc'] = avg_cases / pop * 100000
        GH_STATES_DATA.loc[idx, 'avg_deaths_pc'] = avg_deaths / pop * 100000

In [4]:
def compute_counties_data():
    """Add daily-difference and 7-day-average columns to ``GH_COUNTIES_DATA`` in place.

    Same derived columns as for states (``diff_*``, ``avg_*``, ``*_pc`` per 100,000,
    and ``avg_dates`` shifted back 3.5 days), computed per county ``fips`` after
    ordering each county's rows by ``date``. Rows with a NaN ``fips`` are skipped
    and keep NaN in the derived columns.
    """
    for fips in tqdm(GH_COUNTIES_DATA['fips'].unique()):

        # Rows without a FIPS code cannot be grouped into a single county.
        if np.isnan(fips):
            continue

        # Sort the per-county slice so diff()/rolling() follow chronological order.
        subset = GH_COUNTIES_DATA.loc[GH_COUNTIES_DATA['fips'] == fips, :].sort_values('date')

        county, state = subset['county'].values[0], subset['state'].values[0]
        pop = population(f'{state}, {county}')

        avg_dates = subset['date'] - (timedelta(days=3) + timedelta(hours=12))
        diff_cases = subset['cases'].diff()
        diff_deaths = subset['deaths'].diff()
        avg_cases = diff_cases.rolling(7).mean()
        avg_deaths = diff_deaths.rolling(7).mean()

        # .loc assignment aligns on the index, so the sorted Series land on the right rows.
        idx = subset.index
        GH_COUNTIES_DATA.loc[idx, 'diff_cases'] = diff_cases
        GH_COUNTIES_DATA.loc[idx, 'diff_deaths'] = diff_deaths
        GH_COUNTIES_DATA.loc[idx, 'diff_cases_pc'] = diff_cases / pop * 100000
        GH_COUNTIES_DATA.loc[idx, 'diff_deaths_pc'] = diff_deaths / pop * 100000
        GH_COUNTIES_DATA.loc[idx, 'avg_dates'] = avg_dates
        GH_COUNTIES_DATA.loc[idx, 'avg_cases'] = avg_cases
        GH_COUNTIES_DATA.loc[idx, 'avg_deaths'] = avg_deaths
        GH_COUNTIES_DATA.loc[idx, 'avg_cases_pc'] = avg_cases / pop * 100000
        GH_COUNTIES_DATA.loc[idx, 'avg_deaths_pc'] = avg_deaths / pop * 100000

In [5]:
# Census population table (judging by the file name, ACS 5-year 2018, table B01003).
# get_pop_entry() matches rows on the 'NAME' column and population() reads the third column.
POP_DATA = pd.read_csv(os.path.join(PREFIX, 'ACSDT5Y2018.B01003_data_with_overlays_2020-07-10T111915.csv'))

In [6]:
# Exceptions used by format_region_name() to build the census 'NAME' for a county.
# EMPTY_COUNTIES: per state, county names written as '<county>, <state>' with no
# ' County' suffix. Entries also match as substrings, so 'Borough' / 'Census Area'
# cover every Alaskan area whose name contains them.
EMPTY_COUNTIES = {'Alaska': ['Borough', 'Census Area'],
                  'District of Columbia': ['District of Columbia'],
                  'Maryland': ['Baltimore city'],
                  'Virginia': ['Virginia Beach city', 'Alexandria city', 'Harrisonburg city', 'Charlottesville city',
                               'Williamsburg city', 'Richmond city', 'Newport News city', 'Norfolk city',
                               'Portsmouth city', 'Suffolk city', 'Danville city', 'Chesapeake city',
                               'Fredericksburg city', 'Manassas city', 'Hampton city', 'Lynchburg city',
                               'Poquoson city', 'Radford city', 'Bristol city', 'Galax city',
                               'Roanoke city', 'Hopewell city', 'Manassas Park city', 'Winchester city',
                               'Petersburg city', 'Franklin city', 'Waynesboro city', 'Salem city',
                               'Buena Vista city', 'Emporia city', 'Lexington city', 'Staunton city',
                               'Colonial Heights city', 'Fairfax city', 'Falls Church city',
                               'Norton city', 'Covington city'],
                  'Nevada': ['Carson City'],
                  'Missouri': ['St. Louis city'],}
# REPLACE_COUNTIES: per state, county names from the covid-19-data files mapped to a
# differently named census entry.
# NOTE(review): the New Mexico value contains U+FFFD, presumably to match mis-encoded
# text in the census CSV — confirm before "fixing" it.
REPLACE_COUNTIES = {'Alaska': {'Anchorage': 'Anchorage Municipality, Alaska'},
                    'New York': {'New York City': 'New York County, New York'},
                    'New Mexico': {'Doña Ana': 'Do�a Ana County, New Mexico'}}

def format_region_name(region):
    """Translate ``'State'`` or ``'State, County'`` into the census ``NAME`` format.

    State names are returned unchanged. A county normally becomes
    ``'<County> County, <State>'`` (``Parish`` for Louisiana). Counties listed in
    ``EMPTY_COUNTIES`` (matched exactly or as a substring) become ``'<County>, <State>'``.
    ``REPLACE_COUNTIES`` supplies hand-mapped full names.
    """
    if ', ' not in region:
        return region

    state, county = region.split(', ')

    bare_names = EMPTY_COUNTIES.get(state)
    if bare_names is not None and (county in bare_names or any(part in county for part in bare_names)):
        return f'{county}, {state}'

    renames = REPLACE_COUNTIES.get(state, {})
    if county in renames:
        return renames[county]

    suffix = 'Parish' if state == 'Louisiana' else 'County'
    return f'{county} {suffix}, {state}'

def parse_detailed_name(name):
    """Split a Bokeh county ``'detailed name'`` into ``(state, county)``.

    The state is the text after the last ``', '``. A trailing ``' County'`` or
    ``' Parish'`` is stripped from the county part, so ``'Harris County, Texas'``
    gives ``('Texas', 'Harris')`` and ``'Orleans Parish, Louisiana'`` gives
    ``('Louisiana', 'Orleans')``. Other designations such as ``'Richmond city'`` are
    kept verbatim. A name without ``', '`` yields ``('', name)``.

    NOTE(review): assumes the county data names parishes without the ' Parish'
    suffix, as it does for counties — confirm.
    """
    county, sep, state = name.rpartition(', ')
    if not sep:
        return '', name
    for suffix in (' County', ' Parish'):
        if county.endswith(suffix):
            county = county[:-len(suffix)]
            break
    return state, county

def get_pop_entry(region):
    """Return the ``POP_DATA`` row(s) whose ``NAME`` equals the formatted ``region``.

    ``region`` is normalised with ``format_region_name`` first. The result is a
    (possibly empty) DataFrame slice.
    """
    census_name = format_region_name(region)
    return POP_DATA.loc[POP_DATA['NAME'] == census_name]
    
def population(region):
    """Return the census population of ``region`` as an int.

    ``region`` is a state name or ``'State, County'`` string, resolved through
    ``get_pop_entry``. The value is read from the third column of the first
    matching row.

    If no row matches or the value cannot be converted to ``int``, a warning is
    printed and 1 is returned so callers' per-capita divisions still succeed.
    """
    entry = get_pop_entry(region)
    try:
        return int(entry.values[0][2])
    except (IndexError, ValueError, TypeError):
        # Only lookup/conversion failures fall back; other errors (and Ctrl-C) propagate.
        # NOTE(review): a population of 1 makes per-100k values 100,000x the raw count.
        print(f'Unable to find population of {region}!')
        return 1

In [7]:
gh_states_data_file = os.path.join(PREFIX, 'covid-19-data', 'us-states.csv')
gh_counties_data_file = os.path.join(PREFIX, 'covid-19-data', 'us-counties.csv')

# Territories are removed everywhere; Hawaii and Alaska are also removed from the county data.
drop_states = ['Guam', 'Northern Mariana Islands', 'Virgin Islands', 'Puerto Rico']
drop_counties = drop_states + ['Hawaii', 'Alaska']

# The derived columns are slow to compute (minutes for counties), so the result is
# cached in local CSVs. The cache is rebuilt only when the source file is newer.
# The cache is written with its index, so read it back with index_col=0. Parse
# 'avg_dates' as well, otherwise it comes back as strings on a cache hit.
if not os.path.exists('us-states.csv') or os.stat(gh_states_data_file).st_mtime > os.stat('us-states.csv').st_mtime:
    GH_STATES_DATA = pd.read_csv(gh_states_data_file, parse_dates=['date'])
    # .copy() so compute_states_data() writes new columns into an owned frame, not a view.
    GH_STATES_DATA = GH_STATES_DATA[~GH_STATES_DATA['state'].isin(drop_states)].copy()
    compute_states_data()
    GH_STATES_DATA.to_csv('us-states.csv')
else:
    GH_STATES_DATA = pd.read_csv('us-states.csv', index_col=0, parse_dates=['date', 'avg_dates'])

if not os.path.exists('us-counties.csv') or os.stat(gh_counties_data_file).st_mtime > os.stat('us-counties.csv').st_mtime:
    GH_COUNTIES_DATA = pd.read_csv(gh_counties_data_file, parse_dates=['date'])
    GH_COUNTIES_DATA = GH_COUNTIES_DATA[~GH_COUNTIES_DATA['state'].isin(drop_counties)].copy()
    compute_counties_data()
    GH_COUNTIES_DATA.to_csv('us-counties.csv')
else:
    GH_COUNTIES_DATA = pd.read_csv('us-counties.csv', index_col=0, parse_dates=['date', 'avg_dates'])

STATES = sorted(GH_STATES_DATA['state'].unique())
COUNTIES = sorted({f'{state}, {county}' for county, state in zip(GH_COUNTIES_DATA['county'], GH_COUNTIES_DATA['state'])})

100%|██████████████████████████████████████████████████████████████████████████████████| 51/51 [00:01<00:00, 39.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 3077/3077 [03:28<00:00, 14.75it/s]


In [8]:
# Daily per-state testing data from the COVID Tracking Project API.
# A timeout keeps the notebook from hanging; raise_for_status surfaces HTTP errors
# instead of failing later on an unexpected JSON body.
response = requests.get(url='https://covidtracking.com/api/v1/states/daily.json', timeout=60)
response.raise_for_status()
TRACKING_DATA = pd.DataFrame.from_dict(response.json())

# 'date' is an integer YYYYMMDD; parse it into a datetime column in one vectorised call.
TRACKING_DATA['datetime'] = pd.to_datetime(TRACKING_DATA['date'].astype(str), format='%Y%m%d')
# Positive results as a percentage of test results. NOTE(review): both columns are
# presumably cumulative totals (daily counts use the *Increase columns) — confirm.
TRACKING_DATA['positivity'] = TRACKING_DATA['positive'] / TRACKING_DATA['totalTestResults'] * 100

# State name -> postal code, used to look states up in TRACKING_DATA. Includes the
# District of Columbia because it is one of the selectable STATES.
STATE_ABBRV = {'Alabama': 'AL', 'Alaska': 'AK', 'Arizona': 'AZ', 'Arkansas': 'AR', 'California': 'CA',
               'Colorado': 'CO', 'Connecticut': 'CT', 'Delaware': 'DE', 'District of Columbia': 'DC',
               'Florida': 'FL', 'Georgia': 'GA',
               'Hawaii': 'HI', 'Idaho': 'ID', 'Illinois': 'IL', 'Indiana': 'IN', 'Iowa': 'IA',
               'Kansas': 'KS', 'Kentucky': 'KY', 'Louisiana': 'LA', 'Maine': 'ME', 'Maryland': 'MD',
               'Massachusetts': 'MA', 'Michigan': 'MI', 'Minnesota': 'MN', 'Mississippi': 'MS', 'Missouri': 'MO',
               'Montana': 'MT', 'Nebraska': 'NE', 'Nevada': 'NV', 'New Hampshire': 'NH', 'New Jersey': 'NJ',
               'New Mexico': 'NM', 'New York': 'NY', 'North Carolina': 'NC', 'North Dakota': 'ND', 'Ohio': 'OH',
               'Oklahoma': 'OK', 'Oregon': 'OR', 'Pennsylvania': 'PA', 'Rhode Island': 'RI', 'South Carolina': 'SC',
               'South Dakota': 'SD', 'Tennessee': 'TN', 'Texas': 'TX', 'Utah': 'UT', 'Vermont': 'VT',
               'Virginia': 'VA', 'Washington': 'WA', 'West Virginia': 'WV', 'Wisconsin': 'WI', 'Wyoming': 'WY'}

In [9]:
def compute_log_palette(palette, low, high, value):
    """Pick a color from ``palette`` by placing ``value`` on a log scale.

    Values at or above ``high`` map to ``palette[-1]``. Values below ``low``, and
    NaN, map to ``palette[0]``. In between, ``[low, high)`` is divided into
    ``len(palette)`` bins of equal width in log space.
    """
    if value >= high:
        return palette[-1]
    if not value >= low:  # also catches NaN, which fails every comparison
        return palette[0]
    position = (np.log(value) - np.log(low)) * len(palette) / (np.log(high) - np.log(low))
    # Clamp in case float round-off pushes a value just below `high` to len(palette).
    return palette[min(int(position), len(palette) - 1)]

In [10]:
def get_dataset(region):
    """Return the covid-19-data rows for a state name or a ``'State, County'`` region.

    A region containing ``', '`` is treated as a county (the same convention
    ``format_region_name`` uses), anything else as a state. The kind is decided from
    the string itself rather than from the census GEO_ID, so a region with no
    population row still returns its data instead of raising.
    """
    if ', ' in region:
        state, county = region.split(', ')
        mask = (GH_COUNTIES_DATA['state'] == state).values & (GH_COUNTIES_DATA['county'] == county).values
        return GH_COUNTIES_DATA[mask]
    return GH_STATES_DATA[GH_STATES_DATA['state'] == region]

In [11]:
def get_data(region, per_capita=False, data_type='cases', constant_date=None):
    """Collect the series plotted for one region.

    Parameters
    ----------
    region : str
        State name, or ``'State, County'`` for the ``'cases'`` / ``'deaths'`` types.
        The testing-based types look the state up in ``STATE_ABBRV``.
    per_capita : bool
        Express counts per 100,000 residents (ignored for ``'positivity'``).
    data_type : str
        ``'cases'``, ``'deaths'``, ``'positivity'``, ``'constant positivity'`` or
        ``'constant testing'``.
    constant_date : date, str or None
        Reference day for the two "constant" scenarios. Anything ``pd.Timestamp``
        accepts works (the DatePicker yields a ``date`` or an ISO string). If the
        day is absent from the data, the scenario line is all NaN.

    Returns
    -------
    tuple
        ``(dates, avg_dates, data, avg_data, test_data, label)``. ``test_data`` is
        ``None`` except for the "constant" scenarios, where it is scaled exactly
        like ``data``.

    Raises
    ------
    ValueError
        If ``data_type`` is not one of the values above.
    """
    test_data = None

    if data_type in ('cases', 'deaths'):
        subset = get_dataset(region)
        dates = subset['date']
        avg_dates = subset['avg_dates']

        if per_capita:
            column = f'{data_type}_pc'
            label = f'New {data_type.title()} per 100,000'
        else:
            column = data_type
            label = f'Total New {data_type.title()}'

        return dates, avg_dates, subset[f'diff_{column}'], subset[f'avg_{column}'], test_data, label

    if data_type not in ('positivity', 'constant positivity', 'constant testing'):
        raise ValueError(f'Unknown data type: {data_type!r}')

    subset = TRACKING_DATA[TRACKING_DATA['state'] == STATE_ABBRV[region]].sort_values('date')
    dates = subset['datetime']
    # Shift by half the 7-day window, like avg_dates in the case/death data.
    avg_dates = dates - (np.timedelta64(3, 'D') + np.timedelta64(12, 'h'))

    if data_type == 'positivity':
        data = subset['positivity']
        label = 'Positivity (%)'
    else:
        # The reference day's row; empty if constant_date is missing or out of range.
        reference = subset.loc[subset['datetime'] == pd.Timestamp(constant_date)]
        data = subset['positiveIncrease']
        label = 'Cases'
        if data_type == 'constant positivity':
            positivity = reference['positivity'].iloc[0] if not reference.empty else np.nan
            test_data = (subset['totalTestResults'] * positivity / 100).diff().rolling(7).mean()
        else:
            total_tests = reference['totalTestResultsIncrease'].iloc[0] if not reference.empty else np.nan
            test_data = (subset['positivity'] * total_tests / 100).rolling(7).mean()

    if data_type != 'positivity' and per_capita:
        pop = population(region)
        data = data / pop * 100000
        # Scale the scenario line too so it stays comparable with `data`.
        test_data = test_data / pop * 100000

    avg_data = data.rolling(7).mean()

    if data_type != 'positivity':
        label = f'New {label.title()} per 100,000' if per_capita else f'Total New {label.title()}'

    return dates, avg_dates, data, avg_data, test_data, label

In [12]:
class StateDisplay:
    """Bokeh app comparing the 7-day-average curves of several regions.

    Widgets:

    * a MultiSelect picks regions from ``dataset``;
    * radio groups choose total vs. per-capita, the data type, and linear vs. log axes;
    * a DatePicker supplies the reference day for the "constant" scenarios.

    Linear and log figures share one ColumnDataSource; only one is visible at a time.
    """

    def __init__(self, dataset=STATES):

        self.dataset = dataset

        self.state_selection = MultiSelect(title='States:', options=self.dataset, value=['New York', 'Texas'], height=550)
        self.per_capita = RadioGroup(labels=['Total', 'Per Capita'], active=0)
        self.data_getter = RadioGroup(labels=['Cases', 'Deaths', 'Positivity', 'Constant Positivity',
                                              'Constant Testing'], active=0)
        self.plot_type = RadioGroup(labels=['Linear', 'Logarithmic'], active=0)

        self.constant_date = DatePicker(title='Constant Date', value=(datetime.today() - timedelta(days=1)).date())
        # Subclasses may set this to False to leave the date picker out of the layout.
        self.show_constant_date = True

        self.src = None
        self.p = None
        self.logp = None

    def make_dataset(self, state_list):
        """Build the multi-line data for ``state_list``.

        Returns ``(label, source)``; ``label`` is None when ``state_list`` is empty.
        The source always has all four columns, so updating with an empty selection
        clears the plot instead of leaving the previous lines.
        """
        # Each region gets a fixed color from its position in the full dataset.
        color_cycle = cycle(Category20_20)
        palette = [next(color_cycle) for _ in self.dataset]

        # Widget state is the same for every region, so read it once.
        per_capita = self.per_capita.active == 1
        data_getter = self.data_getter.labels[self.data_getter.active].lower()
        constant_date = self.constant_date.value

        by_state = {'avg_date': [], 'avg_data': [], 'state': [], 'color': []}
        label = None

        for state_name in state_list:
            _, avg_dates, _, avg_data, _, label = get_data(state_name, per_capita, data_getter, constant_date)

            by_state['avg_date'].append(avg_dates.values)
            by_state['avg_data'].append(avg_data.values)
            by_state['state'].append(state_name)
            by_state['color'].append(palette[self.dataset.index(state_name)])

        return label, ColumnDataSource(by_state)

    def _make_figure(self, **axis_kwargs):
        """Return one figure drawing the averaged curves from ``self.src``."""
        fig = figure(title='COVID-19 Cases', x_axis_label='Date',
                     x_axis_type='datetime', y_axis_label='Total Cases', **axis_kwargs)
        fig.multi_line(source=self.src, xs='avg_date', ys='avg_data',
                       legend_field='state', color='color', line_width=2)
        fig.legend.location = 'top_left'
        return fig

    def make_plot(self):
        """Create the linear (``self.p``) and logarithmic (``self.logp``) figures."""
        self.p = self._make_figure()
        self.logp = self._make_figure(y_axis_type='log')

    def update(self, attr, old, new):
        """Rebuild the data for the current widget state (Bokeh ``on_change`` signature)."""
        label, new_src = self.make_dataset(sorted(self.state_selection.value))

        if self.src is None:
            self.src = new_src
            self.make_plot()
        else:
            self.src.data.update(new_src.data)

        show_log = self.plot_type.active != 0
        self.p.visible = not show_log
        self.logp.visible = show_log

        # With nothing selected there is no new label; keep the previous one.
        if label is not None:
            self.p.yaxis.axis_label = label
            self.logp.yaxis.axis_label = label

    def run(self, doc):
        """Bokeh application handler: wire callbacks, draw once and add the layout to ``doc``."""
        self.state_selection.on_change('value', self.update)
        for widget in (self.per_capita, self.data_getter, self.plot_type):
            widget.on_change('active', self.update)
        self.constant_date.on_change('value', self.update)

        controls = [self.state_selection, self.per_capita, self.data_getter, self.plot_type]
        if getattr(self, 'show_constant_date', True):
            controls.append(self.constant_date)

        # Initial draw creates self.p / self.logp.
        self.update(None, None, None)

        layout = row(column(controls), column(self.p, self.logp))
        doc.add_root(layout)

In [13]:
# Serve the multi-state comparison app inline in the notebook.
show(Application(FunctionHandler(StateDisplay().run)))

In [14]:
class SingleStateDisplay:
    """Bokeh app for one state: daily bars, their 7-day average and a scenario line.

    Widgets:

    * a Dropdown picks the state from STATES;
    * radio groups choose total vs. per-capita, the data type, and linear vs. log axes;
    * a DatePicker supplies the reference day for the "constant" scenarios, whose
      ``test_data`` is drawn as a dashed line.
    """

    # Bokeh datetime axes are in milliseconds and VBar's default width is 1, which
    # would make each daily bar 1 ms wide. Use 80% of a day instead.
    BAR_WIDTH = 0.8 * 24 * 60 * 60 * 1000

    def __init__(self):

        self.state = 'New York'
        self.menu = STATES

        self.state_selection = Dropdown(menu=self.menu, label=self.state)
        self.per_capita = RadioGroup(labels=['Total', 'Per Capita'], active=0)
        self.data_getter = RadioGroup(labels=['Cases', 'Deaths', 'Positivity', 'Constant Positivity',
                                              'Constant Testing'], active=0)
        self.plot_type = RadioGroup(labels=['Linear', 'Logarithmic'], active=0)

        self.constant_date = DatePicker(title='Constant Date', value=(datetime.today() - timedelta(days=1)).date())
        self.show_constant_date = True

        self.src = None
        self.p = None
        self.logp = None

    def make_dataset(self, state_name=''):
        """Return ``(label, source)`` with date, data, average and scenario columns for one state."""
        per_capita = self.per_capita.active == 1
        data_getter = self.data_getter.labels[self.data_getter.active].lower()
        constant_date = self.constant_date.value

        dates, avg_dates, data, avg_data, test_data, label = get_data(state_name, per_capita, data_getter, constant_date)

        data_dict = {'date': dates.values, 'avg_date': avg_dates.values, 'data': data.values, 'avg_data': avg_data.values}

        # The dashed scenario line exists only for the "constant" types. Otherwise fill
        # the column with NaN so nothing is drawn. np.full works whatever the dtype of
        # `data` (an integer copy of it could not hold NaN).
        data_dict['test_data'] = np.full(len(data), np.nan) if test_data is None else test_data.values

        return label, ColumnDataSource(data_dict)

    def make_plot(self):
        """Create the linear (``self.p``) and logarithmic (``self.logp``) figures."""
        self.p = figure(title='COVID-19 Cases', x_axis_label='Date',
                        x_axis_type='datetime', y_axis_label='Total Cases')

        self.p.vbar(source=self.src, x='date', top='data', width=self.BAR_WIDTH, color='orange')
        self.p.line(source=self.src, x='avg_date', y='avg_data', line_width=2)
        self.p.line(source=self.src, x='date', y='test_data', line_width=2, line_dash='dashed')

        self.p.legend.visible = False

        self.logp = figure(title='COVID-19 Cases', x_axis_label='Date',
                           x_axis_type='datetime', y_axis_label='Total Cases',
                           y_axis_type='log')

        # A log axis cannot start bars at 0, so use a tiny positive bottom instead.
        self.logp.vbar(source=self.src, x='date', bottom=1e-10, top='data', width=self.BAR_WIDTH, color='orange')
        self.logp.line(source=self.src, x='avg_date', y='avg_data', line_width=2)
        self.logp.line(source=self.src, x='date', y='test_data', line_width=2, line_dash='dashed')

        self.logp.legend.visible = False

    def update(self, attr, old, new):
        """Rebuild the data for the current widget state (Bokeh ``on_change`` signature)."""
        label, new_src = self.make_dataset(self.state)

        if self.src is None:
            self.src = new_src
            self.make_plot()
        else:
            self.src.data.update(new_src.data)

        show_log = self.plot_type.active != 0
        self.p.visible = not show_log
        self.logp.visible = show_log

        self.p.yaxis.axis_label = label
        self.logp.yaxis.axis_label = label

    def update_selection(self, event):
        """Dropdown click handler: switch to ``event.item`` and redraw."""
        self.state = event.item
        self.state_selection.label = self.state
        self.update(None, None, None)

    def run(self, doc):
        """Bokeh application handler: wire callbacks, draw once and add the layout to ``doc``."""
        self.state_selection.on_click(self.update_selection)

        self.per_capita.on_change('active', self.update)
        self.data_getter.on_change('active', self.update)
        self.plot_type.on_change('active', self.update)
        self.constant_date.on_change('value', self.update)

        controls = [self.state_selection, self.per_capita, self.data_getter, self.plot_type]
        if self.show_constant_date:
            controls.append(self.constant_date)

        controls = column(controls)

        # Simulate a click on the initial state to perform the first draw.
        self.update_selection(MenuItemClick(None, self.state))

        plots = column(self.p, self.logp)

        layout = row(controls, plots)
        doc.add_root(layout)

In [15]:
# Serve the single-state detail app inline in the notebook.
show(Application(FunctionHandler(SingleStateDisplay().run)))

In [16]:
class CountyDisplay(StateDisplay):
    """StateDisplay variant comparing counties (the ``'State, County'`` entries of COUNTIES).

    Only the Cases and Deaths data types are offered.
    """

    def __init__(self):

        super().__init__(COUNTIES)

        selector = self.state_selection
        selector.value = ['New York, Washington', 'Texas, Harris']
        selector.title = 'Counties:'

        self.data_getter.labels = ['Cases', 'Deaths']

        # The date picker only matters for the testing-based scenarios, which are not offered here.
        # NOTE(review): this only takes effect if the base class's run() consults the flag — confirm.
        self.show_constant_date = False

In [17]:
# Serve the multi-county comparison app inline in the notebook.
show(Application(FunctionHandler(CountyDisplay().run)))

In [18]:
class StateMap:
    """Choropleth of the contiguous US states for one day, with a play button.

    Colors come from ``compute_log_palette`` over Inferno256. The scale spans
    ``maxval / 256`` to ``maxval``, where ``maxval`` is the largest value of the
    selected quantity over the whole dataset. A log color bar with the same range
    is drawn next to the map.
    """

    def __init__(self):

        self.per_capita = RadioGroup(labels=['Total', 'Per Capita'], active=0, width=100)
        self.data_getter = RadioGroup(labels=['Cases', 'Deaths', 'Positivity'], active=0, width=100)
        self.date = DatePicker(title='Date', width=200)

        self.src = None
        self.p = None
        self.color_mapper = None

        dates = GH_STATES_DATA.loc[:, 'date']
        self.date.value = dates.max().date()
        self.date.enabled_dates = [(dates.min().date(), dates.max().date())]

        self.doc = None
        self.button = None
        self.callback = None
        self.counter = None

    def make_dataset(self):
        """Return ``(label, maxval, source)`` for the selected quantity and day.

        ``'cases'`` / ``'deaths'`` read the 7-day averages in ``GH_STATES_DATA``.
        ``'positivity'`` reads ``TRACKING_DATA``, keyed by the same postal codes as
        ``US_STATES``. Missing, NaN and negative values are drawn as 0.

        Raises ValueError for any other data type.
        """
        per_capita = self.per_capita.active == 1
        data_type = self.data_getter.labels[self.data_getter.active].lower()
        # The DatePicker holds a date object initially and an ISO string after UI
        # changes; a Timestamp compares reliably with the datetime64 columns.
        day = pd.Timestamp(self.date.value)

        if data_type in ('cases', 'deaths'):
            if per_capita:
                column = f'avg_{data_type}_pc'
                label = f'New {data_type.title()} per 100,000'
            else:
                column = f'avg_{data_type}'
                label = f'Total New {data_type.title()}'
            subset = GH_STATES_DATA.loc[GH_STATES_DATA['date'] == day]
            # Assumes one row per state per day.
            lookup = dict(zip(subset['state'], subset[column]))
            raw = [lookup.get(state['name'], np.nan) for state in US_STATES.values()]
            maxval = GH_STATES_DATA[column].max()
        elif data_type == 'positivity':
            label = 'Positivity (%)'
            subset = TRACKING_DATA.loc[TRACKING_DATA['datetime'] == day]
            lookup = dict(zip(subset['state'], subset['positivity']))
            raw = [lookup.get(abbrv, np.nan) for abbrv in US_STATES]
            # positivity is a ratio, so zero test counts yield inf; keep inf out of the scale.
            maxval = TRACKING_DATA['positivity'].replace([np.inf, -np.inf], np.nan).max()
        else:
            raise ValueError(f'Unsupported data type for state map: {data_type!r}')

        data = np.array([0 if np.isnan(value) else max(0, value) for value in raw], dtype=float)

        color_data = {'color': [compute_log_palette(Inferno256, maxval / 256, maxval, val) for val in data],
                      'value': data,
                      'name': [state['name'] for state in US_STATES.values()],
                      'lons': [state['lons'] for state in US_STATES.values()],
                      'lats': [state['lats'] for state in US_STATES.values()]}

        return label, maxval, ColumnDataSource(color_data)

    def make_plot(self, maxval):
        """Create the map figure ``self.p`` and its log color bar."""
        tooltips = [('Name', '@name'),
                    ('Value', '@value')]

        # Same lower bound as compute_log_palette; a log mapper cannot start at 0.
        self.color_mapper = LogColorMapper(palette='Inferno256', low=maxval / 256, high=maxval)

        color_bar = ColorBar(color_mapper=self.color_mapper, ticker=LogTicker(),
                             label_standoff=12, border_line_color=None, location=(0, 0))

        self.p = figure(toolbar_location="left", plot_width=950, aspect_ratio=1.3, tooltips=tooltips)

        self.p.patches(source=self.src, xs='lons', ys='lats', fill_color='color',
                       line_color='white', line_width=0.5)

        self.p.axis.visible = False
        self.p.grid.visible = False
        self.p.outline_line_color = None

        self.p.add_layout(color_bar, 'right')

    def update(self, attr, old, new):
        """Rebuild the map for the current widget state (Bokeh ``on_change`` signature)."""
        label, maxval, new_src = self.make_dataset()

        if self.src is None:
            self.src = new_src
            self.make_plot(maxval)
        else:
            self.src.data.update(new_src.data)

        self.p.title.text = label
        self.color_mapper.low = maxval / 256
        self.color_mapper.high = maxval

    def animate_update(self):
        """Periodic callback: advance one day and stop once the last enabled date is reached."""
        self.counter += 1
        # Normalise: the value may be a date object (initial) or an ISO string.
        next_day = pd.Timestamp(self.date.value).date() + timedelta(days=1)
        self.date.value = next_day.isoformat()
        last_day = pd.Timestamp(self.date.enabled_dates[0][1]).date()
        if next_day > last_day - timedelta(days=1):
            self.animate()

    def animate(self):
        """Toggle playback: start or stop the 200 ms periodic date-advance callback."""
        if self.button.label == '► Play':
            self.button.label = '❚❚ Pause'
            self.counter = 0
            self.callback = self.doc.add_periodic_callback(self.animate_update, 200)
        else:
            self.button.label = '► Play'
            self.doc.remove_periodic_callback(self.callback)

    def run(self, doc):
        """Bokeh application handler: wire callbacks, draw once and add the layout to ``doc``."""
        self.doc = doc

        self.per_capita.on_change('active', self.update)
        self.data_getter.on_change('active', self.update)
        self.date.on_change('value', self.update)

        self.update(None, None, None)

        self.button = Button(label='► Play', width=60)
        self.button.on_click(self.animate)

        controls = row([self.per_capita, self.data_getter, self.date, self.button])
        layout = column(self.p, controls)

        doc.add_root(layout)

In [19]:
# Serve the animated state choropleth inline in the notebook.
show(Application(FunctionHandler(StateMap().run)))

In [20]:
class CountyMap(StateMap):
    """County-level variant of StateMap for the contiguous US.

    Only Cases and Deaths are offered, because the testing data behind
    'Positivity' is per state.
    """

    def __init__(self):

        super().__init__()

        # make_dataset below only understands the county case/death columns.
        self.data_getter.labels = ['Cases', 'Deaths']

        dates = GH_COUNTIES_DATA.loc[:, 'date']
        self.date.value = dates.max().date()
        self.date.enabled_dates = [(dates.min().date(), dates.max().date())]

    def make_dataset(self):
        """Return ``(label, maxval, source)`` for the selected county quantity and day.

        Missing, NaN and negative values are drawn as 0. Raises ValueError for
        data types other than cases/deaths.
        """
        per_capita = self.per_capita.active == 1
        data_type = self.data_getter.labels[self.data_getter.active].lower()
        # The DatePicker holds a date object initially and an ISO string after UI
        # changes; a Timestamp compares reliably with the datetime64 'date' column.
        day = pd.Timestamp(self.date.value)

        # State codes (as stored in US_COUNTIES) left off the map: Alaska, Hawaii and territories.
        excluded = ('ak', 'hi', 'pr', 'gu', 'vi', 'mp', 'as')

        if data_type not in ('cases', 'deaths'):
            raise ValueError(f'Unsupported data type for county map: {data_type!r}')

        if per_capita:
            column = f'avg_{data_type}_pc'
            label = f'New {data_type.title()} per 100,000'
        else:
            column = f'avg_{data_type}'
            label = f'Total New {data_type.title()}'

        # Index the day's rows by (state, county) once instead of scanning the frame
        # for every county shape. setdefault keeps the first row per key.
        subset = GH_COUNTIES_DATA.loc[GH_COUNTIES_DATA['date'] == day]
        lookup = {}
        for key, value in zip(zip(subset['state'], subset['county']), subset[column]):
            lookup.setdefault(key, value)

        counties = [county for county in US_COUNTIES.values() if county['state'] not in excluded]

        data = []
        for county in counties:
            value = lookup.get(parse_detailed_name(county['detailed name']), np.nan)
            data.append(0 if np.isnan(value) else max(0, value))

        maxval = GH_COUNTIES_DATA[column].max()

        color_data = {'color': [compute_log_palette(Inferno256, maxval / 256, maxval, val) for val in data],
                      'value': data,
                      'name': [county['detailed name'] for county in counties],
                      'lons': [county['lons'] for county in counties],
                      'lats': [county['lats'] for county in counties]}

        return label, maxval, ColumnDataSource(color_data)

In [21]:
# Serve the animated county choropleth inline in the notebook.
show(Application(FunctionHandler(CountyMap().run)))

In [22]:
# Spot check: every row for Lake County, Tennessee, including the derived columns.
GH_COUNTIES_DATA[(GH_COUNTIES_DATA['state'] == 'Tennessee').values & (GH_COUNTIES_DATA['county'] == 'Lake').values]

Unnamed: 0,date,county,state,fips,cases,deaths,diff_cases,diff_deaths,diff_cases_pc,diff_deaths_pc,avg_dates,avg_cases,avg_deaths,avg_cases_pc,avg_deaths_pc
55893,2020-04-13,Lake,Tennessee,47095.0,4,0,,,,,2020-04-09 12:00:00,,,,
58596,2020-04-14,Lake,Tennessee,47095.0,4,0,0.0,0.0,0.000000,0.0,2020-04-10 12:00:00,,,,
61311,2020-04-15,Lake,Tennessee,47095.0,4,0,0.0,0.0,0.000000,0.0,2020-04-11 12:00:00,,,,
64038,2020-04-16,Lake,Tennessee,47095.0,4,0,0.0,0.0,0.000000,0.0,2020-04-12 12:00:00,,,,
66784,2020-04-17,Lake,Tennessee,47095.0,4,0,0.0,0.0,0.000000,0.0,2020-04-13 12:00:00,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
369239,2020-07-25,Lake,Tennessee,47095.0,706,0,2.0,0.0,26.574542,0.0,2020-07-21 12:00:00,1.000000,0.0,13.287271,0.0
372451,2020-07-26,Lake,Tennessee,47095.0,710,0,4.0,0.0,53.149083,0.0,2020-07-22 12:00:00,1.428571,0.0,18.981815,0.0
375664,2020-07-27,Lake,Tennessee,47095.0,716,0,6.0,0.0,79.723625,0.0,2020-07-23 12:00:00,2.000000,0.0,26.574542,0.0
378880,2020-07-28,Lake,Tennessee,47095.0,721,0,5.0,0.0,66.436354,0.0,2020-07-24 12:00:00,2.714286,0.0,36.065449,0.0


In [23]:
# Spot check: census population lookup for Trousdale County, Tennessee.
population('Tennessee, Trousdale')

9573

In [24]:
# Manual per-100,000 calculation using the Trousdale population from the previous cell.
# NOTE(review): 897 is a hand-typed count with no visible source here — confirm where it came from.
897/9573*100000

9370.103415857098