In [None]:
import locale

locale.setlocale(locale.LC_ALL, "de_AT.UTF-8");

In [None]:
%matplotlib inline
%precision %.4f
# %load_ext snakeviz
import colorcet
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib.dates
import matplotlib.ticker
import matplotlib.patches
import matplotlib.colors
import seaborn as sns
from importlib import reload
from itertools import count
from datetime import date, datetime, timedelta
from IPython.display import display, Markdown, HTML
import textwrap
from functools import partial
from cycler import cycler
from pathlib import Path

from covidat import cov
from covidat.util import DATAROOT

# Notebook-wide display and plotting defaults; the order matters because
# sns.set_theme() rewrites part of matplotlib's rcParams.
pd.set_option("display.precision", 4)

cov = reload(cov)  # re-import covidat.cov so edits to it apply without a kernel restart

plt.rcParams.update(
    {
        "figure.figsize": (16 * 0.7, 9 * 0.7),
        "figure.dpi": 200,  # 80
        "figure.facecolor": "#fff",
    }
)
sns.set_theme(style="whitegrid")
sns.set_palette("tab10")

plt.rcParams["image.cmap"] = cov.un_cmap

pd.set_option("display.max_rows", 120)
pd.set_option("display.min_rows", 40)

In [None]:
def stampit(fig):
    """Add the EuroMOMO source attribution to ``fig`` via ``cov.stampit``."""
    source_note = "EuroMOMO (euromomo.eu) 2022,2023"
    cov.stampit(fig, source_note)

In [None]:
# Index levels identifying one z-score observation.
Z_INDEX_COLS = ["country", "group", "week"]


def collect_z():
    """Combine all EuroMOMO z-score CSV exports found in ``DATAROOT / "euromomo"``.

    Every ``*.csv`` there is read as ``;``-separated UTF-8. Files without a
    ``zscore`` column and files without data rows are skipped. Files lacking a
    ``country`` column get ``country="ALL"`` inserted as the first column.

    Frames are ordered by the ``week`` value of their last row. When the same
    (country, group, week) appears in several files, the row from the frame
    ordered last wins.

    Returns:
        DataFrame of the deduplicated rows with a fresh RangeIndex.

    Raises:
        ValueError: if no usable z-score file was found.
    """
    src_dir = Path(DATAROOT, "euromomo")
    all_fzs = []
    # sorted(): make file order, and hence tie-breaking below, independent of the filesystem.
    for f in sorted(src_dir.glob("*.csv")):
        fzs = pd.read_csv(f, encoding="utf-8", sep=";")
        # Header-only files would crash the iloc[-1] sort key below.
        if "zscore" not in fzs.columns or fzs.empty:
            continue
        if "country" not in fzs.columns:
            fzs.insert(0, "country", "ALL")
        all_fzs.append(fzs)
    if not all_fzs:
        raise ValueError(f"No EuroMOMO z-score CSV files with data found in {src_dir}")
    # list.sort is stable: frames ending in the same week keep their file order.
    all_fzs.sort(key=lambda fzs: fzs.iloc[-1]["week"])
    combined = pd.concat(all_fzs)
    return combined.drop_duplicates(subset=Z_INDEX_COLS, keep="last", ignore_index=True)

In [None]:
# Weekly z-scores as a Series indexed by (country, group, week). The ISO "YYYY-WW"
# week label is turned into the date of that week's Monday (ISO weekday 1).
# Built as one chain so re-running the cell always gives the same result.
zs = (
    collect_z()
    .assign(week=lambda df: pd.to_datetime(df["week"] + "-1", format="%G-%V-%u"))
    .set_index(Z_INDEX_COLS)
    .sort_index()["zscore"]
)
assert zs.index.is_unique, "duplicate (country, group, week) entries after collect_z()"

In [None]:
# Age-group codes present in the data, in order of first appearance.
zs.index.unique(level="group")

