# Import Packages

In [None]:
import datetime as dt
import json

import folium
import numpy as np
import pandas as pd
import panel as pn
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Load Panel's Plotly extension so the Plotly figures built below can render,
# and make "stretch_width" the default sizing_mode for components.
pn.extension("plotly", sizing_mode="stretch_width")

# Load Data

In [None]:
# Fetch the OWID COVID-19 dataset only when it is not already stored in Panel's
# `pn.state.cache`; otherwise reuse the cached DataFrame instead of
# downloading the CSV again.
if "data" in pn.state.cache:
    covid = pn.state.cache["data"]
else:
    _owid_columns = [
        "date",
        "iso_code",
        "location",
        "new_cases_per_million",
        "new_deaths_per_million",
        "people_fully_vaccinated",
        "population",
        "stringency_index",
    ]
    covid = pd.read_csv(
        "https://covid.ourworldindata.org/data/owid-covid-data.csv",
        usecols=_owid_columns,
        index_col="date",
        parse_dates=True,
    )

    # Share of the population that is fully vaccinated
    covid["vaccination_rate_fully_vaccinated"] = covid["people_fully_vaccinated"].div(covid["population"])
    # Store a copy, so the cached frame and this run's `covid` are separate objects
    pn.state.cache["data"] = covid.copy()

# Data Preprocessing

In [None]:
# Helper function for transforming column to rolling average values
def transform_column_to_rolling(df, column_to_transform, interval, column_to_groupby):
    """Return a per-group rolling mean of one column as a flat DataFrame.

    Args:
        df: Source frame containing ``column_to_transform`` and
            ``column_to_groupby``.
        column_to_transform: Column whose values are averaged.
        interval: Rolling window size. An integer window counts rows, not
            calendar days; a window of 1 reproduces the original values.
        column_to_groupby: Column defining the independent groups.

    Returns:
        pandas.DataFrame: one row per non-NaN rolling value, with the group
        column and the original index level(s) restored as ordinary columns
        next to ``column_to_transform``. Rows whose rolling mean is NaN (for
        example the first ``interval - 1`` rows of every group) are dropped.
    """
    per_group_values = df.groupby(column_to_groupby)[column_to_transform]
    rolling_means = per_group_values.rolling(interval).mean()
    complete_windows = rolling_means.dropna()
    return complete_windows.reset_index()

# Plots

## Plot 1: World Map

In [None]:
def get_map(date, metric, interval):
    """Build a folium choropleth world map of ``metric`` for a single day.

    Args:
        date: Day selected on the date slider; formatted as ``YYYY-MM-DD`` to
            select the matching rows of ``covid``.
        metric: Column of ``covid`` used to colour countries. For
            ``"vaccination_rate_fully_vaccinated"`` the recorded value is used
            and rows outside the 0-1 range are dropped; any other metric is
            first smoothed with ``transform_column_to_rolling`` per country.
        interval: Rolling window size in rows for the non-vaccination metrics
            (1 means the raw daily value).

    Returns:
        folium.Map: map holding the choropleth layer, a hover tooltip with the
        location name and rounded metric value, and a layer control.
    """
    # Index to the specific date based on the date slider
    formatted_date = dt.datetime.strftime(date, format="%Y-%m-%d")

    if metric == "vaccination_rate_fully_vaccinated":
        # Keep only plausible rates; NaN rates fail the comparison and are dropped too.
        result_df = covid.query("(date==@formatted_date) and (0 <= vaccination_rate_fully_vaccinated <= 1)")
        # Green-based colour map for vaccination rate
        fill_color = "YlGn"
        # Tooltip shows rates at or below 1% (after rounding) as "<= 0.01"
        low_threshold, low_label = 0.01, "<= 0.01"
    else:
        # Transform new case/death metric to rolling average data.
        # A one-day rolling average is simply the daily value.
        result_df = transform_column_to_rolling(covid, metric, interval, "iso_code").query("date==@formatted_date")
        # Default colour map for new cases or new deaths
        fill_color = "YlOrRd"
        # Tooltip shows values at or below 1 (after rounding) as "<= 1"
        low_threshold, low_label = 1, "<= 1"

    # Load geojson data. GeoJSON text is UTF-8 (RFC 7946), so decode it
    # explicitly instead of relying on the platform's default encoding, which
    # can garble or reject non-ASCII country names.
    with open("world.geojson", "r", encoding="utf-8") as f:
        geo_json_data = json.load(f)

    m = folium.Map(height=500)
    cp = folium.Choropleth(
        geo_data=geo_json_data,
        name="choropleth",
        data=result_df,
        columns=["iso_code", metric],
        key_on="feature.properties.ISO_A3",
        fill_color=fill_color,
        fill_opacity=0.7,
        line_opacity=0.2,
        highlight=True,
        bins=8,
        nan_fill_color="grey",
    ).add_to(m)

    # Series of rounded metric values indexed by ISO code, used for tooltip lookups.
    # Small values get a readable label and NaN values become "Not Recorded".
    metric_value_s = (
        result_df.set_index("iso_code")[metric]
        .round(2)
        .mask(lambda s: s <= low_threshold, low_label)
        .replace(np.nan, "Not Recorded")
    )

    # Attach each country's value to its GeoJSON feature so the tooltip can
    # display it; countries absent from result_df show "Not Recorded".
    for feature in cp.geojson.data["features"]:
        iso_code = feature.get("properties").get("ISO_A3")
        feature["properties"][metric] = metric_value_s.get(iso_code, "Not Recorded")

    # Hover tooltip showing location name and metric value
    folium.GeoJsonTooltip(
        fields=["NAME_LONG", metric],
        aliases=["Location", metric.replace("_", " ").title()],
    ).add_to(cp.geojson)

    folium.LayerControl().add_to(m)
    return m

