In [None]:
import os

from IPython.display import set_matplotlib_formats
import jax.numpy as np
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as onp
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

# Apply matplotlib's "bmh" style sheet to every figure drawn below.
plt.style.use("bmh")

# Switch inline figures to SVG only when NUMPYRO_SPHINXBUILD is present in the
# environment (presumably set by the documentation build — confirm).
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

# Fail fast if the installed numpyro is not a 0.2.4 release.
assert numpyro.__version__.startswith("0.2.4")

In [None]:
# Location of the WaffleDivorce CSV (semicolon-separated) in the
# rmcelreath/rethinking GitHub repository.
DATASET_URL = (
    'https://raw.githubusercontent.com/rmcelreath/rethinking/'
    'master/data/WaffleDivorce.csv'
)
# Download and parse the file; the bare `dset` below renders it as the cell output.
dset = pd.read_csv(DATASET_URL, sep=';')
dset

In [None]:
# Columns compared against each other in the pair plot.
# NOTE(review): `vars` shadows the Python builtin of the same name; the name is
# kept here because cells outside this view may still reference it.
vars = ["Population", "MedianAgeMarriage", "Marriage", "WaffleHouses", "South", "Divorce"]
# Grid of plots with every listed column on both the x and the y axis.
sns.pairplot(dset, x_vars=vars, y_vars=vars, palette="husl")

# Start a fresh figure first: without it regplot draws on the current axes,
# which at this point is one of the pair-plot panels created just above.
plt.figure()
# x / y / data are passed by keyword; newer seaborn releases no longer accept
# them positionally, while older releases accept the keyword form as well.
sns.regplot(x="WaffleHouses", y="Divorce", data=dset);


In [None]:
def standardize(x):
    """Center ``x`` on its mean and scale it by its standard deviation.

    Args:
        x: any object supporting arithmetic and exposing ``mean()`` and
            ``std()`` methods (in this notebook, a pandas Series).

    Returns:
        ``(x - x.mean()) / x.std()``, using whatever ``std()`` convention the
        type of ``x`` implements.
    """
    return (x - x.mean()) / x.std()
    
# Add standardized copies of the three modelling columns to `dset`; the source
# columns themselves are left untouched, so re-running this cell is harmless.
dset["AgeScaled"] = standardize(dset["MedianAgeMarriage"])
dset["MarriageScaled"] = standardize(dset["Marriage"])
dset["DivorceScaled"] = standardize(dset["Divorce"])


In [None]:
def model(marriage=None, age=None, divorce=None):
    """Linear regression of divorce on optional marriage and age predictors.

    Sample sites, in order: intercept "A" ~ Normal(0, 0.2); slope "bM" ~
    Normal(0, 0.5) only when `marriage` is given; slope "bA" ~ Normal(0, 0.5)
    only when `age` is given; "sigma" ~ Exponential(1); and "obs" ~
    Normal(mu, sigma) with mu = A + bM * marriage + bA * age. A predictor that
    is left as None contributes 0 to mu.

    Args:
        marriage: marriage predictor values, or None to drop that term.
        age: age predictor values, or None to drop that term.
        divorce: passed as `obs=` to the "obs" site; None leaves it unobserved.
    """
    intercept = numpyro.sample("A", dist.Normal(0., 0.2))
    marriage_term = 0.
    age_term = 0.

    if marriage is not None:
        marriage_slope = numpyro.sample("bM", dist.Normal(0., 0.5))
        marriage_term = marriage_slope * marriage
    if age is not None:
        age_slope = numpyro.sample("bA", dist.Normal(0., 0.5))
        age_term = age_slope * age

    noise_scale = numpyro.sample("sigma", dist.Exponential(1.))
    expected = intercept + marriage_term + age_term
    numpyro.sample("obs", dist.Normal(expected, noise_scale), obs=divorce)



In [None]:
# Seed the JAX PRNG, then split it: `rng_key_` is consumed by this run and
# `rng_key` is the remaining key.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

num_warmup, num_samples = 1000, 2000

# Run NUTS. Only `marriage` is supplied, so the model samples "bM" but not "bA".
kernel = NUTS(model)
# Pass the counts by keyword: it is equivalent to the positional form and also
# works with releases that declare these parameters keyword-only.
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key_, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values)
mcmc.print_summary()
# Keep the draws under a run-specific name.
samples_1 = mcmc.get_samples()

In [None]:
def plot_regression(x, y_mean, y_hpdi, y=None):
    """Plot a regression line, its interval band and the observed points.

    All inputs are reordered by ``argsort(x)`` so the line and band are drawn
    from left to right.

    Args:
        x: 1-D array of predictor values.
        y_mean: values of the regression line, aligned with ``x``.
        y_hpdi: array whose first axis holds the lower (index 0) and upper
            (index 1) band edges and whose second axis is aligned with ``x``.
        y: observed responses aligned with ``x``. When omitted, the
            notebook-level ``dset.DivorceScaled.values`` is used, matching the
            original behavior.

    Returns:
        The matplotlib axes the plot was drawn on.
    """
    if y is None:
        y = dset.DivorceScaled.values

    order = np.argsort(x)
    x_sorted = x[order]
    mean_sorted = y_mean[order]
    # Named `band` rather than `hpdi` so the imported hpdi() is not shadowed.
    band = y_hpdi[:, order]
    observed = y[order]

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(x_sorted, mean_sorted)
    ax.plot(x_sorted, observed, "o")
    ax.fill_between(x_sorted, band[0], band[1], alpha=0.3, interpolate=True)
    return ax

# Regression mean for each posterior draw: the trailing axis added to the "A"
# and "bM" samples broadcasts against the vector of scaled marriage rates.
posterior_mu = (
    np.expand_dims(samples_1["A"], -1)
    + np.expand_dims(samples_1["bM"], -1) * dset.MarriageScaled.values
)

# Average over the first axis (presumably the draw axis — confirm) and compute
# the interval with probability mass 0.9 via numpyro's hpdi.
mean_mu = np.mean(posterior_mu, axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)
ax = plot_regression(dset.MarriageScaled.values, mean_mu, hpdi_mu)
ax.set(xlabel="Marriage rate", ylabel="Divorce rate", title="Regression line with 90% CI")