In [None]:
def grpname(grp):
    """Return a human-readable label for an EuroMOMO age-group code.

    "P" becomes "+" (e.g. "85P" -> "85+") and "to" becomes a figure dash
    (e.g. "15to44" -> "15‒44"). Other codes such as "Total" pass through unchanged.
    """
    label = grp
    for code, shown in (("P", "+"), ("to", "‒")):
        label = label.replace(code, shown)
    return label

In [None]:
def draw_detail_ax(country, grp, ax, smallmode=False):
    """Plot the weekly EuroMOMO z-score of one country and age group onto ``ax``.

    Data come from the module-level ``zs`` Series (index levels country, group,
    week). Drawn elements:

    * the raw weekly z-score (thin line, unlabelled);
    * a trailing 52-row rolling mean (dark grey, dashed, "52-Wochen-Schnitt");
    * weeks with 2 < |z| <= 4 as dots in colour C1. They are joined by a C1 line
      that also runs through the |z| > 4 weeks;
    * weeks with |z| > 4 as red squares;
    * shaded bands for |z| <= 2 and |z| <= 1 and a black line at 0. Dotted red
      lines at +4 / -4 are drawn only if the series goes beyond them;
    * a yellow shading from 4 days before the third-to-last week up to the
      current right x-limit. This presumably marks weeks whose data are not yet
      complete.

    The calls are order-sensitive: the labelled artists (mean, unusual,
    substantial) are what a later ``fig.legend()`` picks up.

    Args:
        country: value of the ``country`` index level (e.g. "Austria", "ALL").
        grp: value of the ``group`` index level (e.g. "Total", "0to14").
        ax: matplotlib Axes to draw into.
        smallmode: use smaller markers (for dense multi-panel grids).

    Returns:
        The plotted z-score Series (indexed by week). Returns None, with nothing
        drawn, if ``grp`` does not exist for ``country``.

    Raises:
        KeyError: if ``country`` is not present in ``zs``.
        IndexError: if the series has fewer than 3 weeks (``z0.index[-3]``).
    """
    z0 = zs.xs(country)
    # After xs(country) the index is (group, week); `in` tests the group level.
    if grp in z0.index:
        z0 = z0.xs(grp)
    else:
        # z0 = z0.xs('Total') * np.nan
        return
    ax.plot(z0, zorder=1, lw=1)  # , label="Z-Score")
    ax.axhline(0, color="k")
    ax.plot(
        z0.rolling(52, center=False).mean(), color="darkgrey", ls="--", label="52-Wochen-Schnitt", zorder=6, alpha=0.85
    )
    # "Unusual": 2 < |z| <= 4; "substantial": |z| > 4 (either sign).
    is_unusual = z0.between(2, 4, inclusive="right") | z0.between(-4, -2, inclusive="left")
    is_substantial = (z0 > 4) | (z0 < -4)
    ax.plot(
        z0.where(is_unusual),
        marker=".",
        color="C1",
        label="Außerhalb des Normalbereichs (>2)",
        markersize=2 if smallmode else 5,
    )
    # Unlabelled line through all out-of-range weeks. where() leaves NaN elsewhere,
    # which splits the line into segments.
    ax.plot(z0.where(is_unusual | is_substantial), color="C1")
    ax.plot(
        z0.where(is_substantial),
        marker="s",
        color="r",
        label="Substantielle Abweichung (>4)",
        markersize=1 if smallmode else 3,
    )
    # Background bands for |z| <= 2 and |z| <= 1.
    ax.axhspan(-2, 2, alpha=0.2, zorder=0)
    ax.axhspan(-1, 1, alpha=0.2, zorder=0)
    if z0.max() > 4:
        ax.axhline(4, color="r", ls=":", zorder=1.5)
    if z0.min() < -4:
        ax.axhline(-4, color="r", ls=":", zorder=1.5)
    # Shade the most recent weeks. The right end is the x-limit autoscaled at this
    # moment; callers may change the x-limits afterwards.
    fulldata_end = z0.index[-3] - timedelta(4)
    ax.axvspan(matplotlib.dates.date2num(fulldata_end), ax.get_xlim()[1], color="yellow", alpha=0.2)
    ax.axvline(fulldata_end, color="grey", alpha=0.5, ls=":")
    return z0


