# Funnel Target Results

Visualize ETD variants vs baselines on Neal's funnel distribution.

The funnel is defined hierarchically: $v \sim \mathcal{N}(0, \sigma_v^2)$ and
$x_k \mid v \sim \mathcal{N}(0, e^v)$ for $k = 1, \dots, d-1$, where $e^v$ is the conditional variance of $x_k$.
The scale variable $v$ (last coordinate) controls the spread of all
other dimensions, creating severe heteroscedasticity: the "mouth"
($v > 0$) has large spread while the "neck" ($v < 0$) compresses
particles to near-zero.

This is a notoriously difficult target — methods that collapse to
one scale regime fail to approximate the full posterior.

**Prerequisites:**
```bash
python -m experiments.run experiments/configs/funnel_10d.yaml
```

In [None]:
import json
import sys
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# Project imports
# ROOT is the parent of the current working directory. ROOT/src and ROOT are
# both pushed to the front of sys.path, presumably so that the `etd` and
# `figures` packages imported below can be found.
# NOTE(review): this assumes the notebook is started from a directory exactly
# one level below the repository root — confirm.
ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT / "src"))
sys.path.insert(0, str(ROOT))

from etd.targets.funnel import FunnelTarget
from figures.style import (
    ALGO_COLORS, FULL_WIDTH, COL_WIDTH,
    load_display_metadata, setup_style,
)

# Apply the shared plotting style imported from figures.style.
setup_style()
# IPython magic (not plain Python): request 'retina' output from the inline backend.
%config InlineBackend.figure_format = 'retina'

## 1. Load data

In [None]:
import yaml

# --- Find the latest results directory ---
# Every run is expected to live in its own sub-directory of results/funnel-10d.
# "Latest" means last in lexicographic order of the directory names.
# NOTE(review): presumably run directories are timestamp-named, so name order
# is chronological — confirm.
results_root = ROOT / "results" / "funnel-10d"
if not results_root.exists():
    raise FileNotFoundError(
        f"No results at {results_root}.\n"
        "Run: python -m experiments.run experiments/configs/funnel_10d.yaml"
    )
# Only directories can be runs; a stray file (.DS_Store, a log, ...) must not
# be mistaken for the latest run.
runs = sorted(p for p in results_root.iterdir() if p.is_dir())
if not runs:
    raise FileNotFoundError(
        f"No run directories inside {results_root}.\n"
        "Run: python -m experiments.run experiments/configs/funnel_10d.yaml"
    )
RESULTS_DIR = runs[-1]
print(f"Loading from: {RESULTS_DIR}")

# --- Load config ---
# The run's config.yaml is parsed with the safe loader (no arbitrary object
# construction). Encoding is pinned so the result does not depend on the
# platform default.
with open(RESULTS_DIR / "config.yaml", encoding="utf-8") as f:
    exp_config = yaml.safe_load(f)

exp = exp_config["experiment"]
target_cfg = exp["target"]
# `params` may be missing, or present but empty (YAML null -> None); both
# cases must yield a dict because it is later unpacked with **.
target_params = target_cfg.get("params") or {}

# --- Load metrics ---
# metrics.json is nested seed -> algorithm -> checkpoint -> payload. JSON
# object keys are always strings, so the checkpoint level is converted back
# to int while the rest of the structure is copied as-is.
with open(RESULTS_DIR / "metrics.json") as f:
    raw_metrics = json.load(f)

metrics = {
    seed: {
        algo: {int(ckpt): payload for ckpt, payload in by_ckpt.items()}
        for algo, by_ckpt in by_algo.items()
    }
    for seed, by_algo in raw_metrics.items()
}

# --- Load particles ---
# np.load on an .npz returns a lazy NpzFile that keeps the archive open.
# Read every array eagerly inside a context manager so the file handle is
# released; `particles` is still a plain dict of name -> array.
with np.load(RESULTS_DIR / "particles.npz") as npz:
    particles = {name: npz[name] for name in npz.files}

