After many attempts with SEIR models and other ODEs in pymc, they are still slow! The bottom line seems to be that NUTS is struggling. I've tried many different formulations including alternate DifferentialEquation nodes, one using Euler's method to see if odeint was the bottleneck but nothing helped. I think I need to work harder to make the parameter space firendlier for NUTS.

- using inverse parameters (average time) is more intuitive but in log space that's just the negative so it shouldn't affect sampler performance
- https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/
- https://discourse.pymc.io/t/pymc3-slows-rapidly-with-increasing-numbers-of-parameters/134/2
- Should try to scale parameters to have similar variances
- Should try to reparameterize so parameters are as independent as possible
- Should try to identify unidentifiable parameters and determine what to do with those
  - If they are independent, we can fix them in sampling then vary them in prediction
- https://docs.pymc.io/notebooks/ODE_API_introduction.html
- https://docs.pymc.io/notebooks/ODE_API_shapes_and_benchmarking.html
- https://docs.pymc.io/notebooks/ODE_with_manual_gradients.html
- https://colindcarroll.com/2019/11/29/highlights-of-pymc3-v3.8/

In [1]:
from itertools import product

In [2]:
import holoviews as hv
from holoviews.operation import gridmatrix
import numpy as np
import pandas as pd
import pymc3 as pm
from pymc3.ode import DifferentialEquation
from scipy.integrate import odeint

In [3]:
hv.notebook_extension('bokeh', logo=False)
%opts Overlay [aspect=5/3, responsive=True]

In [4]:
def plot_trace(trace, varnames=None, tune=0):
    """Plot the distribution and trace for each latent variable in a pymc trace object.

    trace: the trace output from pymc.sample
    varnames: Optional specification of variables to include in the trace plot. If None, use all variables not ending with '_'
    """
    vline = hv.VLine(tune).options(color='grey', line_width=1, line_dash='dashed', aspect=3, responsive=True)
    plots = []
    for var in varnames or [var for var in trace.varnames if not var.endswith('_')]:
        x = trace.get_values(var, combine=False)
        if not isinstance(x, list):
            x = [x]
        plots.append(
            hv.Overlay([hv.Distribution(xi[tune:], [var], [f'p({var})']) for xi in x], group=var)
            .options(aspect=3, responsive=True)
        )
        plots.append(
            hv.Overlay([hv.Curve(xi, 'index', var).options(alpha=0.6) for xi in x] + [vline])
            .options(aspect=3, responsive=True)
        )
    return hv.Layout(plots).cols(2)

In [35]:
def trace_grid(trace, combine=False):
    def append_chain_num(df, n):
        df['chain'] = n
        return df

    if combine:
        df = append_chain_num(pm.trace_to_dataframe(trace), 0)
    else:
        df = pd.concat([append_chain_num(pm.trace_to_dataframe(trace, chains=i), i)
                        for i in range(trace.nchains)])

    def make_scatter(x, y):
        return hv.Overlay([
            hv.Points(df.loc[df['chain'] == chain], [x, y]).options(size=2, alpha=0.2, tools=['box_select'])
            for chain in set(df['chain'])
        ]).options(show_legend=False, aspect=None, responsive=False)

    def make_dist(x):
        return hv.Overlay([
            hv.Distribution(df.loc[df['chain'] == chain], [x], [f'p({x})'])
            for chain in set(df['chain'])
        ]).options(show_legend=False, ylabel=x, aspect=None, responsive=False)

    varnames = [v for v in df.columns[:-1]]
    return hv.GridMatrix({
        (x, y): make_dist(x) if x == y else make_scatter(x, y)
        for x, y in product(varnames, varnames)
    })

In [6]:
def diff_eq(y, t, p):
    s, e, i, i_d, r, r_d, f, f_d = [y[i] for i in range(8)]
    beta, beta_d, sigma, theta, gamma, gamma_d, mu, mu_d = [p[i] for i in range(8)]
    
    newly_exposed = s * (i * beta + i_d * beta_d) / (s + e + i + r)
    newly_infectious = e * sigma
    detections = i * theta
    recoveries = i * gamma
    recoveries_d = i_d * gamma_d
    deaths = i * mu
    deaths_d = i * mu_d
    
    ds = - newly_exposed
    de = newly_exposed - newly_infectious
    di = newly_infectious - detections - recoveries - deaths
    di_d = detections - recoveries_d - deaths_d
    dr = recoveries
    dr_d = recoveries_d
    df = deaths
    df_d = deaths_d
    
    return ds, de, di, di_d, dr, dr_d, df, df_d

In [7]:
t = np.linspace(0, 100, 31)

In [17]:
y0 = [8000, 2000, 0, 0, 0, 0, 0, 0]
beta = 1 / 5
beta_d = 1 / 10
sigma = 1 / 10
theta = 1 / 4
gamma = 1 / 20
gamma_d = 1 / 10
mu = 1 / 50
mu_d = 1 / 100
p = beta, beta_d, sigma, theta, gamma, gamma_d, mu, mu_d

Y = odeint(diff_eq, y0, t, args=(p,))
s, e, i, i_d, r, r_d, f, f_d = Y.T

hv.Area.stack(hv.Overlay([
    hv.Area((t, y), 'time', '# people', label=label)
    for y, label in [
        (i_d, 'infectious (known)'),
        (r_d, 'recovered (known)'),
        (f_d, 'deceased (known)'),
        (i, 'infectious'),
        (r, 'recovered'),
        (f, 'deceased'),
        (e, 'exposed'),
        (s, 'susceptible'),
    ]
]))

