# Epidemiology II: Regional Models

In [None]:
import logging
import os
import tempfile
import urllib.request
from collections import OrderedDict

import pandas as pd
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist
from pyro.ops.tensor_utils import convolve

# Notebook setup: render figures inline, default to float64 tensors, turn on
# Pyro's validation checks, and record library versions for reproducibility.
# IPython magic (not plain Python): display matplotlib figures in the notebook.
%matplotlib inline
torch.set_default_dtype(torch.double)  # Needed for numerical stability.
pyro.enable_validation(True)           # Enable Pyro's runtime validation checks.
print(torch.__version__)
print(pyro.__version__)

# Uncomment these two lines to see Pyro's DEBUG-level log messages.
# logging.getLogger("pyro").setLevel(logging.DEBUG)
# logging.getLogger("pyro").handlers[0].setLevel(logging.DEBUG)

## Getting data

Let's consider regional data for Bay Area counties. We'll use data from [Johns Hopkins University](https://github.com/CSSEGISandData).

In [None]:
url = ("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/"
       "csse_covid_19_data/csse_covid_19_time_series/")


def download_df(basename, cache_dir=None):
    """Load a JHU CSSE time-series CSV as a DataFrame, downloading it at most once.

    The file ``url + basename`` is cached at ``os.path.join(cache_dir, basename)``;
    if that file already exists it is read directly and no download happens.

    :param str basename: CSV file name inside the JHU time-series directory.
    :param str cache_dir: directory holding the cached copy. Defaults to the
        system temporary directory (``tempfile.gettempdir()``), which is portable
        across operating systems unlike a hard-coded ``/tmp``.
    :returns: the parsed CSV.
    :rtype: pandas.DataFrame
    """
    if cache_dir is None:
        cache_dir = tempfile.gettempdir()
    local_path = os.path.join(cache_dir, basename)
    if not os.path.exists(local_path):
        # Download to a uniquely named temporary file in the same directory and
        # atomically rename it into place only after the transfer succeeded.
        # Otherwise an interrupted download would leave a truncated CSV that all
        # later calls would silently treat as a valid cached copy.
        fd, partial_path = tempfile.mkstemp(dir=cache_dir, suffix=".part")
        os.close(fd)
        try:
            urllib.request.urlretrieve(url + basename, partial_path)
            os.replace(partial_path, local_path)
        finally:
            # After a successful os.replace the partial file no longer exists;
            # on any failure (including KeyboardInterrupt) it is cleaned up.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    return pd.read_csv(local_path)

# Fetch (or reuse the cached copies of) the cumulative confirmed-case and
# cumulative death tables, in that order.
cum_cases_df, cum_deaths_df = (
    download_df(basename)
    for basename in ("time_series_covid19_confirmed_US.csv",
                     "time_series_covid19_deaths_US.csv")
)
# Peek at the leading column labels; the notebook displays this last expression.
cum_cases_df.columns[:13]

We'll pull out the Bay Area counties with their approximate populations.

In [None]:
# Bay Area counties mapped to their approximate populations. Insertion order
# fixes the region order of the `population` tensor built from it.
counties = OrderedDict()
counties["Santa Clara"] = 1763000
counties["Alameda"] = 1495000
counties["Contra Costa"] = 1038000
counties["San Francisco"] = 871000
counties["San Mateo"] = 712000
counties["Sonoma"] = 479000
counties["Solano"] = 412000
counties["Marin"] = 251000
counties["Napa"] = 135000
population = torch.tensor([counties[name] for name in counties], dtype=torch.double)

Next, we convert the selected rows of each dataframe into PyTorch tensors.

In [None]:
def county_time_series(df, county_names, date_columns, state="California"):
    """Extract cumulative counts for the given counties of one US state.

    :param pandas.DataFrame df: JHU CSSE US time-series table; rows are matched
        on its ``Province_State`` and ``Admin2`` (county name) columns.
    :param county_names: iterable of county names, in the desired output order.
    :param date_columns: labels of the date columns to extract.
    :param str state: state the counties must belong to. Restricting the match
        to one state guards against county names that exist in several states.
    :returns: float64 tensor of shape ``(len(date_columns), len(county_names))``,
        one row per date and one column per county.
    :raises ValueError: if a county matches zero rows or more than one row.
    """
    in_state = df[df["Province_State"] == state]
    rows = []
    for name in county_names:
        match = in_state[in_state["Admin2"] == name]
        if len(match) != 1:
            raise ValueError("Expected exactly one row for county {!r} in {}, found {}"
                             .format(name, state, len(match)))
        rows.append(match)
    # Stack to (num_counties, num_dates) as float64, then transpose to time-major.
    table = pd.concat(rows)[date_columns].to_numpy(dtype="float64")
    return torch.from_numpy(table).T.contiguous()


# Dates start at column 11 of the cases table. The deaths table has one more
# leading metadata column (its dates start at column 12), so select the same
# date columns from both tables by label rather than by position; this also
# guarantees that cases and deaths cover identical dates.
date_columns = cum_cases_df.columns[11:]
cum_cases = county_time_series(cum_cases_df, counties, date_columns)
cum_deaths = county_time_series(cum_deaths_df, counties, date_columns)
assert cum_cases.shape == cum_deaths.shape
print(cum_cases.shape, cum_deaths.shape)

We need to convert from cumulative to difference data. However, the original data is inconsistent due to later corrections of earlier errors, and is noisy due to reporting jitter. To clean up both of these artifacts, we will smooth both cumulative data sources with a square (moving-average) window: the smoothed cumulative count on date ``t`` is the mean of the raw cumulative counts on dates ``t-n`` for ``n`` in ``[0, window)``. This effectively treats each report as delayed by a lag ``n`` that is uniformly distributed in the interval ``[0, window)``. We use the smallest window for which all daily differences of both smoothed series are nonnegative.

In [None]:
# Search for the smallest causal moving-average window that makes every daily
# difference of both smoothed cumulative series nonnegative.
T = len(cum_cases)
max_window = 100
for window in range(1, max_window):
    kernel = torch.ones(window) / window
    # Convolve along time (the last axis after transposing) and keep the first
    # T steps, so day t averages raw cumulative counts on days t-window+1..t
    # (assuming convolve's default full-mode, zero-padded output).
    smooth_cases = convolve(cum_cases.T, kernel).T[:T].round()
    smooth_deaths = convolve(cum_deaths.T, kernel).T[:T].round()
    new_cases = smooth_cases[1:] - smooth_cases[:-1]
    new_deaths = smooth_deaths[1:] - smooth_deaths[:-1]
    if (new_cases >= 0).all() and (new_deaths >= 0).all():
        break
else:
    # Without this the loop would silently fall through with the last window
    # and hand negative daily counts to the model.
    raise RuntimeError(
        "No smoothing window in [1, {}) makes all daily counts nonnegative"
        .format(max_window))
print("window = {}".format(window))
print("shape = {}".format(tuple(new_cases.shape)))

In [None]:
# Plot smoothed daily new cases (top) and deaths (bottom), one line per county.
fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
axes[0].plot(new_cases)
axes[1].plot(new_deaths)
# Label the per-county curves (columns follow the order of `counties`).
axes[0].legend(list(counties), loc="upper left", fontsize="small")
axes[0].set_ylabel("new cases")
axes[1].set_ylabel("new deaths")
axes[1].set_xlabel("Day after 1/23/2020")
# The plotted series are first differences, one step shorter than cum_cases,
# so the last x index is len(new_cases) - 1.
axes[0].set_xlim(0, len(new_cases) - 1)
axes[0].set_title("COVID-19 cases in nine Bay Area Counties")
plt.subplots_adjust(hspace=0)

## Creating a regional model

To create a realistic model of Bay Area data, we'll combine aspects of a number of simpler models: [OverdispersedSEIRModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#overdispersed-seir), [UnknownStartSIRModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#unknown-start-sir), and [RegionalSIRModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#regional-sir).

In [None]:
class Model(CompartmentalModel):
    """Regional SEIR model with overdispersion and coupled regions.

    Compartments are S (susceptible), E (exposed) and I (infectious); R is
    implicit. Every region starts fully susceptible and infection is seeded by a
    constant background rate from external sources. Regions interact through
    ``coupling``: region ``j`` feels infectious pressure
    ``sum_i I[i] * coupling[i, j]`` relative to the equally coupled population
    ``sum_i population[i] * coupling[i, j]``. The effective reproductive number
    follows a log-normal random walk in time that is shared by all regions.

    :param torch.Tensor population: population of each region, length
        ``num_regions``.
    :param torch.Tensor coupling: ``(num_regions, num_regions)`` matrix with
        entries in ``[0, 1]``; ``coupling[i, j]`` weights region ``i`` as a
        source of infection for region ``j``.
    :param torch.Tensor new_cases: ``(duration, num_regions)`` observed new cases.
    :param torch.Tensor new_deaths: ``(duration, num_regions)`` observed new deaths.
    :param float incubation_time: mean number of time steps spent in E.
    :param float recovery_time: mean number of time steps spent in I.
    :param float external_rate: rate of background infections from outside the
        modeled regions.
    """

    def __init__(self, population, coupling, new_cases, new_deaths,
                 incubation_time=4.0,
                 recovery_time=14.0,
                 external_rate=0.01):
        duration, num_regions = new_deaths.shape
        assert new_cases.shape == new_deaths.shape
        assert len(population) == num_regions
        assert coupling.shape == (num_regions, num_regions)
        assert (0 <= coupling).all()
        assert (coupling <= 1).all()

        compartments = ("S", "E", "I")  # R is implicit.
        # Declaring I approximate makes the base class also supply a point
        # estimate state["I_approx"], which transition() uses for coupling.
        super().__init__(compartments, duration, population, approximate=("I",))

        self.incubation_time = incubation_time
        self.recovery_time = recovery_time
        self.external_rate = external_rate
        self.coupling = coupling
        self.new_cases = new_cases
        self.new_deaths = new_deaths

    def global_model(self):
        """Sample the parameters shared by all regions and time steps.

        :returns: tuple ``(R0, tau_e, tau_i, rho, mu, od)`` where ``rho`` is the
            probability an infection is reported as a case, ``mu`` the
            probability a removal is reported as a death, and ``od`` the
            overdispersion used by every flow and observation.
        """
        tau_e = self.incubation_time
        tau_i = self.recovery_time
        R0 = pyro.sample("R0", dist.LogNormal(1., 0.5))  # Weak prior.
        rho = pyro.sample("rho", dist.Beta(10, 10))  # About 50% response rate.
        # Beta(2, 100) has mean 2/102, i.e. about 2%. (Beta(100, 2) would put
        # the prior mean near 98%.)
        mu = pyro.sample("mu", dist.Beta(2, 100))  # About 2% mortality rate.
        od = pyro.sample("od", dist.Beta(2, 6))  # Overdispersion, prior mean 0.25.
        return R0, tau_e, tau_i, rho, mu, od

    def initialize(self, params):
        """Return the initial state: everyone susceptible, beta = 1."""
        # Start with no local infections and base reproductive number.
        return {"S": self.population,
                "E": torch.zeros_like(self.population),
                "I": torch.zeros_like(self.population),
                "beta": torch.tensor(1.)}

    def transition(self, params, state, t):
        """Advance ``state`` by one time step (or a slice of steps) and observe.

        Samples the S->E, E->I and I->R flows in every region, updates the
        compartments in place, and conditions on new cases (reported fraction
        ``rho`` of S->E) and new deaths (fraction ``mu`` of I->R) for observed
        time steps.
        """
        R0, tau_e, tau_i, rho, mu, od = params

        # Assume effective reproductive number Re varies in time,
        # but is the same across all regions (say due to synchronized policy).
        beta = pyro.sample("beta_{}".format(t),
                           dist.LogNormal(state["beta"].log(), 0.1))
        Re = pyro.deterministic("Re_{}".format(t), R0 * beta)

        # Account for strong intra-region infections, weak inter-region infections, and
        # even weaker background infections from external sources. This uses approximate
        # (point estimate) counts I_approx for infection from other regions, but uses
        # the exact (enumerated) count I for infections from one's own region.
        # I_external is divided by Re / tau_i so that, once multiplied by the
        # individual rate Re / tau_i below, the background term does not depend on Re.
        I_external = self.external_rate * tau_i / Re
        I_coupled = state["I_approx"] @ self.coupling + I_external
        I_coupled = I_coupled + (state["I"] - state["I_approx"]) * self.coupling.diag()
        I_coupled = I_coupled.clamp(min=0)  # In case I_approx is negative.
        # The infectious counts above are mixed across regions via `coupling`,
        # so the matching denominator is the equally coupled population.
        pop_coupled = self.population @ self.coupling

        with self.region_plate:
            # Sample flows between compartments.
            S2E = pyro.sample("S2E_{}".format(t),
                              infection_dist(individual_rate=Re / tau_i,
                                             num_susceptible=state["S"],
                                             num_infectious=I_coupled,
                                             population=pop_coupled,
                                             overdispersion=od))
            E2I = pyro.sample("E2I_{}".format(t),
                              binomial_dist(state["E"], 1 / tau_e, overdispersion=od))
            I2R = pyro.sample("I2R_{}".format(t),
                              binomial_dist(state["I"], 1 / tau_i, overdispersion=od))

            # Update compartments and heterogeneous variables.
            state["S"] = state["S"] - S2E
            state["E"] = state["E"] + S2E - E2I
            state["I"] = state["I"] + E2I - I2R
            state["beta"] = beta  # We store the latest beta value in the state dict.

            # Condition on observations; steps past the data (t >= duration)
            # are left unobserved.
            t_is_observed = isinstance(t, slice) or t < self.duration
            pyro.sample("new_cases_{}".format(t),
                        binomial_dist(S2E, rho, overdispersion=od),
                        obs=self.new_cases[t] if t_is_observed else None)
            pyro.sample("new_deaths_{}".format(t),
                        binomial_dist(I2R, mu, overdispersion=od),
                        obs=self.new_deaths[t] if t_is_observed else None)

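We'll also need to specify a coupling matrix. We could be Bayesian about this, for example by placing a prior on the coupling informed by distance or mobility data. For simplicity, we'll assume inter-region infection is 1/10 as strong as intra-region infection, i.e. a coupling matrix with ones on the diagonal and 0.1 everywhere else.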

In [None]:
# Off-diagonal (inter-region) weight 0.1, diagonal (intra-region) weight 1.
coupling = torch.full((len(population), len(population)), 0.1)
coupling.fill_diagonal_(1.0)
model = Model(population, coupling, new_cases, new_deaths)

Now we'll train the model using MCMC.

In [None]:
# Fix the random seed so the MCMC run is reproducible.
pyro.set_rng_seed(20200607)
mcmc = model.fit(
    warmup_steps=300,
    num_samples=200,
    num_quant_bins=2,   # Pyro inference option; see CompartmentalModel.fit docs.
    haar_full_mass=15,  # Pyro inference option; see CompartmentalModel.fit docs.
)

In [None]:
# Summarize the posterior samples from the MCMC run returned by model.fit().
mcmc.summary()