# GMM-2D-4 Benchmark Results

Visualize and analyze the Phase 2 benchmark: ETD variants vs baselines on a 2D, 4-mode Gaussian mixture.

**Algorithms:** ETD-B, ETD-SR, ETD-B-SF, SVGD, ULA, MPPI  
**Target:** 4 isotropic Gaussians on a grid (separation=6.0, σ=1.0)  
**Setup:** N=100 particles, 500 iterations, 5 seeds

In [None]:
import json
import sys
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax.scipy.special import logsumexp

# Project imports
# The repository root is taken to be the parent of the current working
# directory, i.e. the notebook is expected to be launched from a direct
# subdirectory of the repo.
ROOT = Path.cwd().parent
# Both locations are prepended to the import path: ROOT/src is presumably where
# the `etd` package lives and ROOT where the `figures` package lives (both are
# imported right after this) — verify against the repo layout.
sys.path.insert(0, str(ROOT / "src"))
sys.path.insert(0, str(ROOT))

from etd.targets.gmm import GMMTarget
from figures.style import (
    ALGO_COLORS, FULL_WIDTH, COL_WIDTH,
    plot_contours, plot_particles, savefig_paper, setup_style,
)

# Apply the project's plotting style from figures.style before any figure is
# created (presumably sets matplotlib rcParams — confirm in figures/style.py).
setup_style()
# IPython magic (only valid inside a notebook/IPython session): ask the inline
# backend to render figures in the 'retina' format.
%config InlineBackend.figure_format = 'retina'

## 1. Load data

In [None]:
# --- Find the latest results directory ---
results_root = ROOT / "results" / "gmm-2d-4"
# Only sub-directories are candidate runs; a stray file (.DS_Store, a log, ...)
# could otherwise sort last and be mistaken for the latest run.
runs = sorted(p for p in results_root.iterdir() if p.is_dir())
if not runs:
    raise FileNotFoundError(f"No run directories found under {results_root}")
# Paths sort lexicographically; this assumes run directory names sort
# chronologically (e.g. timestamps) — TODO confirm against the runner.
RESULTS_DIR = runs[-1]
print(f"Loading from: {RESULTS_DIR}")

# --- Load metrics ---
with open(RESULTS_DIR / "metrics.json", encoding="utf-8") as f:
    raw_metrics = json.load(f)

# JSON object keys are always strings, so checkpoint keys are converted back to
# ints here.  Resulting layout: metrics[seed][algo][checkpoint] -> value dict.
metrics = {}
for seed, algo_dict in raw_metrics.items():
    metrics[seed] = {}
    for algo, ckpt_dict in algo_dict.items():
        metrics[seed][algo] = {int(k): v for k, v in ckpt_dict.items()}

# --- Load particles ---
# Materialise every array into a plain dict while the archive is open, then
# close the underlying file handle instead of leaving it dangling.
with np.load(RESULTS_DIR / "particles.npz") as npz:
    particles = dict(npz)

# --- Build target ---
target = GMMTarget(dim=2, n_modes=4, arrangement="grid", separation=6.0, component_std=1.0)
modes = np.array(target.means)
print(f"Mode centers:\n{modes}")

# --- Constants ---
SEEDS = [f"seed{i}" for i in range(5)]
# Union over all seeds/algorithms; keys are already ints after the conversion above.
CHECKPOINTS = sorted({
    ckpt
    for algo_dict in metrics.values()
    for ckpt_dict in algo_dict.values()
    for ckpt in ckpt_dict
})
ALL_ALGOS = sorted({
    algo for algo_dict in metrics.values() for algo in algo_dict
})
# Drop ETD-B-SF (diverged) for most plots
ALGOS = [a for a in ALL_ALGOS if a != "ETD-B-SF"]
print(f"Checkpoints: {CHECKPOINTS}")
print(f"Algorithms: {ALGOS}")

## 2. Final particle scatter plots

Seed 0, iteration 500. GMM contours as background.

In [None]:
# One panel per algorithm: final particles (seed 0, iteration 500) drawn over
# the target's contours.
AXIS_LIM = (-6.5, 6.5)
n_algos = len(ALGOS)

fig, axes = plt.subplots(
    1, n_algos,
    figsize=(FULL_WIDTH, FULL_WIDTH / n_algos + 0.15),
    sharex=True, sharey=True,
    squeeze=False,  # always get a 2-D array, even when only one algorithm is left
)
axes = axes[0]  # single row -> 1-D sequence of Axes

