In [1]:
import numpy as np
import pandas as pd
import geopandas as gpd
from datetime import datetime, date, timedelta

from bokeh.layouts import layout, column, row
from bokeh.models import ColumnDataSource, GeoJSONDataSource, DateSlider, Button, Patches, Select, Range1d, LinearColorMapper
from bokeh.plotting import figure
from bokeh.io import show, output_notebook, curdoc

from colorcet import fire as color_palette

from data_handler import DataHandler

# Direct Bokeh output into this notebook so that show(app) in the last cell
# renders the dashboard inline instead of opening a separate browser tab.
output_notebook()

In [2]:
def app(doc):
    """Build the interactive cases/deaths/restrictions dashboard and attach it to ``doc``.

    Meant to be passed to ``bokeh.io.show`` as a function handler. Bokeh then runs
    it as a server application, so every widget callback executes in Python.

    Layout:
      * a choropleth map whose fill colour is driven by a selectable field,
      * a play/pause button and a date slider that can be animated,
      * selectors for colour field, country, restriction category and line field,
      * two datetime line plots fed by the same ColumnDataSource
        (column 'restr' and column 'line', both against 'date').

    Parameters
    ----------
    doc : bokeh.document.Document
        Document supplied by the Bokeh server. The layout is added as its root
        and the animation's periodic callback is registered on it.
    """
    # --- Load data ------------------------------------------------------------
    dh = DataHandler()
    # NOTE(review): initial_view()/update_view() are expected to return data
    # containing the 'date', 'restr' and 'line' columns read by the line glyphs.
    source = ColumnDataSource(dh.initial_view())
    geo_source = GeoJSONDataSource(geojson=dh.geo_data.to_json())

    # --- Initial options for the view -----------------------------------------
    COUNTRIES = list(dh.country_iso.keys())
    INITIAL_COUNTRY = 'Europe'
    # Key into dh.y_range_end for the initial selection. It is derived the same
    # way update() derives it, so the initial axes match a later re-selection.
    initial_iso = dh.country_iso[INITIAL_COUNTRY]
    SELECT_LINE_OPTIONS = dh.fields
    INITIAL_LINE_VIEW = SELECT_LINE_OPTIONS[0]
    SELECT_COLOR_OPTIONS = dh.color_fields
    INITIAL_COLOR_VIEW = SELECT_COLOR_OPTIONS[0]
    RESTRICTION_CATEGORIES = dh.restriction_fields
    INITIAL_RESTR_VIEW = RESTRICTION_CATEGORIES[0]
    COLOR_PALETTE = color_palette
    ANIMATION_PERIOD_MS = 200  # interval between automatic slider steps

    def _as_datetime(value):
        # Normalise any date-like value (date, datetime, pandas Timestamp) to a
        # midnight datetime, so comparisons and axis bounds use a single type.
        return datetime(value.year, value.month, value.day)

    start_dt = _as_datetime(dh.date_range[0])
    end_dt = _as_datetime(dh.date_range[1])
    first_day = start_dt.date()
    last_day = end_dt.date()

    # --- Line plots -----------------------------------------------------------
    # width/height (not the removed plot_width/plot_height) work on Bokeh 1.x-3.x.
    TOOLS_line = 'box_zoom,wheel_zoom,reset'
    lp_restr = figure(width=800, height=200, x_axis_type="datetime", tools=TOOLS_line)
    line_plot = figure(width=800, height=200, x_axis_type="datetime", tools=TOOLS_line)

    lp_restr.line('date', 'restr', source=source)
    line_plot.line('date', 'line', source=source)

    # Both line plots span the full data period on the x-axis.
    for plot in (lp_restr, line_plot):
        plot.x_range.start = start_dt
        plot.x_range.end = end_dt
    lp_restr.y_range = Range1d(0, dh.y_range_end[initial_iso][INITIAL_RESTR_VIEW])
    line_plot.y_range = Range1d(0, dh.y_range_end[initial_iso][INITIAL_LINE_VIEW])

    # --- Choropleth -----------------------------------------------------------
    TOOLS_choropleth = 'hover,tap,pan'
    choropleth = figure(width=500,
                        height=400,
                        tools=TOOLS_choropleth,
                        background_fill_color='#7dade0')
    choropleth.grid.visible = False
    choropleth.hover.tooltips = [
        ('Country', '@country'),
        ('Final Cases', '@cases'),
        ('Final Deaths', '@deaths'),
        ('Final Restrictions', '@all_restr'),
    ]
    # No low/high given, so the mapper auto-scales to the mapped field's data.
    color_mapper = LinearColorMapper(palette=COLOR_PALETTE)
    patches = Patches(xs="xs", ys="ys",
                      fill_alpha=0.7,
                      fill_color={'field': INITIAL_COLOR_VIEW, 'transform': color_mapper},
                      line_color='white',
                      line_width=0.3)
    choropleth.add_glyph(geo_source, patches)

    # --- Interactions ---------------------------------------------------------
    def update(attrname, old, new):
        """Recompute line-plot data and y-axis extents after a slider/selector change."""
        current_day = slider.value_as_datetime.date()
        iso = dh.country_iso[select_country.value]
        source.data = dh.update_view(current_day, select_line.value, select_restr.value, iso)
        line_plot.y_range.update(end=dh.y_range_end[iso][select_line.value])
        lp_restr.y_range.update(end=dh.y_range_end[iso][select_restr.value])

    def update_color(attrname, old, new):
        """Re-colour the choropleth by the newly selected field."""
        patches.update(fill_color={'field': select_color_by.value, 'transform': color_mapper})

    # Handle of the running periodic callback, or None when playback is paused.
    callback_id = None

    def animate_update():
        """Advance the slider by one day, wrapping to the first day after the last one."""
        next_day = slider.value_as_datetime.date() + timedelta(days=1)
        # '>' (not '>=') so the final day is actually shown before wrapping.
        if next_day > last_day:
            next_day = first_day
        slider.value = next_day  # triggers update() via the slider's on_change

    def animate():
        """Toggle playback by starting or stopping the periodic slider-advance callback."""
        # nonlocal (not global): the handle belongs to this app instance's scope.
        nonlocal callback_id
        if button.label == '► Play':
            button.label = '❚❚ Pause'
            callback_id = doc.add_periodic_callback(animate_update, ANIMATION_PERIOD_MS)
        else:
            button.label = '► Play'
            doc.remove_periodic_callback(callback_id)
            callback_id = None

    # Date slider
    slider = DateSlider(start=dh.date_range[0], end=dh.date_range[1], value=dh.date_range[0], step=1, title="Date")
    slider.on_change('value', update)

    # Play/pause button
    button = Button(label='► Play', width=60)
    button.on_click(animate)

    # Selectors
    select_restr = Select(title='Draw restrictions for category:', value=INITIAL_RESTR_VIEW, options=RESTRICTION_CATEGORIES)
    select_line = Select(title='Draw line plot by:', value=INITIAL_LINE_VIEW, options=SELECT_LINE_OPTIONS)
    select_country = Select(title='Country:', value=INITIAL_COUNTRY, options=COUNTRIES)
    for selector in (select_restr, select_line, select_country):
        selector.on_change('value', update)

    select_color_by = Select(title='Color by:', value=INITIAL_COLOR_VIEW, options=SELECT_COLOR_OPTIONS)
    select_color_by.on_change('value', update_color)

    # --- Layout ---------------------------------------------------------------
    controls = column(button, slider, select_color_by, select_country, select_restr, select_line)
    top = row(choropleth, controls)
    bottom = column(lp_restr, line_plot)
    doc.add_root(column(top, bottom))

In [3]:
# Passing the `app` function (not a static layout) makes Bokeh run it as an
# embedded server application, which the Python widget callbacks require.
# NOTE(review): show() assumes the notebook is served at localhost:8888; pass
# notebook_url=... if it is served elsewhere. Confirm for your environment.
show(app)