# --- Load display metadata (if available) ---
display_meta = load_display_metadata(str(RESULTS_DIR))

def algo_color(algo):
    """Return the plot color for ``algo``.

    Resolution order: the ``"color"`` entry for ``algo`` in ``display_meta``,
    then ``ALGO_COLORS``, then dark gray (``"#333"``).

    Args:
        algo: Algorithm name used as the lookup key.

    Returns:
        The resolved color; it is passed as ``color=`` to matplotlib calls
        in this notebook.
    """
    if algo in display_meta:
        entry = display_meta[algo]
        # An entry that lacks "color" falls through to the shared palette
        # instead of raising KeyError.
        if "color" in entry:
            return entry["color"]
    return ALGO_COLORS.get(algo, "#333")

def algo_linestyle(algo):
    """Return the line style for ``algo``.

    Resolution order: the ``"linestyle"`` entry for ``algo`` in
    ``display_meta``, then solid (``"-"``).

    Args:
        algo: Algorithm name used as the lookup key.

    Returns:
        The resolved line style; it is passed as ``linestyle=`` to matplotlib
        calls in this notebook.
    """
    if algo in display_meta:
        entry = display_meta[algo]
        # An entry that lacks "linestyle" falls back to solid instead of
        # raising KeyError.
        if "linestyle" in entry:
            return entry["linestyle"]
    return "-"

# --- Build target ---
# The target is constructed from the params stored in the run's config.
target = FunnelTarget(**target_params)
DIM = target.dim
V_IDX = DIM - 1  # the scale variable v is the final coordinate

# Summary of the target's moments, one item per line.
print(
    f"Target: funnel, dim={DIM}, sigma_v={target.sigma_v}",
    f"  mean = {np.array(target.mean)}",
    f"  variance = {np.array(target.variance)}",
    f"  (Var[x_k] = exp(sigma_v^2 / 2) = {np.exp(target.sigma_v**2 / 2):.1f})",
    sep="\n",
)

# --- Constants ---
# Seed names follow the "seed<value>" pattern built from the configured seeds.
SEEDS = [f"seed{i}" for i in exp["seeds"]]

# Union of all checkpoints / algorithm names seen under any seed, sorted.
CHECKPOINTS = sorted({
    int(ckpt)
    for by_algo in metrics.values()
    for by_ckpt in by_algo.values()
    for ckpt in by_ckpt
})
ALL_ALGOS = sorted({
    name for by_algo in metrics.values() for name in by_algo
})
ALGOS = ALL_ALGOS
N_ALGOS = len(ALGOS)

