In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Forecast Future ICU Usage Under Best, Worst, and Current Estimates
Plot style adapted from the seaborn error-band line plot example: <https://seaborn.pydata.org/examples/errorband_lineplots.html>

In [None]:
# Reference example from the seaborn gallery, using its long-form "fmri" dataset:
# signal over time, coloured by region and line-styled by event.
fmri = sns.load_dataset("fmri")
sns.lineplot(
    data=fmri,
    x="timepoint",
    y="signal",
    hue="region",
    style="event",
)

In [None]:
# Peek at the first rows of the example dataset.
fmri.head(n=5)

# Synthetic Data

# Plot Our Data

In [None]:
import datetime
import logging
import os
import numpy as np
from multiprocessing import Pool
from functools import partial
import us
import pickle
import json
import copy
from collections import defaultdict
from pyseir.models.seir_model import SEIRModel
from pyseir.parameters.parameter_ensemble_generator import ParameterEnsembleGenerator
import pyseir.models.suppression_policies as sp
from pyseir import load_data
from pyseir.reports.county_report import CountyReport
from pyseir.utils import get_run_artifact_path, RunArtifact, RunMode
from pyseir.inference import fit_results
from libs.datasets.dataset_utils import AggregationLevel
from libs.datasets import JHUDataset
import pyseir.ensembles.ensemble_runner

In [None]:
# Ensemble results for a single state, keyed by its FIPS code ("06" is California).
STATE_FIPS = "06"
example = load_data.load_ensemble_results(STATE_FIPS)

In [None]:
# Top-level keys of the ensemble results; the next cell drills into
# 'suppression_policy__inferred'.
example.keys()

In [None]:
# Keys available for the inferred-suppression entry; 't_list' and 'HICU' are used below.
example['suppression_policy__inferred'].keys()

In [None]:
# Time axis and the HICU 'ci_50' series (presumably the median) from the
# inferred-suppression ensemble, both as float arrays.
inferred_ensemble = example['suppression_policy__inferred']
x = np.array(inferred_ensemble['t_list'], dtype=float)
y = np.array(inferred_ensemble['HICU']['ci_50'], dtype=float)

In [None]:
# Single-scenario frame: time, ICU census and a constant scenario marker `r`.
# Built from named columns so a length mismatch between x and y raises instead of
# being silently NaN-padded (as the transpose-of-rows construction would do).
test_df = pd.DataFrame({"t": x, "HICU": y})
test_df["r"] = 1

In [None]:
# First rows of the single-scenario frame.
test_df.head(n=5)

In [None]:
# Quick look at the single forecast curve. Columns are selected by name rather than
# position so the plot survives column reordering; `;` hides the Axes repr.
test_df.plot(x="t", y="HICU", title="Inferred scenario: HICU (ci_50)");

In [None]:
# Fuzz the forecast into synthetic scenarios and reshape it into long form for seaborn

In [None]:
# Positional index into the ICU series at which the synthetic scenarios begin to
# diverge (the scenario loop slices y[DIVERGENCE_PT:]).
DIVERGENCE_PT: int = 300

In [None]:
import numpy as np

In [None]:
# Seeded generator so the synthetic observation noise is reproducible across
# Restart & Run All; change SEED to draw a different noise realisation.
SEED = 42
rng = np.random.default_rng(SEED)

In [None]:
# Exploration: shape of a negative geometric ramp spanning the post-divergence points.
np.geomspace(start=-1, stop=-200, num=len(y[DIVERGENCE_PT:]))

In [None]:
# NOTE(review): `r` is not referenced by any later cell; candidate for removal.
r = np.array([-0.1, 0.0, 0.1])

In [None]:
# Build synthetic forecast scenarios. Every scenario shares the HICU curve up to
# position DIVERGENCE_PT, then adds a geometric ramp from +/-1 to `final_offset`.
# Each scenario becomes a frame with columns (t, HICU, r) collected in `tmp_dfs`.
N_SCENARIOS = 100        # number of final offsets to sample
MAX_FINAL_OFFSET = 3000  # offsets span [-MAX_FINAL_OFFSET, MAX_FINAL_OFFSET]

# Derive everything from test_df rather than the global `y`, which earlier versions
# of this cell overwrote (making N depend on execution history).
t_values = test_df["t"].to_numpy(copy=True)
hicu_base = test_df["HICU"].to_numpy(copy=True)
N = len(hicu_base[DIVERGENCE_PT:])  # number of points after the divergence point

tmp_dfs = []
for final_offset in np.linspace(-MAX_FINAL_OFFSET, MAX_FINAL_OFFSET, N_SCENARIOS):
    y_scenario = hicu_base.copy()
    if final_offset == 0:
        # geomspace cannot reach 0; a zero offset means "no divergence".
        shim = np.zeros(N)
    else:
        # geomspace needs same-sign endpoints: ramp in magnitude, then restore the sign.
        shim = np.sign(final_offset) * np.geomspace(1, abs(final_offset), N)
    y_scenario[DIVERGENCE_PT:] += shim
    tmp_dfs.append(pd.DataFrame({"t": t_values, "HICU": y_scenario, "r": final_offset}))

In [None]:
# y[DIVERGENCE_PT:] + np.arange(len(y[DIVERGENCE_PT:]))*scale

