# TLDR

The contact distribution observed among positive cases is a biased representation of the underlying contact distribution of the population: people with more contacts are more likely to become infected and are therefore over-represented among positive cases, while people with fewer contacts are under-represented.

We would like to use Bayesian inference to recover a more accurate estimate of the underlying contact distribution.

In [29]:
%load_ext autoreload
%autoreload 2

import time
from typing import Optional

import numpy as np
import pandas as pd
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO, Predictive

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
def observe_contact_distribution(
    N: int,
    lamda: float,
    beta: float,
    data: Optional[torch.Tensor] = None,
    num_contact_levels: int = 3,
):

    r"""Pyro model for the contact counts observed among positive cases.

    A latent contact distribution ``p`` over the contact counts
    ``0, 1, ..., num_contact_levels - 1`` is drawn from a symmetric
    Dirichlet(0.1) prior. The population is split into integer group sizes
    ``(N * p).int()``, and the number of positives in each group is modelled
    as Binomial with infection probability ``1 - (1 - lamda * beta) ** k``
    for a person with ``k`` contacts.

    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
        data: observed number of positive cases per contact count (entry k is
            the number of positives with k contacts). When None, the 'obs'
            site is sampled instead of being conditioned on.
        num_contact_levels: number of distinct contact counts modelled.
            Defaults to 3 (0, 1 or 2 contacts); must match the length of
            ``data`` when ``data`` is given.

    Returns:
        The value of the 'obs' sample site (equal to ``data`` when provided).

    Raises:
        ValueError: if ``num_contact_levels`` is smaller than 1.
    """
    if num_contact_levels < 1:
        raise ValueError("num_contact_levels must be at least 1")

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = pyro.sample(
        'p',
        dist.Dirichlet(concentration=torch.full((num_contact_levels,), 0.1)),
    )

    # Integer number of people in each contact group. The cast truncates, so
    # the group sizes can sum to slightly less than N.
    # NOTE(review): the integer cast is not differentiable, so a gradient-based
    # sampler such as NUTS presumably gets no gradient from the likelihood with
    # respect to p -- confirm this is intended.
    binomial_counts = pyro.deterministic("binomial_counts", (N*p).int())

    # Probability of being infected by at least one of k contacts, k = 0, 1, ...
    # NOTE(review): this is exactly 0 for k = 0, so observed positives with
    # zero contacts are (near-)impossible under this likelihood -- verify.
    p_inf = 1-(1-lamda*beta)**torch.arange(num_contact_levels)

    with pyro.plate('data'):
        obs = pyro.sample(
            'obs',
            dist.Binomial(
                total_count=binomial_counts,
                probs=p_inf),
            obs=data
        )
    return obs

In [39]:
N = 200  # population size
lamda = 0.1  # prevalence
beta = 0.1  # probability of infection given contact with positive person
# observed contact distribution, i.e., 5 have 0, 5 have 1, 10 have 2 contacts
y = torch.tensor([5, 5, 10])

# Fix the random seed so the MCMC draws are reproducible after a kernel restart.
pyro.set_rng_seed(0)

nuts_kernel = NUTS(observe_contact_distribution)  # also called "sampler"
mcmc = MCMC(nuts_kernel, warmup_steps=100, num_samples=500)


In [40]:
# Run the sampler on the observed counts and report how long it took.
start_time = time.time()
mcmc.run(N, lamda, beta, y)
elapsed_seconds = time.time() - start_time
print(f"Took {elapsed_seconds} seconds")

Sample: 100%|██████████| 600/600 [39:47,  3.98s/it, step size=8.08e-04, acc. prob=0.862]

Took 2387.341058731079 seconds





In [41]:
# Convert the posterior draws of every sample site from tensors to NumPy arrays.
hmc_samples = {
    site_name: draws.detach().cpu().numpy()
    for site_name, draws in mcmc.get_samples().items()
}


In [42]:
# Utility function to compute latent sites' quantile information.
def summary(samples):
    """Summarise the draws of each site with mean, std and selected quantiles.

    Args:
        samples: mapping from site name to the draws for that site, in any
            form accepted by ``pd.DataFrame`` (rows are draws).

    Returns:
        dict mapping each site name to a DataFrame with one row per column of
        that site's draws and the columns mean, std, 5%, 25%, 50%, 75%, 95%.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    reported_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        name: pd.DataFrame(draws)
        .describe(percentiles=quantiles)
        .transpose()[reported_columns]
        for name, draws in samples.items()
    }

In [43]:
# Print the summary table of every latent site, one after another.
site_tables = summary(hmc_samples)
for site in site_tables:
    values = site_tables[site]
    print("Site: {}".format(site))
    print(values, "\n")

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.314451  0.089374  0.183863  0.251462  0.307641  0.376311  0.466320
1  0.245077  0.097898  0.122148  0.175583  0.218205  0.305509  0.425191
2  0.440472  0.111015  0.254480  0.370630  0.440094  0.504709  0.638854 

