# TLDR

The contact distribution from positive cases is a biased representation of the underlying contact distribution of the population. (People with more contacts are more likely to become infected and contribute to the observed contact distribution, and vice versa.)

We would like to use Bayesian inference to obtain a more accurate underlying contact distribution. 

In [39]:
import numpy as np
# import numpyro
# import numpyro.distributions as dist
# from numpyro.distributions import DirichletMultinomial
# from numpyro.infer import MCMC, NUTS, HMC
# from jax import random

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO, Predictive

import torch
import pandas as pd
import time

In [25]:
def observe_contact_distribution(N: int, lamda: float, beta: float, y: np.array =None):

    r""" 
    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
    """

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = pyro.sample(
        'p', 
        dist.Dirichlet(concentration=torch.tensor([0.1, 0.1, 0.1])), 
        # rng_key=random.PRNGKey(1)
    )

    # loop over k, the number of contacts, and sample the number of infections with k contacts
    for k in range(3):
        p_inf = 1-(1-lamda*beta)**k 
        # print("k: ", k, "p_inf: ", p_inf, "total_count: ", int(N*p[k]))
        pyro.sample(f'infected_w_num_contacts_{k}', dist.Binomial(total_count = int(N * p[k]), probs = torch.tensor([p_inf])))


    # NOTE: logic that samples each person as a Bernoulli
    # is technically correct but makes it a lot harder to obtain the observed contact distribution 
    # and the naming of RVs is not right, has duplicate

    # with numpyro.plate('N', N):
    #     # for each individual, sample their number of contacts k from the Dirichlet distribution
    #     # then sample a Bernoulli with probability 1-(1-lambda*beta)^k

    #     num_contacts = numpyro.sample('num_contacts', dist.Multinomial(probs=p))
    #     p_inf = 1-(1-lamda*beta)**num_contacts
    #     numpyro.sample('infected', dist.Bernoulli(1, p_inf))
        

In [44]:
N=20
lamda = 0.1
beta=0.1
y = torch.tensor([20,5,5])


nuts_kernel = NUTS(observe_contact_distribution) # also called "sampler"
mcmc = MCMC(nuts_kernel, warmup_steps=500, num_samples=10000)
# rng_key = random.PRNGKey(0)


In [45]:
start_time = time.time()
mcmc.run(N, lamda, beta, y)
print(f"Took {time.time()-start_time} seconds")

Warmup:   1%|          | 97/10500 [00:04, 33.89it/s, step size=5.17e+00, acc. prob=0.781]

In [41]:
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}


In [42]:
# Utility function to print latent sites' quantile information.
def summary(samples):
    site_stats = {}
    for site_name, values in samples.items():
        marginal_site = pd.DataFrame(values)
        describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
        site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

In [43]:
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

Site: p
       mean       std            5%       25%       50%       75%       95%
0  0.349389  0.418804  3.849434e-11  0.000049  0.064225  0.869990  0.999715
1  0.277617  0.381926  1.504859e-14  0.000009  0.018681  0.595093  0.997639
2  0.372994  0.423901  1.492209e-07  0.000464  0.098867  0.898001  0.999970 



## Below is using numpyro

In [21]:
# numpyro version

def observe_contact_distribution(N: int, lamda: float, beta: float, y: np.array =None):

    r""" 
    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
    """

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = numpyro.sample(
        'p', 
        DirichletMultinomial(concentration=np.array([0.1, 0.1, 0.1])), 
        rng_key=random.PRNGKey(1)
    )

    # loop over k, the number of contacts, and sample the number of infections with k contacts
    for k in range(3):
        p_inf = 1-(1-lamda*beta)**k 
        numpyro.sample(f'infected_w_num_contacts_{k}', dist.Binomial(total_count = N * p[k], probs = p_inf))


    # NOTE: logic that samples each person as a Bernoulli
    # is technically correct but makes it a lot harder to obtain the observed contact distribution 
    # and the naming of RVs is not right, has duplicate

    # with numpyro.plate('N', N):
    #     # for each individual, sample their number of contacts k from the Dirichlet distribution
    #     # then sample a Bernoulli with probability 1-(1-lambda*beta)^k

    #     num_contacts = numpyro.sample('num_contacts', dist.Multinomial(probs=p))
    #     p_inf = 1-(1-lamda*beta)**num_contacts
    #     numpyro.sample('infected', dist.Bernoulli(1, p_inf))
        
    

In [24]:
N=20
lamda = 0.1
beta=0.1
y = np.array([10,5,5])


nuts_kernel = HMC(observe_contact_distribution) # also called "sampler"
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)


In [25]:
mcmc.run(rng_key, N=N, lamda = lamda, beta=beta, y=y, extra_fields=('potential_energy',))


RuntimeError: This algorithm might only work for discrete sites with enumerate support. But the DirichletMultinomial distribution at site p does not have enumerate support.