## Plot 2: COVID-19 Explorer

In [None]:
def get_explorer_lineplot(locations, metric, interval):
    """Plot ``metric`` over time, one line per selected location.

    Args:
        locations: Location names (values of ``covid["location"]``) to plot.
            An empty selection returns an empty figure.
        metric: Column of ``covid`` to plot. The vaccination rate is plotted
            as recorded (NaN rows dropped); other metrics are smoothed with a
            per-location rolling mean.
        interval: Rolling window size in rows for the smoothed metrics.

    Returns:
        plotly.graph_objects.Figure: line chart coloured by location.
    """
    # Return an empty figure when no locations are selected
    if not locations:
        return go.Figure()

    # Title-case the column name for the y-axis label
    ylabel = metric.replace("_", " ").title()

    # Restrict to the selected locations before any computation. The rolling
    # mean is computed per location, so filtering first yields the same values
    # without a rolling pass over every location in the dataset on each redraw.
    selected_df = covid.query("location in @locations")

    if metric == "vaccination_rate_fully_vaccinated":
        # No rolling average is required, as the vaccination rate is cumulative
        result_df = selected_df.loc[lambda df: df[metric].notna()].reset_index()
        ylabel = "Vaccination Rate (Fully Vaccinated)"
    else:
        # For new case/death metrics, smooth with a per-location rolling average
        result_df = transform_column_to_rolling(selected_df, metric, interval, "location")

    # Line plot for metric, coloured by location
    fig = px.line(
        data_frame=result_df,
        x="date",
        y=metric,
        color="location",
    )

    fig.update_xaxes(title_text="", showgrid=False)
    fig.update_yaxes(title_text=ylabel, showgrid=True)
    fig.update_traces(mode="lines", hovertemplate="%{y:.2f}")

    fig.update_layout(margin=dict(t=20, b=20, l=20, r=20), hovermode="x unified", autosize=True)

    return fig

## Plot 3: Stringency Index

