# Bag of Marbles

In [1]:
import jax
import numpy as np
import numpyro as pyr
import numpyro.distributions as dist

from numpyro.infer import MCMC, NUTS, HMC




Suppose we have a bag containing an unknown number of blue and green marbles.

If we knew the number of each colour, we could compute the probability of randomly picking a green marble as follows:

$$
\theta = \frac{\text{\#green}}{\text{\#green} + \text{\#blue}}
$$

Formally, the probability of picking a green marble ($Y=\text{green}$), given that we know the ratio $\theta$, is:

$$
P( Y=\text{green} \mid \theta) = \theta
$$
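One way to check that $P(Y=\text{green} \mid \theta) = \theta$ is to simulate it. The sketch below assumes each draw is made *with replacement*, so the bag never changes between draws. It uses a hypothetical bag of 10 green and 3 blue marbles and a hypothetical helper `simulate_draws`. The fraction of green draws should settle close to $\theta$.

```python
import numpy as np

def simulate_draws(n_green, n_blue, n_draws, seed=0):
    """Hypothetical helper: draw marbles with replacement (1 = green, 0 = blue)."""
    rng = np.random.default_rng(seed)
    # Encode the bag's contents: one entry per marble
    bag = np.concatenate([np.ones(n_green), np.zeros(n_blue)])
    # Assumption: every marble is equally likely and is returned after each draw
    return rng.choice(bag, size=n_draws, replace=True)

n_green, n_blue = 10, 3  # hypothetical bag composition
theta = n_green / (n_green + n_blue)
draws = simulate_draws(n_green, n_blue, n_draws=100_000)
print(f"theta = {theta:.4f}, empirical P(green) = {draws.mean():.4f}")
```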

Based on our previous knowledge about this kind of probability exercise, we can formulate a prior probability $P(\theta)$ that expresses which values of $\theta$ we consider likely before drawing any marbles.
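As one concrete choice, and an assumption made here for illustration, take a uniform prior $\theta \sim \text{Uniform}(0, 1)$: every ratio is considered equally plausible. This is also the prior used by the NumPyro model further down. Sampling $\theta$ from the prior and then one marble given that $\theta$ shows what the prior predicts before any evidence: a green first draw with probability $\mathbb{E}[\theta] = 1/2$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sims = 200_000

# Step 1: draw candidate ratios from the (assumed) Uniform(0, 1) prior
theta_prior = rng.uniform(0.0, 1.0, size=n_sims)
# Step 2: for each theta, draw one marble (1 = green, 0 = blue)
first_draw = rng.binomial(n=1, p=theta_prior)

print(f"prior mean of theta:             {theta_prior.mean():.3f}")
print(f"prior predictive P(green first): {first_draw.mean():.3f}")
```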

Say we are allowed to pick one marble out of the bag at a time. The first time we pick a green marble, then a blue one, followed by green and then green again. Having seen this evidence, we update our belief about $\theta$ by computing the posterior probability:

$$
P(\theta \mid Y = \text{green},\text{blue},\text{green},\text{green})
$$
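Taking the uniform prior $P(\theta) = 1$ on $[0, 1]$ as an assumed example, and assuming the draws are independent given $\theta$, Bayes' rule gives this posterior in closed form. The likelihood is the product of one factor per draw, and the normalising constant is a simple integral:

$$
\begin{aligned}
P(\theta \mid Y = \text{green},\text{blue},\text{green},\text{green})
  &\propto P(\theta)\,\theta\,(1-\theta)\,\theta\,\theta = \theta^{3}(1-\theta), \\
\int_0^1 \theta^{3}(1-\theta)\,d\theta &= \tfrac{1}{4} - \tfrac{1}{5} = \tfrac{1}{20}, \\
P(\theta \mid Y = \text{green},\text{blue},\text{green},\text{green}) &= 20\,\theta^{3}(1-\theta) = \operatorname{Beta}(\theta;\, 4,\, 2).
\end{aligned}
$$

This posterior has mean $4/6 \approx 0.67$ and mode $3/4$, which is the observed fraction of green draws. More generally, a uniform prior combined with $k$ green draws out of $n$ yields $\operatorname{Beta}(k+1,\, n-k+1)$.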

In [2]:
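n_green = 10*1  # number of green marbles in the bag
n_blue = 3*1    # number of blue marbles in the bag
n_total = n_green + n_blue
theta = n_green / n_total  # true probability of drawing a green marble

# Alternative: simulate n_total random draws instead of listing the bag's contents
# bag_of_marbles = np.random.binomial(n=1, p=theta, size=n_total)

# Observed data: one entry per marble, 1 = green, 0 = blue
bag_of_marbles = np.concatenate([
    np.ones(n_green),
    np.zeros(n_blue),
])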

In [3]:
theta

0.7692307692307693

In [4]:
np.mean(bag_of_marbles)

0.7692307692307693
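Because there is only one unknown parameter, the posterior for this data can also be approximated on a grid, without any sampling. This grid approximation is a swapped-in technique used here only as a sanity check; the notebook itself uses NUTS. The sketch assumes the same uniform prior and Bernoulli likelihood as the NumPyro model, and the variable names are chosen just for this example.

```python
import numpy as np

# Same data as above: 10 green (1) and 3 blue (0) marbles
data = np.concatenate([np.ones(10), np.zeros(3)])
k, n = int(data.sum()), data.size

# Evaluate prior x likelihood on a fine grid of candidate theta values
grid = np.linspace(0.0, 1.0, 10_001)
prior = np.ones_like(grid)                      # assumed uniform prior density on [0, 1]
likelihood = grid**k * (1.0 - grid) ** (n - k)  # product of the n Bernoulli terms
unnormalized = prior * likelihood

# Normalize so the grid weights sum to 1 (a discrete approximation of the posterior)
posterior = unnormalized / unnormalized.sum()

post_mean = np.sum(grid * posterior)
post_sd = np.sqrt(np.sum((grid - post_mean) ** 2 * posterior))
print(f"grid posterior mean = {post_mean:.3f}, sd = {post_sd:.3f}")
```

The grid gives a posterior mean of roughly 0.733, a little below the observed proportion of about 0.769, because the uniform prior pulls the estimate toward 1/2.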

In [5]:
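def marbles_ratio(data=None):
    # Prior: the ratio theta (named "p" here) is uniform on [0, 1]
    p = pyr.sample("p", dist.Uniform(low=0.0, high=1.0))
    # Likelihood: each observed marble is green (1) with probability p, blue (0) otherwise
    pyr.sample("y", dist.Bernoulli(probs=p), obs=data)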

In [6]:
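# NUTS: an adaptive variant of Hamiltonian Monte Carlo
kernel = NUTS(marbles_ratio)
# 1000 warm-up (adaptation) iterations, then 5000 retained samples
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000)

# JAX needs an explicit PRNG key; split off a fresh subkey for this run
rng_key = jax.random.PRNGKey(42)
rng_key, rng_key_ = jax.random.split(rng_key)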

In [7]:
mcmc.run(
    rng_key_, data=bag_of_marbles
)

sample: 100%|███████████████████████████| 6000/6000 [00:02<00:00, 2157.84it/s, 3 steps of size 1.06e+00. acc. prob=0.88]


In [8]:
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         p      0.74      0.11      0.75      0.57      0.92   2164.73      1.00

Number of divergences: 0
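The uniform prior is a $\operatorname{Beta}(1, 1)$ distribution, and the Beta family is conjugate to the Bernoulli likelihood, so the exact posterior here is $\operatorname{Beta}(1 + k,\, 1 + n - k)$ with $k = 10$ green marbles out of $n = 13$. That gives a convenient check on the sampler. A sketch using only the standard library:

```python
import math

# Counts from the bag used in the MCMC run
k, n = 10, 13  # green marbles, total marbles
# Uniform prior = Beta(1, 1); conjugacy gives posterior Beta(1 + k, 1 + n - k)
a, b = 1 + k, 1 + (n - k)

post_mean = a / (a + b)
post_sd = math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))
print(f"Beta({a}, {b}): mean = {post_mean:.3f}, sd = {post_sd:.3f}")
# prints "Beta(11, 4): mean = 0.733, sd = 0.111"
```

The exact mean (about 0.733) and standard deviation (about 0.111) agree closely with the MCMC estimates of 0.74 and 0.11 in the summary above.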