def draw_detail(country, grp):
    """Plot one country/age-group z-score series as a standalone pyplot figure.

    The x-range runs from two weeks before the first to five weeks after the
    last data week. Title and legend sit above the axes.

    Args:
        country: value of the ``country`` level of ``zs`` (e.g. "ALL").
        grp: age-group code (e.g. "Total", "0to14").

    Raises:
        KeyError: if there is no data for ``country``/``grp``.
    """
    fig, ax = plt.subplots()
    z0 = draw_detail_ax(country, grp, ax)
    if z0 is None:
        # draw_detail_ax returns None (drawing nothing) for an unknown group. Fail
        # with a clear message instead of an AttributeError on None below, and
        # don't leave an empty figure behind.
        plt.close(fig)
        raise KeyError(f"no EuroMOMO z-scores for group {grp!r} in {country!r}")
    endd = z0.index[-1].strftime("KW%V %G")
    ax.set_xlim(z0.index[0] - timedelta(14), z0.index[-1] + timedelta(35))
    fig.suptitle(f"EuroMOMO Mortalitäts-Z-Score je Woche für {grpname(grp)} in {country} bis {endd}", y=0.95)
    fig.legend(loc="upper center", frameon=False, bbox_to_anchor=(0.5, 0.93), ncol=4)
    ax.set_ylabel("Z-Score (Standardabweichungen)")
    stampit(fig)


