<a href="https://colab.research.google.com/github/tkorsi/Machine-Learning-Seminars/blob/main/HDI%20Calculation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import pymc as pm
from scipy.stats import norm

# --- 1. Draw standard-normal samples and keep only those inside (-1, 2) ---
# The full sample is seeded, so the retained subset is reproducible. Both
# comparisons are strict, i.e. values exactly on a boundary are dropped.
data_full = norm.rvs(loc=0, scale=1, size=1000, random_state=1337)
above_lower_cut = data_full > -1
below_upper_cut = data_full < 2
data = data_full[above_lower_cut & below_upper_cut]

print("Truncated data size:", len(data))
print("Data min, max =", data.min(), data.max())

# --- 2. Build the PyMC model ---
with pm.Model() as model:
    # Improper priors: mu is flat on the real line, sigma is flat on the
    # positive half-line.
    mu = pm.Flat("mu")
    sigma = pm.HalfFlat("sigma")

    # Adding -log(sigma) to the model log-density turns the flat priors
    # into p(mu, sigma) ∝ 1/sigma.
    pm.Potential("prior_potential", -pm.math.log(sigma))

    # Exponential(1) offsets: how far the unknown truncation points lie
    # beyond the smallest (L) and largest (U) observed value.
    L = pm.Exponential("L", lam=1.0)
    U = pm.Exponential("U", lam=1.0)

    # Truncation bounds, registered as Deterministic so they are stored
    # alongside the sampled variables.
    lower_ = pm.Deterministic("lower_", data.min() - L)
    upper_ = pm.Deterministic("upper_", data.max() + U)

    # Likelihood: a Normal(mu, sigma) truncated between lower_ and upper_.
    untruncated = pm.Normal.dist(mu=mu, sigma=sigma)
    obs = pm.Truncated(
        "obs", untruncated, lower=lower_, upper=upper_, observed=data
    )

    # --- 3. Sample with 2 chains to avoid shape validation issues ---
    sampler_settings = {
        "draws": 10000,        # posterior draws kept per chain
        "tune": 2000,          # tuning steps per chain
        "chains": 2,
        "cores": 1,            # run the chains sequentially in one process
        "random_seed": 1337,
        "target_accept": 0.9,  # passed through to the step method
    }
    trace = pm.sample(**sampler_settings)

# --- 4. Print ArviZ summary ---
# Summary table for the sampled parameters and the derived truncation
# bounds, using 95% HDI columns.
summary_vars = ["mu", "sigma", "L", "U", "lower_", "upper_"]
summary_df = az.summary(trace, var_names=summary_vars, hdi_prob=0.95)
print(summary_df)

# --- 5. Compute & print the 95% HDI for sigma in the format "lower upper" ---
# Pool the draws of every chain into one flat array before computing the HDI.
sigma_posterior = trace.posterior["sigma"]
sigma_vals = sigma_posterior.values.flatten()
sigma_hdi = az.hdi(sigma_vals, hdi_prob=0.95)

# The two bounds are printed space-separated with three decimals.
interval_text = f"{sigma_hdi[0]:.3f} {sigma_hdi[1]:.3f}"
print("\n95% credible interval for sigma:", interval_text)

# --- 6. Plot the trace ---
# Trace plot for the sampled parameters and the derived truncation bounds.
traced_vars = ["mu", "sigma", "L", "U", "lower_", "upper_"]
az.plot_trace(trace, var_names=traced_vars)
plt.tight_layout()
plt.show()
