In [5]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import requests
import pymc3 as pm
import pandas as pd
import numpy as np
import theano
import theano.tensor as tt

from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from matplotlib import ticker

from datetime import date
from datetime import datetime

from IPython.display import clear_output

%config InlineBackend.figure_format = 'retina'

In [19]:
def get_dataframe_india(url):
    """Read one raw-data CSV and keep only the columns used later on.

    Parameters
    ----------
    url : str
        Location of the CSV; handed directly to ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame
        Frame restricted to the 'Patient Number', 'Date Announced' and
        'Detected State' columns.
    """
    wanted_columns = ['Patient Number', 'Date Announced', 'Detected State']

    # Progress message so slow downloads are visible in the cell output
    print("Reading: " + url)

    return pd.read_csv(url, usecols=wanted_columns)

# Download raw_data3.csv .. raw_data6.csv and stack them into a single frame.
urls = [f"https://api.covid19india.org/csv/latest/raw_data{i}.csv" for i in range(3, 7)]
data_frames = list(map(get_dataframe_india, urls))

# Rename to the generic column names that the rest of the notebook uses.
df = (
    pd.concat(data_frames)
    .reset_index(drop=True)
    .rename(columns={'Detected State': 'country',
                     'Date Announced': 'date',
                     'Patient Number': 'positive'})
)

# Persist a local snapshot; it is read back immediately below with date parsing.
df.to_csv("today_data.csv", index=False)


def dateparse(x):
    """Parse a day/month/year string such as '27/04/2020'."""
    return datetime.strptime(x, '%d/%m/%Y')


# NOTE(review): `date_parser` and `squeeze` are deprecated/removed in newer
# pandas releases; this cell relies on the older pandas API.
countries = pd.read_csv(
    "today_data.csv",
    parse_dates=['date'],
    date_parser=dateparse,
    index_col=['country', 'date'],
    squeeze=True,
)
countries = countries.sort_index()

# Keep only the last record for every (country, date) pair.
is_repeat = countries.index.duplicated(keep='last')
countries = countries[~is_repeat]

Reading: https://api.covid19india.org/csv/latest/raw_data3.csv
Reading: https://api.covid19india.org/csv/latest/raw_data4.csv
Reading: https://api.covid19india.org/csv/latest/raw_data5.csv
Reading: https://api.covid19india.org/csv/latest/raw_data6.csv


In [21]:
# Day-over-day change of the West Bengal series; dropna() discards the first
# row, which has no predecessor to subtract. The output below contains negative
# values, so the underlying series is not monotonically increasing.
countries.xs('West Bengal').diff().dropna()

date
2020-04-28      819.0
2020-04-29      116.0
2020-04-30      680.0
2020-05-01     1551.0
2020-05-02       94.0
2020-05-03     1479.0
2020-05-04      -12.0
2020-05-05     1071.0
2020-05-06      179.0
2020-05-07     1728.0
2020-05-08     1124.0
2020-05-09      354.0
2020-05-10     3042.0
2020-05-11     1266.0
2020-05-12     1060.0
2020-05-13     1078.0
2020-05-14      874.0
2020-05-15      972.0
2020-05-16      484.0
2020-05-17     1806.0
2020-05-18     2789.0
2020-05-19       42.0
2020-05-20      511.0
2020-05-21     2611.0
2020-05-22     1419.0
2020-05-23     1773.0
2020-05-24     1627.0
2020-05-25     1661.0
2020-05-26     1509.0
2020-05-27    31981.0
2020-05-28   -26376.0
2020-05-29      -98.0
2020-05-30     1965.0
2020-05-31      882.0
2020-06-01     2894.0
2020-06-02     1433.0
2020-06-03     1579.0
2020-06-04     1180.0
2020-06-05     1960.0
2020-06-06     1797.0
2020-06-07     1366.0
2020-06-08     2004.0
2020-06-09     1133.0
2020-06-10     1402.0
2020-06-11     6877.0
2020-

In [12]:
# Select one region and derive its raw and smoothed daily series.
# NOTE(review): prepare_cases_ksys is defined in a cell that appears further
# down in this notebook, so that cell must have been executed before this one.
country_name = 'West Bengal'
cases = countries.xs(country_name).rename(f"{country_name} cases")
original, smoothed = prepare_cases_ksys(cases)

