Skip to content

Commit

Permalink
Replace prediction part for generative model (#260)
Browse files Browse the repository at this point in the history
* replace prediction part for generative model

* add save trace output

* revert back to the original prediction for generative model

* add logger

* dummy change to trigger quay build
  • Loading branch information
wangfan860 committed Mar 15, 2022
1 parent 7974969 commit 00f72dd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 25 deletions.
34 changes: 27 additions & 7 deletions covid19-notebooks/covid19-sir-bayes-model/sir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,25 @@
from datetime import date
import time
import matplotlib
import logging
import sys

logger = logging.getLogger(__name__)


def setup_logger():
    """
    Sets up logger

    Sets the module-level ``logger`` to INFO and makes sure it has exactly
    one stderr stream handler, formatting records as
    ``[LEVEL] [YYYYMMDD HH:MM:SS] [logger name] - message``.

    The function is idempotent: calling it again reuses the stderr handler
    that is already attached instead of adding another one.
    """
    logger_format = "[%(levelname)s] [%(asctime)s] [%(name)s] - %(message)s"
    logger.setLevel(level=logging.INFO)
    formatter = logging.Formatter(logger_format, datefmt="%Y%m%d %H:%M:%S")

    # Loggers are process-wide singletons, so re-running this script in the
    # same interpreter (e.g. a notebook kernel) would otherwise stack a new
    # handler on every run and emit each message several times.
    handler = next(
        (
            existing
            for existing in logger.handlers
            if isinstance(existing, logging.StreamHandler)
            and getattr(existing, "stream", None) is sys.stderr
        ),
        None,
    )
    if handler is None:
        handler = logging.StreamHandler(sys.stderr)
        logger.addHandler(handler)
    handler.setFormatter(formatter)


setup_logger()
# ------------------------------------------------------------------------------ #
# Step 1: load data
# ------------------------------------------------------------------------------ #
Expand All @@ -31,7 +49,7 @@
(confirmed_cases["Province_State"] == "Illinois")
& (confirmed_cases["Admin2"] == "Cook"),
:,
].columns[-72]
].columns[-96]
month, day, year = map(int, date_data_end.split("/"))

data_begin = date_data_begin
Expand All @@ -52,9 +70,10 @@
"1/22/20":,
]
)[0]

date_data_end = date(year + 2000, month, day)
date_today = date_data_end + datetime.timedelta(days=1)
print(
logger.info(
"Cases yesterday ({}): {} and day before yesterday: {}".format(
date_data_end.isoformat(), *cases_obs[:-3:-1]
)
Expand Down Expand Up @@ -320,7 +339,8 @@ def next_day(λ, S_t, I_t, _):
# -------------------------------------------------------------------------- #
time_beg = time.time()
trace = pm.sample(draws=500, tune=800, chains=2)
print("Model run in {:.2f} s".format(time.time() - time_beg))
pm.save_trace(trace=trace, directory="./sir_model_trace", overwrite=True)
logger.info("Model run in {:.2f} s".format(time.time() - time_beg))

# -------------------------------------------------------------------------------
# Step 4 Plot data of new infections
Expand Down Expand Up @@ -413,7 +433,7 @@ def next_day(λ, S_t, I_t, _):
ylabel="Daily confirmed cases",
)
plt.suptitle(
"With Reported Data From Past 2 Months and SIR Model Predictions",
"With Reported Data From Past 3 Months and SIR Model Predictions",
fontsize=12,
y=0.92,
)
Expand Down Expand Up @@ -470,6 +490,7 @@ def return_obs_cases_future(trace):
obs_cases_future[label] = (
np.cumsum(trace[label], axis=1)
+ np.sum(trace.new_I_past, axis=1)[:, None]
+ cases_obs[0]
+ trace.I_begin[:, None]
)
obs_cases_future[label] = obs_cases_future[label].T
Expand Down Expand Up @@ -503,7 +524,6 @@ def return_obs_cases_future(trace):
for label, color, legend in zip(obs_cases_labels_local, colors, legends_list[1]):
time = np.arange(0, num_days_to_predict)
cases = dict_obsc_cases[label]
cases = cases + cases_obs[0]
# find median
median = np.median(cases, axis=-1)
percentiles = (
Expand Down Expand Up @@ -553,7 +573,7 @@ def return_obs_cases_future(trace):
ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(func_format))

plt.suptitle(
"With Reported Data From Past 2 Months and SIR Model Predictions",
"With Reported Data From Past 3 Months and SIR Model Predictions",
fontsize=12,
y=0.92,
)
Expand All @@ -564,7 +584,7 @@ def return_obs_cases_future(trace):
"IL_tab_charts_cumulative_logistic_last360.svg", dpi=60, bbox_inches="tight"
)

