
# Bayesian Iris Classifier (NumPyro + NUTS)

This notebook replicates a classic Bayesian demonstration using the **Iris** dataset:
- A 1-hidden-layer MLP with **Gaussian priors** over weights and biases
- **Hamiltonian Monte Carlo (NUTS)** to sample from the posterior
- **Posterior predictive** probabilities
- A simple **abstention rule** (“I don’t know”): the model answers only when its top-class probability is at least a confidence threshold `τ` (`tau` in the code)
- Evaluation with accuracy and **selective prediction** (accuracy vs. coverage)
- A **reliability diagram** to check calibration on the test set

> Works on Apple Silicon (CPU) out of the box.



## Environment

If the dependencies are not yet installed in your Pipenv environment, run this once in a terminal:

```bash
pipenv run pip install numpyro==0.15.2 "jax[cpu]==0.4.28" scikit-learn pandas matplotlib
```


In [None]:

# Imports
import math
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


In [None]:

# Data: load, split, scale
iris = load_iris()
X = iris.data.astype(np.float32)   # (150, 4)
y = iris.target.astype(np.int32)   # 0..2

# Stratified split keeps class proportions equal between train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42, stratify=y
)

# Fit the scaler on the training split only, so no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train).astype(np.float32)
X_test  = scaler.transform(X_test).astype(np.float32)

n_features = X_train.shape[1]
n_hidden   = 8
# Take the class count from the dataset metadata, not from the training labels:
# a class absent from y_train would otherwise shrink the output layer, and JAX
# clamps out-of-range gather indices instead of raising, so the likelihood
# would be silently wrong.
n_classes  = len(iris.target_names)

# Sanity checks on the split before it reaches the model.
assert X_train.shape[0] + X_test.shape[0] == X.shape[0]
assert X_train.shape[1] == X_test.shape[1] == n_features
assert y_train.min() >= 0 and y_train.max() < n_classes

Xtr = jnp.array(X_train)
Xte = jnp.array(X_test)
ytr = jnp.array(y_train)
yte = jnp.array(y_test)


In [None]:

# Bayesian MLP with Gaussian priors
def relu(x):
    """Element-wise rectified linear unit: max(x, 0), preserving the input dtype."""
    return jnp.maximum(x, 0)

def model(X, y=None):
    """NumPyro model: one-hidden-layer ReLU MLP with a Categorical likelihood.

    Every weight and bias gets an independent Normal(0, 1) prior. Layer sizes
    come from the module-level ``n_features``, ``n_hidden`` and ``n_classes``.

    Args:
        X: feature matrix, one row per example; ``X.shape[0]`` sizes the data plate.
        y: integer class labels to condition on, or None to leave "obs" unobserved.
    """
    def standard_normal(site, *shape):
        # One N(0, 1) per entry, declared as a single event of rank len(shape).
        prior = dist.Normal(0, 1).expand(list(shape)).to_event(len(shape))
        return numpyro.sample(site, prior)

    # Sites are declared in the fixed order W1, b1, W2, b2; reordering sample
    # sites would presumably change which random key each site receives.
    W1 = standard_normal("W1", n_features, n_hidden)
    b1 = standard_normal("b1", n_hidden)
    W2 = standard_normal("W2", n_hidden, n_classes)
    b2 = standard_normal("b2", n_classes)

    hidden = relu(jnp.dot(X, W1) + b1)
    class_logits = jnp.dot(hidden, W2) + b2

    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("obs", dist.Categorical(logits=class_logits), obs=y)


In [None]:

# Inference with NUTS
# Chains are run one after another. On a single CPU device NumPyro cannot run
# them in parallel; with the default chain_method it would emit a warning and
# fall back to sequential sampling anyway, so make that choice explicit.
RNG_SEED = 0
NUM_WARMUP = 800
NUM_SAMPLES = 800
NUM_CHAINS = 2

rng_key = jax.random.PRNGKey(RNG_SEED)
nuts = NUTS(model, dense_mass=True, target_accept_prob=0.85)
mcmc = MCMC(
    nuts,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    chain_method="sequential",
    progress_bar=True,
)
mcmc.run(rng_key, X=Xtr, y=ytr)
# Default get_samples() merges chains: each value's leading axis indexes all draws.
posterior = mcmc.get_samples()
mcmc.print_summary()


In [None]:

# Posterior predictive probabilities
def forward_logits(X, params):
    """Class logits of the MLP for a single set of weights.

    Same computation as the deterministic part of ``model``: ReLU hidden layer
    followed by a linear output layer.

    Args:
        X: feature matrix, one row per example.
        params: mapping with keys "W1", "b1", "W2", "b2" holding one weight draw.

    Returns:
        Logits array with one row per example of ``X``.
    """
    W1, b1, W2, b2 = (params[key] for key in ("W1", "b1", "W2", "b2"))
    hidden = relu(jnp.dot(X, W1) + b1)
    return jnp.dot(hidden, W2) + b2