# Checkpoints are drawn at evenly spaced integer positions. Only about every
# fifth-of-the-range tick is labelled, plus always the last one.
n_ckpts = len(CHECKPOINTS)
X_POS = np.arange(n_ckpts)
tick_stride = max(1, n_ckpts // 5)
TICK_SHOW = set(range(0, n_ckpts, tick_stride))
TICK_SHOW.add(n_ckpts - 1)
X_LABELS_SPARSE = [
    str(ckpt) if pos in TICK_SHOW else ""
    for pos, ckpt in enumerate(CHECKPOINTS)
]

print(f"Checkpoints: {CHECKPOINTS}")
print(f"Algorithms: {ALGOS}")

In [None]:
# --- Reference samples (exact ancestral sampling) ---
ref_path = RESULTS_DIR / "reference.npz"
if ref_path.exists():
    # Pull out the one array needed and close the archive right away rather
    # than leaving the NpzFile handle open.
    with np.load(ref_path) as ref_npz:
        reference_samples = ref_npz["samples"]
else:
    # No stored reference for this run: draw 5000 samples from the target
    # with a fixed PRNG key.
    reference_samples = np.array(target.sample(jax.random.PRNGKey(99999), 5000))

print(f"Reference samples: {reference_samples.shape}")
# Scale coordinate of the reference draws; reused by later cells.
ref_v = reference_samples[:, V_IDX]
print(f"  v range: [{ref_v.min():.1f}, {ref_v.max():.1f}]")
print(f"  v std: {ref_v.std():.2f} (true: {target.sigma_v:.1f})")

## 2. v–x₁ scatter (the funnel shape)

The defining diagnostic for the funnel: plot $x_1$ vs $v$ (the scale
variable). The reference forms a funnel opening to the right ($v > 0$).
Algorithms that collapse to one scale regime will show a truncated scatter.

In [None]:
# Last recorded checkpoint; later cells reuse this name.
final_ckpt = max(CHECKPOINTS)
n_panels = N_ALGOS + 1  # reference panel + one panel per algorithm

fig, axes = plt.subplots(
    1, n_panels,
    figsize=(FULL_WIDTH, FULL_WIDTH / n_panels + 0.3),
    sharex=True, sharey=True,
)
ax_ref, *algo_axes = axes

# Reference: first 1000 draws, v on the x-axis and x_1 on the y-axis.
ref_head = reference_samples[:1000]
ax_ref.scatter(
    ref_head[:, V_IDX], ref_head[:, 0],
    s=6, color="#999999", alpha=0.3, edgecolors="none", rasterized=True,
)
ax_ref.set_title("Reference", fontsize=9)
ax_ref.set_ylabel("$x_1$", fontsize=9)
ax_ref.set_xlabel("$v$", fontsize=9)

# Algorithms: seed 0 at the final checkpoint. A panel stays empty when the
# corresponding particle array is not in the archive.
for ax, algo in zip(algo_axes, ALGOS):
    pts = particles.get(f"seed0__{algo}__iter{final_ckpt}")
    if pts is not None:
        ax.scatter(
            pts[:, V_IDX], pts[:, 0],
            s=12, color=algo_color(algo),
            edgecolors="white", linewidths=0.3,
            alpha=0.8,
        )
    ax.set_title(algo, fontsize=9)
    ax.set_xlabel("$v$", fontsize=9)
    ax.tick_params(labelleft=False)

fig.subplots_adjust(wspace=0.08)
plt.show()

## 3. Convergence traces (faceted by algorithm)

In [None]:
def gather_metric(metric_name, algos=None):
    """Summarize one metric across seeds for each algorithm.

    Args:
        metric_name: Key looked up in each checkpoint's metric payload.
        algos: Algorithm names to include; defaults to ``ALGOS``.

    Returns:
        Dict mapping algorithm name to ``(median, q25, q75)``, each an array
        with one entry per checkpoint in ``CHECKPOINTS``. Missing
        seed/algorithm/checkpoint/metric entries count as NaN and are
        ignored by the NaN-aware statistics.
    """
    selected = ALGOS if algos is None else algos

    def lookup(seed, algo, ckpt):
        # Every level is optional; absent data becomes NaN.
        per_ckpt = metrics.get(seed, {}).get(algo, {}).get(ckpt, {})
        return per_ckpt.get(metric_name, np.nan)

    summary = {}
    for algo in selected:
        # Rows are seeds, columns are checkpoints.
        table = np.array([
            [lookup(seed, algo, ckpt) for ckpt in CHECKPOINTS]
            for seed in SEEDS
        ])
        median = np.nanmedian(table, axis=0)
        lower = np.nanpercentile(table, 25, axis=0)
        upper = np.nanpercentile(table, 75, axis=0)
        summary[algo] = (median, lower, upper)
    return summary

In [None]:
# Metrics to plot as (key in the metrics payload, y-axis label); the overlay
# cell reuses this list.
metric_specs = [
    ("energy_distance", "Energy dist."),
    ("mean_error", "Mean error"),
]
n_metrics = len(metric_specs)

# One row per metric, one column per algorithm. squeeze=False keeps `axes`
# two-dimensional even when there is a single algorithm.
fig, axes = plt.subplots(
    n_metrics, N_ALGOS,
    figsize=(FULL_WIDTH, 1.5 * n_metrics),
    sharex=True, sharey="row",
    constrained_layout=True,
    squeeze=False,
)

bottom_row = n_metrics - 1
for row, (metric_name, ylabel) in enumerate(metric_specs):
    data = gather_metric(metric_name)
    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # Background: medians of all other algorithms in light gray.
        for other in ALGOS:
            if other != algo:
                ax.plot(X_POS, data[other][0], color="#cccccc", linewidth=0.6, zorder=1)

        # Foreground: this algorithm's interquartile band and median.
        median, q25, q75 = data[algo]
        color = algo_color(algo)
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
        ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

        ax.set_xticks(X_POS)
        ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
        ax.tick_params(direction="out", length=3)

        # Titles on the top row, y-labels on the first column, x-labels on
        # the bottom row only.
        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=8)
        if row == bottom_row:
            ax.set_xlabel("Iteration", fontsize=8)