print(
logger.info(
"effective m: {:.3f} +- {:.3f}".format(
1 + np.median(trace.λ - trace.μ), np.std(trace.λ - trace.μ)
)
Expand Down
46 changes: 28 additions & 18 deletions generative_bayes_model/pymc3_generative_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import os
import logging

t0 = time.time()
import pymc3 as pm
Expand Down Expand Up @@ -29,6 +30,23 @@
#%config InlineBackend.figure_format = 'svg'
theano.config.gcc.cxxflags = "-Wno-c++11-narrowing"

import logging
import sys

logger = logging.getLogger(__name__)


def setup_logger():
    """
    Sets up the logger.

    Puts the module-level ``logger`` at INFO level and ensures a single
    stderr stream handler is attached, with records rendered as
    ``[LEVEL] [YYYYMMDD HH:MM:SS] [logger name] - message``.

    Safe to call repeatedly: an stderr handler that is already attached is
    reused rather than duplicated.
    """
    logger_format = "[%(levelname)s] [%(asctime)s] [%(name)s] - %(message)s"
    logger.setLevel(level=logging.INFO)
    formatter = logging.Formatter(logger_format, datefmt="%Y%m%d %H:%M:%S")

    # A logger lives for the whole process; without this lookup each extra
    # call would add one more handler and every record would be printed once
    # per handler.
    stderr_handlers = [
        existing
        for existing in logger.handlers
        if isinstance(existing, logging.StreamHandler)
        and getattr(existing, "stream", None) is sys.stderr
    ]
    if stderr_handlers:
        handler = stderr_handlers[0]
    else:
        handler = logging.StreamHandler(sys.stderr)
        logger.addHandler(handler)
    handler.setFormatter(formatter)


def _random(self, sigma, mu, size, sample_shape):
if size[len(sample_shape)] == sample_shape:
Expand Down Expand Up @@ -90,6 +108,8 @@ def maximum_zeros_length(a):
return max(all_length)


setup_logger()

jh_data = pd.read_csv(
"https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_US.csv"
)
Expand Down Expand Up @@ -171,6 +191,11 @@ def average_missing_data(numbers, window_size):
trace_r_t_infection_delay = pm.sample(
tune=500, chains=2, cores=8, target_accept=0.9
)
pm.save_trace(
trace=trace_r_t_infection_delay,
directory="./trace_r_t_infection_delay",
overwrite=True,
)


def conv(a, b, len_observed):
Expand Down Expand Up @@ -241,16 +266,6 @@ def get_delay_distribution():
"test_adjusted_positive", conv(infections, p_delay, len_observed)
)

# Stop infections from taking on unreasonably large values that break the NegativeBinomial
# infections = tt.clip(infections, 0, 13_000_000)

# Likelihood
# pm.NegativeBinomial(
# 'obs',
# infections,
# alpha = pm.Gamma('alpha', mu=6, sigma=1),
# observed=daily_data.cases.values
# )
eps = pm.HalfNormal("eps", 10) # Error term
pm.Lognormal(
"obs",
Expand All @@ -263,6 +278,7 @@ def get_delay_distribution():

with model_r_t_onset:
trace_r_t_onset = pm.sample(tune=500, chains=2, cores=8, target_accept=0.9)
pm.save_trace(trace=trace_r_t_onset, directory="./trace_r_t_onset", overwrite=True)

start_date = daily_data.date[0]
fig, ax = plt.subplots(figsize=(10, 6))
Expand All @@ -272,7 +288,6 @@ def get_delay_distribution():
color="0.5",
alpha=0.05,
)
# plt.plot(pd.date_range(start=start_date, periods=len(daily_data.cases.values), freq='D'), trace_r_t_infection_delay['r_t'].T, color='r', alpha=0.1)
ax.set(
xlabel="Time",
ylabel="$R_e(t)$",
Expand Down Expand Up @@ -315,6 +330,7 @@ def get_delay_distribution():
trace = pm.sample(
500, tune=800, chains=1, target_accept=0.95, random_seed=42, cores=8
)
pm.save_trace(trace=trace, directory="./trace", overwrite=True)

with model:
y_future = pm.Poisson("y_future", mu=tt.exp(f[-F:]), shape=F)
Expand Down Expand Up @@ -365,12 +381,6 @@ def get_delay_distribution():
y=1.1,
)

# def thousands(x, pos):
# "The two args are the value and tick position"
# return "%1.0fK" % (x * 1e-3)

# formatter = FuncFormatter(thousands)
# ax.yaxis.set_major_formatter(formatter)
fig.autofmt_xdate()
legend_elements = [
Line2D([0], [0], color="red", lw=2, label="Reported cases"),
Expand Down Expand Up @@ -411,4 +421,4 @@ def get_delay_distribution():
fig.savefig("results/17031/cases.svg", dpi=60, bbox_inches="tight")
t1 = time.time()
totaltime = (t1 - t0) / 3600
print("total run time is {:.4f} hours".format(totaltime))
logger.info("total run time is {:.4f} hours".format(totaltime))

0 comments on commit 00f72dd

Please sign in to comment.