In [1]:
!pip install numpyro

In [2]:
import pandas as pd
import numpy as np
import arviz as az

import numpyro
import numpyro.distributions as dist
from numpyro.infer.reparam import LocScaleReparam
from numpyro.handlers import reparam
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, NUTS, Predictive

import jax
from jax import random
import jax.numpy as jnp

# plotting
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
import arviz as az
az.style.use("arviz-darkgrid")

numpyro.set_host_device_count(4)

# Pooling in Hierarchical Models

Code for [Estimating NBA Free Throw % with Hierarchical Models](https://jramkiss.github.io/2021/01/29/hierarchical-models/)

In [3]:
# Read every recorded free throw and keep regular-season attempts only.
all_free_throws = (
    pd.read_csv("/kaggle/input/nba-free-throws/free_throws.csv")
    .loc[lambda shots: shots["playoffs"] == "regular"]
)

# Restrict the analysis to one season; the index is rebuilt so rows are numbered from 0.
free_throws = (
    all_free_throws
    .loc[lambda shots: shots["season"] == "2015 - 2016"]
    .reset_index(drop=True)
)

In [4]:
# Split the season at the median game_id taken over all free-throw rows:
# rows strictly below the median form the first half, the remainder the second half.
# NOTE(review): this presumes game_id increases chronologically - confirm against the dataset.
median_game_id = np.quantile(free_throws["game_id"], 0.50)
before_median = free_throws["game_id"] < median_game_id
from_median_on = free_throws["game_id"] >= median_game_id

first_half = free_throws[before_median].copy()
print("Number of games in first half of season: ", len(first_half["game_id"].unique()))
second_half = free_throws[from_median_on].copy()
print("Number of games in rest of season: ", len(second_half["game_id"].unique()))

In [5]:
# Number of distinct games in the selected season.
free_throws["game_id"].nunique()

In [6]:
# Preview the first 20 rows of the season's free-throw records.
free_throws.head(20)

In [7]:
# Preview the first 20 rows of the first-half subset.
first_half.head(20)

In [17]:
# Use the first half of the season and estimate free throw % for the top 16 players,
# ranked by the number of free throws they made.
num_players = 16

# Select `shot_made` before summing. `GroupBy.sum` does not accept a list of columns:
# its first positional parameter is `numeric_only`, so a column list passed there is
# silently treated as a flag (or rejected, depending on the pandas version).
top_player_names = (
    first_half.groupby("player")["shot_made"]
    .sum()
    .sort_values(ascending=False)
    .head(num_players)
    .index.tolist()
)
top_player_names

In [45]:
# Per-player training summary from the first half of the season.
# Named aggregation ties each output column to its aggregate explicitly. Aggregating with
# a *set* of function names and then renaming columns by position is unsafe, because set
# iteration order is not guaranteed and makes/attempts could end up swapped.
top_player_data = (
    first_half[first_half.player.isin(top_player_names)]
    .groupby("player")["shot_made"]
    .agg(shots_made="sum", total_shots="count")
    .reset_index()
)
top_player_data["free_throw_percentage"] = top_player_data.shots_made / top_player_data.total_shots
top_player_data.sort_values("free_throw_percentage", ascending=False).head(num_players)

In [67]:
# Per-player hold-out summary from the second half of the season, for the same players.
# Named aggregation ties each output column to its aggregate explicitly. Aggregating with
# a *set* of function names and then renaming columns by position is unsafe, because set
# iteration order is not guaranteed and makes/attempts could end up swapped.
test_data = (
    second_half[second_half.player.isin(top_player_names)]
    .groupby("player")["shot_made"]
    .agg(shots_made="sum", total_shots="count")
    .reset_index()
)
test_data["free_throw_percentage"] = test_data.shots_made / test_data.total_shots
test_data.sort_values("free_throw_percentage", ascending=False).head(num_players)

### Inference Functions

In [46]:
def run_inference(model, ft_attempts, ft_makes, rng_key,
                  num_chains = 2,
                  num_warmup = 50, 
                  num_samples = 200):
    """Run NUTS on `model` and return the fitted MCMC object.

    Args:
        model: NumPyro model callable; it is invoked as `model(ft_attempts, ft_makes)`.
        ft_attempts: first positional model argument (free-throw attempts).
        ft_makes: second positional model argument (free throws made).
        rng_key: JAX PRNG key handed to `MCMC.run`.
        num_chains: number of chains to run.
        num_warmup: warm-up iterations per chain.
        num_samples: retained samples per chain.

    Returns:
        The `MCMC` instance after `run` has completed; the progress bar is disabled.
    """
    mcmc_settings = dict(
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    sampler = MCMC(NUTS(model), **mcmc_settings)
    sampler.run(rng_key, ft_attempts, ft_makes)
    return sampler

## Complete Pooling

In complete pooling we assume that each player has the same probability of success, $\theta$. Assuming that each player's shots are independent Bernoulli trials, we can model the number of ft made for player $i$, $y_i$:

$$ p(y_i \mid \theta) = \text{Binomial}(K_i, \theta) $$

$$ \alpha = \text{logit}(\theta) $$

$$ p(y \mid \theta) =  \prod_{i = 1}^N \text{BinomialLogit}(K_i, \alpha) $$

$$ p(\alpha) =  \text{Normal}(1, 1) $$


This prior specification for $\alpha$ corresponds to $95\%$ of values between $0.26$ and $0.95$ chance of success. 

In [47]:
def fully_pooled(ft_attempts, ft_makes=None):
    """Complete-pooling model: one shared log-odds of success for every player.

    Args:
        ft_attempts: array of attempt counts; its leading dimension sets the plate size.
        ft_makes: observed makes, or None to leave the "obs" site unobserved.

    Sites: "alpha" ~ Normal(1, 1), deterministic "theta" = sigmoid(alpha), and "obs".
    """
    shared_logit = numpyro.sample("alpha", dist.Normal(1, 1))
    # Record the success probability so it is reported alongside alpha.
    numpyro.deterministic("theta", jax.nn.sigmoid(shared_logit))

    with numpyro.plate("num_players", ft_attempts.shape[0]):
        numpyro.sample(
            "obs",
            dist.BinomialLogits(total_count=ft_attempts, logits=shared_logit),
            obs=ft_makes,
        )

In [48]:
# One key for MCMC and a separate one for posterior-predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(1))

train_attempts = jnp.array(top_player_data.total_shots)
train_makes = jnp.array(top_player_data.shots_made)

pooled_mcmc = run_inference(fully_pooled, 
                            ft_attempts = train_attempts,
                            ft_makes = train_makes,
                            num_warmup = 500,
                            num_samples = 1000,
                            rng_key = rng_key)
pooled_samples = pooled_mcmc.get_samples()

# Use the dedicated key from the split above. A JAX key that has been passed to
# `random.split` must not be consumed again, so PRNGKey(1) itself is not reused here.
pooled_posterior_predictive = Predictive(fully_pooled, pooled_samples)(
    rng_key_predict, train_attempts)

# Prior draws use their own, unrelated seed.
pooled_prior = Predictive(fully_pooled, num_samples=500)(
    random.PRNGKey(2), train_attempts)

pooled_az = az.from_numpyro(
    pooled_mcmc,
    prior = pooled_prior,
    posterior_predictive = pooled_posterior_predictive,
)

pooled_mcmc.print_summary()

In [49]:
# Density of `theta` from the complete-pooling fit (the single shared success probability).
az.plot_density(
    pooled_az,
    var_names=["theta"],
    shade=0.1,
    figsize = (6, 3),
)
plt.show()

In [76]:
# Empirical pooled rate: total makes over total attempts across the selected players.
ft_attempts = top_player_data["total_shots"]
ft_made = top_player_data["shots_made"]
print(f" Complete pooling free throw percentage observed = {ft_made.sum()/ft_attempts.sum():.2f}")

## No Pooling

In the complete pooling example, all observations shared the same `chance of success` parameter. The no pooling model is the exact opposite, where there is no sharing of parameters between observations. Again, we'll model the log-odds of the chance of success with $\alpha = \text{logit}(\theta)$

$$ p(y_i \mid \theta_i) = \text{Binomial}(K_i, \theta_i) $$

$$ \alpha_i = \text{logit}(\theta_i) $$

$$ p(y_i \mid \theta_i) = \text{BinomialLogit}(K_i, \alpha_i) $$

$$ p(\alpha_i) = \text{Normal}(1, 1) $$


This prior specification for $\alpha$ corresponds to $95\%$ of values between $0.26$ and $0.95$ chance of success. 
Note that each player has a DIFFERENT $\alpha$. No info is shared among players.

In [51]:
def no_pooling(ft_attempts, ft_makes = None):
    """No-pooling model: every player gets an independent log-odds of success.

    Args:
        ft_attempts: array of attempt counts; its leading dimension is the player count.
        ft_makes: observed makes, or None to leave the "obs" site unobserved.

    Returns:
        The value of the "obs" sample site.
    """
    n_players = ft_attempts.shape[0]

    # Everything inside the plate is replicated per player, so no information is shared.
    with numpyro.plate("players", n_players):
        player_logits = numpyro.sample("alpha", dist.Normal(1, 1))  # independent prior
        assert player_logits.shape == (n_players,), "alpha shape wrong"
        numpyro.deterministic("theta", jax.nn.sigmoid(player_logits))
        made = numpyro.sample(
            "obs",
            dist.BinomialLogits(total_count=ft_attempts, logits=player_logits),
            obs=ft_makes,
        )
    return made

In [52]:
# One key for MCMC and a separate one for posterior-predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(1))

train_attempts = jnp.array(top_player_data.total_shots)
train_makes = jnp.array(top_player_data.shots_made)

non_pooled_mcmc = run_inference(no_pooling, 
                                ft_attempts = train_attempts,
                                ft_makes = train_makes,
                                num_warmup = 500,
                                num_samples = 1000,
                                num_chains = 4,
                                rng_key = rng_key)
non_pooled_samples = non_pooled_mcmc.get_samples()

# Use the dedicated key from the split above. A JAX key that has been passed to
# `random.split` must not be consumed again, so PRNGKey(1) itself is not reused here.
non_pooled_posterior_predictive = Predictive(no_pooling, non_pooled_samples)(
    rng_key_predict, train_attempts)

# Prior draws use their own, unrelated seed.
non_pooled_prior = Predictive(no_pooling, num_samples=500)(
    random.PRNGKey(2), train_attempts)

non_pooled_az = az.from_numpyro(
    non_pooled_mcmc,
    prior = non_pooled_prior,
    posterior_predictive = non_pooled_posterior_predictive,
    coords = {"theta_dim_0": top_player_data.player}
)

non_pooled_mcmc.print_summary()

In [77]:
# Each player has a completely separate distribution
az.plot_density(
    data = non_pooled_az,
    var_names=["theta"],
    shade=0.1,
    grid = (4, 4)
)
plt.show()

## Hierarchical Model - Partial Pooling

We ideally want a balance between these two extremes, and this comes in the form of a partially pooled model. This model has a very subtle but important difference from the `no pooling` model: how we generate $\alpha_i$. Instead of sampling $\alpha_i$ directly from $N(1, 1)$, we sample it from $N(\mu, \sigma)$, where the population mean, $\mu$, and standard deviation, $\sigma$, are themselves estimated using hyper-priors. Here, $\mu$ is the population-level log-odds of success, so $\text{sigmoid}(\mu)$ can be interpreted as the population chance of success. 

In [56]:
def partial_pooling(ft_attempts, ft_makes = None):
    """Hierarchical (partial-pooling) model: player log-odds share a learned population prior.

    Args:
        ft_attempts: array of attempt counts; its leading dimension is the player count.
        ft_makes: observed makes, or None to leave the "y" site unobserved.

    Returns:
        The value of the "y" sample site.
    """
    n_players = ft_attempts.shape[0]

    # Hyperpriors for the population mean and spread of the per-player log-odds.
    population_mean = numpyro.sample("mu", dist.Normal(1, 1))
    population_sd = numpyro.sample("sigma", dist.HalfNormal(scale=1.))

    # One alpha per player, all drawn from the shared population distribution.
    with numpyro.plate("players", n_players):
        player_logits = numpyro.sample("alpha", dist.Normal(population_mean, population_sd))
        numpyro.deterministic("theta", jax.nn.sigmoid(player_logits))
        assert player_logits.shape == (n_players, ), "alpha shape wrong"
        made = numpyro.sample(
            "y",
            dist.BinomialLogits(logits=player_logits, total_count=ft_attempts),
            obs=ft_makes,
        )
    return made

In [78]:
# One key for MCMC and a separate one for posterior-predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(1))

train_attempts = jnp.array(top_player_data.total_shots)
train_makes = jnp.array(top_player_data.shots_made)

partial_pooled_mcmc = run_inference(partial_pooling, 
                                    ft_attempts = train_attempts,
                                    ft_makes = train_makes,
                                    num_warmup = 500,
                                    num_samples = 1000,
                                    num_chains = 4,
                                    rng_key = rng_key)
partial_pooled_samples = partial_pooled_mcmc.get_samples()

# Use the dedicated key from the split above. A JAX key that has been passed to
# `random.split` must not be consumed again, so PRNGKey(1) itself is not reused here.
partial_pooled_posterior_predictive = Predictive(partial_pooling, partial_pooled_samples)(
    rng_key_predict, train_attempts)

# Prior draws use their own, unrelated seed.
partial_pooled_prior = Predictive(partial_pooling, num_samples=500)(
    random.PRNGKey(2), train_attempts)

partial_pooled_az = az.from_numpyro(
    partial_pooled_mcmc,
    prior = partial_pooled_prior,
    posterior_predictive = partial_pooled_posterior_predictive,
    coords = {"theta_dim_0": top_player_data.player}
)

partial_pooled_mcmc.print_summary()

In [79]:
# Overlay per-player `theta` densities from the partially pooled and no-pooling fits
# so the two models can be compared player by player on a 4x4 grid.
az.plot_density(
    data = [partial_pooled_az, non_pooled_az],
    data_labels = ["partially pooled", "no pooling"],
    var_names=["theta"],
    shade=0.1,
    grid = (4, 4)
)
plt.show()

### Where does the difference come from?

The partially pooled and non-pooled models have very similar formulations, but produce very different posterior distributions. The most obvious difference to me is the prior on $\alpha_i$. The partial pooling formulation has more flexibility here, as both $\mu$ and $\sigma$ are estimated from the data. Below I compare $p(\alpha)$ for the partially pooled and non-pooled models, and it seems that the partially pooled prior has more variance than the non-pooled prior.

Now I'm interested to see what impact flatter priors would have on the model. After increasing the prior variance for the non-pooled model, the interval estimates were too wide to be useful, because we have so little data on each player. On the other hand, the interval estimates produced by the hierarchical model were very similar to before, because the hyperpriors are estimated using population data, which we have more of thanks to pooling. 

In [59]:
# compare p(\alpha) for no-pooling and partially pooled models
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(14, 4), sharey = True)
fig.suptitle(r'p($\alpha$) for Non-Pooled and Partially Pooled Models')

sns.histplot(partial_pooled_prior["alpha"].reshape(500*num_players,),
             ax = ax[0],
             color = "red",
             stat="density", common_norm=False,
             label = "partially pooled")
sns.histplot(non_pooled_prior["alpha"].reshape(500*num_players,), 
             stat="density", common_norm=False,
             ax = ax[0],
             label = "non-pooled")
ax[0].legend(fontsize = "small")

sns.histplot(partial_pooled_prior["alpha"].reshape(500*num_players,),
             ax = ax[1],
             stat="density", common_norm=False,
             color = "red",
             label = "partially pooled")
sns.histplot(np.random.normal(1, 1, 1000),
             ax = ax[1],
             stat="density", common_norm=False,
             label = "N(1, 1)")
ax[1].legend(fontsize = "small");

sns.histplot(non_pooled_prior["alpha"].reshape(500*num_players,), 
             color = "red",
             stat="density", common_norm=False,
             ax = ax[2],
             label = "non-pooled")
sns.histplot(np.random.normal(1, 1, 500),
             ax = ax[2],
             stat="density", common_norm=False,
             label = "N(1, 1)")
ax[2].legend(fontsize = "small");

In [60]:
# Compare theta in the complete pooling model with mu in the partial pooling model.
# mu is on the log-odds scale, so it is passed through the sigmoid first; based on our
# interpretation, sigmoid(mu) and theta should then be very similar.
# (Kept as comments rather than a bare non-raw string, in which "\t" would be read as a
# tab and "\m" as an invalid escape sequence.)

mu_as_probability = jax.nn.sigmoid(partial_pooled_samples["mu"])

# The legend reports the mean of the plotted draws, so both series are summarised the same way.
sns.histplot(mu_as_probability, 
             color = "red", 
             label = r"sigmoid($\mu$) mean: %.3f" % float(np.mean(mu_as_probability)),
             stat = "density")
sns.histplot(pooled_samples["theta"],
             label = r"$\theta_{pooled}$ mean: %.3f" % float(np.mean(pooled_samples["theta"])),
             stat = "density")\
.set_title(r"Posterior distributions for sigmoid($\mu$) and $\theta_{pooled}$")
plt.legend(fontsize = "small");

## Centered and Non-Centered Models

In [61]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(14, 5), sharey = True)