def draw_detail_cmp(cmps):
    """Draw a grid of weekly z-score panels, one per (country, group) pair.

    Uses 1 column for up to 2 pairs, 2 columns for up to 10, and 3 beyond. All
    panels share both axes. The figure title, x-range and legend are taken from
    the first panel that has data. If all pairs share the country (or the
    group), that part moves into the figure title and panel titles show only
    the rest.

    Args:
        cmps: non-empty sized iterable of (country, group) tuples.

    Returns:
        ``(fig, axs)``: a Figure built with ``plt.Figure`` (not registered with
        pyplot, so show it with ``display(fig)``) and the 2-D array of Axes.

    Raises:
        ValueError: if ``cmps`` is empty.
        KeyError: if a country is missing from ``zs`` (raised by draw_detail_ax).
    """
    n = len(cmps)
    if n == 0:
        raise ValueError("cmps must contain at least one (country, group) pair")
    ncols = 3 if n > 10 else 2 if n > 2 else 1
    kws = {"figsize": (10, 10) if n < 10 else (10, 15)} if n > 4 else {}
    fig = plt.Figure(**kws)
    # Title/legend offsets below are tuned to the figure height.
    figh = fig.get_size_inches()[1]
    axs = fig.subplots(nrows=(n + ncols - 1) // ncols, ncols=ncols, squeeze=False, sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.1, top=0.94)
    if len(set(c[0] for c in cmps)) == 1:
        titlefmt = "{grp}"
        extratitle = " in " + next(iter(cmps))[0]
    elif len(set(c[1] for c in cmps)) == 1:
        titlefmt = "{country}"
        extratitle = ", " + grpname(next(iter(cmps))[1]) + ","
    else:
        titlefmt = "{grp} in {country}"
        extratitle = ""

    header_done = False
    for (country, grp), ax in zip(cmps, axs.flat):
        z0 = draw_detail_ax(country, grp, ax, smallmode=ncols >= 3)
        # Title, x-range and legend come from the first panel with data.
        # fig.legend() only collects the labelled artists drawn so far, so calling
        # it once here gives one entry per series kind instead of one per panel.
        if z0 is not None and not header_done:
            header_done = True
            endd = z0.index[-1].strftime("KW%V %G")
            ax.set_xlim(z0.index[0] - timedelta(14), z0.index[-1] + timedelta(35))
            fig.suptitle(
                f"EuroMOMO Mortalitäts-Z-Score{extratitle} je Woche bis {endd}",
                y=0.98 if figh > 10 else 1 if figh > 9 else 1.04,
            )
            fig.legend(
                loc="upper center",
                frameon=False,
                bbox_to_anchor=(0.5, 0.975 if figh > 10 else 0.99 if figh > 9 else 1.02),
                ncol=4,
            )
        if ax in axs[:, 0]:
            ax.set_ylabel("Z-Score")
        ax.set_title(titlefmt.format(grp=grpname(grp), country=country), y=0.97 if ncols < 3 else 0.94)
    if ncols > 2:
        fig.autofmt_xdate()
    stampit(fig)
    return fig, axs


# Age group 0-14: the pooled "ALL" series on its own, then one panel per country.
grp = "0to14"
draw_detail("ALL", grp)
single_countries = [c for c in zs.index.get_level_values("country").unique() if c != "ALL"]
fig, axs = draw_detail_cmp([(country, grp) for country in single_countries])
# Keep the shared y-range within at most [-3, 16].
y_lo, y_hi = axs.flat[0].get_ylim()
plt.setp(axs, ylim=(max(-3, y_lo), min(y_hi, 16)))
display(fig)

In [None]:
# Countries present in the data (including the pooled "ALL"), in order of first appearance.
zs.index.unique(level="country")

In [None]:
# Every age group for a single country.
cntry = "Austria"
age_groups = zs.index.get_level_values("group").unique()
fig, axs = draw_detail_cmp([(cntry, group) for group in age_groups])
for ax in axs.flat:
    ax.set_ylim(bottom=-3.5)
display(fig)

In [None]:
# Total mortality: Austria and its neighbours, then the Nordic countries.
grp = "Total"
fig, axs = draw_detail_cmp([(c, grp) for c in ("Austria", "Germany", "Switzerland", "Italy")])
for ax in axs.flat:
    ax.set_ylim(bottom=-3)
display(fig)

fig, axs = draw_detail_cmp([(c, grp) for c in ("Finland", "Sweden", "Denmark", "Norway")])
display(fig)

In [None]:
# Trailing 52-row mean of the z-score, computed separately for each (country, group) series.
zmean = zs.groupby(level=["country", "group"]).transform(lambda series: series.rolling(window=52).mean())

In [None]:
# 52-week mean z-score of every country (pooled "ALL" excluded), one line per
# country and labelled at its end. Title texts are derived from the settings
# below so they cannot drift from what is actually plotted.
ZMEAN_GROUP = "Total"
# The most recent weeks are still incomplete and are left out.
INCOMPLETE_WEEKS = 3

latest_week = zmean.index.get_level_values("week").max()
data_end = latest_week - timedelta(weeks=INCOMPLETE_WEEKS)
zm0 = (
    zmean.to_frame().query("country != 'ALL'")
    .xs(ZMEAN_GROUP, level="group")
).reset_index()
zm0 = zm0[zm0["week"] <= data_end].dropna()
colors = sns.color_palette("cet_glasbey_dark", n_colors=zm0["country"].nunique())
ax = sns.lineplot(zm0, hue="country", y="zscore", x="week", legend=False, palette=colors)
fig = ax.figure
ax.axhline(0, color="grey")
ax.axhspan(-2, 2, alpha=0.06, zorder=0)
ax.axhspan(-1, 1, alpha=0.06, zorder=0)
cov.labelend2(ax, zm0, "zscore", x="week", cats="country", shorten=lambda c: c, colorize=colors)
ax.set_xlim(left=zm0["week"].min())
ax.set_ylim(bottom=-1.1)
# The title used to say "Altersgruppe 15‒44" although group "Total" was plotted.
group_label = "alle Altersgruppen" if ZMEAN_GROUP == "Total" else "Altersgruppe " + grpname(ZMEAN_GROUP)
fig.suptitle(
    f"EuroMOMO Mortalitäts-Z-Score, {group_label}, 52-Wochen-Schnitt bis " + data_end.strftime("KW %V %G*"), y=0.94
)
ax.set_title(
    f"*letzte {INCOMPLETE_WEEKS} Wochen wegen Unvollständigkeit weggelassen, Datenstand "
    + latest_week.strftime("KW %V %G"),
    y=1,
    fontsize="medium",
)
stampit(fig)
ax.set_ylabel("Z-Score (Standardabweichungen)")
ax.annotate("Normaler Bereich (für Einzelwochen): ‒2 bis +2", (ax.get_xlim()[0], 2))
ax.set_xlabel(None);

In [None]:
# For each country / age group / ISO year, count how many weeks fall into each
# z-score band. Only ISO years ZSCAT_FIRST_YEAR..ZSCAT_LAST_YEAR are kept.
# (The name zs17 is kept because later cells use it.)
ZSCAT_FIRST_YEAR = 2016
ZSCAT_LAST_YEAR = 2023

iso_year = zs.index.get_level_values("week").isocalendar().year.to_numpy()
in_year_range = (iso_year >= ZSCAT_FIRST_YEAR) & (iso_year <= ZSCAT_LAST_YEAR)
zs17 = zs.loc[in_year_range]
zs17_year = iso_year[in_year_range]
# The four bands below partition every non-NaN z-score exactly once.
zscat = zs17.groupby(
    [
        zs17.index.get_level_values("country"),
        zs17.index.get_level_values("group"),
        zs17_year,
    ]
).agg(
    **{
        "Substantiell erhöht (>4)": lambda s: (4 < s).sum(),
        "Oberhalb des Normalbereichs (>2)": lambda s: s.between(2, 4, inclusive="right").sum(),
        "Überdurchschnittlich (im Normalbereich)": lambda s: s.between(0, 2, inclusive="right").sum(),
        "Durchschnitt oder darunter (<= 0)": lambda s: (s <= 0).sum(),
    }
)

In [None]:
# Latest ISO year in zscat (its last, unnamed index level).
zscat.index.get_level_values(-1).max()

In [None]:
from logging import getLogger

def plt_y_cats_ax(country, grp, ax):
    """Draw a stacked bar chart of weeks per z-score category and year onto ``ax``.

    Reads the module-level ``zscat`` table. Plotting happens inside
    ``cov.with_palette`` with three Reds_r shades plus blue (presumably applied to
    the four category columns in order). The per-axes legend is removed so that
    callers can add their own. If ``zscat`` has no rows for this country/group, a
    warning is logged and nothing is drawn.
    """
    palette = sns.color_palette("Reds_r", n_colors=3) + ["b"]
    with cov.with_palette(palette):
        try:
            per_group = zscat.xs(grp, level="group")
            pdata = per_group.xs(country, level="country")
        except KeyError as exc:
            getLogger(__name__).warning("Error plotting %s/%s: %r", country, grp, exc)
            return
        pdata.plot.bar(stacked=True, ax=ax, lw=0)
    ax.get_legend().remove()


def plt_y_cats(country, grp):
    """Plot the per-year z-score category counts of one country/group as a pyplot figure.

    Args:
        country: country name as in ``zscat`` (e.g. "Austria" or "ALL").
        grp: age-group code (e.g. "Total").
    """
    fig, ax = plt.subplots()
    plt_y_cats_ax(country, grp, ax)
    ax.legend(loc="upper left")
    fig.suptitle(
        f"{grpname(grp)} in {country}: " + "Übersterblichkeit pro Jahr und Kalenderwoche (EuroMOMO Z-Scores)", y=0.93
    )
    # Add the data-source attribution, as the other figure helpers do.
    stampit(fig)


def draw_ycats_cmp(cmps):
    """Draw a grid of per-year z-score category bar charts, one per (country, group) pair.

    Uses 1 column for up to 2 pairs, 2 columns for up to 10, and 3 beyond. All
    panels share both axes. If all pairs share the country (or the group), that
    part moves into the figure title and panel titles show only the rest. The
    figure legend is taken from the first panel that actually has bars.

    Args:
        cmps: non-empty sized iterable of (country, group) tuples.

    Returns:
        ``(fig, axs)``: a Figure built with ``plt.Figure`` (not registered with
        pyplot, so show it with ``display(fig)``) and the 2-D array of Axes.

    Raises:
        ValueError: if ``cmps`` is empty.
    """
    n = len(cmps)
    if n == 0:
        raise ValueError("cmps must contain at least one (country, group) pair")
    ncols = 3 if n > 10 else 2 if n > 2 else 1
    kws = {"figsize": (10, 10) if n < 10 else (10, 15)} if n > 4 else {}
    fig = plt.Figure(**kws)
    # Title/legend offsets below are tuned to the figure height.
    figh = fig.get_size_inches()[1]
    axs = fig.subplots(nrows=(n + ncols - 1) // ncols, ncols=ncols, squeeze=False, sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.1, top=0.94, hspace=0.25)
    if len(set(c[0] for c in cmps)) == 1:
        titlefmt = "{grp}"
        extratitle = " in " + next(iter(cmps))[0]
    elif len(set(c[1] for c in cmps)) == 1:
        titlefmt = "{country}"
        extratitle = ", " + grpname(next(iter(cmps))[1]) + ","
    else:
        titlefmt = "{grp} in {country}"
        extratitle = ""
    fig.suptitle(
        f"EuroMOMO Mortalitäts-Z-Score{extratitle} Kategorien je Jahr",
        y=0.995 if figh > 10 else 1.02 if figh > 9 else 1.07,
    )

    legend_done = False
    for (country, grp), ax in zip(cmps, axs.flat):
        plt_y_cats_ax(country, grp, ax)
        # fig.legend() collects the labelled bars drawn so far. Add it once, after
        # the first panel that really has bars; a panel stays empty when its pair
        # is missing from zscat.
        if not legend_done and ax.has_data():
            legend_done = True
            fig.legend(
                loc="upper center",
                frameon=False,
                bbox_to_anchor=(0.5, 0.99 if figh > 10 else 1.015 if figh > 9 else 1.06),
                ncol=2,
            )
        if ax in axs[:, 0]:
            ax.set_ylabel("Anzahl Kalenderwochen" if ncols <= 2 else "Anz. KWs")
        ax.set_title(titlefmt.format(grp=grpname(grp), country=country), y=0.97 if ncols < 3 else 0.94)

    # pandas draws the bars at x = 0, 1, ... (one per year). Label each position
    # with its year; building the labels from the same range as the ticks keeps
    # both lists the same length even if a year is absent from zscat.
    years = zscat.index.get_level_values(2)
    first_year, last_year = years.min(), years.max()
    axs.flat[0].set_xlim(right=last_year - first_year + 0.5)
    plt.setp(
        axs[-1],
        xticks=list(range(last_year - first_year + 1)),
        xticklabels=[str(y) for y in range(first_year, last_year + 1)],
    )
    if figh >= 15:
        axs.flat[0].set_yticks([0, 25, 50])
        fig.autofmt_xdate()
    stampit(fig)
    return fig, axs


# Optional category-count figures; only the all-countries grid is rendered by default.
# Other candidates for the selection: Finland, Denmark, Israel.
SHOW_SELECTED_COUNTRIES_AND_AUSTRIAN_GROUPS = False
SHOW_ALL_COUNTRIES = True
SHOW_SINGLE_FIGURES = False

if SHOW_SELECTED_COUNTRIES_AND_AUSTRIAN_GROUPS:
    grp = "Total"
    fig, axs = draw_ycats_cmp([(country, grp) for country in ("Austria", "Germany", "Switzerland", "Sweden")])
    display(fig)

    cntry = "Austria"
    fig, axs = draw_ycats_cmp([(cntry, group) for group in zs17.index.get_level_values("group").unique()])
    display(fig)

if SHOW_ALL_COUNTRIES:
    grp = "Total"
    fig, axs = draw_ycats_cmp([(country, grp) for country in zs17.index.get_level_values("country").unique()])
    display(fig)

if SHOW_SINGLE_FIGURES:
    for country in ("Austria", "Germany", "ALL"):
        plt_y_cats(country, "Total")