## Setup

In [1]:
# Standard Imports
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import os

## Load

In [3]:
from datetime import date

# Load the raw time series and derive the helper columns used below:
#   date       - parsed into pandas datetimes
#   date_no    - whole days elapsed since 2020-01-01
#   per_capita - cases divided by population
df = (
    pd.read_csv("./data/timeseries.csv")
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .assign(
        date_no=lambda frame: (
            frame["date"] - pd.Timestamp('2020-01-01 00:00:00')
        ).dt.days,
        per_capita=lambda frame: frame["cases"] / frame["population"],
    )
)

print(df.columns)

Index(['name', 'level', 'city', 'county', 'state', 'country', 'population',
       'lat', 'long', 'url', 'aggregate', 'tz', 'cases', 'deaths', 'recovered',
       'active', 'tested', 'hospitalized', 'discharged', 'growthFactor',
       'date', 'date_no', 'per_capita'],
      dtype='object')


  interactivity=interactivity, compiler=compiler, result=result)


## Analysis

In [4]:
from pymc3.ode import DifferentialEquation
import pymc3 as pm

In [5]:
def sir_ode(y, t, p, N):
    """Right-hand side of the SIR ordinary differential equations.

    Parameters
    ----------
    y : indexable
        Current state ``[S, I, R]``.
    t : scalar
        Time; accepted for the ODE-solver signature but not used.
    p : indexable
        Parameters ``[beta, gamma]``.
    N : scalar
        Population size that normalises the infection term.

    Returns
    -------
    tuple
        ``(dS/dt, dI/dt, dR/dt)``.
    """
    susceptible, infected = y[0], y[1]
    beta, gamma = p[0], p[1]

    # Flow S -> I and flow I -> R; each appears with opposite signs in two
    # of the three derivatives.
    new_infections = beta * susceptible * infected / N
    new_removals = gamma * infected

    return (
        -new_infections,
        new_infections - new_removals,
        new_removals,
    )

def sample(data, start=21, end=0):
    """Fit an SIR model to a trailing window of one place's case counts.

    The rows of ``data`` are ordered by ``date_no``. The window begins
    ``start`` rows before the end of the series and stops ``end`` rows before
    the end, so it holds ``start - end`` observations. The ``cases`` values in
    that window are used as observations of the R compartment.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows for a single place; must provide the columns ``population``,
        ``date_no`` and ``cases``.
    start : int
        Distance (in rows, counted from the end) at which the window begins.
    end : int
        Distance (in rows, counted from the end) at which the window stops.
        ``0`` means the window runs through the last row.

    Returns
    -------
    The trace returned by ``pm.sample`` for the variables ``gamma``, ``beta``
    and ``i_0``.

    Raises
    ------
    ValueError
        If ``data`` does not contain enough rows to fill the requested window.
    """
    population = data["population"].iloc[0]

    ordered = data.sort_values("date_no")
    n_rows = len(ordered)
    # Use explicit positional bounds instead of ``[-start:-end]``: with
    # end == 0 that slice is ``[-start:0]``, which selects no rows at all.
    window = ordered.iloc[max(n_rows - start, 0):max(n_rows - end, 0)]
    R_true = np.array(window["cases"])

    n_obs = start - end
    if len(R_true) != n_obs:
        # The ODE is evaluated at exactly ``n_obs`` time points, so a shorter
        # window cannot be matched against the model output.
        raise ValueError(
            "expected {} observations in the window, got {}".format(
                n_obs, len(R_true)))

    sir = DifferentialEquation(
        func=lambda y, t, p: sir_ode(y, t, p, population),
        times=np.arange(0, n_obs, 1),
        n_states=3, n_theta=2, t0=0)

    with pm.Model() as model:
        gamma = pm.Gamma('gamma', 2, 2)
        beta = pm.Gamma('beta', 2, 2)
        # Prior rate for the initial number of infected: the increase in
        # cases between the first and fifth observation of the window.
        # NOTE(review): this requires at least five observations, and a zero
        # increase gives a Poisson rate of 0 -- confirm that is acceptable.
        i_0 = pm.Poisson('i_0', int(R_true[4] - R_true[0]))

        # Initial state [S, I, R]: everyone not infected or already counted
        # in R starts as susceptible.
        prior_start = [population - i_0 - R_true[0], i_0, R_true[0]]
        sir_mean = sir(y0=prior_start, theta=[beta, gamma])
        # Observed case counts are modelled as Poisson around the R
        # compartment (third state) of the ODE solution.
        Y = pm.Poisson("Y", mu=sir_mean[:, 2], observed=R_true)

        trace = pm.sample(draws=500, tune=500, step=pm.NUTS(), chains=1)

    return trace

In [6]:
def sample_place(name):
    """Sample the SIR posterior for one place and save it to ``./results``.

    Selects the rows of the global ``df`` whose ``name`` equals ``name``, runs
    ``sample`` on them with ``start=26, end=5``, and stores the ``gamma``,
    ``beta`` and ``i_0`` draws in ``./results/trace_<name>.npz``.
    """
    place_rows = df.loc[df["name"] == name]
    posterior = sample(place_rows, start=26, end=5)

    out_path = "./results/trace_{}.npz".format(name)
    saved_vars = {var: posterior[var] for var in ("gamma", "beta", "i_0")}
    np.savez(out_path, **saved_vars)

