# TLDR

The contact distribution from positive cases is a biased representation of the underlying contact distribution of the population. (People with more contacts are more likely to become infected and contribute to the observed contact distribution, and vice versa.)

We would like to use Bayesian inference to obtain a more accurate underlying contact distribution. 

In [1]:
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import DirichletMultinomial

from numpyro.infer import MCMC, NUTS

from jax import random


  from .autonotebook import tqdm as notebook_tqdm


In [21]:
def observe_contact_distribution(N: int, lamda: float, beta: float, y: np.array =None):

    r""" 
    Args:
        N: population size
        lamda: prevalence
        beta: probability of infection given contact with positive
    """

    # p = distribution of the number of contacts of a person; use Dirichlet prior
    p = numpyro.sample(
        'p', 
        DirichletMultinomial(concentration=np.array([0.1, 0.1, 0.1])), 
        rng_key=random.PRNGKey(1)
    )

    # loop over k, the number of contacts, and sample the number of infections with k contacts
    for k in range(3):
        p_inf = 1-(1-lamda*beta)**k 
        numpyro.sample(f'infected_w_num_contacts_{k}', dist.Binomial(total_count = N * p[k], probs = p_inf))


    # NOTE: logic that samples each person as a Bernoulli
    # is technically correct but makes it a lot harder to obtain the observed contact distribution 
    # and the naming of RVs is not right, has duplicate

    # with numpyro.plate('N', N):
    #     # for each individual, sample their number of contacts k from the Dirichlet distribution
    #     # then sample a Bernoulli with probability 1-(1-lambda*beta)^k

    #     num_contacts = numpyro.sample('num_contacts', dist.Multinomial(probs=p))
    #     p_inf = 1-(1-lamda*beta)**num_contacts
    #     numpyro.sample('infected', dist.Bernoulli(1, p_inf))
        

In [22]:
N=20
lamda = 0.1
beta=0.1
y = np.array([10,5,5])


nuts_kernel = NUTS(observe_contact_distribution) # also called "sampler"
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)


In [20]:
mcmc.run(rng_key, N=N, lamda = lamda, beta=beta, y=y, extra_fields=('potential_energy',))


RuntimeError: This algorithm might only work for discrete sites with enumerate support. But the DirichletMultinomial distribution at site p does not have enumerate support.