# Banana Target Results

Visualize ETD variants vs baselines on the 2D banana (Rosenbrock) target.

The banana distribution is defined hierarchically: $x_1 \sim \mathcal{N}(0, \sigma_1^2)$
and, for each $k \ge 2$, $x_k \mid x_1 \sim \mathcal{N}(b(x_1^2 - a), \sigma_2^2)$.
The curvature parameter $b$ creates a curved ridge that challenges methods
assuming Gaussian or factorized structure.

**Prerequisites:**
```bash
python -m experiments.run experiments/configs/banana_2d.yaml
```

In [None]:
import json
import sys
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# Project imports
# ROOT is the parent of the current working directory, i.e. the notebook is
# assumed to be launched from a direct subdirectory of the repository root
# (TODO confirm). Both ROOT/src and ROOT are put at the front of sys.path,
# presumably so the `etd` and `figures` packages imported below resolve
# without installing the project.
ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT / "src"))
sys.path.insert(0, str(ROOT))

from etd.targets.banana import BananaTarget
from figures.style import (
    ALGO_COLORS, ALGO_LINESTYLES, FULL_WIDTH, COL_WIDTH,
    facet_grid, frame_panel, ref_line,
    load_display_metadata, plot_contours, plot_particles,
    setup_style,
)

# Apply the plotting style helper imported from figures.style before any
# figure is created.
setup_style()
# IPython magic (not plain Python): sets the inline backend's figure format
# to 'retina'.
%config InlineBackend.figure_format = 'retina'

## 1. Load data

In [None]:
import yaml

# --- Find the latest results directory ---
# Every entry of results/banana-2d that is a directory is treated as one run;
# the lexicographically last name is taken as the latest run.
results_root = ROOT / "results" / "banana-2d"
if not results_root.is_dir():
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(
        f"No results at {results_root}.\n"
        "Run: python -m experiments.run experiments/configs/banana_2d.yaml"
    )

# Keep directories only: a stray file (e.g. .DS_Store or a log) could
# otherwise sort last and be mistaken for the latest run.
runs = sorted(p for p in results_root.iterdir() if p.is_dir())
if not runs:
    raise FileNotFoundError(
        f"No run directories inside {results_root}.\n"
        "Run: python -m experiments.run experiments/configs/banana_2d.yaml"
    )
RESULTS_DIR = runs[-1]
print(f"Loading from: {RESULTS_DIR}")

# --- Load config (source of truth) ---
# Parse the config.yaml stored next to the results with the safe YAML loader.
config_path = RESULTS_DIR / "config.yaml"
with config_path.open() as config_file:
    exp_config = yaml.safe_load(config_file)

# Drill down to the sections the rest of the notebook reads.
exp = exp_config["experiment"]
target_cfg = exp["target"]
# "params" is optional; an empty mapping means no constructor arguments.
target_params = target_cfg["params"] if "params" in target_cfg else {}

# --- Load metrics ---
with (RESULTS_DIR / "metrics.json").open() as metrics_file:
    raw_metrics = json.load(metrics_file)

# Layout is seed -> algorithm -> checkpoint -> values. JSON object keys are
# always strings, so the checkpoint keys are converted back to ints here.
metrics = {
    seed: {
        algo: {int(ckpt): entry for ckpt, entry in ckpt_dict.items()}
        for algo, ckpt_dict in algo_dict.items()
    }
    for seed, algo_dict in raw_metrics.items()
}

# --- Load particles ---
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Read every array inside a context manager so the handle is released as soon
# as loading finishes, instead of leaking it for the whole session.
with np.load(RESULTS_DIR / "particles.npz") as npz:
    particles = {name: npz[name] for name in npz.files}

# --- Load display metadata (if available) ---
display_meta = load_display_metadata(str(RESULTS_DIR))

def algo_color(algo):
    """Return the plot color for *algo*.

    Lookup order: the "color" entry of the run's display metadata, then the
    shared ``ALGO_COLORS`` table, then dark gray (``"#333"``).
    """
    if algo not in display_meta:
        return ALGO_COLORS.get(algo, "#333")
    return display_meta[algo]["color"]