def posterior_class_probs(X, posterior_dict):
    """Posterior predictive class probabilities for every row of ``X``.

    Runs the forward pass once per posterior draw (vectorized with ``jax.vmap``
    over the leading axis of every value in ``posterior_dict``), applies a
    softmax per draw, and averages the probabilities across draws.

    Args:
        X: feature matrix, one row per example.
        posterior_dict: dict of stacked posterior samples (e.g. from
            ``mcmc.get_samples()``), leading axis = draw index.

    Returns:
        NumPy array of shape (N, C); each row sums to 1 up to rounding.
    """
    def logits_for_draw(params):
        return forward_logits(X, params)

    per_draw_logits = jax.vmap(logits_for_draw)(posterior_dict)   # (S, N, C)
    per_draw_probs = jax.nn.softmax(per_draw_logits, axis=-1)
    averaged = per_draw_probs.mean(axis=0)                        # (N, C)
    return np.asarray(averaged)

# Posterior-mean probabilities on both splits; hard predictions are their argmax.
p_tr, p_te = (posterior_class_probs(features, posterior) for features in (Xtr, Xte))

yhat_tr = np.argmax(p_tr, axis=1)
yhat_te = np.argmax(p_te, axis=1)

acc_tr, acc_te = (
    accuracy_score(labels, predicted)
    for labels, predicted in ((y_train, yhat_tr), (y_test, yhat_te))
)
acc_tr, acc_te


In [None]:

# Selective prediction: abstain below threshold tau
def evaluate_with_abstention(probs, y_true, tau=0.8):
    """Score a classifier that abstains ("I don't know") when unsure.

    An example is *answered* when its top-class probability is >= ``tau``;
    all other examples are abstentions and are excluded from accuracy.

    Args:
        probs: array-like of shape (N, C), class probabilities per example.
        y_true: array-like of N integer labels (list, NumPy or JAX array).
        tau: confidence threshold.

    Returns:
        (coverage, acc) as floats: the fraction of examples answered, and the
        accuracy on the answered subset. ``acc`` is NaN when nothing is
        answered; both are NaN for empty input.
    """
    probs = np.asarray(probs)
    # Convert labels as well, so boolean-mask indexing also works for lists.
    y_true = np.asarray(y_true)
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    answered = conf >= tau
    if answered.size == 0:
        return float("nan"), float("nan")
    coverage = float(answered.mean())
    if not answered.any():
        return coverage, float("nan")
    acc = float((pred[answered] == y_true[answered]).mean())
    return coverage, acc

# Sweep the abstention threshold and record (coverage, accuracy on answered) at each value.
taus = np.linspace(0.5, 0.95, 10)
results = [evaluate_with_abstention(p_te, y_test, tau=float(t)) for t in taus]
coverages = [coverage for coverage, _ in results]
accs = [accuracy for _, accuracy in results]

# Accuracy-coverage curve (selective risk = 1 - accuracy). Raising tau can only
# lower coverage, since an example is answered when its confidence >= tau.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(coverages, accs, marker="o")
ax.set_xlabel("Coverage (fraction answered)")
ax.set_ylabel("Accuracy on answered")
ax.set_title("Selective Prediction: Accuracy vs Coverage")
ax.grid(True)
plt.show()


In [None]:

# Reliability diagram (optional): 10-bin calibration
def reliability_curve(probs, y_true, bins=10):
    """Bin predictions by top-class confidence for a reliability diagram.

    Confidence is the maximum class probability. The interval [0, 1] is split
    into ``bins`` equal-width bins that are closed on the left. The last bin
    is also closed on the right, so a confidence of exactly 1.0 is counted.

    Args:
        probs: array-like of shape (N, C), class probabilities per example.
        y_true: array-like of N integer labels.
        bins: number of equal-width confidence bins (>= 1).

    Returns:
        (bin_conf, bin_acc, bin_count, bin_edges):
        - bin_conf: mean confidence per bin (NaN for empty bins);
        - bin_acc: accuracy per bin (NaN for empty bins);
        - bin_count: number of examples per bin (ints);
        - bin_edges: the bins + 1 edges from 0.0 to 1.0.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = pred == y_true

    bin_edges = np.linspace(0.0, 1.0, bins + 1)
    # Digitizing against the interior edges yields ids 0..bins-1 directly.
    # Digitizing against all edges and subtracting 1 would send conf == 1.0
    # to id `bins` and drop it from every bin.
    bin_ids = np.digitize(conf, bin_edges[1:-1])

    bin_conf = np.full(bins, np.nan)
    bin_acc = np.full(bins, np.nan)
    bin_count = np.zeros(bins, dtype=int)
    for b in range(bins):
        mask = bin_ids == b
        n_in_bin = int(mask.sum())
        bin_count[b] = n_in_bin
        if n_in_bin:
            bin_conf[b] = conf[mask].mean()
            bin_acc[b] = correct[mask].mean()
    return bin_conf, bin_acc, bin_count, bin_edges

bin_conf, bin_acc, bin_cnt, edges = reliability_curve(p_te, y_test, bins=10)

# Only bins that received at least one test example have a defined point.
valid = ~np.isnan(bin_conf)

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot([0, 1], [0, 1], linestyle="--")   # perfect calibration: accuracy == confidence
ax.scatter(bin_conf[valid], bin_acc[valid])
ax.set_xlabel("Mean predicted confidence")
ax.set_ylabel("Empirical accuracy")
ax.set_title("Reliability Diagram (Test)")
ax.grid(True)
plt.show()