In [18]:
Yobs = Y[10:, [3, 5, 7]]

In [26]:
%%time

ode = DifferentialEquation(diff_eq, t, n_states=8, n_theta=8)

sd = 0.5

with pm.Model():

    s0 = pm.Lognormal('s0', np.log(8_000), sd)
    e0 = pm.Lognormal('e0', np.log(2_000), sd)
    y0 = s0, e0, 0, 0, 0, 0, 0, 0
    
    beta = 1 / pm.Lognormal('ibeta', np.log(5), sd)
    beta_d = beta * pm.Uniform('ratio_beta_d', 0, 1)
    sigma = 1 / pm.Lognormal('isigma', np.log(10), sd)
    theta = 1 / pm.Lognormal('itheta', np.log(4), sd)
    igamma =  pm.Lognormal('igamma', np.log(20), sd)
    gamma = 1 / igamma
    gamma_d = 1 / (igamma * pm.Uniform('ratio_gamma_d', 0, 1))
    mu = 1 / pm.Lognormal('imu', np.log(50), sd)
    mu_d = gamma * pm.Uniform('ratio_mu_d', 0, 1)
    p = beta, beta_d, sigma, theta, gamma, gamma_d, mu, mu_d
    
    y = ode(y0, p)
    y_with_obs = y[10:, (3, 5, 7)]
    
    # Setting this to a constant made the sampling MUCH faster but led to divergences
    # HalfNormal may be faster than HalfCauchy
    # error =  pm.HalfCauchy('error', 100.0)
    # error = 0.01
    error = np.sqrt(100**2 + y_with_obs * y_with_obs / 50**2)
    pm.Normal('y', mu=y_with_obs, sd=error, observed=Yobs)
    
    # trace = pm.sample(40, tune=10, target_accept=0.99, compute_convergence_checks=False)
    
    step = pm.NUTS(max_treedepth=6, early_max_treedepth=4, target_accept=0.9)
    trace = pm.sample(40, tune=10, step=step, compute_convergence_checks=False)
    
    # Though we have gradient info with DifferentialEquation, the tutorial at
    # https://docs.pymc.io/notebooks/ODE_with_manual_gradients.html, which seems
    # to predate DifferentialEquation, suggests that sequential monte carlo (SMC)
    # is a good choice for ODEs.
    # trace = pm.sample_smc(100, progressbar=True, parallel=True, cores=8)  # TODO: Try kernel='ABC'
    
plot_trace(trace)

  rval = inputs[0].__getitem__(inputs[1:])
Only 40 samples in chain.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [ratio_mu_d, imu, ratio_gamma_d, igamma, itheta, isigma, ratio_beta_d, ibeta, e0, s0]
Sampling 4 chains, 0 divergences: 100%|██████████| 200/200 [21:25<00:00,  6.43s/draws]
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.


CPU times: user 7.43 s, sys: 845 ms, total: 8.27 s
Wall time: 21min 33s


In [28]:
trace_grid(trace)

In [38]:
%%time

ode = DifferentialEquation(diff_eq, t[10:], n_states=8, n_theta=8)

sd = 0.5

with pm.Model():

    y0 = pm.Lognormal('y0', np.log(1_000), 1, shape=8)
    
    beta = 1 / pm.Lognormal('ibeta', np.log(5), sd)
    beta_d = beta * pm.Uniform('ratio_beta_d', 0, 1)
    sigma = 1 / pm.Lognormal('isigma', np.log(10), sd)
    theta = 1 / pm.Lognormal('itheta', np.log(4), sd)
    igamma =  pm.Lognormal('igamma', np.log(20), sd)
    gamma = 1 / igamma
    gamma_d = 1 / (igamma * pm.Uniform('ratio_gamma_d', 0, 1))
    mu = 1 / pm.Lognormal('imu', np.log(50), sd)
    mu_d = gamma * pm.Uniform('ratio_mu_d', 0, 1)
    p = beta, beta_d, sigma, theta, gamma, gamma_d, mu, mu_d
    
    y = ode(y0, p)
    y_with_obs = y[:, (3, 5, 7)]
    
    error = np.sqrt(100**2 + y_with_obs * y_with_obs / 50**2)
    pm.Normal('y', mu=y_with_obs, sd=error, observed=Yobs)
    
#     trace = pm.sample(40, tune=10, target_accept=0.9, compute_convergence_checks=False)
    
    step = pm.NUTS(max_treedepth=8, early_max_treedepth=6, target_accept=0.9)
    trace = pm.sample(100, tune=40, step=step, cores=8, chains=6, compute_convergence_checks=False)
    
plot_trace(trace)

  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
Only 100 samples in chain.
Multiprocess sampling (6 chains in 8 jobs)
NUTS: [ratio_mu_d, imu, ratio_gamma_d, igamma, itheta, isigma, ratio_beta_d, ibeta, y0]
Sampling 6 chains, 1 divergences:  63%|██████▎   | 533/840 [2:22:09<1:21:52, 16.00s/draws]


CPU times: user 7.71 s, sys: 921 ms, total: 8.63 s
Wall time: 2h 22min 19s


In [39]:
trace_grid(trace)

KeyError: 3