In [1]:
import os

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import gaussian_kde

import jax.numpy as jnp
from jax import random, vmap

import numpyro
import numpyro.distributions as dist

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")


In [2]:
p_grid = jnp.linspace(start=0, stop=1, num=1000)
prior = jnp.repeat(1, 1000)
likelihood = jnp.exp(dist.Binomial(total_count=9, probs=p_grid).log_prob(6))
posterior = likelihood * prior
posterior = posterior / jnp.sum(posterior)
samples = p_grid[dist.Categorical(posterior).sample(random.PRNGKey(100), (10000,))]

In [16]:
# 3E1
(samples < .2).sum() / 1e4

DeviceArray(0.001, dtype=float32)

In [17]:
# 3E2
(samples < 0.8).sum() / 1e4

DeviceArray(0.879, dtype=float32)

In [21]:
# 3E3
((0.2 < samples) & (samples < 0.8)) / 1e4

DeviceArray([1.e-04, 1.e-04, 1.e-04, ..., 1.e-04, 1.e-04, 1.e-04], dtype=float32)

In [41]:
# 3E4 - 20% of the posterior probability lies below what value of p?
def infer_bruteforce_lt(samples: type(samples), x: float, rtol: float = 1e-5) -> float: 
    """x percent of posterior probability lies below what value of p?"""
    answers = list()
    for k in  jnp.linspace(start=0, stop=1, num=1000): 
        if jnp.isclose((samples < k).sum() / samples.shape[0], x, rtol=rtol): 
            answers.append(k)
    if not answers: 
        raise ValueError("needs to be called with a more permissive tolerance for this data")
    return jnp.array(answers).mean()

infer_bruteforce_lt(samples, 0.2, 1e-2)

DeviceArray(0.5195195, dtype=float32)

In [43]:
# 3E5 - 20% of the posterior probability lies above what value of p?
def infer_bruteforce_gt(samples: type(samples), x: float, rtol: float = 1e-5) -> float: 
    """x percent of posterior probability lies below what value of p?"""
    answers = list()
    for k in  jnp.linspace(start=0, stop=1, num=1000): 
        if jnp.isclose((samples > k).sum() / samples.shape[0], x, rtol=rtol): 
            answers.append(k)
    if not answers: 
        raise ValueError("needs to be called with a more permissive tolerance for this data")
    return jnp.array(answers).mean()

infer_bruteforce_gt(samples, 0.2, 1e-2)

DeviceArray(0.7602603, dtype=float32)