def algo_linestyle(algo):
    """Return the line style for *algo*.

    Lookup order: the "linestyle" entry of the run's display metadata, then
    the shared ``ALGO_LINESTYLES`` table, then a solid line (``"-"``).
    """
    if algo not in display_meta:
        return ALGO_LINESTYLES.get(algo, "-")
    return display_meta[algo]["linestyle"]

# --- Build target from saved config ---
# Instantiate the target with the `params` mapping read from the run's
# config.yaml (an empty mapping if the config stored none).
target = BananaTarget(**target_params)
# Echo the target's attributes so the loaded configuration can be checked
# at a glance.
print(f"Target: banana, dim={target.dim}")
print(f"  curvature={target.curvature}, offset={target.offset}")
print(f"  sigma1={target.sigma1}, sigma2={target.sigma2}")
print(f"  mean={np.array(target.mean)}")

# --- Constants ---
# Seed labels ("seed<N>") used to index the per-seed level of `metrics`.
SEEDS = [f"seed{seed_id}" for seed_id in exp["seeds"]]

# Union of checkpoint iterations and algorithm names over all seeds, sorted.
CHECKPOINTS = sorted({
    int(ckpt)
    for per_algo in metrics.values()
    for per_ckpt in per_algo.values()
    for ckpt in per_ckpt
})
ALL_ALGOS = sorted({
    name for per_algo in metrics.values() for name in per_algo
})
ALGOS = ALL_ALGOS
N_ALGOS = len(ALGOS)

