In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# import ipywidgets as widgets
from ipywidgets import interact, interactive, Layout, FloatSlider, HBox, VBox
# from ipywidgets.embed import embed_minimal_html, embed_data

import matplotlib.dates as mdates
# import matplotlib.ticker as ticker

from functools import reduce
from covis.utils import get_project_root

# Notebook-wide seaborn styling applied to every figure below.
sns.set_style("white")
sns.set_palette("tab20")

# Regional threshold plot for the percentage of deaths due to COVID-19

In [None]:
# Load the long-form regional weekly deaths table; its first column becomes the index.
long_reg = pd.read_csv(
    get_project_root() / "output" / "long_form_regional_weekly_deaths.csv",
    index_col=0,
)

In [None]:
# Quick sanity check: (rows, columns) and the first few rows.
print(f"{long_reg.shape}")
long_reg.head(n=5)

In [None]:
# Parse "Week ended" as datetimes and make it the index so later cells can
# filter by date. Guarded (and non-inplace) so re-running this cell does not
# fail with a KeyError once the column has already moved into the index.
if "Week ended" in long_reg.columns:
    long_reg = long_reg.assign(
        **{"Week ended": pd.to_datetime(long_reg["Week ended"])}
    ).set_index("Week ended")

In [None]:
# Keep weeks ending from 1 Mar 2020 through the whole of 1 Jul 2021.
# A boolean mask is used instead of a label slice: with several regions per
# week the DatetimeIndex may not be sorted, and pandas 2.x raises KeyError when
# slicing a non-monotonic DatetimeIndex with bounds that are not in the index.
long_reg = long_reg.loc[
    (long_reg.index >= pd.Timestamp("2020-03-01"))
    & (long_reg.index < pd.Timestamp("2021-07-02"))
]

In [None]:
# 0/1 flag: 1 where the COVID-19 share of deaths is strictly above 0.04 (missing values -> 0).
long_reg["above thresh"] = (long_reg["pc deaths due to covid19"] > 0.04).astype("int64")

In [None]:
# long_reg.head(50)

In [None]:
# First look: one dot per region/week, coloured by the 0/1 "above thresh" flag set above.
fig, ax = plt.subplots(figsize=(13, 5))
sns.stripplot(
    data=long_reg,
    x="Week ended",
    y="region",
    hue="above thresh",
    orient="h",
    jitter=False,
    legend=False,
    ax=ax,
)
ax.set_ylabel("")
fig.tight_layout();

In [None]:
def t_plot(thresh, save=False):
    """Plot every region/week as a dot, highlighting weeks at or above ``thresh``.

    Each row of ``long_reg`` becomes one dot on a region-by-week strip plot:
    red when its "pc deaths due to covid19" value is ``>= thresh``, light grey
    otherwise.

    Parameters
    ----------
    thresh : float
        Threshold on "pc deaths due to covid19", given as a fraction
        (0.3 means 30%); legend and title show ``100 * thresh`` rounded.
    save : bool, default False
        When truthy, also write the figure to
        ``<project root>/figures/t_plot_<thresh:.2f>_dots.png``.

    Notes
    -----
    Reads the notebook-level ``long_reg`` frame but does not overwrite its
    "above thresh" column; the flag is computed on a copy.
    """
    # Local import keeps the cell self-contained; matplotlib is already in use.
    from matplotlib.lines import Line2D

    # 1 for weeks at/above the threshold, 0 otherwise (NaN compares False -> 0).
    plot_df = long_reg.assign(
        **{"above thresh": (long_reg["pc deaths due to covid19"] >= thresh).astype(int)}
    )
    # Fixed level -> colour mapping. With a bare two-colour list, seaborn maps
    # colours onto whichever levels are present, so when only level 1 exists
    # (threshold at/below the data minimum) the above-threshold weeks were
    # drawn grey or the call errored, depending on the seaborn version.
    colours = {0: "lightgrey", 1: "tab:red"}

    fig = plt.figure(figsize=(13, 5))
    ax = fig.add_subplot(111)

    sns.stripplot(
        data=plot_df,
        x="Week ended",
        y="region",
        jitter=False,
        hue="above thresh",
        hue_order=[0, 1],
        palette=colours,
        legend=False,  # legend is built explicitly below
        orient="h",
        ax=ax,
    )

    ax.set_ylabel("")
    ax.set_xlabel("")

    # Proxy handles guarantee one entry per label with the right colour, even
    # when one of the two groups is absent (previously 1 handle was zipped
    # against 2 labels, so a lone red group could be labelled "below").
    handles = [
        Line2D([], [], linestyle="", marker="o", color=colours[level])
        for level in (0, 1)
    ]
    ax.legend(
        handles,
        [
            f"weeks with percentage of deaths due to COVID19 below {100*thresh:.0f}%",
            f"weeks with percentage of deaths due to COVID19 above {100*thresh:.0f}%",
        ],
        loc="lower center",
        ncols=2,
        bbox_to_anchor=(0.5, -0.2),
    )

    fig.suptitle(f"weekly percentage of deaths due to COVID19 > {100*thresh:.0f}%")
    fig.tight_layout()

    if save:
        fig.savefig(
            get_project_root() / f"figures/t_plot_{thresh:.2f}_dots.png",
            # keep the legend anchored below the axes inside the saved image
            bbox_inches="tight",
        )

