In [1]:
import jax.numpy as jnp
import jax.random as random
import numpyro
from numpyro import deterministic, sample, plate
from numpyro.distributions import Normal, InverseGamma, Exponential
from numpyro.infer import MCMC, NUTS, Predictive
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
df = pd.read_csv("data/radon_dataset.txt")
df.columns = df.columns.str.strip()
state = "MN"
cols = ["county", "floor", "log_radon"]
df2 = df.query("state == @state").assign(log_radon=lambda x: np.log(x.activity.where(x.activity != 0, other=0.1)), county=lambda x: x.county.str.strip().str.title())[cols].reset_index(drop=True)

df2["floor"] = df2["floor"] -1
df2["floor_jittered"] = df2["floor"] + 0.01 * np.random.normal(size=len(df2))

df2[["county", "floor", "log_radon"]]

Unnamed: 0,county,floor,log_radon
0,Aitkin,0,0.788457
1,Aitkin,-1,0.788457
2,Aitkin,-1,1.064711
3,Aitkin,-1,0.000000
4,Anoka,-1,1.131402
...,...,...,...
914,Wright,-1,1.856298
915,Wright,-1,1.504077
916,Wright,-1,1.609438
917,Yellow Medicine,-1,1.308333


In [3]:

mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)
rng_key = random.PRNGKey(12)
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)


In [4]:
data_dict = dict(
    floor=jnp.array(df2.floor),
    log_radon=jnp.array(df2.log_radon),
)

In [5]:
def complete_pooling(floor, log_radon=None):
    α = sample("α", Normal(0, 5))
    β = sample("β", Normal(0, 2))

    μ = deterministic("μ", α + β * floor)
    σ = sample("σ", InverseGamma(1, 0.5))
    sample("log_radon", Normal(μ, σ), obs=log_radon)

In [6]:
pooled_mcmc = MCMC(NUTS(complete_pooling), **mcmc_kwargs)

  pooled_mcmc = MCMC(NUTS(complete_pooling), **mcmc_kwargs)


In [7]:
pooled_mcmc.run(
    seed1,
    **data_dict
)

sample: 100%|██████████████████████████████████| 4000/4000 [00:30<00:00, 131.22it/s, 15 steps of size 2.77e-01. acc. prob=0.93]
sample: 100%|████████████████████████████████████| 4000/4000 [06:24<00:00, 10.41it/s, 9 steps of size 2.47e-01. acc. prob=0.96]
sample: 100%|████████████████████████████████████| 4000/4000 [07:02<00:00,  9.46it/s, 3 steps of size 2.91e-01. acc. prob=0.93]
sample: 100%|████████████████████████████████████| 4000/4000 [07:58<00:00,  8.35it/s, 1 steps of size 3.07e-01. acc. prob=0.93]


In [17]:
pooled_mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         α      0.72      0.07      0.72      0.61      0.83   3001.77      1.00
         β     -0.61      0.07     -0.61     -0.74     -0.50   2976.53      1.00
         σ      0.82      0.02      0.82      0.79      0.85   4192.81      1.00

Number of divergences: 0


In [18]:
def partial_pooling(σ_α, county, floor, log_radon=None):

    μ_α = sample("μ_α", Normal(0, 5))

    with plate("counties", N_COUNTIES):
        α = sample("α", Normal(μ_α, σ_α))
        β = sample("β", Normal(0, 1))

    μ = deterministic("μ", α[county] + β[county] * floor)
    τ = sample("τ", InverseGamma(1, 0.5))
    sample("log_radon", Normal(μ, τ), obs=log_radon)

mcmc = MCMC(NUTS(partial_pooling), progress_bar=False, **mcmc_kwargs)

pooled_alphas = {}
pooled_taus = {}
for σ_α in [0.01, 0.02, 0.05, 0.07, 0.09] + list(np.linspace(0.1, 1, num=20)):
    print("Pooling: ", σ_α)
    data_dict.update({"σ_α": σ_α})
    mcmc.run(seed3, **data_dict)
    pooled_alphas[σ_α] = mcmc.get_samples()["α"]
    pooled_taus[σ_α] = mcmc.get_samples()["τ"]
    

Pooling:  0.01


  mcmc = MCMC(NUTS(partial_pooling), progress_bar=False, **mcmc_kwargs)


TypeError: partial_pooling() missing 1 required positional argument: 'county'