plt.show()

## 4. Convergence overlay

In [None]:
# One panel per metric, with all algorithms overlaid in each panel.
fig, axes = plt.subplots(
    1, len(metric_specs),
    figsize=(FULL_WIDTH, 2.2),
    constrained_layout=True,
)

for panel, (metric_name, ylabel) in zip(axes, metric_specs):
    summary = gather_metric(metric_name)
    for algo in ALGOS:
        median, q25, q75 = summary[algo]
        color = algo_color(algo)
        ls = algo_linestyle(algo)
        # Faint interquartile band under the median line.
        panel.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
        panel.plot(
            X_POS, median,
            color=color, linewidth=1.5, linestyle=ls, label=algo,
        )

    panel.set_xticks(X_POS)
    panel.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
    panel.set_ylabel(ylabel, fontsize=8)
    panel.set_xlabel("Iteration", fontsize=8)
    panel.tick_params(direction="out", length=3)

# A single legend on the first panel serves both.
axes[0].legend(fontsize=7, frameon=False)
plt.show()

## 5. Marginal distributions

Compare per-dimension marginals against the reference.

Key things to watch:
- **$v$ (last dim)**: should be Gaussian with std $\sigma_v$
- **$x_k$ (all other dims)**: heavy-tailed due to the scale mixture $x_k \mid v \sim \mathcal{N}(0, e^v)$

