# Bayesian Logistic Regression — German Credit

Visualize and analyze ETD variants vs baselines on the Bayesian logistic regression posterior.

**Prerequisites:**
```bash
# 1. Generate NUTS reference (one-time, ~5 min)
python -m experiments.nuts --target blr --dataset german_credit --force

# 2. Run the experiment
python -m experiments.run configs/blr_german.yaml
```

In [None]:
import json
import os
import sys
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Project imports — chdir to project root so relative paths
# (e.g., data/etd.duckdb, results/reference/) resolve correctly.
# The cell is safe to re-run: once the working directory already is the
# project root (recognised by its src/ directory, which is added to sys.path
# below), stay there instead of climbing one level higher on every execution.
_cwd = Path.cwd()
ROOT = _cwd if (_cwd / "src").is_dir() else _cwd.parent
os.chdir(ROOT)
# Insert src/ first and ROOT second so ROOT ends up ahead of src/ on sys.path;
# skip entries that are already present to avoid duplicates on re-runs.
for _entry in (str(ROOT / "src"), str(ROOT)):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)

from etd.targets.blr import BLRTarget
from experiments.nuts import load_reference
from figures.style import (
    ALGO_COLORS, FULL_WIDTH, COL_WIDTH,
    savefig_paper, setup_style,
)

# Apply the shared figure style from figures.style
# (presumably sets matplotlib defaults — see that module for details).
setup_style()
# IPython magic: ask the inline backend for 'retina' figure output.
%config InlineBackend.figure_format = 'retina'

## 1. Load data

In [None]:
import yaml

# --- Find the latest results directory ---
# NOTE(review): this notebook targets German Credit, yet results are read from
# "blr-ionosphere" — confirm the directory name matches the experiment run.
results_root = ROOT / "results" / "blr-ionosphere"
if not results_root.is_dir():
    raise FileNotFoundError(
        f"No results at {results_root}.\n"
        "Run: python -m experiments.run configs/blr/CONFIG_NAME.yaml"
    )
# Only sub-directories can be runs (config.yaml etc. are read from inside
# them); stray files such as .DS_Store must never be picked as the latest run.
# "Latest" means last in lexicographic order of the directory names.
runs = sorted(p for p in results_root.iterdir() if p.is_dir())
if not runs:
    raise FileNotFoundError(
        f"No run directories inside {results_root}.\n"
        "Run: python -m experiments.run configs/blr/CONFIG_NAME.yaml"
    )
RESULTS_DIR = runs[-1]
print(f"Loading from: {RESULTS_DIR}")

# --- Load config ---
# The run directory stores the YAML config the experiment was launched with.
with (RESULTS_DIR / "config.yaml").open() as config_file:
    exp_config = yaml.safe_load(config_file)

# Convenience handles into the parsed config.
exp = exp_config["experiment"]
target_cfg = exp["target"]
# Keyword arguments for the target; an absent "params" entry means no arguments.
target_params = target_cfg.get("params", {})

# --- Load metrics ---
with (RESULTS_DIR / "metrics.json").open() as metrics_file:
    raw_metrics = json.load(metrics_file)

# Same nesting as the JSON file (seed -> algorithm -> checkpoint -> values),
# except that checkpoint keys are converted from JSON strings to ints.
metrics = {
    seed: {
        algo: {int(ckpt_key): entry for ckpt_key, entry in ckpt_dict.items()}
        for algo, ckpt_dict in algo_dict.items()
    }
    for seed, algo_dict in raw_metrics.items()
}

# --- Load particles ---
# Read every array eagerly inside a context manager so the .npz archive is
# closed afterwards instead of leaving an open file handle behind.
with np.load(RESULTS_DIR / "particles.npz") as _npz:
    particles = {name: _npz[name] for name in _npz.files}

# --- Build target ---
# Rebuild the target from the parameters stored in the run's config.
target = BLRTarget(**target_params)
d = target.dim
print(f"Target: blr, dim={d}, prior_std={target.prior_std}")
print(f"Data: X {target.X.shape}, y {target.y.shape} ({float(target.y.mean()):.1%} positive)")

# --- Constants ---
# Seed labels in the form used as top-level keys of `metrics`.
SEEDS = [f"seed{i}" for i in exp["seeds"]]
# Every checkpoint that occurs for any seed/algorithm, in ascending order.
CHECKPOINTS = sorted({
    int(k)
    for per_algo in metrics.values()
    for per_ckpt in per_algo.values()
    for k in per_ckpt
})
# Union of the algorithm names seen across all seeds.
ALL_ALGOS = sorted(set().union(*metrics.values()))
ALGOS = ALL_ALGOS  # Keep all for BLR
print(f"Checkpoints: {CHECKPOINTS}")
print(f"Algorithms: {ALGOS}")

In [None]:
# --- Load NUTS reference ---
# load_reference returns None when no reference is available; HAS_REF records
# that outcome for the reference-dependent cells further down.
ref_samples = load_reference("blr", target_params)
HAS_REF = ref_samples is not None
if not HAS_REF:
    print("No NUTS reference found. Run:")
    print("  python -m experiments.nuts --target blr --dataset german_credit")
else:
    print(f"NUTS reference: {ref_samples.shape}")
    # Per-dimension summary statistics of the reference samples.
    ref_mean = ref_samples.mean(axis=0)
    ref_std = ref_samples.std(axis=0)

## 2. Convergence traces

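Each row is one metric and each column is one algorithm. The colored line is that algorithm's median across seeds with its interquartile band; light gray traces show the other algorithms' medians for context.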

In [None]:
def gather_metric(metric_name, algos=None):
    """Summarise one metric across seeds at every checkpoint.

    Args:
        metric_name: Key of the metric inside each checkpoint entry.
        algos: Algorithm names to include; defaults to ``ALGOS``.

    Returns:
        Dict mapping each algorithm to a ``(median, q25, q75)`` tuple of
        arrays with one entry per checkpoint in ``CHECKPOINTS``. Missing
        seed/algorithm/checkpoint/metric entries count as NaN and are ignored
        by the NaN-aware statistics.
    """
    selected = ALGOS if algos is None else algos
    summary = {}
    for name in selected:
        # One row per seed, one column per checkpoint.
        table = np.array([
            [
                metrics.get(seed_key, {})
                .get(name, {})
                .get(ckpt, {})
                .get(metric_name, np.nan)
                for ckpt in CHECKPOINTS
            ]
            for seed_key in SEEDS
        ])
        centre = np.nanmedian(table, axis=0)
        lower = np.nanpercentile(table, 25, axis=0)
        upper = np.nanpercentile(table, 75, axis=0)
        summary[name] = (centre, lower, upper)
    return summary


# Checkpoints are plotted at evenly spaced integer positions and labelled
# with the checkpoint value itself.
X_LABELS = list(map(str, CHECKPOINTS))
X_POS = np.arange(len(X_LABELS))

# Sparse tick labels: every (n // 5)-th position plus the last one; all other
# positions get an empty label.
TICK_SHOW = set(range(0, len(X_LABELS), max(1, len(X_LABELS) // 5))) | {len(X_LABELS) - 1}
X_LABELS_SPARSE = [
    label if idx in TICK_SHOW else ""
    for idx, label in enumerate(X_LABELS)
]

In [None]:
# (metric key, y-axis label) pairs; one subplot row per entry.
metric_specs = [
    ("energy_distance", "Energy distance"),
    ("sliced_wasserstein", "Sliced $W_2$"),
    ("mean_rmse", "Mean RMSE"),
    ("mean_error", "Mean error"),
]
n_metrics = len(metric_specs)
n_algos = len(ALGOS)

# squeeze=False keeps `axes` two-dimensional even when there is only one
# algorithm (or one metric), so axes[row, col] below is always valid.
fig, axes = plt.subplots(
    n_metrics, n_algos,
    figsize=(FULL_WIDTH, 1.5 * n_metrics),
    sharex=True, sharey="row",
    squeeze=False,
    constrained_layout=True,
)

for row, (metric_name, ylabel) in enumerate(metric_specs):
    # Median / IQR across seeds for every algorithm, per checkpoint.
    data = gather_metric(metric_name)
    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # Gray context: medians of all OTHER algorithms
        for other in ALGOS:
            if other == algo:
                continue
            med_o, _, _ = data[other]
            ax.plot(X_POS, med_o, color="#cccccc", linewidth=0.6, zorder=1)

        # Foreground: this algorithm's median + IQR
        median, q25, q75 = data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
        ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

        ax.set_xticks(X_POS)
        ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
        ax.tick_params(direction="out", length=3)

        # Titles on the top row, y-labels on the first column, x-labels on
        # the bottom row only.
        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=8)
        if row == n_metrics - 1:
            ax.set_xlabel("Iteration", fontsize=8)

plt.show()

## 3. Convergence overlay

All algorithms on the same axes for direct comparison.

In [None]:
# One panel per metric, all algorithms overlaid in each panel.
fig, axes = plt.subplots(
    nrows=1, ncols=len(metric_specs),
    figsize=(FULL_WIDTH, 2.2),
    constrained_layout=True,
)

# Sparse tick indices — every quarter of the checkpoint list plus the last
# checkpoint, to prevent label overlap
n_ckpts = len(CHECKPOINTS)
step = max(1, n_ckpts // 4)
tick_idx = sorted({*range(0, n_ckpts, step), n_ckpts - 1})

# Algorithms in this tuple are drawn dashed; everything else is drawn solid.
baseline_names = ("SVGD", "ULA", "MPPI", "MALA")

for ax, (metric_name, ylabel) in zip(axes, metric_specs):
    data = gather_metric(metric_name)
    for algo in ALGOS:
        median, q25, q75 = data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        # IQR band first, median line on top of it.
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.12)
        ax.plot(
            X_POS, median,
            color=color,
            linewidth=1.5,
            linestyle="--" if algo in baseline_names else "-",
            label=algo,
        )

    ax.set_xticks([X_POS[i] for i in tick_idx])
    ax.set_xticklabels([str(CHECKPOINTS[i]) for i in tick_idx], fontsize=7)
    ax.set_xlabel("Iteration", fontsize=8)
    ax.set_ylabel(ylabel, fontsize=8)
    ax.tick_params(direction="out", length=3)

# A single legend on the first panel covers all panels.
axes[0].legend(fontsize=7, frameon=False, loc="center right")
plt.show()

## 4. Marginal posterior ridgeline

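Ridgeline plot for the dimensions with the widest marginals (largest NUTS posterior std; if no reference is available, the spread of the first algorithm's particles is used instead).
Each row shows one dimension: NUTS reference as a gray fill, algorithm KDEs overlaid.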

In [None]:
# Select dimensions to display: top-6 by NUTS posterior std (or particle spread)
final_ckpt = max(CHECKPOINTS)
N_DIMS_SHOW = 6

# Per-dimension spread used for ranking: the NUTS std when a reference exists,
# otherwise the std of the first algorithm's seed-0 particles at the final
# checkpoint.
spread = (
    ref_std
    if HAS_REF
    else particles[f"seed0__{ALGOS[0]}__iter{final_ckpt}"].std(axis=0)
)
dim_order = np.argsort(-spread)  # largest std first

DIMS_SHOW = dim_order[:N_DIMS_SHOW]
print(f"Displaying dimensions: {DIMS_SHOW}")

In [None]:
# Size the grid by the number of dimensions actually selected: when the target
# has fewer than N_DIMS_SHOW dimensions there must be no empty rows, and the
# bottom axis styling must land on the real last row.
n_ridge_rows = len(DIMS_SHOW)
fig, axes = plt.subplots(
    n_ridge_rows, 1,
    figsize=(COL_WIDTH * 1.8, 1.0 * n_ridge_rows),
    squeeze=False,  # always an array, even for a single row
    constrained_layout=True,
)
axes = axes[:, 0]

for row, dim_idx in enumerate(DIMS_SHOW):
    ax = axes[row]
    is_last_row = row == n_ridge_rows - 1

    # --- Determine x range ---
    # Span of the reference samples (or, without a reference, of the pooled
    # seed-0 particles), padded by 30% of a standard deviation on each side.
    if HAS_REF:
        ref_dim = ref_samples[:, dim_idx]
        pad = 0.3 * ref_dim.std()
        x_lo, x_hi = ref_dim.min() - pad, ref_dim.max() + pad
    else:
        all_vals = []
        for algo in ALGOS:
            key = f"seed0__{algo}__iter{final_ckpt}"
            if key in particles:
                all_vals.append(particles[key][:, dim_idx])
        combined = np.concatenate(all_vals)
        pad = 0.3 * combined.std()
        x_lo, x_hi = combined.min() - pad, combined.max() + pad

    x_grid = np.linspace(x_lo, x_hi, 300)

    # --- NUTS reference ---
    if HAS_REF:
        kde_ref = stats.gaussian_kde(ref_dim)
        ref_density = kde_ref(x_grid)
        ax.fill_between(
            x_grid, ref_density, color="#DDDDDD", alpha=0.6,
            label="NUTS" if row == 0 else None,
        )
        ax.plot(x_grid, ref_density, color="#AAAAAA", linewidth=0.7)

    # --- Algorithm KDEs (seed 0) ---
    for algo in ALGOS:
        key = f"seed0__{algo}__iter{final_ckpt}"
        if key not in particles:
            continue
        vals = particles[key][:, dim_idx]
        kde = stats.gaussian_kde(vals)
        color = ALGO_COLORS.get(algo, "#333")
        is_baseline = algo in ("SVGD", "ULA", "MPPI", "MALA")
        ax.plot(
            x_grid, kde(x_grid),
            color=color, linewidth=1.2,
            linestyle="--" if is_baseline else "-",
            # Label only on the first row so the legend has one entry per algorithm.
            label=algo if row == 0 else None,
        )

    # --- Styling ---
    ax.set_ylabel(f"dim {dim_idx}", fontsize=8, rotation=0, labelpad=28, va="center")
    ax.set_yticks([])
    ax.set_xlim(x_lo, x_hi)

    # Spine cleanup: only keep bottom spine on the last row
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(is_last_row)
    if is_last_row:
        ax.tick_params(direction="out", length=3)
        ax.set_xlabel(r"$\theta$", fontsize=9)
    else:
        ax.tick_params(bottom=False, labelbottom=False)

axes[0].legend(fontsize=7, frameon=False, ncol=min(len(ALGOS), 4), loc="upper right")
fig.suptitle("Marginal posteriors (seed 0, final iteration)", fontsize=10, y=1.01)
plt.show()

## 5. Coefficient error heatmap

Signed error of each algorithm's posterior mean relative to the NUTS reference,
per dimension. Blue = algorithm underestimates, red = overestimates, white = agreement.

In [None]:
if HAS_REF:
    # --- Compute per-algorithm signed mean error vs NUTS ---
    # Rows are dimensions, columns are algorithms; NaN marks "no particles".
    mean_errors = np.full((d, len(ALGOS)), np.nan)

    for j, algo in enumerate(ALGOS):
        # Build particle keys from the configured seed labels (SEEDS) rather
        # than positional indices, so seed lists that do not start at 0 or are
        # not contiguous are still found.
        means_per_seed = [
            particles[key].mean(axis=0)
            for key in (f"{seed}__{algo}__iter{final_ckpt}" for seed in SEEDS)
            if key in particles
        ]

        if means_per_seed:
            # Median across seeds of the per-seed posterior means.
            algo_mean = np.median(means_per_seed, axis=0)
            mean_errors[:, j] = algo_mean - ref_mean

    # Transpose: (d, n_algos) -> (n_algos, d) — wide layout
    mean_errors_T = mean_errors.T

    # Signed square-root transform to compress outliers
    err_transformed = np.sign(mean_errors_T) * np.sqrt(np.abs(mean_errors_T))
    # Symmetric colour range with a floor of 0.1; guard the all-NaN case
    # (no particles at all), where nanmax would yield NaN limits.
    if np.isfinite(err_transformed).any():
        vmax_err = max(np.nanmax(np.abs(err_transformed)), 0.1)
    else:
        vmax_err = 0.1

    # Sparse feature-index labels
    feat_labels_sparse = [
        str(i) if (i % max(1, d // 16) == 0 or i == d - 1) else ""
        for i in range(d)
    ]

    n_algos = len(ALGOS)
    fig, ax = plt.subplots(
        figsize=(FULL_WIDTH, n_algos * 0.4 + 0.8),
        constrained_layout=True,
    )

    im = ax.imshow(
        err_transformed,
        aspect="auto",
        cmap="RdBu_r",
        vmin=-vmax_err,
        vmax=vmax_err,
        interpolation="nearest",
    )
    ax.set_yticks(np.arange(n_algos))
    ax.set_yticklabels(ALGOS, fontsize=7)
    ax.set_xticks(np.arange(d))
    ax.set_xticklabels(feat_labels_sparse, fontsize=6)
    ax.set_xlabel("Feature index", fontsize=8)
    ax.set_title("Posterior mean error vs NUTS", fontsize=9, loc="left")

    cb = fig.colorbar(im, ax=ax, shrink=0.7, pad=0.02, aspect=15)
    cb.ax.tick_params(labelsize=7)
    cb_ticks = cb.get_ticks()
    # Pin the tick locations before relabelling them, so the custom labels
    # cannot drift away from the ticks they describe.
    cb.set_ticks(cb_ticks)
    # Labels undo the signed-sqrt transform: they show the raw signed error.
    cb.set_ticklabels([f"{np.sign(t)*t**2:.2f}" for t in cb_ticks])
    cb.set_label("Signed error", fontsize=8)

    plt.show()
else:
    print("Skipping coefficient heatmap — no NUTS reference.")

## 6. Per-dimension variance ratio

Ratio of particle variance to NUTS reference variance per dimension.
Values < 1 indicate under-dispersion (common failure mode of VI).

In [None]:
if HAS_REF:
    fig, axes = plt.subplots(
        1, len(ALGOS),
        figsize=(FULL_WIDTH, 2.0),
        sharey=True,
        squeeze=False,  # keep an array even when there is a single algorithm
        constrained_layout=True,
    )
    axes = axes[0]

    ref_var = ref_samples.var(axis=0)  # (d,)
    # Floor the denominator so a (near-)zero reference variance cannot blow up
    # the ratio.
    safe_ref_var = np.maximum(ref_var, 1e-10)

    for ax, algo in zip(axes, ALGOS):
        # Build particle keys from the configured seed labels (SEEDS) rather
        # than positional indices, so non-contiguous seed lists are found.
        ratios_per_seed = [
            particles[key].var(axis=0) / safe_ref_var
            for key in (f"{seed}__{algo}__iter{final_ckpt}" for seed in SEEDS)
            if key in particles
        ]

        if not ratios_per_seed:
            continue

        # Median across seeds, per dimension.
        median_ratio = np.median(ratios_per_seed, axis=0)
        color = ALGO_COLORS.get(algo, "#333")

        ax.bar(
            np.arange(d), median_ratio,
            color=color, alpha=0.7, width=0.7,
        )
        # Reference line at 1.0: particle variance equals the NUTS variance.
        ax.axhline(1.0, color="#999", linewidth=0.8, linestyle="--")
        ax.set_title(algo, fontsize=9)
        ax.set_xlabel("Dimension", fontsize=8)
        ax.tick_params(direction="out", length=3)
        ax.set_xticks(np.arange(0, d, max(1, d // 5)))

    axes[0].set_ylabel("Var ratio (particle / NUTS)", fontsize=8)
    plt.show()
else:
    print("Skipping variance ratio plot — no NUTS reference.")

## 7. Pairwise posterior structure

2D scatter plots for selected pairs of dimensions to see how well
algorithms capture posterior correlations.

In [None]:
# Pick 3 pairs: the most correlated dimensions (from NUTS or particles)
N_PAIRS = 3

if HAS_REF:
    corr_mat = np.corrcoef(ref_samples.T)  # (d, d)
else:
    key0 = f"seed0__{ALGOS[0]}__iter{final_ckpt}"
    corr_mat = np.corrcoef(particles[key0].T)

# Zero out diagonal and take absolute value. NaN correlations (produced by a
# constant dimension) are ranked as 0 so they can never be selected as
# "most correlated".
np.fill_diagonal(corr_mat, 0)
abs_corr = np.nan_to_num(np.abs(corr_mat), nan=0.0)

# Rank the strict upper triangle only: each unordered pair appears once.
upper_i, upper_j = np.triu_indices(d, k=1)
ranking = np.argsort(-abs_corr[upper_i, upper_j], kind="stable")

# Find top pairs, skipping a pair when BOTH of its dimensions already appear
# in earlier picks (a pair sharing just one dimension is still allowed).
pairs = []
used = set()
for pos in ranking:
    # Plain Python ints keep the printed list readable.
    i, j = int(upper_i[pos]), int(upper_j[pos])
    if i in used and j in used:
        continue
    pairs.append((i, j))
    used.update([i, j])
    if len(pairs) == N_PAIRS:
        break

print(f"Most correlated pairs: {pairs}")
for i, j in pairs:
    print(f"  dims ({i}, {j}): corr = {corr_mat[i, j]:.3f}")

In [None]:
# Grid: one row per algorithm, one column per selected dimension pair.
# squeeze=False guarantees a 2-D axes array, including the one-algorithm case.
fig, axes = plt.subplots(
    len(ALGOS), N_PAIRS,
    figsize=(COL_WIDTH * 1.8, 1.4 * len(ALGOS)),
    squeeze=False,
    constrained_layout=True,
)

for row, algo in enumerate(ALGOS):
    # Algorithm particles (seed 0), final checkpoint
    key = f"seed0__{algo}__iter{final_ckpt}"
    has_particles = key in particles

    for col, (di, dj) in enumerate(pairs):
        ax = axes[row, col]

        # NUTS reference (small gray dots)
        if HAS_REF:
            subsample = ref_samples[::4]  # thin for speed
            ax.scatter(
                subsample[:, di], subsample[:, dj],
                s=4, color="#CCCCCC", alpha=0.3, zorder=1, rasterized=True,
            )

        # Particles are drawn on top of the reference cloud.
        if has_particles:
            pts = particles[key]
            color = ALGO_COLORS.get(algo, "#333")
            ax.scatter(
                pts[:, di], pts[:, dj],
                s=12, color=color, edgecolors="white", linewidths=0.4,
                zorder=5, alpha=0.8,
            )

        # Column titles on the top row, algorithm names on the first column.
        if row == 0:
            ax.set_title(f"dims ({di}, {dj})", fontsize=8)
        if col == 0:
            ax.set_ylabel(algo, fontsize=8)
        ax.tick_params(direction="out", length=2, labelsize=6)

plt.show()

## 8. Predictive performance

Posterior predictive accuracy and calibration on the training data.
For each algorithm, compute $p(y=1 \mid x) = \frac{1}{N}\sum_i \sigma(x \cdot \theta_i)$
using the final particles as an ensemble.

In [None]:
from sklearn.metrics import brier_score_loss

# NumPy copies of the target's design matrix and labels, used by the
# NumPy-based predictive computations in this section.
X_np = np.array(target.X)
y_np = np.array(target.y)


def posterior_predictive(pts, X):
    """Compute ensemble predictive probabilities.

    Each particle is treated as one logistic-regression weight vector; the
    per-particle sigmoid outputs are averaged with equal weight.

    Args:
        pts: Particles, shape (N, d).
        X: Design matrix, shape (n_data, d).

    Returns:
        Predicted probabilities, shape (n_data,).
    """
    logits = X @ pts.T  # (n_data, N)
    # Numerically stable sigmoid: only ever exponentiate a non-positive
    # number, so large-magnitude logits cannot overflow in np.exp.
    #   x >= 0: 1 / (1 + exp(-x))        x < 0: exp(x) / (1 + exp(x))
    decay = np.exp(-np.abs(logits))
    probs = np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    return probs.mean(axis=1)  # (n_data,)


print(f"{'Algorithm':<15} {'Accuracy':>10} {'Brier':>10}")
print("-" * 37)

# Per-algorithm lists of per-seed accuracy and Brier score at the final
# checkpoint; the printed table shows the median across seeds.
pred_results = {}
for algo in ALGOS:
    accs, briers = [], []
    # Build particle keys from the configured seed labels (SEEDS) rather than
    # positional indices, so non-contiguous seed lists are found.
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key not in particles:
            continue
        p_pred = posterior_predictive(particles[key], X_np)
        # Threshold the ensemble probability at 0.5 for the hard prediction.
        accs.append(float(np.mean((p_pred >= 0.5) == y_np)))
        briers.append(float(brier_score_loss(y_np, p_pred)))

    pred_results[algo] = {"accuracy": accs, "brier": briers}
    if accs:
        print(
            f"{algo:<15} {np.median(accs):>9.3f}  {np.median(briers):>9.4f}"
        )
    else:
        # No particles stored for this algorithm: avoid a median over an
        # empty list (NaN plus a RuntimeWarning) and say so explicitly.
        print(f"{algo:<15} {'N/A':>9}  {'N/A':>9}")

In [None]:
# Calibration plot: predicted probability vs observed frequency
N_BINS = 10

fig, axes = plt.subplots(
    1, len(ALGOS),
    figsize=(FULL_WIDTH, 2.0),
    sharex=True, sharey=True,
    squeeze=False,  # keep an array even when there is a single algorithm
    constrained_layout=True,
)
axes = axes[0]

# Equal-width probability bins; identical for every algorithm, so they are
# computed once outside the loop.
bin_edges = np.linspace(0, 1, N_BINS + 1)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

for ax, algo in zip(axes, ALGOS):
    # Use seed 0
    key = f"seed0__{algo}__iter{final_ckpt}"
    if key not in particles:
        continue
    p_pred = posterior_predictive(particles[key], X_np)

    # Bin predictions
    observed_freq = np.zeros(N_BINS)
    bin_counts = np.zeros(N_BINS)

    for b in range(N_BINS):
        # Bins are half-open [lo, hi) ...
        mask = (p_pred >= bin_edges[b]) & (p_pred < bin_edges[b + 1])
        if b == N_BINS - 1:
            mask |= (p_pred == bin_edges[b + 1])  # ... except the last, which includes the right edge
        if mask.sum() > 0:
            observed_freq[b] = y_np[mask].mean()
            bin_counts[b] = mask.sum()

    color = ALGO_COLORS.get(algo, "#333")
    # Only draw bars for bins that actually contain predictions.
    valid = bin_counts > 0
    ax.bar(
        bin_centers[valid], observed_freq[valid],
        width=0.08, color=color, alpha=0.7,
    )
    # Diagonal = perfect calibration.
    ax.plot([0, 1], [0, 1], "--", color="#999", linewidth=0.8)
    ax.set_title(algo, fontsize=9)
    ax.set_aspect("equal")
    ax.tick_params(direction="out", length=3)

axes[0].set_ylabel("Observed frequency", fontsize=8)
# A single x-label on the middle panel stands in for the whole row.
axes[len(ALGOS) // 2].set_xlabel("Predicted probability", fontsize=8)
plt.show()

## 9. Variance ratio over iterations

Track the median variance ratio across dimensions over the course of optimization.
This reveals whether algorithms converge to the correct posterior spread or collapse.

In [None]:
# Only plotted when the experiment config lists the metric.
if "variance_ratio_ref" in exp.get("metrics", []):
    vr_data = gather_metric("variance_ratio_ref")

    fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.0))

    for algo in ALGOS:
        median, q25, q75 = vr_data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        # Same baseline set as the overlay and ridgeline plots (includes
        # MALA), so baselines are dashed consistently across all figures.
        is_baseline = algo in ("SVGD", "ULA", "MPPI", "MALA")
        ls = "--" if is_baseline else "-"
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
        ax.plot(X_POS, median, color=color, linewidth=1.5, linestyle=ls, label=algo)

    # A ratio of 1.0 means the particle variance matches the reference.
    ax.axhline(1.0, color="#999", linewidth=0.8, linestyle="--", label="Ideal")
    ax.set_xticks(X_POS)
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
    ax.set_ylabel("Median variance ratio", fontsize=8)
    ax.set_xlabel("Iteration", fontsize=8)
    ax.legend(fontsize=7, frameon=False)
    ax.tick_params(direction="out", length=3)
    plt.show()
else:
    print("variance_ratio_ref not in metrics config — skipping.")

## 10. Final metrics summary

In [None]:
# Metrics to tabulate: those listed in the experiment config, with a small
# default set when the config has no "metrics" entry.
metric_names = exp.get("metrics", ["energy_distance", "mean_error"])

# Header: every metric column is 2 spaces + a 28-character cell.
print(f"{'Algorithm':<15}", end="")
for m in metric_names:
    print(f"  {m:<28}", end="")
print()
print("-" * (15 + 30 * len(metric_names)))

for algo in ALGOS:
    print(f"{algo:<15}", end="")
    for m in metric_names:
        # Final-checkpoint value per seed. metrics.get(...) tolerates a seed
        # that is missing from the metrics file, matching gather_metric.
        vals = [
            metrics.get(s, {}).get(algo, {}).get(final_ckpt, {}).get(m, np.nan)
            for s in SEEDS
        ]
        valid = [v for v in vals if not np.isnan(v)]
        if valid:
            med = np.median(valid)
            iqr_lo, iqr_hi = np.percentile(valid, [25, 75])
            cell = f"{med:>8.4f}  ({iqr_lo:.4f}\u2013{iqr_hi:.4f})"
        else:
            cell = "N/A"
        # Pad every cell to the header's width so the columns line up.
        print(f"  {cell:<28}", end="")
    print()

## 11. Log-posterior evaluation

Mean log-posterior of particles over iterations — a quick sanity check
that particles are moving toward high-probability regions.

In [None]:
fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.0))

for algo in ALGOS:
    # Rows: seeds, columns: checkpoints; NaN where no particles were stored.
    lp_per_seed = []
    # Build particle keys from the configured seed labels (SEEDS) rather than
    # positional indices, so non-contiguous seed lists are found.
    for seed in SEEDS:
        lp_ckpts = []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                pts = jnp.array(particles[key])
                # Mean of the target's log_prob over the particle set.
                lp = float(target.log_prob(pts).mean())
            else:
                lp = np.nan
            lp_ckpts.append(lp)
        lp_per_seed.append(lp_ckpts)

    lp_arr = np.array(lp_per_seed)
    median = np.nanmedian(lp_arr, axis=0)
    q25 = np.nanpercentile(lp_arr, 25, axis=0)
    q75 = np.nanpercentile(lp_arr, 75, axis=0)

    color = ALGO_COLORS.get(algo, "#333")
    # Same baseline set as the overlay and ridgeline plots (includes MALA),
    # so baselines are dashed consistently across all figures.
    is_baseline = algo in ("SVGD", "ULA", "MPPI", "MALA")
    ls = "--" if is_baseline else "-"
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
    ax.plot(X_POS, median, color=color, linewidth=1.5, linestyle=ls, label=algo)

# NUTS reference log-posterior
if HAS_REF:
    ref_lp = float(target.log_prob(jnp.array(ref_samples)).mean())
    ax.axhline(ref_lp, color="#999", linewidth=0.8, linestyle=":", label="NUTS mean")

ax.set_xticks(X_POS)
ax.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
ax.set_ylabel("Mean log-posterior", fontsize=8)
ax.set_xlabel("Iteration", fontsize=8)
ax.legend(fontsize=7, frameon=False)
ax.tick_params(direction="out", length=3)
plt.show()