<a href="https://colab.research.google.com/github/tkorsi/ClusterPrePermissions/blob/master/hierarchical%20model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt

# 1) Read the data
url = "https://stepik.org/media/attachments/lesson/1251114/euroweight.dat.txt"
euro = pd.read_csv(url, sep="\t", names=["weight", "batch"], index_col=0)

# Each 'batch' is an integer 1..8, each with 250 observations
print(euro.head())
print(euro["batch"].value_counts())

# 2) Build the hierarchical model
with pm.Model() as model:
    # Global hyperparameters
    mu_b = pm.Normal("mu_b", mu=0, sigma=100)
    sigma_b = pm.HalfNormal("sigma_b", sigma=5)

    # Batch-level means
    # shape=8 because there are 8 batches
    mu_batch = pm.Normal("mu_batch", mu=mu_b, sigma=sigma_b, shape=8)

    # Observation-level scale
    sigma = pm.HalfNormal("sigma", sigma=10)

    # Map each coin's batch to the correct mu_batch index
    # The 'batch' column is 1-based, so shift by -1 for 0-based indexing:
    batch_idx = euro["batch"].values - 1

    # Likelihood for each coin's weight
    pm.Normal(
        "obs",
        mu=mu_batch[batch_idx],
        sigma=sigma,
        observed=euro["weight"].values
    )

    # 3) Sample from the posterior
    trace = pm.sample(10000, chains=1, random_seed=42, progressbar=True)

   weight  batch
1   7.512      1
2   7.502      1
3   7.461      1
4   7.562      1
5   7.528      1
batch
1    250
2    250
3    250
4    250
5    250
6    250
7    250
8    250
Name: count, dtype: int64


Output()

In [12]:

# 4) Summarize
print(az.summary(trace, var_names=["mu_b", "sigma_b", "mu_batch", "sigma"]))

# --- Identify the highest and lowest posterior means among the 8 mu_batch ---
# 'trace.posterior["mu_batch"]' typically has shape (chain, draw, 8)
post_mu_batch = trace.posterior["mu_batch"].values  # numpy array
# We get the mean across chain/draw for each of the 8 batch means:
batch_means = post_mu_batch.mean(axis=(0, 1))  # shape: (8,)

i_max = np.argmax(batch_means)  # index of highest mean
i_min = np.argmin(batch_means)  # index of lowest mean

print("\nPosterior mean of mu_batch:", batch_means)
print(f"Highest batch index: {i_max}, mean={batch_means[i_max]:.4f}")
print(f"Lowest  batch index: {i_min}, mean={batch_means[i_min]:.4f}")



Shape validation failed: input_shape: (1, 10000), minimum_shape: (chains=2, draws=4)


              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  \
mu_b         7.521  0.003   7.515    7.528        0.0      0.0   10458.0   
sigma_b      0.009  0.003   0.004    0.015        0.0      0.0    9629.0   
mu_batch[0]  7.520  0.002   7.516    7.524        0.0      0.0   15868.0   
mu_batch[1]  7.523  0.002   7.519    7.527        0.0      0.0   16248.0   
mu_batch[2]  7.510  0.002   7.506    7.514        0.0      0.0   14000.0   
mu_batch[3]  7.530  0.002   7.526    7.534        0.0      0.0   16683.0   
mu_batch[4]  7.531  0.002   7.527    7.535        0.0      0.0   14823.0   
mu_batch[5]  7.516  0.002   7.512    7.520        0.0      0.0   14704.0   
mu_batch[6]  7.523  0.002   7.519    7.527        0.0      0.0   18092.0   
mu_batch[7]  7.517  0.002   7.513    7.521        0.0      0.0   16326.0   
sigma        0.034  0.001   0.033    0.035        0.0      0.0   16137.0   

             ess_tail  r_hat  
mu_b           6973.0    NaN  
sigma_b        6954.0    

In [13]:

# 5) Compute mu_i - mu_j for each MCMC sample
mu_i_samples = post_mu_batch[..., i_max]  # shape (chain, draw)
mu_j_samples = post_mu_batch[..., i_min]
diff_samples = mu_i_samples - mu_j_samples  # shape (chain, draw)

# Flatten across chains and draws
diff_samples_1d = diff_samples.flatten()

# 6) 95% HDI for the difference
hdi_95 = az.hdi(diff_samples_1d, hdi_prob=0.95)
diff_lower, diff_upper = hdi_95[0], hdi_95[1]

print(f"\n95% credible interval for mu_{i_max} - mu_{i_min}:")
print(f"[{diff_lower:.4f}, {diff_upper:.4f}]")

# 7) Check whether the interval contains 0
if diff_lower <= 0 <= diff_upper:
    print("The 95% interval contains 0.")
else:
    print("The 95% interval does NOT contain 0.")

# 8) Also compute E[ mu_i ]
mu_i_mean = mu_i_samples.mean()  # average across chain & draws
print(f"\nE[ mu_{i_max} ] = {mu_i_mean:.4f}")


95% credible interval for mu_4 - mu_2:
[0.0143, 0.0264]
The 95% interval does NOT contain 0.

E[ mu_4 ] = 7.5307


In [None]:

# 9) Plot the posterior for the difference (optional)
az.plot_posterior(diff_samples_1d, hdi_prob=0.95)
plt.title(f"Posterior of mu_{i_max} - mu_{i_min}")
plt.show()
