In [1]:
# Import libraries.
import numpy as np
import pymc as pm

from scipy import stats
from scipy import integrate, optimize

import matplotlib.pyplot as plt

In [2]:
# British boarding school flu data
t_obs = np.arange(15)
I_obs = np.array([1, 3, 6, 25, 73, 222, 294, 258, 237, 191, 125, 69, 27, 11, 4])

N = 764
S_0 = 763
I_0 = 1
R_0 = 0

In [3]:
# SIR ODEs

# Defines the system of differential equations
# for the SIR model.
def sir(sir, t, beta, nu):
    S, I, R = sir
    dSdt = -beta * S * I / N
    dIdt = beta * S * I / N - nu * I
    dRdt = nu * I
    return dSdt, dIdt, dRdt

# Returns the infected data for the SIR model with 
# inputted parameters beta and nu.
def fit_odeint(t, beta, nu):
    fit = integrate.odeint(sir, (S_0, I_0, R_0), t_obs, args = (beta, nu))
    return fit[:, 1]

In [4]:
# Solve for the optimal values via curve-fitting
curv_opt, cov = optimize.curve_fit(fit_odeint, t_obs, I_obs)

In [5]:
# Create likelihood, prior, and posterior
def likelihood(state):
    beta, nu = state
    I_fit = fit_odeint(t_obs, beta, nu)
    lik = stats.poisson.logpmf(I_obs, I_fit).sum()
    return lik

def prior(state):
    beta, nu = state
    return stats.norm.logpdf(beta, 1.5, 0.25) + stats.norm.logpdf(nu, 0.5, 0.1)

def posterior(state):
    return likelihood(state) + prior(state)

In [6]:
# Metropolis algorithm

# Proposal function
def q(state):
    return stats.norm.rvs(state, [0.1, 0.1], size = 2)

# MCMC
def metropolis(start, num_iter):
    chain = np.zeros((num_iter + 1, 2))
    chain[0] = start
    
    for i in range(num_iter):
        if i % 100 == 0:
            print(np.round(chain[i], 3), '\t', np.round(likelihood(chain[i]), 3))
        proposal = q(chain[i])
        p = np.exp(posterior(proposal) - posterior(chain[i]))
        if stats.uniform.rvs() < p:
            chain[i + 1] = proposal
        else:
            chain[i + 1] = chain[i]
    return chain

In [7]:
# Solve for optimal values via MCMC
chain = metropolis([1, 0.1], 3000)
mcmc_opt = chain[-1, :]

[1.  0.1] 	 -1979.049
[1.673 0.489] 	 -83.234
[1.698 0.495] 	 -83.049
[1.7   0.493] 	 -82.906
[1.706 0.485] 	 -82.997
[1.667 0.483] 	 -83.345
[1.665 0.471] 	 -83.51
[1.707 0.479] 	 -83.233
[1.713 0.491] 	 -83.757
[1.654 0.452] 	 -86.986
[1.676 0.472] 	 -82.732
[1.691 0.5  ] 	 -83.686
[1.664 0.475] 	 -83.411
[1.705 0.484] 	 -82.922
[1.705 0.467] 	 -84.356
[1.678 0.48 ] 	 -82.438
[1.678 0.48 ] 	 -82.438
[1.678 0.48 ] 	 -82.438
[1.682 0.473] 	 -82.544
[1.682 0.473] 	 -82.544
[1.682 0.473] 	 -82.544
[1.71  0.485] 	 -83.346
[1.673 0.478] 	 -82.689
[1.656 0.477] 	 -84.403
[1.692 0.496] 	 -83.135
[1.675 0.476] 	 -82.602
[1.673 0.483] 	 -82.8
[1.705 0.487] 	 -82.937
[1.698 0.488] 	 -82.548
[1.682 0.476] 	 -82.361


In [8]:
print('Curve-fitting:\t', curv_opt)
print('Metropolis:\t', mcmc_opt)

Curve-fitting:	 [1.66427741 0.44582665]
Metropolis:	 [1.68223192 0.47617593]