# The y-axis shows log(sigma), so it is labelled as such rather than as sigma itself.
log_sigma = np.log(partial_pooled_samples["sigma"])

# One panel per player index (0 and 1): theta for that player against log(sigma).
for player_idx, panel_ax in enumerate(ax):
    sns.scatterplot(y = log_sigma, 
                    x = partial_pooled_samples["theta"][:, player_idx],
                    alpha = 0.2, ax = panel_ax
                   ).set(title = "Centered Hierarchical Model Funnel Plot",
                         xlabel = r"$\theta_%d$" % player_idx,
                         ylabel = r"$\log(\sigma)$, log SD of player log-odds");

## Posterior Predictive Distribution and Predictions

In [62]:
# The posterior draws are indexed by the players' order in `top_player_data`, and the
# summaries below are attached to `test_data` by position - so the two must line up.
assert test_data.player.tolist() == top_player_data.player.tolist(), \
    "test_data and top_player_data must list the same players in the same order"

posterior_predictions = Predictive(partial_pooling, partial_pooled_samples)(
    random.PRNGKey(1), jnp.array(test_data.total_shots))

# Axis 0 runs over posterior draws; reducing along it leaves one value per player.
theta_draws = posterior_predictions["theta"]

test_data_predictions = test_data.copy()
test_data_predictions["FT_prediction_mean"] = np.asarray(jnp.mean(theta_draws, axis=0))
test_data_predictions["FT_prediction_median"] = np.asarray(jnp.median(theta_draws, axis=0))
test_data_predictions

