# TLDR

The contact distribution from positive cases is a biased representation of the underlying contact distribution of the population. (People with more contacts are more likely to become infected and contribute to the observed contact distribution, and vice versa.)

We would like to use Bayesian inference to obtain a more accurate underlying contact distribution. 

In [9]:
%load_ext autoreload
%autoreload 2

import numpy as np
# import numpyro
# import numpyro.distributions as dist
# from numpyro.distributions import DirichletMultinomial
# from numpyro.infer import MCMC, NUTS, HMC
# from jax import random

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO, Predictive

import torch
import pandas as pd
import time

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
def observe_contact_distribution(N: int, lamda: float, beta: float, y: torch.Tensor =None):

    r""" 
    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
    """

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = pyro.sample(
        'p', 
        dist.Dirichlet(concentration=torch.tensor([0.1, 0.1, 0.1])), 
    )

    # loop over k, the number of contacts, and sample the number of infections with k contacts

    # for k in pyro.plate('data', 3):
    for k in range(3):

        p_inf = 1-(1-lamda*beta)**k 

        # print("k: ", k, ", p_inf: ", p_inf, ", total_count: ", int(N*p[k]))

        pyro.sample(
            #f'infected_w_num_contacts_{k}', 
            f"obs_{k}",
            dist.Binomial(total_count = int(N * p[k]), probs = torch.tensor([p_inf])),
            obs = y[k]
        )
        

In [11]:
N=20 # population size
lamda = 0.1 # prevalence
beta=0.1 # probability of infection given contact with positive person
y = torch.tensor([5,5,10]) # observed contact distribution, i.e., 5 have 0, 5 have 1, 10 have 2 contacts

nuts_kernel = NUTS(observe_contact_distribution) # also called "sampler"
mcmc = MCMC(nuts_kernel, warmup_steps=500, num_samples=5000)


In [12]:
start_time = time.time()
mcmc.run(N, lamda, beta, y)
print(f"Took {time.time()-start_time} seconds")

Warmup:   0%|          | 0/5500 [00:00, ?it/s]

model_args:  (20, 0.1, 0.1, tensor([ 5,  5, 10]))
model_kwargs:  {}
model:  _bound_partial(functools.partial(<function _context_wrap at 0x7f630c562940>, <pyro.infer.autoguide.initialization.InitMessenger object at 0x7f62667be970>, _bound_partial(functools.partial(<function _context_wrap at 0x7f630c562940>, <pyro.poutine.enum_messenger.EnumMessenger object at 0x7f6270052eb0>, _bound_partial(functools.partial(<function _context_wrap at 0x7f630c562940>, <pyro.poutine.infer_config_messenger.InferConfigMessenger object at 0x7f62667be910>, <function observe_contact_distribution at 0x7f62700e94c0>))))))
num_changes, max_tries_initial_params:  1 100
attempt  0
samples:  {'p': tensor([0.0405, 0.7584, 0.2011])}
pe_grad, pe:  {'p': tensor([-0.0879,  0.0581])} tensor(inf)
attempt  1
samples:  {'p': tensor([0.1178, 0.3895, 0.4927])}
pe_grad, pe:  {'p': tensor([-0.0647, -0.0117])} tensor(inf)
attempt  2
samples:  {'p': tensor([0.1809, 0.5329, 0.2862])}
pe_grad, pe:  {'p': tensor([-0.0457,  0.0301])}

ValueError: Model specification seems incorrect - cannot find valid initial params.

In [5]:
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}


In [6]:
# Utility function to print latent sites' quantile information.
def summary(samples):
    site_stats = {}
    for site_name, values in samples.items():
        marginal_site = pd.DataFrame(values)
        describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
        site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

In [7]:
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

Site: p
       mean       std            5%           25%           50%       75%  \
0  0.174443  0.341887  1.175494e-38  1.175494e-38  2.783717e-10  0.062711   
1  0.366756  0.429835  7.356503e-11  5.900387e-05  6.351625e-02  0.924413   
2  0.458802  0.446026  4.768372e-07  2.090540e-03  3.112279e-01  0.992556   

        95%  
0  0.998156  
1  0.999974  
2  1.000000   



## Below is using numpyro

In [21]:
# numpyro version

def observe_contact_distribution(N: int, lamda: float, beta: float, y: np.array =None):

    r""" 
    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
    """

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = numpyro.sample(
        'p', 
        DirichletMultinomial(concentration=np.array([0.1, 0.1, 0.1])), 
        rng_key=random.PRNGKey(1)
    )

    # loop over k, the number of contacts, and sample the number of infections with k contacts
    for k in range(3):
        p_inf = 1-(1-lamda*beta)**k 
        numpyro.sample(f'infected_w_num_contacts_{k}', dist.Binomial(total_count = N * p[k], probs = p_inf))


    # NOTE: logic that samples each person as a Bernoulli
    # is technically correct but makes it a lot harder to obtain the observed contact distribution 
    # and the naming of RVs is not right, has duplicate

    # with numpyro.plate('N', N):
    #     # for each individual, sample their number of contacts k from the Dirichlet distribution
    #     # then sample a Bernoulli with probability 1-(1-lambda*beta)^k

    #     num_contacts = numpyro.sample('num_contacts', dist.Multinomial(probs=p))
    #     p_inf = 1-(1-lamda*beta)**num_contacts
    #     numpyro.sample('infected', dist.Bernoulli(1, p_inf))
        
    

In [24]:
N=20
lamda = 0.1
beta=0.1
y = np.array([10,5,5])


nuts_kernel = HMC(observe_contact_distribution) # also called "sampler"
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)


In [25]:
mcmc.run(rng_key, N=N, lamda = lamda, beta=beta, y=y, extra_fields=('potential_energy',))


RuntimeError: This algorithm might only work for discrete sites with enumerate support. But the DirichletMultinomial distribution at site p does not have enumerate support.