In [1]:
import pandas as pd
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import jax.numpy as jnp

In [3]:
# Load the raw per-store sales table (relative path: resolved against the kernel's working directory).
sales = pd.read_csv(filepath_or_buffer='storeSales.csv')

In [5]:
def model(StoreID, logPrice, logVol_obs=None, n_stores=None):
    """Hierarchical (partially pooled) linear model with per-store intercepts and slopes.

    For observation i belonging to store s = StoreID[i]:
        logVol[i] ~ Normal(α[s] + β[s] * logPrice[i], σ)
        α[s]      ~ Normal(μ_α, σ_α)
        β[s]      ~ Normal(μ_β, σ_β)
    with Normal(0, 10) priors on μ_α, μ_β and HalfCauchy(3) priors on σ_α, σ_β, σ.
    β[s] is store s's coefficient on logPrice; μ_β is the population mean of those
    coefficients.

    Args:
        StoreID: integer store index per observation. It must be a concrete array,
            not a traced JAX value, because the number of stores is computed with NumPy.
        logPrice: predictor value per observation, same length as StoreID.
        logVol_obs: observed response per observation, or None to sample it
            (e.g. for prior/posterior predictive draws).
        n_stores: number of store-level (α, β) parameters. Defaults to
            max(StoreID) + 1. For contiguous 0..K-1 codes (e.g. from LabelEncoder)
            this equals the number of unique stores. Pass it explicitly when
            calling the model on a subset of stores (e.g. with Predictive) so the
            parameter vectors keep the shape they had during fitting.

    Raises:
        ValueError: if any StoreID lies outside [0, n_stores).
    """
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 10.))
    σ_α = numpyro.sample("σ_α", dist.HalfCauchy(3.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 10.))
    σ_β = numpyro.sample("σ_β", dist.HalfCauchy(3.0))

    store_idx = np.asarray(StoreID)
    if n_stores is None:
        # Size by the largest index, not by len(np.unique(...)). With a
        # non-contiguous or partial set of IDs, the unique count is too small,
        # and JAX silently clamps out-of-bounds gathers in α[StoreID] instead
        # of raising.
        n_stores = int(store_idx.max()) + 1 if store_idx.size else 0
    if store_idx.size and (store_idx.min() < 0 or store_idx.max() >= n_stores):
        # Negative indices would wrap and too-large ones would be clamped,
        # so reject both explicitly.
        raise ValueError(
            f"StoreID values must lie in [0, {n_stores}); "
            f"got range [{store_idx.min()}, {store_idx.max()}]"
        )

    with numpyro.plate("plate_i", n_stores):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfCauchy(3.))
    # Gather each observation's store-level intercept and slope.
    logVol_est = α[StoreID] + β[StoreID] * logPrice

    with numpyro.plate("data", len(StoreID)):
        numpyro.sample("obs", dist.Normal(logVol_est, σ), obs=logVol_obs)

In [6]:
# Encode the raw STORE_NUM labels as integer codes 0..K-1 (in sorted label order)
# so they can index the per-store parameter vectors directly.
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
sales['StoreID'] = le.fit_transform(sales['STORE_NUM'].values)

# Plain NumPy arrays for the model: store index, price predictor (column logFL),
# and observed response (column logVol).
StoreID, logPrice, logVol_obs = (
    sales[column].values for column in ('StoreID', 'logFL', 'logVol')
)

In [7]:
# Fit the model with NUTS using a fixed seed for reproducibility:
# 2000 warm-up (adaptation) iterations, then 2000 retained posterior draws.
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=2000, num_samples=2000)
mcmc.run(rng_key, StoreID, logPrice, logVol_obs=logVol_obs)

# Mapping from sample-site name to its array of posterior draws.
posterior_samples = mcmc.get_samples()

sample: 100%|██████████| 4000/4000 [02:54<00:00, 22.97it/s, 127 steps of size 3.85e-02. acc. prob=0.87]


In [26]:
# Posterior mean of μ_β, the population-level mean of the per-store
# coefficients β on logPrice (the average price sensitivity across stores).
coef = posterior_samples["μ_β"]
means = coef.mean(axis=0)
print(means)



-2.3809643