In [None]:
t_plot(thresh=0.3)
# t_plot(0.3, save=True)

In [None]:
# remove the save argument:
def t_plot(thresh):
    """Plot every region/week as a dot, highlighting weeks at or above ``thresh``.

    Variant without a ``save`` option; defining it rebinds the global name
    ``t_plot`` that ``t_interactive`` passes to ``interactive``.

    Each row of ``long_reg`` becomes one dot: red when its
    "pc deaths due to covid19" value is ``>= thresh``, light grey otherwise.

    Parameters
    ----------
    thresh : float
        Threshold on "pc deaths due to covid19", given as a fraction
        (0.25 means 25%); legend and title show ``100 * thresh`` rounded.

    Notes
    -----
    Reads the notebook-level ``long_reg`` frame but does not overwrite its
    "above thresh" column; the flag is computed on a copy.
    """
    # Local import keeps the cell self-contained; matplotlib is already in use.
    from matplotlib.lines import Line2D

    # 1 for weeks at/above the threshold, 0 otherwise (NaN compares False -> 0).
    plot_df = long_reg.assign(
        **{"above thresh": (long_reg["pc deaths due to covid19"] >= thresh).astype(int)}
    )
    # Fixed level -> colour mapping so the slider's lowest position (only
    # level 1 present) still draws above-threshold weeks in red.
    colours = {0: "lightgrey", 1: "tab:red"}

    fig = plt.figure(figsize=(13, 5))
    ax = fig.add_subplot(111)

    sns.stripplot(
        data=plot_df,
        x="Week ended",
        y="region",
        jitter=False,
        hue="above thresh",
        hue_order=[0, 1],
        palette=colours,
        legend=False,  # legend is built explicitly below
        orient="h",
        ax=ax,
    )

    ax.set_ylabel("")
    ax.set_xlabel("")

    # Proxy handles: always two entries, each paired with its own colour.
    handles = [
        Line2D([], [], linestyle="", marker="o", color=colours[level])
        for level in (0, 1)
    ]
    ax.legend(
        handles,
        [
            f"weeks with percentage of deaths due to COVID19 below {100*thresh:.0f}%",
            f"weeks with percentage of deaths due to COVID19 above {100*thresh:.0f}%",
        ],
        loc="lower center",
        ncols=2,
        bbox_to_anchor=(0.5, -0.2),
    )

    fig.suptitle(f"weekly percentage of deaths due to COVID19 > {100*thresh:.0f}%")
    fig.tight_layout()

# create the interactive plot:
def t_interactive():
    """Display a threshold slider wired to ``t_plot``, with the plot above it.

    The slider spans the observed min..max of
    ``long_reg["pc deaths due to covid19"]`` in steps of 0.01, starting at
    0.25. Changing it re-runs the global ``t_plot(thresh=...)``.
    Values are fractions, so the readout is formatted as a percentage to match
    the slider's description.
    """
    pc_deaths = long_reg["pc deaths due to covid19"]
    # Full-width slider. `Layout` does not appear to define a `position`
    # trait, so the former `position=""` argument had no effect and is dropped.
    slider_layout = Layout(width="100%")
    t_fig = interactive(
        t_plot,
        thresh=FloatSlider(
            min=pc_deaths.min(),
            max=pc_deaths.max(),
            step=1e-2,
            value=0.25,
            layout=slider_layout,
            description="percentage of deaths due to COVID-19",
            # show 0.25 as "25%" rather than "0.25"
            readout_format=".0%",
            style={"description_width": "initial"},
        ),
    )

    # `interactive` puts its output widget last; show it above the controls.
    # https://stackoverflow.com/questions/52980565/arranging-widgets-in-ipywidgets-interactive/53048425#53048425
    controls = HBox(t_fig.children[:-1], layout=Layout(flex_flow="row wrap"))
    output = t_fig.children[-1]
    display(VBox([output, controls]))