for col, (ax, algo) in enumerate(zip(axes, ALGOS)):
    plot_contours(ax, target.log_prob, AXIS_LIM, AXIS_LIM)
    key = f"seed0__{algo}__iter500"
    # A run without this snapshot still gets its contour panel, matching how
    # the other cells treat missing particle keys.
    if key in particles:
        plot_particles(ax, particles[key], color=ALGO_COLORS.get(algo, "#333"), s=18)
    ax.set_title(algo, fontsize=9)
    ax.set_xlim(AXIS_LIM)
    ax.set_ylim(AXIS_LIM)
    ax.set_aspect("equal")
    # The y axis is shared, so only the left-most panel keeps tick labels.
    if col > 0:
        ax.tick_params(labelleft=False)

fig.subplots_adjust(wspace=0.08)
plt.show()

## 3. Convergence (faceted by algorithm)

One column per algorithm. Light gray traces of other algorithms for context.

In [None]:
def gather_metric(metric_name, algos=None):
    """Summarise one metric across seeds at every checkpoint.

    Reads the notebook-level ``metrics``, ``SEEDS`` and ``CHECKPOINTS``.
    Any (seed, algo, checkpoint, metric) entry that is absent is treated as
    NaN and ignored by the NaN-aware reductions.

    Args:
        metric_name: Key of the metric inside each checkpoint's dict.
        algos: Algorithm names to summarise.  ``None`` (the default) means the
            notebook-level ``ALGOS`` as it is *when the function is called*,
            so re-running the data-loading cell is picked up without having
            to re-run this definition.

    Returns:
        ``{algo: (median, q25, q75)}`` where each entry is an array with one
        value per checkpoint, reduced over seeds.
    """
    if algos is None:
        algos = ALGOS
    out = {}
    for algo in algos:
        # Rows are seeds, columns are checkpoints.  dtype=float also turns a
        # JSON null (None) into NaN instead of producing an object array.
        vals = np.array(
            [
                [
                    metrics.get(seed, {}).get(algo, {}).get(c, {}).get(metric_name, np.nan)
                    for c in CHECKPOINTS
                ]
                for seed in SEEDS
            ],
            dtype=float,
        )
        out[algo] = (
            np.nanmedian(vals, axis=0),
            np.nanpercentile(vals, 25, axis=0),
            np.nanpercentile(vals, 75, axis=0),
        )
    return out


# Evenly-spaced x positions with checkpoint labels (avoids log(0) problem)
X_POS = np.arange(len(CHECKPOINTS))
X_LABELS = [str(c) for c in CHECKPOINTS]

# Show a sparse subset of tick labels to avoid crowding: every second
# checkpoint index.  Derived from the number of checkpoints so that runs with
# more than ten checkpoints still get labels along the whole axis.
TICK_SHOW = set(range(0, len(CHECKPOINTS), 2))  # indices into CHECKPOINTS to label
X_LABELS_SPARSE = [str(c) if i in TICK_SHOW else "" for i, c in enumerate(CHECKPOINTS)]

In [None]:
# Convergence curves: one row per metric, one column per algorithm.
# (metric key in the metrics dict, y-axis label)
metric_specs = [
    ("energy_distance", "Energy distance"),
    ("mean_error", "Mean error"),
]
n_metrics = len(metric_specs)
n_algos = len(ALGOS)

fig, axes = plt.subplots(
    n_metrics, n_algos,
    figsize=(FULL_WIDTH, 1.6 * n_metrics),
    sharex=True, sharey="row",
    constrained_layout=True,
    squeeze=False,  # keep axes 2-D so axes[row, col] works with a single algorithm
)

for row, (metric_name, ylabel) in enumerate(metric_specs):
    data = gather_metric(metric_name)
    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # Gray context: medians of all OTHER algorithms
        for other in ALGOS:
            if other == algo:
                continue
            med_o, _, _ = data[other]
            ax.plot(X_POS, med_o, color="#cccccc", linewidth=0.6, zorder=1)

        # Foreground: this algorithm's median + IQR band
        median, q25, q75 = data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
        ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

        # Axis setup: a tick at every checkpoint, but only a sparse subset labelled
        ax.set_xticks(X_POS)
        ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
        ax.tick_params(direction="out", length=3)

        # Titles on the top row, y labels on the first column, x labels on the bottom row
        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=8)
        if row == n_metrics - 1:
            ax.set_xlabel("Iteration", fontsize=8)

plt.show()

## 4. Per-mode responsibility

For each particle $x_i$, compute the GMM posterior responsibility
$r_{ik} = p(k \mid x_i)$. Then the **mode share** $R_k = \frac{1}{N} \sum_i r_{ik}$
measures the soft fraction of mass on each mode.