In [9]:
# Places that have a row with more than 100 cases dated more than 26 days
# before the last day in the data set.
places_qualified = df.loc[
    (df["cases"] > 100) & (df["date_no"] < df["date_no"].max() - 26),
    "name",
].unique()

# Of those, keep only the entries reported at country level.
countries = df.loc[
    df["name"].isin(places_qualified) & (df["level"] == "country"),
    "name",
].unique()

In [10]:
def sample_countries():
    """Run ``sample_place`` for every qualified country not yet sampled.

    A country is skipped when ``./results`` already contains its
    ``trace_<country>.npz`` file. A failure while sampling one country is
    printed and the loop moves on to the next country.
    """
    for country in countries:
        print(country)

        already_sampled = "trace_{}.npz".format(country) in os.listdir("./results")
        if already_sampled:
            continue

        try:
            sample_place(country)
        except Exception as err:
            # Best effort: report the problem and continue with the rest.
            print("EXCEPTION:")
            print(err)

In [11]:
# Places that have a row with more than 50 cases dated more than 26 days
# before the last day in the data set.
counties_qualified = df.loc[
    (df["cases"] > 50) & (df["date_no"] < df["date_no"].max() - 26),
    "name",
].unique()

# Number of qualified US county rows per state on day 114 (2020-04-24, given
# that date_no counts days since 2020-01-01).
# NOTE(review): 114 is hard-coded; presumably it is the last day available in
# the data set and each county has one row for that day -- confirm.
top_states = df.loc[
    df["name"].isin(counties_qualified)
    & (df["level"] == "county")
    & (df["country"] == 'United States')
    & (df["date_no"] == 114),
    "state",
].value_counts()

In [12]:
# Show the ten states with the most qualified counties.
print(top_states.head(10))

New York         21
New Jersey       17
California       16
Pennsylvania     12
Florida          12
Colorado         11
Massachusetts    10
Washington       10
Texas             9
Georgia           9
Name: state, dtype: int64


In [13]:
def sample_state(state):
    """Run ``sample_place`` for every qualified county of ``state``.

    Counties are the rows of the global ``df`` with ``level == "county"``,
    the given ``state``, and a ``name`` listed in ``counties_qualified``.
    A county is skipped when ``./results`` already contains its
    ``trace_<name>.npz`` file. Progress is printed as ``[i/total] name``.
    """
    is_county = df["level"] == "county"
    in_state = df["state"] == state
    is_qualified = df["name"].isin(counties_qualified)

    targets = df[is_county & in_state & is_qualified]["name"].unique()
    total = len(targets)

    for position, place in enumerate(targets, start=1):
        print("[{}/{}] {}".format(position, total, place))
        if "trace_{}.npz".format(place) in os.listdir("./results"):
            continue
        sample_place(place)

In [None]:
# sample_state("New York")

In [None]:
# sample_state("New Jersey")

In [None]:
# sample_state("California")

In [16]:
from tqdm import tqdm

# For every saved trace archive, draw the raw chain (left column) and a
# histogram (right column) of each sampled parameter, and write the figure to
# ./traceplot/<archive name>.png.
for file in tqdm(os.listdir("./results")):
    # Skip anything in the folder that is not a saved trace archive.
    if not file.endswith(".npz"):
        continue

    fig, axs = plt.subplots(3, 2, figsize=(16, 12))

    # The context manager closes the archive's file handle after reading.
    with np.load(os.path.join("./results", file)) as trace:
        for ax, var in zip(axs, ["i_0", "beta", "gamma"]):
            draws = trace[var]  # read each array from the archive once
            ax[0].plot(draws)
            ax[0].set_title(var)
            ax[1].hist(draws, bins=25)
            ax[1].set_title("mean={:.4f}".format(np.mean(draws)))

    # Strip only the final extension: place names may themselves contain
    # periods, and split(".")[0] would cut the name at the first one.
    fig.savefig('./traceplot/' + os.path.splitext(file)[0] + ".png")
    plt.close(fig)


100%|██████████| 104/104 [00:41<00:00,  2.50it/s]


In [40]:
import seaborn as sns

# For every saved trace archive, draw a pair grid of kernel density estimates
# for the sampled parameters and write it to ./kdeplot/<archive name>.png.
for file in tqdm(os.listdir("./results")):
    # Skip anything in the folder that is not a saved trace archive.
    if not file.endswith(".npz"):
        continue

    # The context manager closes the archive's file handle after reading.
    with np.load(os.path.join("./results", file)) as trace:
        draws = pd.DataFrame({k: trace[k] for k in ['i_0', 'beta', 'gamma']})

    g = sns.PairGrid(draws)
    g.map_diag(sns.kdeplot)
    g.map_offdiag(sns.kdeplot, shade=True)
    g.fig.suptitle(file)
    # Strip only the final extension: place names may themselves contain
    # periods, and split(".")[0] would cut the name at the first one.
    g.fig.savefig('./kdeplot/' + os.path.splitext(file)[0] + ".png")
    plt.close(g.fig)

100%|██████████| 104/104 [02:28<00:00,  1.43s/it]