In [None]:
# Show v + a few x dimensions.
# Candidates are de-duplicated and restricted to valid indices, so a
# low-dimensional target does not plot the same marginal twice or index out
# of range.
candidate_dims = [V_IDX, 0, 1]
if DIM > 4:
    candidate_dims.append(DIM // 2)
show_dims = []
for d in candidate_dims:
    if 0 <= d < DIM and d not in show_dims:
        show_dims.append(d)

# Axis labels per dimension; the variance-ratio cell reuses this mapping.
dim_names = {
    V_IDX: "$v$",
}
for i in range(DIM - 1):
    dim_names[i] = f"$x_{{{i+1}}}$"

n_show = len(show_dims)

fig, axes = plt.subplots(
    n_show, 1,
    figsize=(COL_WIDTH * 1.8, 1.3 * n_show),
    constrained_layout=True,
    squeeze=False,
)
axes = axes[:, 0]  # squeeze=False keeps this indexable even for one row

for row, dim_idx in enumerate(show_dims):
    ax = axes[row]
    dim_label = dim_names.get(dim_idx, f"dim {dim_idx}")

    # Reference KDE on a grid spanning the central 99% of reference values.
    ref_dim = reference_samples[:, dim_idx]
    kde_ref = gaussian_kde(ref_dim)
    x_grid = np.linspace(
        np.percentile(ref_dim, 0.5), np.percentile(ref_dim, 99.5), 300,
    )
    ref_density = kde_ref(x_grid)  # evaluated once, used for fill and outline
    ax.fill_between(
        x_grid, ref_density,
        color="#DDDDDD", alpha=0.6,
        label="Reference" if row == 0 else None,
    )
    ax.plot(x_grid, ref_density, color="#999999", linewidth=0.8)

    # Algorithm KDEs (seed 0, final checkpoint), on the reference grid.
    for algo in ALGOS:
        key = f"seed0__{algo}__iter{final_ckpt}"
        if key not in particles:
            continue
        vals = particles[key][:, dim_idx]
        try:
            kde = gaussian_kde(vals)
        except (np.linalg.LinAlgError, ValueError) as err:
            # A degenerate particle set (e.g. collapsed to a point) has no
            # usable KDE; report it rather than silently omit the curve.
            print(f"Skipping KDE for {algo}, {dim_label}: {err}")
            continue
        ax.plot(
            x_grid, kde(x_grid),
            color=algo_color(algo), linewidth=1.2,
            linestyle=algo_linestyle(algo),
            label=algo if row == 0 else None,
        )

    ax.set_ylabel(dim_label, fontsize=9)
    ax.set_yticks([])
    ax.tick_params(direction="out", length=3)

axes[0].legend(fontsize=7, frameon=False, ncol=min(N_ALGOS + 1, 4))
axes[-1].set_xlabel("Value", fontsize=9)
fig.suptitle("Marginal distributions (seed 0, final iteration)", fontsize=10, y=1.01)
plt.show()

## 6. Per-dimension variance ratio

**This is the critical funnel diagnostic.** The true marginal variance of
$x_k$ is $\exp(\sigma_v^2 / 2)$, which for $\sigma_v = 3$ equals
$\exp(4.5) \approx 90$. Algorithms that fail to explore the mouth of
the funnel will massively under-estimate this variance.

In [None]:
true_var = np.array(target.variance)
print(f"True variance per dim: {true_var}")
# Dimensions whose true variance is ~0 are divided by 1 instead, to avoid
# dividing by zero.
safe_true_var = np.where(true_var > 1e-12, true_var, 1.0)

fig, axes = plt.subplots(
    1, N_ALGOS,
    figsize=(FULL_WIDTH, 2.0),
    sharey=True,
    constrained_layout=True,
    squeeze=False,
)
axes = axes[0]  # always a 1-D array of Axes, even for a single algorithm

dim_labels = [dim_names.get(i, f"{i}") for i in range(DIM)]
x = np.arange(DIM)

for col, (ax, algo) in enumerate(zip(axes, ALGOS)):
    # Empirical / true variance per dimension, one row per available seed.
    # Keys use the seed names in SEEDS rather than positional indices, so
    # runs whose seeds are not 0..n-1 are not silently dropped.
    # NOTE(review): assumes particles.npz keys use the same "seed<value>"
    # names as metrics.json — confirm.
    ratios_per_seed = []
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key in particles:
            emp_var = np.var(particles[key], axis=0)
            ratios_per_seed.append(emp_var / safe_true_var)

    if ratios_per_seed:
        arr = np.array(ratios_per_seed)
        median_r = np.nanmedian(arr, axis=0)
        q25 = np.nanpercentile(arr, 25, axis=0)
        q75 = np.nanpercentile(arr, 75, axis=0)

        # Bars show the median across seeds; error bars span the
        # interquartile range.
        ax.bar(x, median_r, 0.6, color=algo_color(algo), alpha=0.7,
               yerr=[median_r - q25, q75 - median_r], capsize=2,
               error_kw={"linewidth": 0.8})

    # A ratio of 1 means the empirical variance matches the true variance.
    ax.axhline(1.0, color="#999", linewidth=0.8, linestyle="--")
    ax.set_title(algo, fontsize=9)
    ax.set_xticks(range(DIM))
    ax.set_xticklabels(dim_labels, fontsize=6)
    ax.set_xlabel("Dimension", fontsize=7)
    if col == 0:
        ax.set_ylabel("Var ratio (emp / true)", fontsize=8)

plt.show()

## 7. Scale variable $v$ over iterations

Track the mean and spread of the scale variable $v$ over time. The true
distribution is $v \sim \mathcal{N}(0, \sigma_v^2)$. Algorithms that
collapse into the neck will show a narrow $v$ distribution concentrated
at negative values of $v$.

In [None]:
# Row 0: mean of v over iterations; row 1: std of v over iterations.
# squeeze=False keeps `axes` two-dimensional for any number of algorithms.
fig, axes = plt.subplots(
    2, N_ALGOS,
    figsize=(FULL_WIDTH, 3.0),
    sharex=True, sharey="row",
    constrained_layout=True,
    squeeze=False,
)

for col, algo in enumerate(ALGOS):
    # Per-seed traces of mean(v) and std(v); NaN where a checkpoint is absent.
    # Keys use the seed names in SEEDS rather than positional indices, so
    # runs whose seeds are not 0..n-1 are not silently dropped.
    # NOTE(review): assumes particles.npz keys use the same "seed<value>"
    # names as metrics.json — confirm.
    v_mean_seeds = []
    v_std_seeds = []
    for seed in SEEDS:
        mean_row, std_row = [], []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                v_vals = particles[key][:, V_IDX]
                mean_row.append(np.mean(v_vals))
                std_row.append(np.std(v_vals))
            else:
                mean_row.append(np.nan)
                std_row.append(np.nan)
        v_mean_seeds.append(mean_row)
        v_std_seeds.append(std_row)

    color = algo_color(algo)

    # Mean of v (dashed line marks 0).
    ax = axes[0, col]
    arr = np.array(v_mean_seeds)
    med = np.nanmedian(arr, axis=0)
    q25, q75 = np.nanpercentile(arr, 25, axis=0), np.nanpercentile(arr, 75, axis=0)
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18)
    ax.plot(X_POS, med, color=color, linewidth=1.5)
    ax.axhline(0, color="#999", linewidth=0.5, linestyle="--")
    ax.set_title(algo, fontsize=9)
    if col == 0:
        ax.set_ylabel("Mean $v$", fontsize=8)

    # Std of v (dashed line marks target.sigma_v).
    ax = axes[1, col]
    arr = np.array(v_std_seeds)
    med = np.nanmedian(arr, axis=0)
    q25, q75 = np.nanpercentile(arr, 25, axis=0), np.nanpercentile(arr, 75, axis=0)
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18)
    ax.plot(X_POS, med, color=color, linewidth=1.5)
    ax.axhline(target.sigma_v, color="#999", linewidth=0.8, linestyle="--",
               label=f"True $\\sigma_v={target.sigma_v}$" if col == 0 else None)
    ax.set_xlabel("Iteration", fontsize=8)
    ax.set_xticks(X_POS)
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
    if col == 0:
        ax.set_ylabel("Std $v$", fontsize=8)
        ax.legend(fontsize=7, frameon=False)