Perfect sampling gives $R_k = 1/K = 0.25$ for all $k$ (in expectation, up to finite-sample noise). Mode imbalance shows up as
deviation from this value.

In [None]:
def compute_responsibilities(pts, target):
    """Average the GMM posterior responsibilities over a particle set.

    Args:
        pts: Particle positions, shape (N, d).
        target: GMMTarget instance.

    Returns:
        Mode shares R_k, shape (K,).  Sums to 1.
    """
    # Per-component joint log-densities, one row per particle: (N, K)
    #   log p(k, x_i) = log w_k + log N(x_i; mu_k, sigma^2 I)
    joint_logp = target._log_component_densities(jnp.array(pts))
    # Per-particle normaliser log p(x_i), kept as a column so it broadcasts: (N, 1)
    log_evidence = logsumexp(joint_logp, axis=1, keepdims=True)
    # r_ik = p(k | x_i): normalising in log space is a softmax over k
    resp = np.array(jnp.exp(joint_logp - log_evidence))  # (N, K)
    # R_k: responsibility averaged over particles
    return resp.mean(axis=0)  # (K,)

In [None]:
# Mode shares for every (algorithm, seed, checkpoint) combination.
# resp_data maps algo -> array of shape (n_seeds, n_checkpoints, K); snapshots
# that are absent from `particles` are filled with NaN.
K = target.n_modes
resp_data = {}

for algo in ALGOS:
    per_seed = []
    for seed_idx, _ in enumerate(SEEDS):
        # Particle keys follow "seed<i>__<algo>__iter<checkpoint>"
        names = [f"seed{seed_idx}__{algo}__iter{step}" for step in CHECKPOINTS]
        per_seed.append([
            compute_responsibilities(particles[name], target)
            if name in particles
            else np.full(K, np.nan)
            for name in names
        ])
    resp_data[algo] = np.array(per_seed)  # (n_seeds, n_checkpoints, K)

print(f"Shape per algo: {resp_data[ALGOS[0]].shape}")

In [None]:
# Max responsibility deviation |R_k - 1/K|, faceted by algorithm
ideal = 1.0 / K

# Worst-mode deviation per (seed, checkpoint) and its across-seed median,
# computed once per algorithm instead of once per panel that draws it.
max_dev = {name: np.abs(resp_data[name] - ideal).max(axis=2) for name in ALGOS}
median_dev = {name: np.nanmedian(dev, axis=0) for name, dev in max_dev.items()}

fig, axes = plt.subplots(
    1, len(ALGOS),
    figsize=(FULL_WIDTH, 1.6),
    sharex=True, sharey=True,
    constrained_layout=True,
    squeeze=False,  # always 2-D, so indexing works with a single algorithm too
)
axes = axes[0]  # single row -> 1-D sequence of Axes

for col, algo in enumerate(ALGOS):
    ax = axes[col]

    # Gray context: medians of all OTHER algorithms
    for other in ALGOS:
        if other == algo:
            continue
        ax.plot(X_POS, median_dev[other],
                color="#cccccc", linewidth=0.6, zorder=1)

    # Foreground: this algorithm's median + IQR band over seeds
    devs = max_dev[algo]  # (n_seeds, n_checkpoints)
    median = median_dev[algo]
    q25 = np.nanpercentile(devs, 25, axis=0)
    q75 = np.nanpercentile(devs, 75, axis=0)
    color = ALGO_COLORS.get(algo, "#333")
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
    ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

    # Reference line at zero deviation, i.e. perfectly balanced mode shares
    ax.axhline(0, color="#999", linewidth=0.5, linestyle="--", zorder=0)
    ax.set_title(algo, fontsize=9)
    ax.set_xticks(X_POS)
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
    ax.tick_params(direction="out", length=3)
    ax.set_xlabel("Iteration", fontsize=8)
    if col == 0:
        ax.set_ylabel(r"Max $|R_k - 1/K|$", fontsize=8)

plt.show()

In [None]:
# Per-mode responsibility breakdown at final iteration, one panel per algo
fig, axes = plt.subplots(
    1, len(ALGOS),
    figsize=(FULL_WIDTH, 1.8),
    sharey=True,
    squeeze=False,  # always 2-D, so iteration works with a single algorithm too
)
axes = axes[0]  # single row -> 1-D sequence of Axes

# Label each bar with the signed, rounded coordinates of its mode centre
mode_labels = [f"({modes[k,0]:+.0f},{modes[k,1]:+.0f})" for k in range(K)]
x_pos = np.arange(K)
bar_width = 0.6