In [None]:
def get_stringency_index_lineplot(locations: list):
    """Plot each location's stringency index in its own subplot.

    Locations are placed alphabetically, left to right and top to bottom, on a
    grid with one column for a single location and two columns otherwise.
    Periods during which a location sat at its highest stringency index are
    shaded red; periods at its lowest are shaded green.

    Args:
        locations: Location names (values of ``covid["location"]``) to plot.

    Returns:
        plotly.graph_objects.Figure: the subplot grid, or an empty figure when
        no locations are selected.
    """
    locations = sorted(locations)

    # Return an empty figure when no locations are selected
    if not locations:
        return go.Figure()

    # ceil(n / 2) rows; a lone location gets the full width instead of half.
    num_locations = len(locations)
    fig = make_subplots(
        rows=(num_locations + 1) // 2,
        cols=1 if num_locations == 1 else 2,
        shared_xaxes=True,
        shared_yaxes=True,
    )

    # (row, col) of every location's subplot, filling each row left to right
    grid_cell = {name: (idx // 2 + 1, idx % 2 + 1) for idx, name in enumerate(locations)}

    # Referenced from: https://stackoverflow.com/questions/71311169/grouping-by-consecutive-dates-into-date-ranges-using-python
    def find_consecutive_dates(df):
        """Collapse the rows of ``df`` into runs of consecutive calendar days.

        ``df`` needs ``date`` and ``stringency_index`` columns. A new run
        starts wherever the gap to the previous date with the same stringency
        value is anything other than exactly one day. Returns one row per run
        with ``start`` and ``end`` dates (plus the run label column).
        """
        gaps = df.groupby("stringency_index")["date"].diff()
        run_labels = gaps.ne(pd.Timedelta(days=1)).cumsum()
        return df.groupby(run_labels)["date"].agg(start="first", end="last").reset_index()

    for location in locations:
        row, col = grid_cell[location]
        result_df = covid.query("location==@location").reset_index()
        max_stringency_index = result_df.query("stringency_index == stringency_index.max()")
        min_stringency_index = result_df.query("stringency_index == stringency_index.min()")

        # Line plot of stringency_index for this location
        fig.add_trace(
            go.Scatter(x=result_df["date"], y=result_df["stringency_index"], name=location, showlegend=True),
            row=row,
            col=col,
        )

        # Shade maximum-index periods red first, then minimum-index periods green
        for shade, extreme_rows in (("red", max_stringency_index), ("green", min_stringency_index)):
            periods = find_consecutive_dates(extreme_rows)
            for start, end in zip(periods["start"], periods["end"]):
                fig.add_vrect(
                    x0=start,
                    x1=end,
                    line_width=0,
                    fillcolor=shade,
                    opacity=0.2,
                    row=row,
                    col=col,
                )

    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(title_text="Stringency Index")
    fig.update_layout(
        legend_title="location",
        hovermode="x unified",
        margin=dict(t=20, b=20, l=20, r=20),
        autosize=True,
    )
    return fig

# Create Dashboard

## Widgets

In [None]:
# Every location in the dataset can be chosen; option_limit caps how many
# options the dropdown lists at once, so users type to search.
_all_locations = covid["location"].unique().tolist()

# Widget for selecting up to four locations
locations_selector = pn.widgets.MultiChoice(
    name="Locations (Type to search locations)",
    value=["Australia", "China", "Germany", "United States"],
    options=_all_locations,
    solid=False,
    max_items=4,
    option_limit=4,
)

# Display label -> `covid` column shown by the plots
_metric_options = {
    "New cases per million people": "new_cases_per_million",
    "New deaths per million people": "new_deaths_per_million",
    "People fully vaccinated": "vaccination_rate_fully_vaccinated",
}
metric_selector = pn.widgets.Select(name="Metric", options=_metric_options)

# Display label -> rolling window size in rows. "Cumulative" only makes sense
# for the vaccination metric, so it starts out disabled.
_interval_options = {
    "New per day": 1,
    "7-day rolling average": 7,
    "14-day rolling average": 14,
    "Cumulative": "Cumulative",
}
interval_selector = pn.widgets.Select(
    name="Interval",
    options=_interval_options,
    value=7,
    disabled_options=["Cumulative"],
)


@pn.depends(metric=metric_selector, watch=True)
def _update_interval_disabled(metric):
    """Keep the interval selector consistent with the newly selected metric.

    The vaccination rate is plotted as-is, so the interval is pinned to
    "Cumulative" and the selector is locked. For the per-day metrics the
    "Cumulative" option is disabled again. The user's current 1/7/14-day choice
    is kept; only a "Cumulative" value falls back to the 7-day default.

    Args:
        metric: New value of ``metric_selector``.
    """
    if metric == "vaccination_rate_fully_vaccinated":
        # Re-enable "Cumulative" before selecting it, then lock the widget.
        interval_selector.disabled_options = []
        interval_selector.value = "Cumulative"
        interval_selector.disabled = True
    else:
        # Move off "Cumulative" before disabling it again. Switching between the
        # cases and deaths metrics no longer discards the user's chosen interval.
        if interval_selector.value == "Cumulative":
            interval_selector.value = 7
        interval_selector.disabled_options = ["Cumulative"]
        interval_selector.disabled = False


# Date slider bounds: the first and last dates on which any location recorded a
# positive number of new cases per million. Filter and sort only once.
_positive_case_dates = covid["new_cases_per_million"].loc[lambda s: s > 0].sort_index().index.date
start = _positive_case_dates[0]
end = _positive_case_dates[-1]

date_slider = pn.widgets.DateSlider(start=start, end=end, value=end, name="Date")

## Dashboard Body

In [None]:
def get_markdown_words(metric):
    """Return a Markdown heading pane describing the selected ``metric``.

    Args:
        metric: Value of the metric selector (a ``covid`` column name).

    Returns:
        pn.pane.Markdown | None: the heading pane, or None for a metric that
        has no heading defined.
    """
    headings = {
        "new_cases_per_million": "## Daily new confirmed COVID-19 cases per million people",
        "new_deaths_per_million": "## Daily new confirmed COVID-19 deaths per million people",
        "vaccination_rate_fully_vaccinated": "## Share of people who are fully vaccinated",
    }
    heading = headings.get(metric)
    if heading is None:
        return None
    return pn.pane.Markdown(heading)

In [None]:
# Page template hosting both dashboard sections. Sections are appended to
# `template.main`, and the append order is the order they appear on the page.
template = pn.template.FastListTemplate(
    title="Covid-19 Dashboard",
    accent="#96ceb4",
    theme_toggle=False,
    corner_radius=5
)

# Widgets
# Settings column for the map/explorer section: metric, interval and locations
# selectors in a column 250 pixels wide.
part1_widgets = pn.Column(
    pn.pane.Markdown("## Settings"),
    pn.Column(metric_selector),
    pn.Column(interval_selector),
    pn.Column(locations_selector),
    width=250
)

# World map
# pn.bind re-evaluates get_map with the current date, metric and interval
# widget values; the date slider is placed underneath the map.
worldmap = pn.Column(
    pn.panel(
        pn.bind(get_map, date_slider, metric_selector, interval_selector),
        min_width=480,
        sizing_mode="stretch_both"
    ),
    pn.Column(date_slider)
)

# Covid-19 explorer
# Line plot bound to the locations, metric and interval widgets.
explorer = pn.Column(
    pn.panel(
        pn.bind(get_explorer_lineplot, locations_selector, metric_selector, interval_selector)
    )
)

# Organise map and explorer in the same row, below a heading bound to the metric
combined_plots = pn.Column(
    pn.Row(pn.bind(get_markdown_words, metric_selector), sizing_mode="stretch_both"),
    pn.Row(worldmap, explorer, sizing_mode="stretch_both")
)

# First section: settings column on the left, heading + map + explorer on the right
template.main.append(pn.Row(part1_widgets, combined_plots))

# Widget for Stringency Index plot
# Reuses the same `locations_selector` object as the first section, so both
# sections are driven by a single location selection.
part2_widgets = pn.Column(pn.pane.Markdown("## Settings"), pn.Column(locations_selector), width=250)

# Create Stringency Index plot
# NOTE(review): the Markdown text below is indented inside the string literal;
# this relies on the Markdown pane dedenting it (Panel's `dedent` option) —
# confirm for the Panel version in use, otherwise it renders as a code block.
si_index_heading = pn.Column(
    pn.pane.Markdown(
        """
                ## COVID-19: Stringency Index
                * The stringency index is a composite measure, rescaled to a value from 0 to 100 (100 = strictest).
                * Red shaded area indicates maximum stringency index periods.
                * Green shaded area indicates minimum stringency index periods.
                """
    )
)

# Subplot grid re-evaluated with the current location selection
si_index_plot = pn.Column(pn.panel(pn.bind(get_stringency_index_lineplot, locations_selector)))

# Second section: settings column on the left, heading + stringency plot on the right
template.main.append(pn.Row(part2_widgets, pn.Column(si_index_heading, si_index_plot, sizing_mode="stretch_both")))
# Mark the template as the app served by `panel serve`
template.servable()