In [51]:
import pandas as pd
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import jax.numpy as jnp 

In [2]:
# Load the dataset from a path relative to the notebook's working directory.
# Later cells read the 'logAmount' and 'userIndex' columns from this frame.
uber = pd.read_csv('data/uberRidesSmall.csv')

In [4]:
def model(userID, y=None):
    """Hierarchical varying-intercept model: one intercept per distinct user.

    Generative structure:
        μ_α   ~ Normal(0, 5)          population mean of the intercepts
        σ_α   ~ HalfCauchy(3)         population spread of the intercepts
        α_j   ~ Normal(μ_α, σ_α)      one draw per distinct user (plate "plate_i")
        σ     ~ HalfCauchy(3)         observation noise
        obs_i ~ Normal(α[user(i)], σ) one draw per observation (plate "data")

    Args:
        userID: 1-D array of concrete user identifiers, one per observation.
            The identifiers do not have to be 0-based or contiguous: they are
            mapped to dense positions 0..n_Users-1 in sorted order, so α[j]
            belongs to the j-th smallest distinct identifier.
        y: Observed responses aligned with ``userID``. Leave as ``None`` to
            have the "obs" site sampled instead of conditioned on.
    """
    μ_α = numpyro.sample("μ_α", dist.Normal(0., 5.))
    σ_α = numpyro.sample("σ_α", dist.HalfCauchy(3.))

    # Index α by the dense inverse mapping rather than by the raw IDs. With raw
    # IDs, any identifier >= n_Users (1-based or gappy IDs) would index out of
    # bounds, which JAX clamps silently instead of raising. When the IDs are
    # already exactly 0..n_Users-1 the inverse equals userID, so nothing changes.
    unique_IDs, user_idx = np.unique(userID, return_inverse=True)
    n_Users = len(unique_IDs)

    with numpyro.plate("plate_i", n_Users):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))

    σ = numpyro.sample("σ", dist.HalfCauchy(3.))
    # Expected response of each observation is its user's intercept.
    y_est = α[user_idx]

    with numpyro.plate("data", len(userID)):
        numpyro.sample("obs", dist.Normal(y_est, σ), obs=y)

In [5]:
# Pull the response and the per-observation user index out of the frame
# as plain NumPy arrays for the model.
y = uber['logAmount'].to_numpy()
userID = uber['userIndex'].to_numpy()

In [None]:
# Fixed seed so the sampling run is reproducible.
rng_key = random.PRNGKey(0)

# NUTS sampler over the hierarchical model: 2000 warmup iterations,
# then 2000 retained draws.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=2000, num_samples=2000)
mcmc.run(rng_key, userID, y=y)

# Posterior draws, looked up by sample-site name in later cells.
posterior_samples = mcmc.get_samples()

In [32]:
# Posterior draws for the "α" site, i.e. the per-user intercepts sampled
# inside the model's "plate_i" plate.
# Presumably shaped (num_samples, n_Users) — TODO confirm.
coef = posterior_samples["α"]

In [64]:
# 2.5% and 97.5% quantiles of `coef`, reduced along axis 0.
# Assumes axis 0 indexes posterior draws, which would make this an
# equal-tailed 95% interval per user — TODO confirm.
quantiles = jnp.quantile(coef,jnp.array([0.025, 0.975]), axis=0)
# Mean of `coef` along the same axis.
means = jnp.mean(coef, axis=0)

# Save the posterior means as comma-delimited text in the current working
# directory, converting to a NumPy array first.
np.savetxt("alphaMeanNumPyro.csv", np.asarray(means), delimiter=",")