plt.show()

## 8. Conditional spread: $\text{std}(x_1 \mid v)$

Bin particles by their $v$ value and measure the conditional standard
deviation of $x_1$ in each bin. The true conditional std is $\exp(v/2)$.
This reveals whether algorithms correctly adapt their x-spread to the
local scale set by $v$.

In [None]:
n_bins = 8
# Bin edges span the 2nd..98th percentile of the reference v values.
v_range = np.linspace(
    np.percentile(ref_v, 2), np.percentile(ref_v, 98), n_bins + 1,
)
v_centers = 0.5 * (v_range[:-1] + v_range[1:])
# Reference curve: exp(v / 2) evaluated at the bin centers.
true_cond_std = np.exp(v_centers / 2)


def _binned_std(v, x, edges, min_count):
    """Return std of ``x`` in each half-open ``v`` bin ``[edges[b], edges[b+1])``.

    A bin yields NaN unless it holds strictly more than ``min_count`` points.
    """
    stds = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (v >= lo) & (v < hi)
        stds.append(np.std(x[in_bin]) if in_bin.sum() > min_count else np.nan)
    return stds


fig, axes = plt.subplots(
    1, N_ALGOS + 1,
    figsize=(FULL_WIDTH, 2.0),
    sharey=True,
    constrained_layout=True,
)