# Evenly-spaced x positions for convergence plots: checkpoints are drawn by
# index, and only roughly every fifth tick (plus the last one) is labelled.
n_ckpts = len(CHECKPOINTS)
X_POS = np.arange(n_ckpts)
tick_step = max(1, n_ckpts // 5)
TICK_SHOW = set(range(0, n_ckpts, tick_step))
TICK_SHOW.add(n_ckpts - 1)
X_LABELS_SPARSE = [
    str(ckpt) if pos in TICK_SHOW else ""
    for pos, ckpt in enumerate(CHECKPOINTS)
]

print(f"Checkpoints: {CHECKPOINTS}")
print(f"Algorithms: {ALGOS}")

In [None]:
# --- Reference samples (exact sampling) ---
import jax

ref_path = RESULTS_DIR / "reference.npz"
if not ref_path.exists():
    # No stored reference: draw 5000 samples from the target with a fixed key.
    reference_samples = np.array(target.sample(jax.random.PRNGKey(99999), 5000))
else:
    reference_samples = np.load(ref_path)["samples"]

print(f"Reference samples: {reference_samples.shape}")

# Axis limits from reference bulk (percentile-based, robust to heavy tails):
# the [lo, hi] percentile range of each coordinate, widened by ref_pad.
ref_pad = 2.0
lo, hi = 2, 98  # percentiles
ref_x1 = reference_samples[:, 0]
ref_x2 = reference_samples[:, 1]
X1_LIM = (
    float(np.percentile(ref_x1, lo)) - ref_pad,
    float(np.percentile(ref_x1, hi)) + ref_pad,
)
X2_LIM = (
    float(np.percentile(ref_x2, lo)) - ref_pad,
    float(np.percentile(ref_x2, hi)) + ref_pad,
)
print(f"x1 range: {X1_LIM}")
print(f"x2 range: {X2_LIM}")

## 2. Final particle scatter

Seed 0, final iteration. Banana density contours as background.
The key question: do particles fill the curved ridge or collapse to a blob?

In [None]:
# Last recorded iteration; reused by later cells.
final_ckpt = max(CHECKPOINTS)

fig, axes = facet_grid(N_ALGOS, panel_size=1.6, square=False)

# One panel per algorithm: density contours first, then seed-0 particles.
for panel, name in zip(axes, ALGOS):
    plot_contours(panel, target.log_prob, X1_LIM, X2_LIM)
    snapshot_key = f"seed0__{name}__iter{final_ckpt}"
    # A missing snapshot leaves the panel with contours only.
    if snapshot_key in particles:
        plot_particles(panel, particles[snapshot_key], color=algo_color(name), s=10)
    panel.set_title(name, fontsize=9)
    panel.set_xlim(X1_LIM)
    panel.set_ylim(X2_LIM)
    frame_panel(panel)

plt.show()

## 3. Convergence traces (faceted by algorithm)

One column per algorithm, one row per metric. Foreground is median over seeds
with IQR shading; gray traces show other algorithms for context.

In [None]:
def gather_metric(metric_name, algos=None):
    """Summarize one metric over seeds at every checkpoint.

    Parameters
    ----------
    metric_name : key of the metric inside each checkpoint entry of `metrics`.
    algos : algorithms to include; defaults to the module-level ``ALGOS``.

    Returns
    -------
    dict mapping each algorithm to a ``(median, q25, q75)`` tuple of arrays
    aligned with ``CHECKPOINTS``. Missing seed / algorithm / checkpoint /
    metric entries count as NaN and are ignored by the NaN-aware statistics.
    """
    selected = ALGOS if algos is None else algos

    def _seed_row(seed, algo):
        # Values of the metric for one (seed, algo) pair across checkpoints.
        per_ckpt = metrics.get(seed, {}).get(algo, {})
        return [
            per_ckpt.get(c, {}).get(metric_name, np.nan)
            for c in CHECKPOINTS
        ]

    summary = {}
    for algo in selected:
        # Rows are seeds, columns are checkpoints.
        table = np.array([_seed_row(seed, algo) for seed in SEEDS])
        summary[algo] = (
            np.nanmedian(table, axis=0),
            np.nanpercentile(table, 25, axis=0),
            np.nanpercentile(table, 75, axis=0),
        )
    return summary

In [None]:
# (metric key, y-axis label) for each row of the grid; reused by the overlay
# figure further down.
metric_specs = [
    ("energy_distance", "Energy dist."),
    ("sliced_wasserstein", "Sliced $W_2$"),
    ("mean_error", "Mean error"),
]
n_metrics = len(metric_specs)

# squeeze=False keeps `axes` two-dimensional (metrics x algorithms) even when
# there is a single algorithm, so no manual reshaping is needed.
fig, axes = plt.subplots(
    n_metrics, N_ALGOS,
    figsize=(FULL_WIDTH, 1.5 * n_metrics),
    sharex=True, sharey="row",
    squeeze=False,
    constrained_layout=True,
)
bottom_row = n_metrics - 1

for row, (metric_name, ylabel) in enumerate(metric_specs):
    data = gather_metric(metric_name)
    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # Gray context: medians of every other algorithm, drawn underneath.
        for other in (name for name in ALGOS if name != algo):
            ax.plot(X_POS, data[other][0], color="#cccccc", linewidth=0.6, zorder=1)

        # Foreground: this panel's algorithm, median line over IQR band.
        median, q25, q75 = data[algo]
        color = algo_color(algo)
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
        ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

        ax.set_xticks(X_POS)
        ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
        ax.tick_params(direction="out", length=3)

        # Titles on the top row, metric names on the left column, and the
        # shared x label on the bottom row only.
        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=8)
        if row == bottom_row:
            ax.set_xlabel("Iteration", fontsize=8)

plt.show()

## 4. Convergence overlay

All algorithms on the same axes for direct comparison.

In [None]:
# One panel per metric, all algorithms overlaid.
fig, axes = plt.subplots(
    1, len(metric_specs),
    figsize=(FULL_WIDTH, 2.2),
    constrained_layout=True,
)

for panel, (metric_name, ylabel) in zip(axes, metric_specs):
    summary = gather_metric(metric_name)
    for algo in ALGOS:
        med, lower, upper = summary[algo]
        line_color = algo_color(algo)
        # Faint IQR band behind the median line.
        panel.fill_between(X_POS, lower, upper, color=line_color, alpha=0.10)
        panel.plot(
            X_POS, med,
            color=line_color, linewidth=1.5,
            linestyle=algo_linestyle(algo), label=algo,
        )

    panel.set_xticks(X_POS)
    panel.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
    panel.set_ylabel(ylabel, fontsize=8)
    panel.set_xlabel("Iteration", fontsize=8)
    panel.tick_params(direction="out", length=3)