In [13]:
# Display the smoothed series (note that the output below includes negative values).
smoothed

date
2020-04-27     559.0
2020-04-28     653.0
2020-04-29     652.0
2020-04-30     740.0
2020-05-01     748.0
2020-05-02     769.0
2020-05-03     739.0
2020-05-04     750.0
2020-05-05     769.0
2020-05-06     847.0
2020-05-07    1021.0
2020-05-08    1226.0
2020-05-09    1350.0
2020-05-10    1449.0
2020-05-11    1384.0
2020-05-12    1282.0
2020-05-13    1143.0
2020-05-14     998.0
2020-05-15    1117.0
2020-05-16    1193.0
2020-05-17    1245.0
2020-05-18    1337.0
2020-05-19    1351.0
2020-05-20    1419.0
2020-05-21    1480.0
2020-05-22    1542.0
2020-05-23    1666.0
2020-05-24    3817.0
2020-05-25    3630.0
2020-05-26    3616.0
2020-05-27    2626.0
2020-05-28    1014.0
2020-05-29     -60.0
2020-05-30    -159.0
2020-05-31    -447.0
2020-06-01    1607.0
2020-06-02    1708.0
2020-06-03    1664.0
2020-06-04    1652.0
2020-06-05    1618.0
2020-06-06    1630.0
2020-06-07    1607.0
2020-06-08    1949.0
2020-06-09    1840.0
2020-06-10    1809.0
2020-06-11    1650.0
2020-06-12    1379.0
2020-06-

In [4]:
def prepare_cases_ksys(cases, cutoff=25):
    """Convert a cumulative series into raw and smoothed per-step differences.

    The differences are smoothed with a centred 7-point Gaussian window
    (std=2) and rounded. Both returned series start at the first position
    where the smoothed value reaches ``cutoff``.

    Parameters
    ----------
    cases : pandas.Series
        Cumulative counts.
    cutoff : number, default 25
        Minimum smoothed value that marks the start of the returned series.

    Returns
    -------
    (original, smoothed) : tuple of pandas.Series
        ``original`` holds the unsmoothed differences and ``smoothed`` the
        rounded, smoothed ones. Both are empty when the smoothed series never
        reaches ``cutoff``.
    """
    new_cases = cases.diff()

    smoothed = new_cases.rolling(7,
        win_type='gaussian',
        min_periods=1,
        center=True).mean(std=2).round()

    # The smoothed series is not sorted, so a binary search (np.searchsorted)
    # cannot be trusted here. Scan for the first value at or above the cutoff;
    # for sorted input this gives exactly what searchsorted would return.
    reached_cutoff = (smoothed >= cutoff).values
    if reached_cutoff.any():
        idx_start = int(reached_cutoff.argmax())
    else:
        idx_start = len(smoothed)

    # Slice both series by position so they stay aligned row for row.
    smoothed = smoothed.iloc[idx_start:]
    original = new_cases.iloc[idx_start:]

    return original, smoothed