In [63]:
# The posterior draws are indexed by the players' order in `top_player_data`, and the
# summaries below are attached to `test_data` by position - so the two must line up.
assert test_data.player.tolist() == top_player_data.player.tolist(), \
    "test_data and top_player_data must list the same players in the same order"

posterior_predictions_no_pooling = Predictive(no_pooling, non_pooled_samples)(
    random.PRNGKey(1), jnp.array(test_data.total_shots))

# Axis 0 runs over posterior draws; reducing along it leaves one value per player.
theta_draws_no_pool = posterior_predictions_no_pooling["theta"]

test_data_predictions_no_pool = test_data.copy()
test_data_predictions_no_pool["FT_prediction_mean"] = np.asarray(jnp.mean(theta_draws_no_pool, axis=0))
test_data_predictions_no_pool["FT_prediction_median"] = np.asarray(jnp.median(theta_draws_no_pool, axis=0))
test_data_predictions_no_pool

## Partial Pooling - Non-Centered Parameterization

- What is the problem with the centered parameterization?
- How does this new formulation solve the problem?
- Write up non-centered model from scratch and compare it to `reparam`
- Why are the posteriors for the non-centered formulation smoother?
- Why does the `effective_sample_size` change for the non-centered and centered formulations?

In [64]:
# One key for MCMC and a separate one for posterior-predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(1))