# The legend is identical for every panel, so it is drawn on the first only.
axes[0].legend(fontsize=7, frameon=False, loc="upper right")
plt.show()

## 5. Marginal distributions

KDE of each coordinate at the final iteration, overlaid with the true marginal.
For the banana, $x_1$ is Gaussian but $x_2$ is skewed, with a long one-sided
tail, due to the $b(x_1^2 - a)$ dependence.

In [None]:
# One stacked panel per coordinate. squeeze=False plus column selection gives
# a 1-D sequence of axes even when the target has a single dimension.
fig, axes = plt.subplots(
    target.dim, 1,
    figsize=(COL_WIDTH * 1.8, 1.5 * target.dim),
    squeeze=False,
    constrained_layout=True,
)
axes = axes[:, 0]

for dim_idx, ax in enumerate(axes):
    first_panel = dim_idx == 0

    # Reference KDE (percentile-based grid, robust to tails)
    ref_dim = reference_samples[:, dim_idx]
    kde_ref = gaussian_kde(ref_dim)
    x_grid = np.linspace(
        np.percentile(ref_dim, 1) - 1.0,
        np.percentile(ref_dim, 99) + 1.0,
        300,
    )
    # Evaluate the reference density once and reuse it for fill and outline.
    ref_density = kde_ref(x_grid)
    ax.fill_between(
        x_grid, ref_density,
        color="#DDDDDD", alpha=0.6,
        label="Reference" if first_panel else None,
    )
    ax.plot(x_grid, ref_density, color="#999999", linewidth=0.8)

    # Algorithm KDEs (seed 0)
    for algo in ALGOS:
        key = f"seed0__{algo}__iter{final_ckpt}"
        if key not in particles:
            continue
        vals = particles[key][:, dim_idx]
        try:
            kde = gaussian_kde(vals)
            ax.plot(
                x_grid, kde(x_grid),
                color=algo_color(algo), linewidth=1.2,
                linestyle=algo_linestyle(algo),
                label=algo if first_panel else None,
            )
        except np.linalg.LinAlgError:
            # Best effort: a KDE that fails with a linear-algebra error
            # (e.g. degenerate samples) is simply left out of the panel.
            pass

    ax.set_ylabel(f"$x_{dim_idx + 1}$", fontsize=9)
    ax.set_yticks([])
    ax.tick_params(direction="out", length=3)

axes[-1].set_xlabel("Value", fontsize=9)

# Shared legend, anchored just above the top of the figure. Labels were only
# attached in the first panel, so its handles cover every entry.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels, fontsize=7, frameon=False,
    ncol=min(N_ALGOS + 1, 6),
    loc="upper center", bbox_to_anchor=(0.5, 1.06),
)
plt.show()

## 6. Particle evolution snapshots

How particles evolve over iterations. The banana's curved ridge is clearly
visible in the contours — watch how different algorithms navigate it.

In [None]:
# Pick ~5 evenly spaced checkpoints (fewer if the run recorded fewer).
n_snaps = min(5, len(CHECKPOINTS))
snap_indices = np.linspace(0, len(CHECKPOINTS) - 1, n_snaps, dtype=int)
snapshot_iters = [CHECKPOINTS[i] for i in snap_indices]

# Compute global limits from all particles across all algos/snapshots
# (percentile-based to handle outliers)
all_pts = []
for algo in ALGOS:
    for it in snapshot_iters:
        key = f"seed0__{algo}__iter{it}"
        if key in particles:
            all_pts.append(particles[key])

if all_pts:
    all_pts = np.concatenate(all_pts, axis=0)
    pad = 2.0
    evo_x1_lim = (float(np.percentile(all_pts[:, 0], 2)) - pad,
                  float(np.percentile(all_pts[:, 0], 98)) + pad)
    evo_x2_lim = (float(np.percentile(all_pts[:, 1], 2)) - pad,
                  float(np.percentile(all_pts[:, 1], 98)) + pad)
else:
    # No seed-0 snapshots were stored: np.concatenate would raise on an empty
    # list, so fall back to the reference-based limits and still draw the
    # contour panels.
    evo_x1_lim, evo_x2_lim = X1_LIM, X2_LIM