In [1]:
class MCMCModel(object):
    """Poisson model of an onset series with a time-varying rate ``theta``.

    Each observation after the first is modelled as Poisson with mean
    ``onset[t-1] / cumulative_p_delay[t-1] * cumulative_p_delay[t] * exp(theta[t])``
    (floored at 0.1). ``r_t`` is derived as ``theta * serial_interval + 1``
    with a Gamma prior on the serial interval.

    Two priors for ``theta`` are available: a Gaussian random walk (``run``)
    and a latent Gaussian process (``run_gp``).
    """

    def __init__(self, region, onset, cumulative_p_delay, window=50):
        """Store the last ``window`` observations of the inputs.

        Parameters
        ----------
        region : hashable
            Label used only to identify the model in result tables.
        onset : pandas.Series
            Onset counts; only the last ``window`` entries are modelled.
        cumulative_p_delay : sequence supporting slicing
            Values aligned with ``onset``; sliced to the same trailing window.
        window : int, default 50
            Number of trailing observations to keep.
        """
        # Just for identification purposes
        self.region = region

        # For the model, we'll only look at the last `window` observations
        self.onset = onset.iloc[-window:]
        self.cumulative_p_delay = cumulative_p_delay[-window:]

        # Where we store the results
        self.trace = None
        # One theta / r_t value per observation after the first
        self.trace_index = self.onset.index[1:]

    def _add_observation_model(self, theta):
        """Add the serial interval, ``r_t`` and the Poisson likelihood.

        Must be called inside an active ``pm.Model`` context. ``theta`` needs
        one entry per observation after the first.
        """
        # Let the serial interval be a random variable and calculate r_t
        serial_interval = pm.Gamma('serial_interval', alpha=6, beta=1.5)
        gamma = 1.0 / serial_interval
        pm.Deterministic('r_t', theta / gamma + 1)

        inferred_yesterday = self.onset.values[:-1] / self.cumulative_p_delay[:-1]
        expected_today = inferred_yesterday * self.cumulative_p_delay[1:] * pm.math.exp(theta)

        # Ensure cases stay above zero for poisson
        mu = pm.math.maximum(.1, expected_today)
        observed = self.onset.round().values[1:]
        pm.Poisson('cases', mu=mu, observed=observed)

    def run(self, chains=1, tune=3000, draws=1000, target_accept=.95):
        """Sample the model with a Gaussian random-walk prior on ``theta``.

        The sampler arguments are forwarded to ``pm.sample``. The trace is
        stored on ``self.trace`` and ``self`` is returned.
        """
        with pm.Model():

            # Random walk magnitude
            step_size = pm.HalfNormal('step_size', sigma=.03)

            # Theta random walk: one initial value plus len(onset) - 2 steps,
            # accumulated into len(onset) - 1 values
            theta_raw_init = pm.Normal('theta_raw_init', 0.1, 0.1)
            theta_raw_steps = pm.Normal('theta_raw_steps', shape=len(self.onset)-2) * step_size
            theta_raw = tt.concatenate([[theta_raw_init], theta_raw_steps])
            theta = pm.Deterministic('theta', theta_raw.cumsum())

            self._add_observation_model(theta)

            self.trace = pm.sample(
                chains=chains,
                tune=tune,
                draws=draws,
                target_accept=target_accept)

        return self

    def run_gp(self, chains=1, tune=1000, draws=1000, target_accept=.8):
        """Sample the model with a latent Gaussian-process prior on ``theta``.

        The sampler arguments are forwarded to ``pm.sample``; their defaults
        match the values this method previously hard-coded. The trace is
        stored on ``self.trace`` and ``self`` is returned.
        """
        with pm.Model():
            gp_shape = len(self.onset) - 1

            length_scale = pm.Gamma("length_scale", alpha=3, beta=.4)

            eta = .05
            cov_func = eta**2 * pm.gp.cov.ExpQuad(1, length_scale)

            gp = pm.gp.Latent(mean_func=pm.gp.mean.Constant(c=0),
                              cov_func=cov_func)

            # Place a GP prior over the function f.
            theta = gp.prior("theta", X=np.arange(gp_shape)[:, None])

            self._add_observation_model(theta)

            self.trace = pm.sample(
                chains=chains,
                tune=tune,
                draws=draws,
                target_accept=target_accept)

        return self

In [2]:
def df_from_model(model):
    """Summarise a sampled model's ``r_t`` draws in a DataFrame.

    The result has one row per entry of ``model.trace_index``, indexed by
    (region, date), with the columns mean, median, lower_90, upper_90,
    lower_50 and upper_50.
    """
    samples = model.trace['r_t']

    # Point estimates followed by the 90% and 50% HPD bounds, side by side
    summary = np.c_[
        np.mean(samples, axis=0),
        np.median(samples, axis=0),
        pm.stats.hpd(samples, credible_interval=.9),
        pm.stats.hpd(samples, credible_interval=.5),
    ]
    column_names = ['mean', 'median', 'lower_90', 'upper_90', 'lower_50','upper_50']

    row_index = pd.MultiIndex.from_product(
        [[model.region], model.trace_index],
        names=['region', 'date'])

    return pd.DataFrame(data=summary, index=row_index, columns=column_names)