In [None]:
# Show the threshold slider with the linked per-week dot plot.
t_interactive()

In [None]:
# target = get_project_root() / "output/threshold_slider.html"
# embed_minimal_html(target, views=[t_fig], title="threshold slider")

In [None]:
# target = get_project_root() / "output/threshold_slider.html"
# with open(target, "w") as f:
#     f.write(t_fig)

In [None]:
# slider = FloatSlider(value=40)
# embed_minimal_html('../output/export.html', views=[slider], title='Widgets export')

## Developing the plot: highlighting runs of consecutive weeks above the threshold

In [None]:
def t_plot(thresh, save=False):
    """Draw runs of consecutive weeks at/above ``thresh`` as thick bars per region.

    For each region in ``long_reg``, rows whose "pc deaths due to covid19"
    value is ``>= thresh`` are grouped into runs of consecutive rows. Each run
    is drawn as a thick line from its first to its last "Week ended" date.
    Every above-threshold week also gets a dot, which keeps single-week runs
    (zero-length lines) visible. The x-axis has monthly major ticks, and
    alternate 7-day bands are shaded.

    Parameters
    ----------
    thresh : float
        Threshold as a fraction (0.15 means 15%); labels show ``100 * thresh``.
    save : bool, default False
        When truthy, also write ``<project root>/figures/t_plot_<thresh:.2f>.png``.
    """
    highlight = sns.color_palette("deep")[3]

    # Flag computed on a copy so the notebook-level `long_reg` is not mutated.
    plot_df = long_reg.assign(
        **{"above thresh": (long_reg["pc deaths due to covid19"] >= thresh).astype(int)}
    )

    fig = plt.figure(figsize=(13, 5))
    ax = fig.add_subplot(111)

    # One groupby pass instead of rebuilding the groupby for every region;
    # sort=False keeps regions in first-appearance order, as .unique() did.
    for reg, reg_rows in plot_df.groupby("region", sort=False):
        # 0..n-1 positional index. Consecutive positions are treated as
        # consecutive weeks, i.e. this assumes one row per week in date order
        # within each region — TODO confirm against the source CSV.
        reg_rows = reg_rows.reset_index()
        week_ended = reg_rows["Week ended"]
        hits = reg_rows.index[reg_rows["above thresh"] == 1]

        # Collapse hit positions into [start, end] runs of consecutive integers.
        runs = []
        for pos in hits:
            if runs and runs[-1][1] + 1 == pos:
                runs[-1][1] = pos
            else:
                runs.append([pos, pos])

        for start, end in runs:
            sns.lineplot(
                x=[week_ended[start], week_ended[end]],
                y=[reg, reg],
                ax=ax,
                color=highlight,
                lw=10,
                legend=False,
            )

    # A dot on every above-threshold week: it caps the bars and is the only
    # visible mark for isolated single-week crossings.
    week_dots = plot_df[plot_df["above thresh"] == 1]
    sns.stripplot(
        data=week_dots,
        x="Week ended",
        y="region",
        jitter=False,
        size=9,
        color=highlight,
        orient="h",
        ax=ax,
    )

    ax.set_ylabel("")
    ax.set_xlabel(f"weeks with percentage of deaths due to COVID-19 above {100*thresh:.0f}%")

    # Monthly major ticks labelled like "Mar-20".
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%y"))
    # Minor ticks every 7 days; shade every other gap between them.
    # NOTE(review): these ticks come from the locator, not from the
    # "Week ended" dates, so bands may not align with reported weeks — confirm.
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=7))
    minor_ticks = ax.get_xticks(minor=True)
    for x0, x1 in zip(minor_ticks[::2], minor_ticks[1::2]):
        ax.axvspan(x0, x1, color="black", alpha=0.1, zorder=0)

    fig.suptitle(f"weekly percentage of deaths due to COVID19 > {100*thresh:.0f}%")
    fig.tight_layout()

    if save:
        fig.savefig(
            get_project_root() / f"figures/t_plot_{thresh:.2f}.png"
        )

In [None]:
# help(mdates.DayLocator)

In [None]:
t_plot(thresh=0.15)
# t_plot(0.15, save=True)

