# COVID-19 Spread Toy Example
## Initialization Anchored NN Ensemble Sanity Check


In this notebook I want to demonstrate that my TensorFlow implementation of the ensemble neural network actually works and is useful. In the spirit of the times, I will try to learn the _hypothetical_ spread of COVID-19 on the _hypothetical_ island of Wakanda over a period of one year.

In [1]:
from simba.infrastructure import MLPEnsemble
import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt



First, we generate some data using the [SIR model](https://www.lewuathe.com/covid-19-dynamics-with-sir-model.html) of COVID-19 dynamics:
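For reference, these are the SIR equations that `deriv` integrates in the cell below. $N$ is `population`, $\beta$ is the infection (contact) rate and $\gamma$ is the recovery rate. Summing the three right-hand sides gives zero, so $S + I + R = N$ holds at all times. The basic reproduction number $\beta / \gamma$ is a standard SIR quantity that the notebook does not compute; with the values used here it is $0.3 / 0.02 = 15$. It should not be confused with `r_0` in the code, which is the initial number of recovered people.

```latex
\frac{dS}{dt} = -\frac{\beta S I}{N}, \qquad
\frac{dI}{dt} = \frac{\beta S I}{N} - \gamma I, \qquad
\frac{dR}{dt} = \gamma I, \qquad
\frac{\beta}{\gamma} = \frac{0.3}{0.02} = 15
```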

In [2]:
def generate_covid_19_infection_rate_data():
    # https://www.lewuathe.com/covid-19-dynamics-with-sir-model.html
    # https://scipython.com/book/chapter-8-scipy/additional-examples/the-sir-epidemic-model/
    population = 15000            # total population N
    days = 365
    i_0, r_0 = 2, 0               # initially infected and recovered people
    s_0 = population - i_0 - r_0  # everyone else starts susceptible
    beta, gamma = 0.3, 0.02       # infection rate and recovery rate, per day
    t = np.arange(days, dtype=float)  # one time point per day: 0, 1, ..., 364

    def deriv(y, t, population, beta, gamma):
        # SIR right-hand side: S -> I at rate beta*S*I/N, I -> R at rate gamma*I
        S, I, R = y
        dSdt = -beta * S * I / population
        dIdt = beta * S * I / population - gamma * I
        dRdt = gamma * I
        return dSdt, dIdt, dRdt
    y_0 = s_0, i_0, r_0
    ret = odeint(deriv, y_0, t, args=(population, beta, gamma))
    _, infected_people, _ = ret.transpose()  # keep only the infected compartment I(t)
    return t, infected_people
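As a quick cross-check of the generator, the following sketch integrates the same SIR equations without SciPy, using a hand-written fixed-step classical Runge-Kutta (RK4) scheme. The function name `sir_rk4` and the step size `dt` are hypothetical choices for illustration and are not part of the notebook, which uses `odeint`. The sketch checks two properties: S + I + R stays equal to the population, and the infected curve peaks at roughly 11,300 people, comfortably inside the plot's y-limit of 15,000.

```python
import numpy as np


def sir_rk4(population=15000, i_0=2, r_0=0, beta=0.3, gamma=0.02,
            days=365, dt=0.1):
    """Integrate the SIR ODEs with fixed-step RK4 and return one state per day."""
    def deriv(y):
        s, i, r = y
        new_infections = beta * s * i / population
        return np.array([-new_infections, new_infections - gamma * i, gamma * i])

    steps_per_day = int(round(1 / dt))
    y = np.array([population - i_0 - r_0, i_0, r_0], dtype=float)
    daily = [y.copy()]                      # state at day 0
    for _ in range(days - 1):
        for _ in range(steps_per_day):      # advance one day in small RK4 steps
            k1 = deriv(y)
            k2 = deriv(y + 0.5 * dt * k1)
            k3 = deriv(y + 0.5 * dt * k2)
            k4 = deriv(y + dt * k3)
            y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        daily.append(y.copy())
    return np.arange(days), np.array(daily)  # sir has shape (days, 3): S, I, R


t, sir = sir_rk4()
print("max drift in S+I+R:", np.abs(sir.sum(axis=1) - 15000).max())
print("infections peak on day", t[np.argmax(sir[:, 1])],
      "with about", round(sir[:, 1].max()), "people")
```

Because every RK4 stage derivative sums to zero, the total population is preserved up to floating-point rounding.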

Suppose we only have access to noisy measurements of how many people were sick on a given day:

In [3]:
time, infected_people = generate_covid_19_infection_rate_data()
n_samples = 2  # noisy measurements per day
noise = 0.1    # measurement std as a fraction of the true count
# Repeat every day n_samples times and draw one Gaussian measurement per copy.
inputs = np.repeat(time, n_samples)
targets = np.random.normal(np.repeat(infected_people, n_samples),
                           noise * np.repeat(infected_people, n_samples))
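Because the measurement standard deviation is `noise` times the true count, the noise is heteroscedastic: it is small early on and large near the peak. The NN ensemble will have to capture this as aleatoric uncertainty. The following sketch checks this numerically on a few hypothetical true counts. It uses NumPy's `default_rng` with a fixed seed for reproducibility, which differs from the notebook's global `np.random`.

```python
import numpy as np

rng = np.random.default_rng(0)
noise = 0.1                                         # same relative noise level as above
true_counts = np.array([10.0, 1_000.0, 10_000.0])   # hypothetical true infected counts
n_draws = 100_000

# One row per true count; scale = noise * count, so the spread grows with the count.
draws = rng.normal(loc=true_counts[:, None],
                   scale=noise * true_counts[:, None],
                   size=(true_counts.size, n_draws))

absolute_std = draws.std(axis=1)
relative_std = absolute_std / true_counts
print(absolute_std)   # roughly proportional to the true count
print(relative_std)   # each entry is close to 0.1
```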

In [4]:
# Some hyperparameters
def make_model(sess):
    mlp_dict = dict(
        input_dim=1,
        targets_dim=1,
        learning_rate=0.01,
        n_layers=3,
        hidden_size=128,
        activation=tf.nn.relu,
        anchor=False,
        init_std_bias=0.5,
        init_std_weights=0.5,
        data_noise=0.5
    )
    return MLPEnsemble(
        sess=sess,
        ensemble_size=5,
        n_epochs=4000,
        batch_size=512,
        **mlp_dict
    )
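The title refers to initialization-anchored ensembling (Pearce et al.). In this scheme, each member draws its initial parameters from a prior and is then regularized toward that draw (its "anchor") rather than toward zero, with a penalty weight of data noise variance over prior variance. This notebook does not show how `simba`'s `MLPEnsemble` uses `anchor`, `init_std_weights`/`init_std_bias` and `data_noise`. Reading them as the on/off switch, the prior standard deviations and the noise level of this scheme is an assumption based on their names. Note also that the cell above sets `anchor=False`. The sketch uses a linear model so that the anchored fit has a closed form. `anchored_loss` and `fit_member` are hypothetical helpers.

```python
import numpy as np


def anchored_loss(w, anchor, X, y, data_noise, prior_std):
    """Mean squared error plus an L2 pull toward this member's anchor.

    Assumed form: penalty weight data_noise**2 / prior_std**2, both terms divided by N.
    """
    n = y.shape[0]
    residual = y - X @ w
    lam = data_noise ** 2 / prior_std ** 2
    return (residual @ residual + lam * np.sum((w - anchor) ** 2)) / n


def fit_member(anchor, X, y, data_noise, prior_std):
    """Closed-form minimizer of anchored_loss for a linear model (ridge toward anchor)."""
    lam = data_noise ** 2 / prior_std ** 2
    A = X.T @ X + lam * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ y + lam * anchor)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=20)
X = np.stack([x, np.ones_like(x)], axis=1)          # slope and intercept features
y = 2.0 * x + 0.5 + rng.normal(0.0, 0.1, size=x.shape)

data_noise, prior_std = 0.1, 1.0                    # hypothetical values
anchors = rng.normal(0.0, prior_std, size=(5, 2))   # one prior draw per member
members = np.array([fit_member(a, X, y, data_noise, prior_std) for a in anchors])

print(members)              # all close to the true [2.0, 0.5]
print(members.std(axis=0))  # small here: 20 points pin down just 2 parameters
```

Where data are scarce, the anchors dominate and the members disagree more. This disagreement is the source of the epistemic uncertainty estimated later.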

Run the training loop:

In [6]:
tf.reset_default_graph()

In [None]:
mean, std = inputs.mean(), inputs.std()
n_particles = 20
x_val = np.tile(time, n_particles)
x = (inputs - mean) / (std + 1e-8)
with tf.Session() as sess:
    model = make_model(sess)
    sess.run(tf.global_variables_initializer())
    losses = model.fit(x[:, np.newaxis], targets[:, np.newaxis])
    pred = np.squeeze(model.predict(
        (x_val[:, np.newaxis] - mean) / (std + 1e-8)))

Epoch  0  | Losses = [216438.68977813932, 247910.72325751593, 240608.1325727085, 223406.16587615234, 244755.86627295462]
Epoch  20  | Losses = [95283.43399483076, 215017.64836877983, 88428.76510919986, 107969.99611274821, 89255.34101078685]
Epoch  40  | Losses = [87205.29129651809, 142031.46124455694, 89907.35753266234, 84035.60742481922, 83597.9734372082]
Epoch  60  | Losses = [71895.46821505246, 97233.23930277654, 82065.80097302527, 76709.30095421463
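The training cell standardizes the day index using the training inputs' mean and standard deviation, and reuses those same two numbers for `x_val` at prediction time. The following sketch shows that pattern on hypothetical day indices. `standardize` is an illustrative helper, not a notebook function.

```python
import numpy as np

# Hypothetical training inputs: day indices 0..364, each observed twice (n_samples = 2).
train_days = np.repeat(np.arange(365.0), 2)
mean, std = train_days.mean(), train_days.std()


def standardize(days):
    """Scale raw day indices with the *training* statistics, not the batch's own."""
    return (days - mean) / (std + 1e-8)


x_train = standardize(train_days)
x_query = standardize(np.array([0.0, 182.0, 364.0]))
print(x_train.mean(), x_train.std())  # approximately 0 and 1
print(x_query)                        # symmetric around 0, roughly -1.73, 0, 1.73
```

Here `x_val` happens to have the same mean and standard deviation as `inputs`, because both are repeats of `time`. Reusing the training statistics is what keeps the mapping correct in general.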

Reshape the predictions back to (ensemble_size, n_particles, 365 days). The n_particles repeated predictions per member help estimate the aleatoric uncertainty, and the ensemble_size members help estimate the epistemic uncertainty.
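`np.tile(time, n_particles)` lays out the whole time grid once per particle, so the flat prediction axis is particle-major and a C-order reshape recovers one row per particle. The following sketch checks this with tiny hypothetical sizes and a fake predictor. It assumes that `model.predict` returns one row per ensemble member, i.e. shape `(ensemble_size, len(x_val))` after squeezing. This is inferred from the reshape in the next cell, not confirmed by the notebook.

```python
import numpy as np

ensemble_size, n_particles, n_days = 3, 4, 5   # small hypothetical sizes
time = np.arange(n_days, dtype=float)
x_val = np.tile(time, n_particles)             # [t0..t4, t0..t4, ...], particle-major

# Fake predictor: member m outputs 100 * m + day, so each value records its member and day.
flat_pred = np.stack([100.0 * m + x_val for m in range(ensemble_size)])

pred = np.reshape(flat_pred, (ensemble_size, n_particles, n_days))
print(pred.shape)   # (3, 4, 5)
print(pred[2, 1])   # member 2, particle 1: [200. 201. 202. 203. 204.]
```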

In [None]:
pred = np.reshape(pred, 
                  (model.ensemble_size, n_particles, time.shape[0]))
# For more details on decomposition of uncertainties: http://proceedings.mlr.press/v80/depeweg18a/depeweg18a.pdf 
aleatoric_uncertainty = np.mean(np.std(pred, axis=1) ** 2, axis=0)
epistemic_uncertainty = np.std(np.mean(pred, axis=1), axis=0) ** 2
print(aleatoric_uncertainty)
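The two estimates above follow the decomposition in the linked Depeweg et al. paper. The aleatoric term is the average within-member variance, and the epistemic term is the variance of the member means. With population variances (NumPy's default `ddof=0`) and the same number of particles per member, they add up exactly to the total variance of all predictions for a day (the law of total variance). The following sketch demonstrates this on synthetic predictions. The member offsets and particle noise levels are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
ensemble_size, n_particles, n_days = 5, 20, 365

# Synthetic stand-in for pred: each member has its own mean curve (epistemic spread),
# and each particle adds independent noise on top (aleatoric spread).
member_means = rng.normal(0.0, 2.0, size=(ensemble_size, 1, n_days))
pred = member_means + rng.normal(0.0, 1.0, size=(ensemble_size, n_particles, n_days))

aleatoric = np.mean(np.std(pred, axis=1) ** 2, axis=0)  # E_m[Var_p], shape (n_days,)
epistemic = np.std(np.mean(pred, axis=1), axis=0) ** 2  # Var_m[E_p], shape (n_days,)
total = np.var(pred, axis=(0, 1))                        # variance over all samples

print(np.allclose(total, aleatoric + epistemic))  # True
```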

Plot the noisy data, the mean prediction, and both uncertainty bands:

In [None]:
fig = plt.figure(figsize=(12, 10), dpi= 80, facecolor='w', edgecolor='k')
ax = fig.subplots()
ax.set_ylim([-100, 15e3])
ax.scatter(inputs, targets, color='#FF764D', alpha=0.6,
           s=5, label='Infectious people a day')
# Average over ensemble members and particles. A separate name avoids overwriting
# `mean`, the input-normalization statistic used by the prediction cell.
pred_mean = np.mean(pred, axis=(0, 1))
ax.plot(time, pred_mean, '-', color='#C20093', linewidth=1, label='Predictions')
ax.fill_between(time, pred_mean - np.sqrt(epistemic_uncertainty),
                pred_mean + np.sqrt(epistemic_uncertainty),
                color='#FC206C', alpha=0.15, label='Epistemic uncertainty')
ax.errorbar(time, pred_mean, yerr=np.sqrt(aleatoric_uncertainty), linewidth=0.0,
            ecolor='silver', elinewidth=3, capsize=0.0, label='Aleatoric uncertainty')
legend = ax.legend(loc='upper right', fontsize='medium')
plt.xlabel("Days")
plt.ylabel("Infectious people")
plt.show()

In [None]:
# fig = plt.figure(figsize=(12, 10), dpi= 80, facecolor='w', edgecolor='k')
# ax = fig.subplots()
# ax.set_ylim([-100, 15e3])
# ax.scatter(inputs, targets, color='#FF764D', alpha=0.6,
#            s=5, label='Infectious people a day')
# ax.plot(time, mus, '-', color='#C20093', linewidth=1, label='Predictions')
# ax.errorbar(time, mus, yerr=sigmas, linewidth=0.0,
#              ecolor='silver', elinewidth=3, capsize=0.0, label='Aleatoric uncertainty', alpha=0.2)
# legend = ax.legend(loc='upper right', fontsize='medium')
# plt.xlabel("Days")
# plt.ylabel("Infectious people")
# plt.show()