# Intro to epidemiological models

This notebook introduces the [pyro.contrib.epidemiology](http://docs.pyro.ai/en/latest/contrib.epidemiology.html) module, an epidemiological modeling language with a number of black box inference algorithms. This tutorial assumes you have already learned the basics of [modeling](http://pyro.ai/examples/intro_part_ii.html), [inference](http://pyro.ai/examples/intro_part_ii.html), and [distribution shapes](http://pyro.ai/examples/tensor_shapes.html).

#### Summary

- To create a new model, inherit from the [CompartmentalModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel) base class.
- Override methods [.global_model()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.global_model), [.initialize(params)](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.initialize), and [.transition(params, state, t)](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.transition).
- Take care to support broadcasting and vectorized interpretation in those methods.
- For single time series, set `population` to an integer.
- For batched time series, let `population` be a vector, and use [self.region_plate](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel).
- For models with complex inter-compartment flows, override the [.compute_flows()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.compute_flows) method. 
- Flows with loops (undirected or directed) are not currently supported.
- To perform cheap approximate inference via SVI, call the [.fit_svi()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.fit_svi) method.
- To perform more expensive inference via MCMC, call the [.fit_mcmc()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.fit_mcmc) method.
- To stochastically predict latent and future variables, call the [.predict()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.predict) method.

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist

%matplotlib inline
assert pyro.__version__.startswith('1.3.1')
pyro.enable_validation(True)
smoke_test = ('CI' in os.environ)

## Scope

The [pyro.contrib.epidemiology](http://docs.pyro.ai/en/latest/contrib.epidemiology.html) module provides a modeling language for a class of discrete-valued discrete-time stochastic epidemiological compartmental models, together with a number of black box inference algorithms to perform joint inference on global parameters and latent variables. This modeling language is more restrictive than the full Pyro probabilistic programming language:

- control flow must be static;
- compartmental distributions are restricted to [binomial_dist()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.distributions.binomial_dist), [beta_binomial_dist()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.distributions.beta_binomial_dist), and [infection_dist()](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.distributions.infection_dist);
- plates are not allowed, except for the single optional [.region_plate](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.compartmental.CompartmentalModel.region_plate);
- all random variables must be either global or Markov and sampled at every time step, so e.g. windowed random variables are not supported;
- models must support broadcasting and vectorization of time `t`.

These restrictions allow inference algorithms to vectorize over the time dimension, so that their per-iteration parallel complexity is sublinear in the length of the time axis. The restriction on distributions allows inference algorithms to approximate parts of the model as Gaussian via moment matching, further speeding up inference. Finally, because real data is so often overdispersed relative to Binomial idealizations, the three distribution helpers provide an [overdispersion](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.distributions.binomial_dist) parameter calibrated so that in the large-population limit all distribution helpers converge to log-normal.
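To make the moment-matching idea concrete, here is a minimal sketch in plain Python (standard library only, not the Pyro implementation). It assumes the simplest case: a Binomial(n, p) flow is replaced by the Gaussian with the same mean and variance. The helper names `binomial_pmf`, `moment_matched_normal`, and `normal_pdf` are hypothetical and chosen for illustration.

```python
import math

def binomial_pmf(n, p, k):
    """Exact probability of k successes under Binomial(n, p)."""
    return math.comb(n, k) * p**k * (1 - p)**(n - k)

def moment_matched_normal(n, p):
    """Mean and standard deviation of the Gaussian matching Binomial(n, p)."""
    mean = n * p
    std = math.sqrt(n * p * (1 - p))
    return mean, std

def normal_pdf(x, mean, std):
    """Density of Normal(mean, std) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))

n, p = 200, 0.3
mean, std = moment_matched_normal(n, p)

# Largest pointwise gap between the exact pmf and the Gaussian density.
max_gap = max(abs(binomial_pmf(n, p, k) - normal_pdf(k, mean, std))
              for k in range(n + 1))
print("mean:", mean, "std:", std, "max gap:", max_gap)
```

The gap shrinks as `n` grows, which is one reason a Gaussian surrogate can be a reasonable stand-in for large compartments; for small counts the approximation is cruder.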

Black box inference algorithms currently include [SVI](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.svi.SVI) with a moment-matching approximation, and [NUTS](http://docs.pyro.ai/en/latest/mcmc.html#pyro.infer.mcmc.NUTS) either with a moment-matching approximation or with an exact auxiliary variable method detailed in the [SIR HMC tutorial](http://pyro.ai/examples/sir_hmc.html). All three algorithms initialize using [SMC](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.smcfilter.SMCFilter) and reparameterize time-dependent variables using a fast [Haar wavelet](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.haar.HaarReparam) transform. Default inference parameters are set for cheap approximate results; accurate results will require more steps and ideally comparison among different inference algorithms. We especially recommend running MCMC algorithms with multiple chains to diagnose mixing issues.
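As one illustration of what a multi-chain mixing check can look like, the following plain-Python sketch computes the classic Gelman-Rubin potential scale reduction factor for a scalar quantity. It is not the diagnostic code that Pyro ships; the function name `gelman_rubin` and the synthetic chains are assumptions made for this example. Values near 1 suggest the chains agree, while values well above 1 suggest they are exploring different regions.

```python
import random
import statistics

def gelman_rubin(chains):
    """Potential scale reduction factor for equal-length chains of scalars."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    chain_vars = [statistics.variance(c) for c in chains]  # per-chain sample variance
    within = statistics.fmean(chain_vars)                  # W: mean within-chain variance
    between = n * statistics.variance(chain_means)         # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n          # pooled variance estimate
    return (var_plus / within) ** 0.5

rng = random.Random(0)

# Four synthetic chains drawn from the same distribution (well mixed).
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]

# Four synthetic chains where two are stuck around a different mode.
stuck = [[rng.gauss(mu, 1.0) for _ in range(500)] for mu in (0.0, 0.0, 3.0, 3.0)]

r_hat_mixed = gelman_rubin(mixed)
r_hat_stuck = gelman_rubin(stuck)
print("mixed:", r_hat_mixed, "stuck:", r_hat_stuck)
```

In practice one would apply such a check to posterior samples of each global parameter (for example `R0` and `rho` in the model below), one chain per list.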

## Modeling

The [pyro.contrib.epidemiology.models](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#module-pyro.contrib.epidemiology.models) module provides a number of example models. While in principle these are reusable, we recommend forking and modifying these models for your task. Let's take a look at one of the simplest examples, [SimpleSIRModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.models.SimpleSIRModel):

In [None]:
class SimpleSIRModel(CompartmentalModel):
    def __init__(self, population, recovery_time, data):
        compartments = ("S", "I")  # R is implicit.
        duration = len(data)
        super().__init__(compartments, duration, population)
        assert isinstance(recovery_time, float)
        assert recovery_time > 1
        self.recovery_time = recovery_time
        self.data = data

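    def global_model(self):
        # Sample global (time-independent) parameters and return them
        # as the `params` passed to initialize() and transition().
        tau = self.recovery_time
        R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
        rho = pyro.sample("rho", dist.Beta(10, 10))
        return R0, tau, rho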

    def initialize(self, params):
        # Start with a single infection.
        return {"S": self.population - 1, "I": 1}

    def transition(self, params, state, t):
        R0, tau, rho = params

        # Sample flows between compartments.
        S2I = pyro.sample("S2I_{}".format(t),
                          infection_dist(individual_rate=R0 / tau,
                                         num_susceptible=state["S"],
                                         num_infectious=state["I"],
                                         population=self.population))
        I2R = pyro.sample("I2R_{}".format(t),
                          binomial_dist(state["I"], 1 / tau))

        # Update compartments with flows.
        state["S"] = state["S"] - S2I
        state["I"] = state["I"] + S2I - I2R

        # Condition on observations.
        t_is_observed = isinstance(t, slice) or t < self.duration
        pyro.sample("obs_{}".format(t),
                    binomial_dist(S2I, rho),
                    obs=self.data[t] if t_is_observed else None)
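To build intuition for the generative process this model describes, here is a rough forward simulation in plain Python (standard library only, no Pyro). It is a sketch under stated assumptions, not a reimplementation of the library: `R0` and `rho` are fixed rather than sampled from priors, there is no overdispersion, and the per-susceptible infection probability is assumed to be `1 - (1 - (R0 / tau) / population) ** I`, which may differ in detail from what `infection_dist` computes. The names `sample_binomial` and `simulate_sir` are hypothetical.

```python
import random

def sample_binomial(rng, n, p):
    """Draw from Binomial(n, p) as a sum of n Bernoulli trials (fine for small n)."""
    return sum(rng.random() < p for _ in range(n))

def simulate_sir(population, R0, tau, rho, duration, seed=0):
    """Simulate discrete-time stochastic SIR dynamics with noisy observations."""
    rng = random.Random(seed)
    S, I = population - 1, 1  # start with a single infection; R is implicit
    history = []
    for t in range(duration):
        # ASSUMPTION: each infectious individual independently infects a given
        # susceptible with probability (R0 / tau) / population per time step.
        p_pair = min(max(R0 / tau / population, 0.0), 1.0)
        p_infect = 1.0 - (1.0 - p_pair) ** I

        # Sample flows between compartments from the pre-update state.
        S2I = sample_binomial(rng, S, p_infect)
        I2R = sample_binomial(rng, I, 1.0 / tau)

        # Update compartments with flows.
        S -= S2I
        I += S2I - I2R

        # Each new infection is observed with probability rho.
        obs = sample_binomial(rng, S2I, rho)
        history.append({"t": t, "S": S, "I": I, "S2I": S2I, "I2R": I2R, "obs": obs})
    return history

history = simulate_sir(population=1000, R0=2.0, tau=7.0, rho=0.5, duration=30)
for row in history[:5]:
    print(row)
```

The `obs` column plays the role of `data` in the model above: a list of observed new-infection counts, one per time step.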