In [None]:
# remove save arg:
def t_plot(thresh):
    """Draw runs of consecutive weeks at/above ``thresh`` as thick bars per region.

    Variant without a ``save`` option; defining it rebinds the global name
    ``t_plot`` that ``t_interactive`` passes to ``interactive``.

    For each region in ``long_reg``, rows whose "pc deaths due to covid19"
    value is ``>= thresh`` are grouped into runs of consecutive rows. Each run
    is drawn as a thick line from its first to its last "Week ended" date.
    Every above-threshold week also gets a dot, which keeps single-week runs
    (zero-length lines) visible. The x-axis has monthly major ticks, and
    alternate 7-day bands are shaded.

    Parameters
    ----------
    thresh : float
        Threshold as a fraction (0.15 means 15%); labels show ``100 * thresh``.
    """
    highlight = sns.color_palette("deep")[3]

    # Flag computed on a copy so the notebook-level `long_reg` is not mutated.
    plot_df = long_reg.assign(
        **{"above thresh": (long_reg["pc deaths due to covid19"] >= thresh).astype(int)}
    )

    fig = plt.figure(figsize=(13, 5))
    ax = fig.add_subplot(111)

    # One groupby pass instead of rebuilding the groupby for every region;
    # sort=False keeps regions in first-appearance order, as .unique() did.
    for reg, reg_rows in plot_df.groupby("region", sort=False):
        # 0..n-1 positional index. Consecutive positions are treated as
        # consecutive weeks, i.e. this assumes one row per week in date order
        # within each region — TODO confirm against the source CSV.
        reg_rows = reg_rows.reset_index()
        week_ended = reg_rows["Week ended"]
        hits = reg_rows.index[reg_rows["above thresh"] == 1]

        # Collapse hit positions into [start, end] runs of consecutive integers.
        runs = []
        for pos in hits:
            if runs and runs[-1][1] + 1 == pos:
                runs[-1][1] = pos
            else:
                runs.append([pos, pos])

        for start, end in runs:
            sns.lineplot(
                x=[week_ended[start], week_ended[end]],
                y=[reg, reg],
                ax=ax,
                color=highlight,
                lw=10,
                legend=False,
            )

    # A dot on every above-threshold week: it caps the bars and is the only
    # visible mark for isolated single-week crossings.
    week_dots = plot_df[plot_df["above thresh"] == 1]
    sns.stripplot(
        data=week_dots,
        x="Week ended",
        y="region",
        jitter=False,
        size=9,
        color=highlight,
        orient="h",
        ax=ax,
    )

    ax.set_ylabel("")
    ax.set_xlabel(f"weeks with percentage of deaths due to COVID-19 above {100*thresh:.0f}%")

    # Monthly major ticks labelled like "Mar-20".
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%y"))
    # Minor ticks every 7 days; shade every other gap between them.
    # NOTE(review): these ticks come from the locator, not from the
    # "Week ended" dates, so bands may not align with reported weeks — confirm.
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=7))
    minor_ticks = ax.get_xticks(minor=True)
    for x0, x1 in zip(minor_ticks[::2], minor_ticks[1::2]):
        ax.axvspan(x0, x1, color="black", alpha=0.1, zorder=0)

    fig.suptitle(f"weekly percentage of deaths due to COVID19 > {100*thresh:.0f}%")
    fig.tight_layout()


def t_interactive():
    """Display a threshold slider wired to ``t_plot``, with the plot above it.

    The slider spans the observed min..max of
    ``long_reg["pc deaths due to covid19"]`` in steps of 0.01, starting at
    0.15. Changing it re-runs the global ``t_plot(thresh=...)``.
    Values are fractions, so the readout is formatted as a percentage to match
    the slider's description.
    """
    pc_deaths = long_reg["pc deaths due to covid19"]
    # Full-width slider. `Layout` does not appear to define a `position`
    # trait, so the former `position=""` argument had no effect and is dropped.
    slider_layout = Layout(width="100%")
    t_fig = interactive(
        t_plot,
        thresh=FloatSlider(
            min=pc_deaths.min(),
            max=pc_deaths.max(),
            step=1e-2,
            value=0.15,
            layout=slider_layout,
            description="percentage of deaths due to COVID-19",
            # show 0.15 as "15%" rather than "0.15"
            readout_format=".0%",
            style={"description_width": "initial"},
        ),
    )

    # `interactive` puts its output widget last; show it above the controls.
    controls = HBox(t_fig.children[:-1], layout=Layout(flex_flow="row wrap"))
    output = t_fig.children[-1]
    display(VBox([output, controls]))

In [None]:
# Show the threshold slider with the linked run/bar plot.
t_interactive()