# Reference
ax = axes[0]
ref_cond_std = _binned_std(ref_v, reference_samples[:, 0], v_range, min_count=5)

ax.plot(v_centers, true_cond_std, "k--", linewidth=1, label="$e^{v/2}$")
ax.plot(v_centers, ref_cond_std, "o-", color="#999", markersize=4, label="Reference")
ax.set_title("Reference", fontsize=9)
ax.set_xlabel("$v$", fontsize=8)
ax.set_ylabel("$\\text{std}(x_1 \\mid v)$", fontsize=8)
ax.legend(fontsize=6, frameon=False)

# Algorithms
for ax, algo in zip(axes[1:], ALGOS):
    ax.plot(v_centers, true_cond_std, "k--", linewidth=1, alpha=0.5)

    # One row of per-bin stds for every seed with final-checkpoint particles.
    # Keys use the seed names in SEEDS rather than positional indices, so
    # runs whose seeds are not 0..n-1 are not silently dropped.
    # NOTE(review): assumes particles.npz keys use the same "seed<value>"
    # names as metrics.json — confirm.
    cond_stds_seeds = []
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key not in particles:
            continue
        pts = particles[key]
        cond_stds_seeds.append(
            _binned_std(pts[:, V_IDX], pts[:, 0], v_range, min_count=3)
        )

    if cond_stds_seeds:
        # Median across seeds, ignoring bins a seed did not populate.
        med = np.nanmedian(np.array(cond_stds_seeds), axis=0)
        ax.plot(v_centers, med, "o-", color=algo_color(algo), markersize=4)

    ax.set_title(algo, fontsize=9)
    ax.set_xlabel("$v$", fontsize=8)
    ax.tick_params(labelleft=False)

plt.show()

## 9. PCA projections over time

Project particles onto the top-2 PCs fitted on the reference samples.
The funnel's structure should be visible as a fan-shaped cloud.

In [None]:
from sklearn.decomposition import PCA

# Fit a 2-component PCA on the reference samples; all particle sets are
# projected with this same fit.
pca = PCA(n_components=2)
pca.fit(reference_samples)
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")

# Pick ~5 checkpoints, evenly spaced by index (first and last included).
n_snaps = min(5, len(CHECKPOINTS))
snap_indices = np.linspace(0, len(CHECKPOINTS) - 1, n_snaps, dtype=int)
snapshot_iters = [CHECKPOINTS[i] for i in snap_indices]

# squeeze=False keeps `axes` two-dimensional even with a single algorithm
# and/or a single snapshot, so axes[row, col] is always valid.
fig, axes = plt.subplots(
    N_ALGOS, n_snaps,
    figsize=(FULL_WIDTH, 1.2 * N_ALGOS + 0.3),
    sharex=True, sharey=True,
    constrained_layout=True,
    squeeze=False,
)

# Projection of the first 500 reference samples, drawn in every panel as
# background.
ref_2d = pca.transform(reference_samples[:500])

for row, algo in enumerate(ALGOS):
    for col, it in enumerate(snapshot_iters):
        ax = axes[row, col]

        # Light reference cloud
        ax.scatter(
            ref_2d[:, 0], ref_2d[:, 1],
            s=3, color="#DDDDDD", alpha=0.3, edgecolors="none",
            rasterized=True,
        )

        # Algorithm particles (seed 0); panel shows only the reference
        # cloud when this snapshot is missing.
        key = f"seed0__{algo}__iter{it}"
        if key in particles:
            pts_2d = pca.transform(particles[key])
            ax.scatter(
                pts_2d[:, 0], pts_2d[:, 1],
                s=10, color=algo_color(algo), alpha=0.6,
                edgecolors="none",
            )

        if row == 0:
            ax.set_title(f"iter {it}", fontsize=8)
        if col == 0:
            ax.set_ylabel(algo, fontsize=8)
        ax.set_aspect("equal")
        ax.tick_params(labelsize=6)