train_attempts = jnp.array(top_player_data.total_shots)
train_makes = jnp.array(top_player_data.shots_made)

# Same hierarchical model, with the `alpha` site handled by LocScaleReparam(0).
reparam_model = reparam(partial_pooling, config={'alpha': LocScaleReparam(0)})
reparam_mcmc = run_inference(reparam_model, 
                             ft_attempts = train_attempts,
                             ft_makes = train_makes,
                             num_warmup = 500,
                             num_samples = 1000,
                             num_chains = 4,
                             rng_key = rng_key)
reparam_samples = reparam_mcmc.get_samples()

# Use the dedicated key from the split above. A JAX key that has been passed to
# `random.split` must not be consumed again, so PRNGKey(1) itself is not reused here.
reparam_posterior_predictive = Predictive(reparam_model, reparam_samples)(
    rng_key_predict, train_attempts)

# Prior draws use their own, unrelated seed.
reparam_prior = Predictive(reparam_model, num_samples=500)(
    random.PRNGKey(2), train_attempts)

reparam_az = az.from_numpyro(
    reparam_mcmc,
    prior = reparam_prior,
    posterior_predictive = reparam_posterior_predictive,
    coords = {"theta_dim_0": top_player_data.player}
)

reparam_mcmc.print_summary()

In [65]:
# Overlay per-player `theta` densities from the original partially pooled fit and the
# reparameterized fit on a 4x4 grid.
az.plot_density(
    data = [partial_pooled_az, reparam_az],
    data_labels = ["partially pooled", "reparam"],
    var_names=["theta"],
    shade=0.1,
    grid = (4, 4)
)
plt.show()

In [66]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(14, 5), sharey = True)

# The y-axis shows log(sigma), so it is labelled as such rather than as sigma itself.
log_sigma = np.log(reparam_samples["sigma"])

# One panel per player index (0 and 1): theta for that player against log(sigma).
# The x-labels are wrapped in $...$ so matplotlib renders them as math instead of
# printing the literal text "\theta_0".
for player_idx, panel_ax in enumerate(ax):
    sns.scatterplot(y = log_sigma, 
                    x = reparam_samples["theta"][:, player_idx],
                    alpha = 0.2, ax = panel_ax
                   ).set(title = "Non-Centered Hierarchical Model Funnel Plot",
                         xlabel = r"$\theta_%d$" % player_idx,
                         ylabel = r"$\log(\sigma)$, log SD of player log-odds");