In [24]:
### copied from https://jckantor.github.io/CBE30338/03.09-COVID-19.html

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# parameter values
R0 = 5.15 # how many friends on avg each user has 
# t_infective = 5.1 + 3.3

# initial number of infected and recovered individuals
i_initial = 1/336224
s_initial = 1 - i_initial

# gamma = 1/t_infective
beta = R0

def discrete_laplacian(u):
    L = -2*u
    L += np.roll(u, -1)
    L += np.roll(u, +1)
    L[0]   = 2*(-u[0]+u[1])
    L[-1] = 2*(-u[-1]+u[-2])
    return L

# SIR model differential equations.
def deriv(x, t, beta,gamma):
    s, i = x
    dsdt = -beta * s * i + discrete_laplacian(s)
    didt = beta * s * i + discrete_laplacian(i)
    return [dsdt, didt]

t = np.linspace(0, 180, 2000)
x_initial = s_initial, i_initial
soln = odeint(deriv, x_initial, t, args=(beta,gamma))
s, i = soln.T
e = None

def plotdata(t, s, i, e=None):
    # plot the data
    fig = plt.figure(figsize=(12,6))
    ax = [fig.add_subplot(221, axisbelow=True),
          fig.add_subplot(223),
          fig.add_subplot(122)]

    ax[0].plot(t, s, lw=3, label='Fraction Susceptible')
    ax[0].plot(t, i, lw=3, label='Fraction Infective')
    ax[0].set_title('Susceptible and Infected Populations')
    ax[0].set_xlabel('Time /days')
    ax[0].set_ylabel('Fraction')

    ax[1].plot(t, i, lw=3, label='Infective')
    ax[1].set_title('Infectious Population')
    if e is not None: ax[1].plot(t, e, lw=3, label='Exposed')
    ax[1].set_ylim(0, 0.3)
    ax[1].set_xlabel('Time /days')
    ax[1].set_ylabel('Fraction')

    ax[2].plot(s, i, lw=3, label='s, i trajectory')
    ax[2].plot([1/R0, 1/R0], [0, 1], '--', lw=3, label='di/dt = 0')
    ax[2].plot(s[0], i[0], '.', ms=20, label='Initial Condition')
    ax[2].plot(s[-1], i[-1], '.', ms=20, label='Final Condition')
    ax[2].set_title('State Trajectory')
    ax[2].set_aspect('equal')
    ax[2].set_ylim(0, 1.05)
    ax[2].set_xlim(0, 1.05)
    ax[2].set_xlabel('Susceptible')
    ax[2].set_ylabel('Infectious')

    for a in ax:
        a.grid(True)
        a.legend()

    plt.tight_layout()

plotdata(t, s, i)

IndexError: invalid index to scalar variable.