fig.suptitle("PCA projection (seed 0)", fontsize=10, y=1.02)
plt.show()

## 10. Mean log-posterior over iterations

In [None]:
fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.2))

for algo in ALGOS:
    # Mean target log-probability of the particles, per seed and checkpoint;
    # NaN where a snapshot is missing.
    # Keys use the seed names in SEEDS rather than positional indices, so
    # runs whose seeds are not 0..n-1 are not silently dropped.
    # NOTE(review): assumes particles.npz keys use the same "seed<value>"
    # names as metrics.json — confirm.
    lp_per_seed = []
    for seed in SEEDS:
        lp_ckpts = []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                pts = jnp.array(particles[key])
                lp = float(target.log_prob(pts).mean())
            else:
                lp = np.nan
            lp_ckpts.append(lp)
        lp_per_seed.append(lp_ckpts)

    lp_arr = np.array(lp_per_seed)
    median = np.nanmedian(lp_arr, axis=0)
    q25 = np.nanpercentile(lp_arr, 25, axis=0)
    q75 = np.nanpercentile(lp_arr, 75, axis=0)

    color = algo_color(algo)
    ls = algo_linestyle(algo)
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
    ax.plot(X_POS, median, color=color, linewidth=1.5, linestyle=ls, label=algo)

# Reference log-posterior: mean log-probability of the reference samples.
ref_lp = float(target.log_prob(jnp.array(reference_samples)).mean())
ax.axhline(ref_lp, color="#999", linewidth=0.8, linestyle=":", label="Reference mean")

ax.set_xticks(X_POS)
ax.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
ax.set_ylabel("Mean log-posterior", fontsize=8)
ax.set_xlabel("Iteration", fontsize=8)
ax.legend(fontsize=7, frameon=False)
ax.tick_params(direction="out", length=3)
plt.show()

## 11. Summary table

In [None]:
# Metrics listed in the experiment config, with a two-metric fallback.
metric_names = exp.get("metrics", ["energy_distance", "mean_error"])

# Header: algorithm column (15), one 2+24 wide column per metric, then the
# computed v-std as a bonus column (2+16).
print(f"{'Algorithm':<15}", end="")
for m in metric_names:
    print(f"  {m:<24}", end="")
print(f"  {'v_std':<16}")
print("-" * (15 + 26 * len(metric_names) + 18))

for algo in ALGOS:
    print(f"{algo:<15}", end="")
    for m in metric_names:
        # Final-checkpoint value per seed. Every level is optional (same
        # tolerant lookup as gather_metric); dtype=float maps None to NaN.
        vals = np.asarray(
            [
                metrics.get(s, {}).get(algo, {}).get(final_ckpt, {}).get(m, np.nan)
                for s in SEEDS
            ],
            dtype=float,
        )
        valid = vals[~np.isnan(vals)]
        if valid.size:
            med = np.median(valid)
            q25, q75 = np.percentile(valid, [25, 75])
            cell = f"{med:.4f} ({q25:.3f}\u2013{q75:.3f})"
        else:
            cell = "N/A"
        # Pad every cell to the header's column width so rows line up.
        print(f"  {cell:<24}", end="")

    # Compute v-std at final checkpoint (median across seeds).
    # Keys use the seed names in SEEDS rather than positional indices, so
    # runs whose seeds are not 0..n-1 are not silently dropped.
    # NOTE(review): assumes particles.npz keys use the same "seed<value>"
    # names as metrics.json — confirm.
    v_stds = []
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key in particles:
            v_stds.append(np.std(particles[key][:, V_IDX]))
    if v_stds:
        print(f"  {np.median(v_stds):>5.2f} (true={target.sigma_v})")
    else:
        print(f"  {'N/A':<16}")