for col, (ax, algo) in enumerate(zip(axes, ALGOS)):
    # Mode shares at the last checkpoint, summarised across seeds
    final_resp = resp_data[algo][:, -1, :]  # (n_seeds, K)
    median_r = np.nanmedian(final_resp, axis=0)
    q25_r = np.nanpercentile(final_resp, 25, axis=0)
    q75_r = np.nanpercentile(final_resp, 75, axis=0)

    color = ALGO_COLORS.get(algo, "#333")
    # yerr is given as asymmetric (lower, upper) distances from the bar height,
    # so the error bars span the inter-quartile range.
    ax.bar(x_pos, median_r, bar_width, color=color, alpha=0.8,
           yerr=[median_r - q25_r, q75_r - median_r], capsize=2,
           error_kw={"linewidth": 0.8})
    # Dashed reference at the balanced share 1/K
    ax.axhline(ideal, color="#999", linewidth=0.5, linestyle="--")
    ax.set_title(algo, fontsize=8)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(mode_labels, fontsize=6)
    if col == 0:
        ax.set_ylabel("Mode share $R_k$", fontsize=8)

fig.subplots_adjust(wspace=0.12)
plt.show()

## 5. Particle evolution snapshots

Show how particles evolve over time for ETD-B vs SVGD (seed 0).

In [None]:
# Grid of particle snapshots for seed 0: one row per algorithm, one column per
# iteration, each drawn over the target's contours.
snapshot_iters = [0, 10, 50, 200, 500]
compare_algos = ["ETD-B", "SVGD"]

fig, axes = plt.subplots(
    len(compare_algos), len(snapshot_iters),
    figsize=(FULL_WIDTH, 2.8),
    sharex=True, sharey=True,
)

bottom_row = len(compare_algos) - 1

for row, algo in enumerate(compare_algos):
    for col, it in enumerate(snapshot_iters):
        ax = axes[row, col]
        plot_contours(ax, target.log_prob, AXIS_LIM, AXIS_LIM)

        # Snapshots that were not saved leave the panel with contours only
        key = f"seed0__{algo}__iter{it}"
        if key in particles:
            plot_particles(ax, particles[key],
                           color=ALGO_COLORS.get(algo, "#333"), s=12)

        ax.set_xlim(AXIS_LIM)
        ax.set_ylim(AXIS_LIM)
        ax.set_aspect("equal")

        # Column headers on the top row only
        if row == 0:
            ax.set_title(f"iter {it}", fontsize=8)

        # Row label on the first column; every other column drops its y tick labels
        if col == 0:
            ax.set_ylabel(algo, fontsize=9)
        else:
            ax.tick_params(labelleft=False)

        # Only the bottom row keeps x tick labels
        if row != bottom_row:
            ax.tick_params(labelbottom=False)

fig.subplots_adjust(wspace=0.05, hspace=0.12)
plt.show()

## 6. Final metrics summary table

In [None]:
# Plain-text table: median and IQR across seeds of each metric at the last
# checkpoint, one row per algorithm.
final_ckpt = max(CHECKPOINTS)
metric_names = ["energy_distance", "mode_coverage", "mean_error"]

NAME_W = 10  # width of the algorithm-name column
CELL_W = 30  # width of each metric column; wide enough for "  0.1234  (IQR 0.1234–0.1234)"

# Header and data cells share the same widths so the columns line up, and the
# rule is sized from the header instead of a hard-coded length.
header = f"{'Algorithm':<{NAME_W}}" + "".join(f"  {m:<{CELL_W}}" for m in metric_names)
print(header)
print("-" * len(header))

for algo in ALGOS:
    cells = []
    for m in metric_names:
        # Missing seeds/algorithms/checkpoints/metrics all fall back to NaN;
        # dtype=float also maps a JSON null (None) to NaN.
        vals = np.array(
            [metrics.get(s, {}).get(algo, {}).get(final_ckpt, {}).get(m, np.nan)
             for s in SEEDS],
            dtype=float,
        )
        valid = vals[~np.isnan(vals)]
        if valid.size:
            cell = (f"{np.median(valid):>8.4f}  "
                    f"(IQR {np.percentile(valid, 25):.4f}–{np.percentile(valid, 75):.4f})")
        else:
            # Right-aligned under the median column
            cell = f"{'N/A':>8}"
        cells.append(f"  {cell:<{CELL_W}}")
    print(f"{algo:<{NAME_W}}" + "".join(cells))