In [1]:
import pickle
import pandas as pd
import numpy as np
import datetime
from scipy import stats
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from ipywidgets import interact
import seaborn as sn
%matplotlib inline

from scrape_data import query

In [None]:
# Query the trailing 12-hour window, which starts `delta_dt` before the current
# local time. `query` comes from scrape_data; presumably it stores its result in
# data.pkl, which the next cell loads (confirm in scrape_data).
delta_dt = datetime.timedelta(hours=12)
initial_dt = datetime.datetime.now()
initial_dt -= delta_dt
query(initial_dt, delta_dt)

In [None]:
# Load the two parallel lists: the scraped tables and the timestamps they were
# recorded at. Only unpickle files you produced yourself; pickle can run
# arbitrary code on load.
with open("data.pkl", mode="rb") as pkl_file:
    df_list, dt_list = pickle.load(pkl_file)

In [None]:
def _snapshot_to_series(raw, timestamp):
    """Convert one scraped case table into a Series of case counts per Bundesland.

    Parameters
    ----------
    raw : pandas.DataFrame
        One table from ``df_list``. It must contain the columns ``"Bundesland"``
        and ``"Fälle"``, either as its header or in its first row.
    timestamp :
        The ``dt_list`` entry belonging to ``raw``. It is used as the Series name,
        so it becomes the row label after concatenation.

    Returns
    -------
    pandas.Series
        The ``"Fälle"`` values indexed by ``"Bundesland"``.
    """
    if "Bundesland" not in raw.columns:
        # The header ended up in the first data row, so promote it to column labels.
        # Copy first so the frames loaded from data.pkl are not modified in place.
        raw = raw.copy()
        raw.columns = raw.iloc[0]
    # Drop header rows that appear as data (including the one just promoted).
    table = raw[raw["Bundesland"] != "Bundesland"]
    # Unify the two spellings of Schleswig-Holstein.
    table = table.replace("Schleswig Holstein", "Schleswig-Holstein")
    series = table.set_index("Bundesland")["Fälle"]
    series.name = timestamp
    return series


series_list = [
    _snapshot_to_series(raw, timestamp) for raw, timestamp in zip(df_list, dt_list)
]

# One row per snapshot and one column per Land. A Land missing from a snapshot counts as 0.
df = pd.concat(series_list, axis=1, sort=True).transpose()
df = df.fillna(0).astype(int)
# "Repatriierte" is not a Bundesland. errors="ignore" keeps the cell working
# when none of the loaded snapshots contains that row.
df = df.drop(columns=["Repatriierte"], errors="ignore")
if "Sachsen-Anhalt" not in df.columns:
    # Keep a stable set of columns even when no snapshot reports Sachsen-Anhalt.
    df["Sachsen-Anhalt"] = 0
    df = df.reindex(sorted(df.columns), axis=1)
# Put the overall total first, renamed from "Gesamt" to "Total".
df.insert(0, "Total", df.pop("Gesamt"))
# Name the snapshot axis explicitly so reset_index yields a "Time" column
# regardless of whatever name the transposed index carried.
df = df.rename_axis("Time").reset_index()
df["Time"] = pd.to_datetime(df["Time"])

df.to_csv("covid19-germany-lands.csv")

df

In [None]:
def _fit_exponential(times, counts):
    """Least-squares fit of log2(counts) against time, in days.

    Parameters
    ----------
    times : pandas.Series of datetimes
        Observation times. They do not need to be sorted, but must contain at
        least two distinct values.
    counts : array-like of positive numbers
        Case counts observed at ``times``.

    Returns
    -------
    tuple
        ``(fit_times, fit_counts, doubling_time)``:
        - ``fit_times``: 100 evenly spaced datetimes spanning the data.
        - ``fit_counts``: the fitted counts at those times.
        - ``doubling_time``: the doubling time in days, or ``None`` when the
          fitted growth rate is not positive.
    """
    start = times.min()
    # Measure time as days since the earliest observation. Staying in
    # Timestamp/Timedelta arithmetic avoids the naive POSIX round trip:
    # pd.Timestamp.timestamp() treats naive values as UTC, while
    # datetime.fromtimestamp() returns local time, which shifted the fitted
    # curve by the local UTC offset.
    t_days = ((times - start) / pd.Timedelta(days=1)).to_numpy()
    log2_counts = np.log2(np.asarray(counts, dtype=float))
    slope, intercept, _, _, _ = stats.linregress(t_days, log2_counts)

    t_dense = np.linspace(t_days.min(), t_days.max(), 100)
    fit_times = (start + pd.to_timedelta(t_dense, unit="D")).to_pydatetime()
    fit_counts = 2 ** (slope * t_dense + intercept)
    # The slope is in doublings per day, so its reciprocal is the doubling time.
    # A non-positive slope has no meaningful doubling time.
    doubling_time = 1.0 / slope if slope > 0 else None
    return fit_times, fit_counts, doubling_time


@interact(land=df.columns.to_list()[1:])
def plot_cases(land="Total", log_scale=False, show_fit=False):
    """Plot one region's case counts over time, optionally with an exponential fit.

    Parameters
    ----------
    land : str
        Column of ``df`` to plot. "Total" is the overall sum and is shown as "Germany".
    log_scale : bool
        Use a logarithmic y axis. Otherwise the y axis starts at 0.
    show_fit : bool
        Overlay a least-squares exponential fit over the non-zero data and
        report the doubling time in the title when the growth rate is positive.

    Side effect: the figure is written to ``covid19-germany.png``, overwriting
    it on every call.
    """
    fig, ax = plt.subplots(figsize=(6, 4), dpi=120)
    # Pass x/y by keyword: recent seaborn versions reject positional data vectors.
    sn.lineplot(x=df["Time"], y=df[land], marker="o", ax=ax)

    ax.xaxis.set_major_locator(mdates.DayLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%d-%m'))
    # Rotate through the Axes so the setting also applies to ticks created later.
    ax.tick_params(axis="x", labelrotation=45)
    ax.grid()
    ax.set_ylabel("Cases")

    doubling_time = None
    if show_fit:
        # log2 is undefined at 0, so fit only the snapshots that have cases.
        non_zero = df.loc[df[land] > 0, ["Time", land]]
        # A linear regression needs at least two distinct time points.
        if non_zero["Time"].nunique() >= 2:
            fit_times, fit_counts, doubling_time = _fit_exponential(
                non_zero["Time"], non_zero[land]
            )
            ax.plot(fit_times, fit_counts, zorder=-1)

    if log_scale:
        ax.set_yscale('log')
    else:
        ax.set_ylim([0, None])

    region = "Germany" if land == "Total" else land
    title = f"COVID-19 cases in {region} (RKI data)"
    if doubling_time is not None:
        title += f" \n Doubling time: {doubling_time:.1f} days"
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig("covid19-germany.png")
    plt.show()