# squeeze=False guarantees a 2-D (algorithms x snapshots) array of axes, so
# `axes[row, col]` also works with a single algorithm and/or a single
# checkpoint.
fig, axes = plt.subplots(
    N_ALGOS, n_snaps,
    figsize=(FULL_WIDTH, 1.3 * N_ALGOS + 0.3),
    sharex=True, sharey=True,
    squeeze=False,
    constrained_layout=True,
)

for row, algo in enumerate(ALGOS):
    for col, it in enumerate(snapshot_iters):
        ax = axes[row, col]
        plot_contours(ax, target.log_prob, evo_x1_lim, evo_x2_lim)
        key = f"seed0__{algo}__iter{it}"
        if key in particles:
            plot_particles(
                ax, particles[key], color=algo_color(algo), s=12,
            )
        ax.set_xlim(evo_x1_lim)
        ax.set_ylim(evo_x2_lim)
        frame_panel(ax)

        # Iteration labels on the top row, algorithm names on the left column.
        if row == 0:
            ax.set_title(f"iter {it}", fontsize=8)
        if col == 0:
            ax.set_ylabel(algo, fontsize=8)

plt.show()

## 7. Per-dimension variance ratio

Ratio of particle variance to true marginal variance per dimension.
For the banana, $\text{Var}(x_2)$ is large due to the nonlinear dependence
on $x_1$: algorithms that fail to capture the curvature will under-estimate
the $x_2$ variance.

In [None]:
true_var = np.array(target.variance)
print(f"True variance: {true_var}")
# Dimensions with (near-)zero true variance are divided by 1 instead, so the
# ratio stays finite.
safe_var = np.where(true_var > 1e-12, true_var, 1.0)

# squeeze=False plus row selection yields a 1-D sequence of axes for any
# number of algorithms.
fig, axes = plt.subplots(
    1, N_ALGOS,
    figsize=(min(FULL_WIDTH, 2.0 * N_ALGOS), 2.0),
    sharey=True,
    squeeze=False,
    constrained_layout=True,
)
axes = axes[0]

# Braces around the index keep multi-digit subscripts (x_{10}, ...) intact.
dim_labels = [f"$x_{{{i + 1}}}$" for i in range(target.dim)]
x = np.arange(target.dim)

for ax, algo in zip(axes, ALGOS):
    # Build keys from the actual seed labels in SEEDS rather than from
    # positional indices, so runs whose seeds are not 0..n-1 are found too.
    ratios_per_seed = []
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key in particles:
            emp_var = np.var(particles[key], axis=0)
            ratios_per_seed.append(emp_var / safe_var)

    if ratios_per_seed:
        arr = np.array(ratios_per_seed)
        median_r = np.nanmedian(arr, axis=0)
        q25 = np.nanpercentile(arr, 25, axis=0)
        q75 = np.nanpercentile(arr, 75, axis=0)

        # Asymmetric error bars spanning the interquartile range over seeds.
        ax.errorbar(
            x, median_r,
            yerr=[median_r - q25, q75 - median_r],
            fmt="o", color=algo_color(algo), markersize=4,
            markeredgecolor="white", markeredgewidth=0.5,
            capsize=3, capthick=0.8, linewidth=0.8, zorder=5,
        )

    # A ratio of 1 means the empirical variance matches the true variance.
    ref_line(ax, 1.0)
    ax.set_title(algo, fontsize=9)
    ax.set_xticks(range(target.dim))
    ax.set_xticklabels(dim_labels, fontsize=8)
    if ax is axes[0]:
        ax.set_ylabel("Var ratio (emp / true)", fontsize=8)

plt.show()

## 8. Correlation capture

The banana has strong nonlinear dependence between $x_1$ and $x_2$ (their
linear correlation is zero by symmetry, so a correlation coefficient alone
would miss it). Compare the empirical $x_1$-$x_2$ scatter against the
reference to see whether each algorithm captures this structure.

In [None]:
# One extra panel up front for the reference samples.
fig, axes = facet_grid(N_ALGOS + 1, panel_size=1.6, square=False)

