# Lesson 5 — Gibbs Sampling

In this notebook we build a **Gibbs sampler** step by step for a Normal model whose mean $\mu$ and variance
$\sigma^2$ are *both unknown*. The steps are:

1. Specify conjugate priors  
   $\;\mu\mid\sigma^2\sim\mathcal N\bigl(\mu_0,\,\sigma^2/\kappa_0\bigr),\;\;
   \sigma^2\sim\text{Inv‑Gamma}(\nu_0,\beta_0)$  
2. Derive and code the two full‑conditional distributions  
   $p(\mu\mid\sigma^2,\,y)$ and $p(\sigma^2\mid\mu,\,y)$  
3. Implement the Gibbs sampling loop in Python  
4. Diagnose convergence with trace plots and R‑hat  
5. Summarize the posterior and compare to prior beliefs

## 1  Setup

In [None]:

import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from scipy.stats import invgamma


In [None]:

# Data (same as Lesson 4)
y = np.array([1.2, 1.4, -0.5, 0.3, 0.9, 2.3, 1.0, 0.1, 1.3, 1.9])
n = len(y)
ybar = y.mean()
print(f"n = {n}, ȳ = {ybar:.3f}")


### 1.1  Prior hyper‑parameters

In [None]:

mu0    = 0.0
kappa0 = 1.0    # prior pseudo-sample size: Var(mu | sig2) = sig2 / kappa0
nu0    = 1.0
beta0  = 1.0
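
To get a feel for what these hyperparameters imply before seeing any data, one option is to simulate from the prior. The sketch below is an illustration rather than part of the lesson's pipeline: the seed and the number of draws are arbitrary choices, and the Inverse-Gamma draw is obtained as the reciprocal of a Gamma draw instead of through `scipy.stats.invgamma`. Quantiles are reported because an Inverse-Gamma with shape 1 has no finite mean.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for this sketch

mu0, kappa0, nu0, beta0 = 0.0, 1.0, 1.0, 1.0
n_draws = 100_000               # arbitrary number of prior draws

# If X ~ Gamma(shape=nu0, rate=beta0), then 1/X ~ Inv-Gamma(nu0, scale=beta0).
sig2_prior = 1.0 / rng.gamma(shape=nu0, scale=1.0 / beta0, size=n_draws)

# mu | sig2 ~ N(mu0, sig2 / kappa0)
mu_prior = rng.normal(mu0, np.sqrt(sig2_prior / kappa0))

probs = [0.05, 0.5, 0.95]
sig2_q = np.quantile(sig2_prior, probs)
mu_q = np.quantile(mu_prior, probs)
print("prior sig2 quantiles (5%, 50%, 95%):", np.round(sig2_q, 2))
print("prior mu   quantiles (5%, 50%, 95%):", np.round(mu_q, 2))
```

The wide spread of both sets of quantiles shows that this prior is weak relative to ten observations.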


## 2  Full conditional simulators
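
Both conditionals follow from the joint density $p(\mu,\sigma^2\mid y)\propto p(y\mid\mu,\sigma^2)\,p(\mu\mid\sigma^2)\,p(\sigma^2)$ by keeping only the factors that involve the parameter being updated. A sketch of the result under the priors listed in the introduction, reading $\beta_0$ as the scale of the Inverse-Gamma (an assumption, chosen to match the `scale=` argument of `scipy.stats.invgamma`):

$$
\mu \mid \sigma^2, y \;\sim\; \mathcal N\!\left(\frac{n\bar y + \kappa_0\mu_0}{n+\kappa_0},\; \frac{\sigma^2}{n+\kappa_0}\right)
$$

$$
\sigma^2 \mid \mu, y \;\sim\; \text{Inv-Gamma}\!\left(\nu_0 + \frac{n+1}{2},\;\; \beta_0 + \frac12\Bigl[\sum_{i=1}^{n}(y_i-\mu)^2 + \kappa_0(\mu-\mu_0)^2\Bigr]\right)
$$

The extra $1/2$ in the shape and the $\kappa_0(\mu-\mu_0)^2$ term in the scale come from $p(\mu\mid\sigma^2)$, which itself depends on $\sigma^2$.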

In [None]:

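def sample_mu(n, ybar, sig2, mu0, kappa0):
    """Draw mu from p(mu | sig2, y), a Normal distribution."""
    # The prior N(mu0, sig2 / kappa0) has precision kappa0 / sig2; the data add n / sig2.
    var = 1.0 / (n / sig2 + kappa0 / sig2)
    mean = var * (n * ybar / sig2 + kappa0 * mu0 / sig2)
    return np.random.normal(mean, np.sqrt(var))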


In [None]:

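def sample_sig2(n, y, mu, nu0, beta0, mu0, kappa0):
    """Draw sig2 from p(sig2 | mu, y), an Inverse-Gamma distribution."""
    # The prior on mu depends on sig2, so it adds 1/2 to the shape
    # and kappa0 * (mu - mu0)**2 / 2 to the scale.
    nu1 = nu0 + (n + 1) / 2.0
    beta1 = beta0 + 0.5 * (np.sum((y - mu)**2) + kappa0 * (mu - mu0)**2)
    return invgamma.rvs(a=nu1, scale=beta1)


## 3  Gibbs sampler

In [None]:

def gibbs_normal(y, n_iter=1000, mu0=0.0, kappa0=1.0, nu0=1.0, beta0=1.0):
    """Alternate draws from the two full conditionals for n_iter iterations."""
    n = len(y)
    ybar = y.mean()
    mu_chain   = np.empty(n_iter)
    sig2_chain = np.empty(n_iter)
    mu_now = mu0  # initial value for mu
    for i in range(n_iter):
        sig2_now = sample_sig2(n, y, mu_now, nu0, beta0, mu0, kappa0)
        mu_now   = sample_mu(n, ybar, sig2_now, mu0, kappa0)
        mu_chain[i] = mu_now
        sig2_chain[i] = sig2_now
    return {"mu": mu_chain, "sig2": sig2_chain}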


### 3.1  Run the sampler

In [None]:

np.random.seed(53)
post = gibbs_normal(y, n_iter=1000, mu0=mu0, kappa0=kappa0, nu0=nu0, beta0=beta0)
print('Posterior means:')
print('  mu   =', post['mu'].mean())
print('  sig2 =', post['sig2'].mean())
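
Because the prior stated in the introduction is fully conjugate, the posterior means are also available in closed form, which gives a useful check on the Monte Carlo output. The sketch below applies the standard Normal-Inverse-Gamma update; the names `kappa_n`, `mu_n`, `nu_n` and `beta_n` are introduced here for illustration and do not appear elsewhere in the notebook.

```python
import numpy as np

y = np.array([1.2, 1.4, -0.5, 0.3, 0.9, 2.3, 1.0, 0.1, 1.3, 1.9])
mu0, kappa0, nu0, beta0 = 0.0, 1.0, 1.0, 1.0
n, ybar = len(y), y.mean()

# Normal-Inverse-Gamma update (standard conjugate result)
kappa_n = kappa0 + n
mu_n = (kappa0 * mu0 + n * ybar) / kappa_n
nu_n = nu0 + n / 2.0
beta_n = (beta0
          + 0.5 * np.sum((y - ybar) ** 2)
          + 0.5 * kappa0 * n * (ybar - mu0) ** 2 / kappa_n)

exact_mean_mu = mu_n                      # E[mu | y]
exact_mean_sig2 = beta_n / (nu_n - 1.0)   # E[sig2 | y], finite because nu_n > 1

print(f"exact E[mu | y]   = {exact_mean_mu:.3f}")    # 0.900
print(f"exact E[sig2 | y] = {exact_mean_sig2:.3f}")  # 0.924
```

A sampler that targets this posterior should give averages near these values, up to Monte Carlo error; a persistent gap points to a mistake in one of the full conditionals.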


### 3.2  Trace diagnostics

In [None]:

idata = az.from_dict({'mu': post['mu'], 'sig2': post['sig2']})
az.plot_trace(idata, figsize=(8,4), compact=True)
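
R-hat compares several chains, so it needs more than the single run produced above. The sketch below is one way to get it, under stated assumptions: `gibbs_chain` and `gelman_rubin` are hypothetical helpers written for this example, the starting values and burn-in length are arbitrary, the sampler uses the full conditionals implied by the priors in the introduction, and the statistic is the classic (non-split) Gelman-Rubin version computed with NumPy.

```python
import numpy as np

def gibbs_chain(y, n_iter, mu_init, rng, mu0=0.0, kappa0=1.0, nu0=1.0, beta0=1.0):
    """Compact Gibbs sampler; returns an (n_iter, 2) array of (mu, sig2) draws."""
    n, ybar = len(y), y.mean()
    mu = mu_init
    out = np.empty((n_iter, 2))
    for i in range(n_iter):
        # sig2 | mu, y ~ Inv-Gamma(shape, scale), drawn as 1 / Gamma(shape, rate=scale)
        shape = nu0 + (n + 1) / 2.0
        scale = beta0 + 0.5 * (np.sum((y - mu) ** 2) + kappa0 * (mu - mu0) ** 2)
        sig2 = 1.0 / rng.gamma(shape, 1.0 / scale)
        # mu | sig2, y ~ Normal with precision (n + kappa0) / sig2
        mean = (n * ybar + kappa0 * mu0) / (n + kappa0)
        mu = rng.normal(mean, np.sqrt(sig2 / (n + kappa0)))
        out[i] = mu, sig2
    return out

def gelman_rubin(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    n_draws = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()            # mean within-chain variance
    B = n_draws * chains.mean(axis=1).var(ddof=1)    # between-chain variance
    var_hat = (n_draws - 1) / n_draws * W + B / n_draws
    return np.sqrt(var_hat / W)

y = np.array([1.2, 1.4, -0.5, 0.3, 0.9, 2.3, 1.0, 0.1, 1.3, 1.9])
rng = np.random.default_rng(53)
starts = [-5.0, 0.0, 5.0, 10.0]   # deliberately dispersed starting values for mu
burn = 100                        # draws discarded from the start of each chain

# Resulting shape: (chain, draw, parameter)
draws = np.stack([gibbs_chain(y, 1100, s, rng)[burn:] for s in starts])

rhat_mu = gelman_rubin(draws[:, :, 0])
rhat_sig2 = gelman_rubin(draws[:, :, 1])
print(f"R-hat mu   = {rhat_mu:.3f}")
print(f"R-hat sig2 = {rhat_sig2:.3f}")
```

Values close to 1 indicate that the chains agree with one another. ArviZ's `az.rhat` reports a rank-normalized split variant of the same idea when it is given draws with a chain dimension.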


## 4  Posterior summaries

In [None]:

az.summary(idata, kind='stats', hdi_prob=0.95)