In [None]:
def create_and_run_model(name, state):
    """Build an ``MCMCModel`` for one region and sample it.

    ``state`` must expose a ``positive`` attribute holding a series that is
    differenced here. ``confirmed_to_onset``, ``adjust_onset_for_right_censorship``
    and ``p_delay`` are looked up from the surrounding notebook namespace.

    Returns the model after ``run()`` has been called on it.
    """
    daily_confirmed = state.positive.diff().dropna()
    onset = confirmed_to_onset(daily_confirmed, p_delay)

    # Only the cumulative delay is used below; the adjusted series is discarded.
    _, cumulative_p_delay = adjust_onset_for_right_censorship(onset, p_delay)

    model = MCMCModel(name, onset, cumulative_p_delay)
    return model.run()

In [None]:
# Fit one model per state and keep the fitted models keyed by state name.
# NOTE(review): `states` is not defined in the visible cells of this notebook;
# it is assumed to be a frame with a 'state' level/column and a 'positive'
# column. Confirm before running.
models = {}

for state, grp in states.groupby('state'):

    print(state)

    # Because `models` is re-created above, this branch only triggers if that
    # reset is removed in order to reuse models from an earlier run.
    if state in models:
        print(f'Skipping {state}, already in cache')
        continue

    # create_and_run_model takes (name, state); the name was previously omitted,
    # which raises a TypeError for the missing positional argument.
    models[state] = create_and_run_model(state, grp.droplevel(0))

Downloading file, this will take a while ~100mb
Something went wrong. Try again.


In [None]:
# Count how many draws in each model's trace were flagged as diverging
def n_diverging(x):
    return x.trace['diverging'].nonzero()[0].size


divergences = pd.Series(
    [n_diverging(fitted) for fitted in models.values()],
    index=models.keys())
has_divergences = divergences.gt(0)

print('Diverging states:')
display(divergences[has_divergences])

# Sample again for every state that reported at least one divergence
for state, n_divergences in divergences[has_divergences].items():
    models[state].run()

In [None]:
# Collect the per-model summary frames and concatenate them once at the end.
# Concatenating inside the loop re-copies all earlier rows on every iteration.
result_frames = []

for state, model in models.items():

    df = df_from_model(model)
    result_frames.append(df)

# Keep the previous contract: `results` stays None when there are no models.
results = pd.concat(result_frames, axis=0) if result_frames else None

In [3]:
def plot_rt(name, result, ax, c=(.3,.3,.3,1), ci=(0,0,0,.05)):
    """Draw one region's median line and lower_90/upper_90 band on ``ax``.

    Parameters
    ----------
    name : str
        Used as the axes title.
    result : pandas.DataFrame
        Needs 'median', 'lower_90' and 'upper_90' columns; its index supplies
        the x values.
    ax : matplotlib axes
        Target axes.
    c : colour, optional
        Colour of the median line.
    ci : colour, optional
        Fill colour of the band.
    """
    # Fixed y-range so that panels are directly comparable
    ax.set_ylim(0.5, 1.6)
    ax.set_title(name)

    median_style = dict(
        marker='o',
        markersize=4,
        markerfacecolor='w',
        lw=1,
        c=c,
        markevery=2,
    )
    ax.plot(result['median'], **median_style)

    lower = result['lower_90'].values
    upper = result['upper_90'].values
    ax.fill_between(result.index, lower, upper, color=ci, lw=0)

    # Reference line at 1.0
    ax.axhline(1.0, linestyle=':', lw=1)

    # Month/day tick labels on every second week
    date_axis = ax.xaxis
    date_axis.set_major_formatter(mdates.DateFormatter('%m/%d'))
    date_axis.set_major_locator(mdates.WeekdayLocator(interval=2))

In [None]:
# Lay the regions out on a grid with four panels per row
ncols = 4
n_regions = len(results.index.levels[0])
nrows = -(-n_regions // ncols)  # ceiling division

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(14, nrows*3),
    sharey='row')

# One panel per region; the region level is dropped so the x axis is the date
panels = zip(axes.flat, results.groupby('region'))
for ax, (state, result) in panels:
    plot_rt(state, result.droplevel(0), ax)

fig.tight_layout()
fig.set_facecolor('w')