# Reference: first 500 samples only, drawn faintly in gray.
ref_subset = reference_samples[:500]
ax_ref = axes[0]
ax_ref.scatter(
    ref_subset[:, 0], ref_subset[:, 1],
    s=8, color="#999999", alpha=0.4, edgecolors="none",
)
ax_ref.set_title("Reference", fontsize=9)
frame_panel(ax_ref)

# Algorithms: seed-0 particles at the final checkpoint, one panel each.
for panel, name in zip(axes[1:], ALGOS):
    snapshot_key = f"seed0__{name}__iter{final_ckpt}"
    if snapshot_key in particles:
        cloud = particles[snapshot_key]
        panel.scatter(
            cloud[:, 0], cloud[:, 1],
            s=10, color=algo_color(name),
            edgecolors="white", linewidths=0.4,
            alpha=0.8,
        )
    panel.set_title(name, fontsize=9)
    frame_panel(panel)

plt.show()

## 9. Mean log-posterior over iterations

Sanity check that particles move toward high-probability regions.
The reference line shows the mean log-posterior of exact samples.

In [None]:
fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.2))

for algo in ALGOS:
    # Mean log-probability per (seed, checkpoint). Keys are built from the
    # actual seed labels in SEEDS rather than positional indices, so runs
    # whose seeds are not 0..n-1 are found too.
    lp_per_seed = []
    for seed in SEEDS:
        lp_ckpts = []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                pts = jnp.array(particles[key])
                lp = float(target.log_prob(pts).mean())
            else:
                # Missing snapshot: ignored by the NaN-aware statistics below.
                lp = np.nan
            lp_ckpts.append(lp)
        lp_per_seed.append(lp_ckpts)

    # Rows are seeds, columns are checkpoints.
    lp_arr = np.array(lp_per_seed)
    median = np.nanmedian(lp_arr, axis=0)
    q25 = np.nanpercentile(lp_arr, 25, axis=0)
    q75 = np.nanpercentile(lp_arr, 75, axis=0)

    color = algo_color(algo)
    ls = algo_linestyle(algo)
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
    ax.plot(X_POS, median, color=color, linewidth=1.5, linestyle=ls, label=algo)

# Reference log-posterior: mean over the reference samples.
ref_lp = float(target.log_prob(jnp.array(reference_samples)).mean())
ref_line(ax, ref_lp, linestyle=":", label="Reference mean")

ax.set_xticks(X_POS)
ax.set_xticklabels(X_LABELS_SPARSE, fontsize=7)
ax.set_ylabel("Mean log-posterior", fontsize=8)
ax.set_xlabel("Iteration", fontsize=8)
ax.legend(fontsize=7, frameon=False)
ax.tick_params(direction="out", length=3)
plt.show()

## 10. Summary table

In [None]:
# Metrics to tabulate: the list stored in the experiment config, or a default
# pair when the config has none.
metric_names = exp.get("metrics", ["energy_distance", "mean_error"])

# Fixed column widths shared by header and body so the columns line up.
name_w, cell_w = 18, 28

print(f"{'Algorithm':<{name_w}}", end="")
for m in metric_names:
    print(f"  {m:<{cell_w}}", end="")
print()
print("-" * (name_w + (cell_w + 2) * len(metric_names)))

for algo in ALGOS:
    print(f"{algo:<{name_w}}", end="")
    for m in metric_names:
        # Tolerate a missing seed / algorithm / checkpoint / metric the same
        # way gather_metric does: treat it as NaN instead of raising KeyError.
        vals = [
            metrics.get(s, {}).get(algo, {}).get(final_ckpt, {}).get(m, np.nan)
            for s in SEEDS
        ]
        # JSON nulls arrive as None; np.isnan(None) would raise a TypeError.
        valid = [v for v in vals if v is not None and not np.isnan(v)]
        if valid:
            med = np.median(valid)
            q25 = np.percentile(valid, 25)
            q75 = np.percentile(valid, 75)
            cell = f"{med:>8.4f} ({q25:.4f}\u2013{q75:.4f})"
        else:
            cell = "N/A"
        # Pad every cell to the header's width.
        print(f"  {cell:<{cell_w}}", end="")
    print()