In [None]:
# Stack all scenarios into one long-form frame for seaborn. ignore_index avoids the
# duplicate row labels (one 0..n-1 range per scenario) that plain concat produces.
df = pd.concat(tmp_dfs, ignore_index=True)
df.columns = ["t", "HICU", "R"]
# Assign back instead of `df['HICU'].clip(inplace=True)`: that is chained assignment,
# which does not update `df` under pandas Copy-on-Write.
df["HICU"] = df["HICU"].clip(lower=0)

# Label each row by where its scenario's final offset R falls: above the 75th
# percentile -> "Worst", below the 25th -> "Best", otherwise "Current".
high_cutoff = df.R.quantile(.75)
low_cutoff = df.R.quantile(.25)
df["Performance"] = np.select(
    [df.R > high_cutoff, df.R < low_cutoff],
    ["Worst", "Best"],
    default="Current",
)

In [None]:
# Summary statistics of the scenario final offsets.
df["R"].describe()

In [None]:
# First rows of the long-form scenario frame.
df.head(n=5)

In [None]:
# pd.DataFrame([x,y]).T

In [None]:
# Scenario forecast: seaborn aggregates the scenarios in each Performance group into
# one line with an error band; the black line and crosses show the shared history.
OBS_NOISE_SCALE = 50  # std. dev. of the synthetic observation noise

fig, ax = plt.subplots(figsize=(8, 6))
sns.lineplot(data=df, x="t", y="HICU", hue="Performance", ax=ax)

# Select history positionally so it ends exactly where the scenario loop starts
# shifting the series (y[DIVERGENCE_PT:]); comparing `t` values to DIVERGENCE_PT
# only matches that when t_list is 0, 1, 2, ...
observed = test_df.iloc[:DIVERGENCE_PT]
sns.lineplot(x=observed.t, y=observed.HICU, ax=ax, color="k")

# Synthetic "observations": history plus Gaussian noise, clipped at zero because ICU
# counts cannot be negative (no upper cap, which would silently flatten large values).
noisy_obs = (observed.HICU + rng.normal(scale=OBS_NOISE_SCALE, size=len(observed))).clip(lower=0)
sns.scatterplot(x=observed.t, y=noisy_obs, ax=ax, color="lightblue", marker="x")

ax.set(title="Forecast ICU usage (HICU) by scenario", xlabel="t", ylabel="HICU");

In [None]:
# Re-display the scenario figure created in the plotting cell.
fig

In [None]:
# Write the scenario figure to disk, trimming surrounding whitespace.
fig.savefig(fname="tmp.png", bbox_inches="tight")

In [None]:
# Overlay every individual scenario on one axes. `df.groupby('R').plot(...)` opens a
# separate figure per distinct R value, flooding the output and memory.
fig_r, ax_r = plt.subplots(figsize=(8, 6))
for _, scenario_df in df.groupby("R"):
    scenario_df.plot(x="t", y="HICU", ax=ax_r, legend=False)
ax_r.set(title="All synthetic scenarios", xlabel="t", ylabel="HICU");

In [None]:
# Re-run the pickled fitted model for one state under a chosen future-suppression
# scenario, using the parameters from that state's inference result.
EXAMPLE = "California"
scenario = "inferred"
# Date up to which the final suppression break is measured. Defaults to today (the
# previous behaviour); pin it, e.g. datetime.datetime(2020, 5, 1), for reproducible runs.
AS_OF_DATE = datetime.datetime.today()

state_obj = us.states.lookup(EXAMPLE)

artifact_path = get_run_artifact_path(state_obj.fips, RunArtifact.MLE_FIT_MODEL)

# NOTE: pickle.load can execute arbitrary code from the file; only load artifacts
# produced by this project's own pipeline.
with open(artifact_path, "rb") as f:
    model = pickle.load(f)
inferred_params = fit_results.load_inference_result(state_obj.fips)

# Determine the appropriate future suppression policy based on the scenario of
# interest: keep the inferred eps2, or derive it from the inferred R0.
if scenario == "inferred":
    eps_final = inferred_params["eps2"]
else:
    eps_final = sp.get_future_suppression_from_r0(inferred_params["R0"], scenario=scenario)

# Whole days elapsed between the model's t0_date and AS_OF_DATE.
t_break_final = (
    AS_OF_DATE - datetime.datetime.fromisoformat(inferred_params["t0_date"])
).days

model.suppression_policy = sp.get_epsilon_interpolator(
    eps=inferred_params["eps"],
    t_break=inferred_params["t_break"],
    eps2=inferred_params["eps2"],
    t_delta_phases=inferred_params["t_delta_phases"],
    t_break_final=t_break_final,
    eps_final=eps_final,
)
model.run()

In [None]:
# Inspect the inference result loaded for the example state.
inferred_params

In [None]:
# Inspect the suppression policy assigned from sp.get_epsilon_interpolator above.
model.suppression_policy

In [None]:
# NOTE(review): the model-setup cell already ends with model.run(); this re-runs the
# same model with the same suppression policy.
model.run()

In [None]:
# List the fitted model's public attributes. This replaces an unfinished `model.`
# tab-completion probe, which is a SyntaxError and breaks Restart & Run All.
[name for name in dir(model) if not name.startswith("_")]