# High-Dimensional GMM Benchmark Results

Analysis and visualization of ETD variants vs baselines on Gaussian mixture
targets in $d > 2$. Unlike the 2D notebook, we cannot rely on particle-on-contour
plots. Instead we use:

1. **Metric convergence** — energy distance, sliced Wasserstein, mode proximity, mode balance, mean error
2. **Mode responsibility analysis** — soft mode shares $R_k$ via the posterior $p(k \mid x)$
3. **PCA projections** — 2D embeddings of particle clouds at key iterations
4. **Per-dimension marginals** — KDE overlays vs. the true 1D marginals
5. **Distance to nearest mode** — geometric diagnostic of how close particles sit to the modes
6. **Per-dimension variance ratio** — empirical vs. true marginal variance in each coordinate
7. **Summary table** — final metrics at a glance

In [None]:
import json
import sys
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax.scipy.special import logsumexp
from scipy.stats import gaussian_kde

# Project imports
ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT / "src"))
sys.path.insert(0, str(ROOT))

from etd.targets.gmm import GMMTarget
from figures.style import (
    ALGO_COLORS, FULL_WIDTH, COL_WIDTH,
    facet_grid, frame_panel, ref_line,
    savefig_paper, setup_style,
)

# Apply the shared figure style defined in figures.style.
setup_style()
# IPython magic (valid only inside a notebook, not plain Python): render inline
# figures at high ("retina") resolution.
%config InlineBackend.figure_format = 'retina'

## 1. Load data

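Point `EXPERIMENT_NAME` at any $d > 2$ GMM experiment. The last run directory
under `results/<EXPERIMENT_NAME>/` in sorted order (normally the most recent) is
loaded, and everything else is derived from its saved `config.yaml`.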

In [None]:
import yaml

# --- Configure which experiment to load ---
EXPERIMENT_NAME = "gmm-10d-5"

results_root = ROOT / "results" / EXPERIMENT_NAME
# Only directories are runs; stray files (notes, archives) in results_root must
# not be mistaken for the latest run.
runs = sorted(p for p in results_root.iterdir() if p.is_dir())
if not runs:
    raise FileNotFoundError(f"No run directories found under {results_root}")
RESULTS_DIR = runs[-1]  # latest run (last in sorted order of directory names)
print(f"Loading from: {RESULTS_DIR}")

# --- Load config (source of truth) ---
with open(RESULTS_DIR / "config.yaml") as f:
    exp_config = yaml.safe_load(f)

exp = exp_config["experiment"]
target_cfg = exp["target"]
target_params = target_cfg.get("params", {})

# --- Load metrics ---
with open(RESULTS_DIR / "metrics.json") as f:
    raw_metrics = json.load(f)

# JSON object keys are strings; convert checkpoint keys back to ints so they
# match CHECKPOINTS below.
metrics = {}
for seed, algo_dict in raw_metrics.items():
    metrics[seed] = {}
    for algo, ckpt_dict in algo_dict.items():
        metrics[seed][algo] = {int(k): v for k, v in ckpt_dict.items()}

# --- Load particles ---
# np.load on an .npz returns an NpzFile that keeps the archive open; copy every
# array into a plain dict and close the file via the context manager.
with np.load(RESULTS_DIR / "particles.npz") as npz:
    particles = dict(npz)

# --- Load reference samples (if available) ---
ref_path = RESULTS_DIR / "reference.npz"
reference_samples = None
if ref_path.exists():
    with np.load(ref_path) as npz:
        ref_data = dict(npz)
    reference_samples = ref_data.get("samples", None)
    if reference_samples is not None:
        print(f"Reference samples: {reference_samples.shape}")

# --- Build target from saved config ---
target = GMMTarget(**target_params)
modes = np.array(target.means)
DIM = target.dim
K = target.n_modes
print(f"Target: {target_cfg['type']}, dim={DIM}, K={K}")
print(f"Mode centers (first 3 dims):\n{modes[:, :3]}")

# --- Constants ---
SEEDS = [f"seed{i}" for i in exp["seeds"]]
CHECKPOINTS = sorted(set(
    int(k) for algo_dict in metrics.values()
    for ckpt_dict in algo_dict.values()
    for k in ckpt_dict
))
ALL_ALGOS = sorted(set(
    algo for algo_dict in metrics.values() for algo in algo_dict
))
ALGOS = ALL_ALGOS  # no need to filter for HD experiments (yet)
N_ALGOS = len(ALGOS)

# Evenly-spaced x-positions for convergence plots
X_POS = np.arange(len(CHECKPOINTS))
X_LABELS = [str(c) for c in CHECKPOINTS]
TICK_SHOW = set(range(0, len(CHECKPOINTS), max(1, len(CHECKPOINTS) // 5)))
X_LABELS_SPARSE = [str(c) if i in TICK_SHOW else "" for i, c in enumerate(CHECKPOINTS)]

print(f"Checkpoints: {CHECKPOINTS}")
print(f"Algorithms: {ALGOS}")
print(f"Seeds: {len(SEEDS)}")

## 2. Metric convergence (faceted by algorithm)

Each column is one algorithm; each row is one metric. The foreground line
is the median over seeds, with interquartile-range (IQR) shading. Gray traces
show the medians of the other algorithms for context.

In [None]:
def gather_metric(metric_name, algos=None):
    """Summarise one metric across seeds, per algorithm and checkpoint.

    Args:
        metric_name: Key looked up inside each checkpoint's metric dict.
        algos: Algorithms to include; ``None`` means all of ``ALGOS``.

    Returns:
        Dict mapping algo -> (median, q25, q75), each an array aligned with
        ``CHECKPOINTS``. Any missing seed/algo/checkpoint/metric entry counts as
        NaN and is ignored by the nan-aware reductions.
    """
    selected = ALGOS if algos is None else algos

    def _lookup(seed, algo, ckpt):
        # Chained .get so absent entries at any level fall back to NaN.
        per_ckpt = metrics.get(seed, {}).get(algo, {}).get(ckpt, {})
        return per_ckpt.get(metric_name, np.nan)

    summary = {}
    for algo in selected:
        table = np.array(
            [[_lookup(seed, algo, ckpt) for ckpt in CHECKPOINTS] for seed in SEEDS]
        )  # (n_seeds, n_checkpoints)
        summary[algo] = (
            np.nanmedian(table, axis=0),
            np.nanpercentile(table, 25, axis=0),
            np.nanpercentile(table, 75, axis=0),
        )
    return summary

In [None]:
# Metrics to plot as (key in metrics.json, y-axis label) pairs; one grid row each.
metric_specs = [
    ("energy_distance", "Energy dist."),
    ("sliced_wasserstein", "Sliced W₂"),
    ("mode_proximity", "Mode proximity"),
    ("mode_balance", "Mode balance"),
    ("mean_error", "Mean error"),
]
n_metrics = len(metric_specs)

# Grid: rows = metrics, columns = algorithms. The y-axis is shared within each
# row so all algorithms are compared on the same scale for a given metric.
fig, axes = plt.subplots(
    n_metrics, N_ALGOS,
    figsize=(FULL_WIDTH, 1.4 * n_metrics),
    sharex=True, sharey="row",
    constrained_layout=True,
)
# With a single algorithm, plt.subplots returns a 1-D array; add the column
# axis back so axes[row, col] indexing works below.
if N_ALGOS == 1:
    axes = axes[:, np.newaxis]

for row, (metric_name, ylabel) in enumerate(metric_specs):
    data = gather_metric(metric_name)  # algo -> (median, q25, q75) over CHECKPOINTS
    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # Gray context: other algorithms
        for other in ALGOS:
            if other == algo:
                continue
            med_o, _, _ = data[other]
            ax.plot(X_POS, med_o, color="#cccccc", linewidth=0.6, zorder=1)

        # Foreground: this algorithm
        median, q25, q75 = data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.18, zorder=2)
        ax.plot(X_POS, median, color=color, linewidth=1.5, zorder=3)

        # Checkpoints sit at evenly spaced positions (X_POS), not at their
        # iteration values; only a sparse subset of ticks is labelled.
        ax.set_xticks(X_POS)
        ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
        ax.tick_params(direction="out", length=3)

        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(ylabel, fontsize=8)
        if row == n_metrics - 1:
            ax.set_xlabel("Iteration", fontsize=8)

plt.show()

## 3. Convergence overlay (all algorithms on one axis)

A single-panel view per metric for direct comparison. Useful for talks and
quick inspection.

In [None]:
# One panel per metric; all algorithms overlaid (median line + light IQR band).
fig, axes = plt.subplots(
    1, n_metrics,
    figsize=(FULL_WIDTH, 1.8),
    constrained_layout=True,
)

for ax, (metric_name, ylabel) in zip(axes, metric_specs):
    data = gather_metric(metric_name)  # algo -> (median, q25, q75) over CHECKPOINTS
    for algo in ALGOS:
        median, q25, q75 = data[algo]
        color = ALGO_COLORS.get(algo, "#333")
        ax.fill_between(X_POS, q25, q75, color=color, alpha=0.10)
        ax.plot(X_POS, median, color=color, linewidth=1.5, label=algo)

    ax.set_xticks(X_POS)
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
    ax.set_ylabel(ylabel, fontsize=8)
    ax.set_xlabel("Iteration", fontsize=8)
    ax.tick_params(direction="out", length=3)

# Single legend below the figure. Every panel draws the same algorithms in the
# same order, so the first panel's handles are representative.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", ncol=N_ALGOS, fontsize=7,
           bbox_to_anchor=(0.5, -0.08), frameon=False)
plt.show()

## 4. Mode responsibility analysis

For each particle $x_i$, we compute the GMM posterior responsibility
$r_{ik} = p(k \mid x_i)$ and average to get the **mode share**
$R_k = \frac{1}{N} \sum_i r_{ik}$.

Exact sampling gives $R_k \approx w_k$, the mixture weight of mode $k$. The plots
below assume equal weights, so the reference level is $1/K$. Mode collapse shows
up as deviation from this uniform level. This diagnostic is fully dimension-agnostic.

In [None]:
def compute_responsibilities(pts, target):
    """Average GMM posterior responsibilities over a particle cloud.

    Args:
        pts: Particle positions; converted to a JAX array before scoring.
        target: Object exposing ``_log_component_densities(x)`` that returns
            per-particle, per-component log densities of shape (N, K)
            (presumably including the log mixture weights, so that the
            row-normalised values are posteriors -- TODO confirm).

    Returns:
        NumPy array of shape (K,) holding the soft mode shares
        R_k = mean_i p(k | x_i). The shares sum to 1.
    """
    log_dens = target._log_component_densities(jnp.array(pts))  # (N, K)
    # Row-normalise in log space so tiny densities do not underflow to 0/0.
    log_post = log_dens - logsumexp(log_dens, axis=1, keepdims=True)
    posterior = np.array(jnp.exp(log_post))  # r_ik = p(k | x_i)
    return posterior.mean(axis=0)


# Compute for all seeds / algorithms / checkpoints.
# Particle keys have the form "<seed>__<algo>__iter<ckpt>". This assumes
# particles.npz uses the same "seed<N>" labels as metrics.json (SEEDS, built
# from the config's seed list). Iterating over SEEDS itself, rather than
# range(len(SEEDS)), keeps non-contiguous seed values (e.g. [42, 43]) from
# silently turning into all-NaN rows.
resp_data = {}  # algo -> array of shape (n_seeds, n_checkpoints, K)
for algo in ALGOS:
    all_seeds = []
    for seed in SEEDS:
        ckpt_shares = []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                shares = compute_responsibilities(particles[key], target)
            else:
                shares = np.full(K, np.nan)  # missing snapshot -> NaN row
            ckpt_shares.append(shares)
        all_seeds.append(ckpt_shares)
    resp_data[algo] = np.array(all_seeds)

print(f"Responsibility shape per algo: {resp_data[ALGOS[0]].shape}")

In [None]:
# Mode responsibility heatmap: rows = modes, columns = checkpoints,
# one panel per algorithm. Color = median R_k across seeds.
# The reference share 1/K assumes equal mixture weights.
ideal = 1.0 / K

fig, axes = plt.subplots(
    1, N_ALGOS,
    figsize=(FULL_WIDTH, 1.5 + 0.12 * K),
    sharey=True,
    constrained_layout=True,
)
# A single Axes is returned bare when N_ALGOS == 1; wrap it so zip() works.
if N_ALGOS == 1:
    axes = [axes]

for ax, algo in zip(axes, ALGOS):
    median_r = np.nanmedian(resp_data[algo], axis=0).T  # (K, n_checkpoints)
    # Diverging colormap centred on the ideal share: the range [0, 2/K] puts
    # 1/K at the midpoint; values above 2/K saturate.
    im = ax.imshow(
        median_r, aspect="auto", cmap="RdBu_r",
        vmin=0, vmax=2 * ideal,
        interpolation="nearest",
    )
    ax.set_xticks(range(len(CHECKPOINTS)))
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
    ax.set_title(algo, fontsize=9)
    ax.set_xlabel("Iteration", fontsize=8)
    # y-axis is shared, so mode labels are only needed on the first panel.
    if ax is axes[0]:
        ax.set_yticks(range(K))
        ax.set_yticklabels([f"Mode {k+1}" for k in range(K)], fontsize=7)
        ax.set_ylabel("Mode", fontsize=8)

# All panels share vmin/vmax, so the last image's colour scale applies to every panel.
fig.colorbar(im, ax=axes, label=f"Mode share $R_k$ (ideal = {ideal:.2f})",
             shrink=0.8, pad=0.02)
plt.show()

In [None]:
# Max responsibility deviation |R_k - 1/K|, faceted by algorithm.
# Worst-case deviation over modes for every algorithm, computed once up front:
# algo -> array of shape (n_seeds, n_checkpoints).
max_dev = {name: np.abs(resp_data[name] - ideal).max(axis=2) for name in ALGOS}

fig, axes = facet_grid(N_ALGOS, panel_size=1.6, square=False)

for col, (ax, algo) in enumerate(zip(axes, ALGOS)):
    # Context: median trace of every other algorithm in light gray.
    for other in (name for name in ALGOS if name != algo):
        ax.plot(X_POS, np.nanmedian(max_dev[other], axis=0),
                color="#cccccc", linewidth=0.6, zorder=1)

    # Foreground: this algorithm's median with an IQR band.
    own = max_dev[algo]
    band_lo = np.nanpercentile(own, 25, axis=0)
    band_hi = np.nanpercentile(own, 75, axis=0)
    center = np.nanmedian(own, axis=0)
    shade = ALGO_COLORS.get(algo, "#333")
    ax.fill_between(X_POS, band_lo, band_hi, color=shade, alpha=0.18, zorder=2)
    ax.plot(X_POS, center, color=shade, linewidth=1.5, zorder=3)

    ref_line(ax, 0)
    ax.set_title(algo, fontsize=9)
    ax.set_xticks(X_POS)
    ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
    ax.tick_params(direction="out", length=3)
    if col == 0:
        ax.set_ylabel(r"Max $|R_k - 1/K|$", fontsize=8)

fig.supxlabel("Iteration", fontsize=8)
plt.show()

In [None]:
# Per-mode responsibility at the final iteration: median over seeds with
# IQR error bars, one panel per algorithm, dashed reference at 1/K.
fig, axes = plt.subplots(
    1, N_ALGOS,
    figsize=(FULL_WIDTH, 1.8),
    sharey=True,
    constrained_layout=True,
)
# A single Axes is returned bare when N_ALGOS == 1; wrap it so zip() works.
if N_ALGOS == 1:
    axes = [axes]

mode_labels = [str(k + 1) for k in range(K)]  # 1-based labels for display
x_pos = np.arange(K)

for ax, algo in zip(axes, ALGOS):
    final_resp = resp_data[algo][:, -1, :]  # (n_seeds, K) at the last checkpoint
    median_r = np.nanmedian(final_resp, axis=0)
    q25_r = np.nanpercentile(final_resp, 25, axis=0)
    q75_r = np.nanpercentile(final_resp, 75, axis=0)

    color = ALGO_COLORS.get(algo, "#333")
    # errorbar expects distances below/above the point, not absolute quantiles.
    ax.errorbar(
        x_pos, median_r,
        yerr=[median_r - q25_r, q75_r - median_r],
        fmt="o", color=color, markersize=4,
        markeredgecolor="white", markeredgewidth=0.5,
        capsize=2, capthick=0.8, linewidth=0.8, zorder=5,
    )
    ref_line(ax, ideal)
    ax.set_title(algo, fontsize=8)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(mode_labels, fontsize=6)
    if ax is axes[0]:
        ax.set_ylabel("Mode share $R_k$", fontsize=8)

fig.supxlabel("Mode", fontsize=8)
plt.show()

## 5. PCA projections

Project the $d$-dimensional particles onto the top-2 principal components,
fitted on the reference samples when they are available and otherwise on the
true mode centers. This gives a consistent coordinate system across algorithms
and iterations. True mode centers are shown as $\times$ markers.

In [None]:
from sklearn.decomposition import PCA

# Fit a single 2-component PCA used for every projection below. Prefer the
# reference samples (they capture within-mode spread); otherwise fall back to
# the K mode centers.
pca = PCA(n_components=2)
if reference_samples is not None:
    pca.fit(reference_samples)
else:
    pca.fit(modes)

modes_2d = pca.transform(modes)  # (K, 2) mode centers in PC coordinates
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")
print(f"Mode centers in PC space:\n{modes_2d}")

In [None]:
# PCA snapshot grid: rows = algorithms, columns = iterations (first seed).
# Up to five evenly spaced checkpoints, always including the first and last.
n_snaps = min(5, len(CHECKPOINTS))
snap_indices = np.linspace(0, len(CHECKPOINTS) - 1, n_snaps, dtype=int)
snapshot_iters = [CHECKPOINTS[i] for i in snap_indices]
# Seed label as used in the particle keys (e.g. "seed0"); taken from SEEDS
# rather than hard-coded so that runs with other seed values still plot.
snap_seed = SEEDS[0]

# squeeze=False keeps `axes` 2-D for any grid shape (one algorithm and/or one
# snapshot), so axes[row, col] indexing is always valid.
fig, axes = plt.subplots(
    N_ALGOS, n_snaps,
    figsize=(FULL_WIDTH, 1.2 * N_ALGOS + 0.3),
    sharex=True, sharey=True,
    squeeze=False,
    constrained_layout=True,
)

for row, algo in enumerate(ALGOS):
    color = ALGO_COLORS.get(algo, "#333")
    for col, it in enumerate(snapshot_iters):
        ax = axes[row, col]
        key = f"{snap_seed}__{algo}__iter{it}"
        if key in particles:
            pts_2d = pca.transform(particles[key])
            ax.scatter(pts_2d[:, 0], pts_2d[:, 1], c=color, s=10,
                       alpha=0.6, edgecolors="none")

        # True mode centers, drawn on top of the particles
        ax.scatter(modes_2d[:, 0], modes_2d[:, 1], marker="x",
                   c="black", s=40, linewidths=1.2, zorder=10)

        if row == 0:
            ax.set_title(f"iter {it}", fontsize=8)
        if col == 0:
            ax.set_ylabel(algo, fontsize=8)
        frame_panel(ax)

fig.suptitle(f"PCA projection ({snap_seed})", fontsize=10, y=1.02)
plt.show()

## 6. Per-dimension marginal distributions

Compare the empirical marginal distribution of particles in each coordinate
to the true 1D marginal (a mixture of 1D Gaussians). Shown at the final
iteration, pooled over seeds.

We show a subset of dimensions: the first two (where the modes live in the
ring arrangement), plus a mid-range dimension and the last dimension when $d$
is large enough. For a ring GMM the higher dimensions carry no mode structure,
so their marginal should be a single Gaussian $\mathcal{N}(0, \sigma^2)$.

In [None]:
def true_marginal_pdf(x_1d, dim_idx, target):
    """Evaluate the exact 1D marginal density of the GMM along one coordinate.

    For an isotropic GMM, the marginal in dim j is the 1D mixture
    sum_k w_k N(x; mu_kj, sigma^2), with means ``target.means[:, j]`` and
    std ``target.component_std``.

    Args:
        x_1d: Evaluation points; a scalar or any 1-D array-like (list, NumPy
            or JAX array).
        dim_idx: Coordinate index j.
        target: Object exposing ``means`` (K, d), ``log_weights`` (K,) and
            ``component_std``.

    Returns:
        1-D NumPy array of densities, one per evaluation point.
    """
    # Accept scalars and plain sequences, not only 1-D ndarrays.
    x = np.atleast_1d(np.asarray(x_1d, dtype=float))
    means_j = np.asarray(target.means)[:, dim_idx]  # (K,)
    weights = np.exp(np.asarray(target.log_weights))  # (K,)
    sigma = np.asarray(target.component_std, dtype=float)

    z = (x[:, None] - means_j[None, :]) / sigma  # (len(x), K)
    component_pdfs = np.exp(-0.5 * z ** 2) / (sigma * np.sqrt(2 * np.pi))
    return (weights[None, :] * component_pdfs).sum(axis=1)


# Dimensions to show: the first two (modes live here for the ring arrangement),
# plus a mid-range and the last coordinate when DIM is large enough.
candidate_dims = [0, 1]
if DIM > 2:
    candidate_dims.append(DIM // 2)  # a mid-range dim
if DIM > 4:
    candidate_dims.append(DIM - 1)  # the last dim
# Deduplicate while preserving order (for DIM == 3, DIM // 2 == 1 would repeat
# dim 1) and drop indices that do not exist for very small DIM.
show_dims = [d for d in dict.fromkeys(candidate_dims) if d < DIM]

n_dims_show = len(show_dims)
final_ckpt = max(CHECKPOINTS)
# Pad the evaluation grid 4 component standard deviations beyond the outermost
# mode (identical to the former fixed padding of 4 when sigma == 1).
x_pad = 4.0 * float(np.max(np.asarray(target.component_std)))

# squeeze=False keeps `axes` 2-D for any grid shape.
fig, axes = plt.subplots(
    n_dims_show, N_ALGOS,
    figsize=(FULL_WIDTH, 1.3 * n_dims_show),
    sharex="row", sharey="row",
    squeeze=False,
    constrained_layout=True,
)

for row, dim_idx in enumerate(show_dims):
    # True marginal density on a grid symmetric about 0
    mode_extent = float(np.abs(modes[:, dim_idx]).max())
    x_range = np.linspace(-(mode_extent + x_pad), mode_extent + x_pad, 300)
    true_pdf = true_marginal_pdf(x_range, dim_idx, target)

    for col, algo in enumerate(ALGOS):
        ax = axes[row, col]

        # True marginal
        ax.fill_between(x_range, true_pdf, alpha=0.12, color="#333")
        ax.plot(x_range, true_pdf, color="#333", linewidth=0.8,
                linestyle="--", label="True" if row == 0 and col == 0 else None)

        # Empirical KDE pooled across all seeds. Iterate over the seed labels
        # themselves so non-contiguous seed values are found.
        all_vals = []
        for seed in SEEDS:
            key = f"{seed}__{algo}__iter{final_ckpt}"
            if key in particles:
                all_vals.append(particles[key][:, dim_idx])
        if all_vals:
            pooled = np.concatenate(all_vals)
            color = ALGO_COLORS.get(algo, "#333")
            try:
                kde = gaussian_kde(pooled)
            except np.linalg.LinAlgError:
                # Singular covariance, e.g. all particles collapsed to one
                # value: mark the panel instead of silently leaving it empty.
                ax.text(0.5, 0.5, "KDE failed\n(collapsed)",
                        transform=ax.transAxes, ha="center", va="center",
                        fontsize=6, color=color)
            else:
                ax.plot(x_range, kde(x_range), color=color, linewidth=1.2)

        if row == 0:
            ax.set_title(algo, fontsize=9)
        if col == 0:
            ax.set_ylabel(f"dim {dim_idx}", fontsize=8)
        if row == n_dims_show - 1:
            ax.set_xlabel("Value", fontsize=7)
        ax.tick_params(labelsize=6)
        ax.set_yticks([])

plt.show()

## 7. Distance to nearest mode

For each particle, compute the Euclidean distance to its nearest mode center.
The distribution of these distances reveals whether particles are tightly
concentrated around modes (desirable) or drifting in empty space.

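For a $d$-dimensional isotropic Gaussian with standard deviation $\sigma$, the
distance of a sample to its own mode is $\sigma$ times a chi-distributed variable
with $d$ degrees of freedom, whose median is $\approx \sigma \sqrt{d - 2/3}$.
The dashed reference line uses this value. The distance to the *nearest* mode can
only be smaller, so the reference is an approximation that is accurate when the
modes are well separated.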

In [None]:
def nearest_mode_distances(pts, modes):
    """Euclidean distance from each particle to its nearest mode.

    Args:
        pts: (N, d) particle positions; any array-like (NumPy/JAX array or
            nested lists).
        modes: (K, d) mode centers; any array-like.

    Returns:
        (N,) NumPy array of distances.
    """
    # Accept array-likes, not only ndarrays (plain lists do not support
    # the [:, None, :] indexing used below).
    pts = np.asarray(pts)
    modes = np.asarray(modes)
    # Broadcast to all particle/mode pairs: (N, 1, d) - (1, K, d) -> (N, K, d).
    diffs = pts[:, None, :] - modes[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)  # (N, K)
    return dists.min(axis=1)


# Reference level for the distance plots: the median distance of an isotropic
# Gaussian sample to its own center, i.e. sigma times the median of chi(df=DIM).
from scipy.stats import chi
sigma_ref = target.component_std
expected_median = sigma_ref * chi(df=DIM).median()

In [None]:
# Violin plot: distance-to-nearest-mode at final iteration, by algorithm
fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.2))

# Number of reference samples used. This caps the (N, K, d) pairwise array
# built by nearest_mode_distances.
N_REF_VIOLIN = 500

violin_data = []
labels = []
colors = []

for algo in ALGOS:
    # Pool every seed's final-iteration particles. Iterate over the seed labels
    # themselves (not range(len(SEEDS))) so non-contiguous seeds are found.
    dists = []
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key in particles:
            dists.append(nearest_mode_distances(particles[key], modes))
    if dists:
        violin_data.append(np.concatenate(dists))
        labels.append(algo)
        colors.append(ALGO_COLORS.get(algo, "#333"))

# Reference: distances from exact samples
if reference_samples is not None:
    ref_dists = nearest_mode_distances(reference_samples[:N_REF_VIOLIN], modes)
    violin_data.append(ref_dists)
    labels.append("Reference")
    colors.append("#333")

if violin_data:
    parts = ax.violinplot(violin_data, showmedians=True, showextrema=False)
    for body, body_color in zip(parts["bodies"], colors):
        body.set_facecolor(body_color)
        body.set_alpha(0.5)
    parts["cmedians"].set_color("black")
    parts["cmedians"].set_linewidth(1)
else:
    # violinplot would fail on an empty dataset; show a placeholder instead.
    ax.text(0.5, 0.5, "No particle data at final iteration",
            transform=ax.transAxes, ha="center", va="center", fontsize=8)

ax.axhline(expected_median, color="#999", linewidth=0.8, linestyle="--",
           label=f"Chi median ({expected_median:.1f})")
ax.set_xticks(range(1, len(labels) + 1))
ax.set_xticklabels(labels, fontsize=7, rotation=20, ha="right")
ax.set_ylabel("Distance to nearest mode", fontsize=8)
ax.legend(fontsize=7, frameon=False)
ax.set_title(f"Iteration {final_ckpt}", fontsize=9)
plt.show()

In [None]:
# Median distance-to-nearest-mode over iterations, by algorithm.
# Per seed: median over particles at each checkpoint. Across seeds: median + IQR.
fig, ax = plt.subplots(figsize=(COL_WIDTH * 1.5, 2.2))

for algo in ALGOS:
    medians_per_seed = []
    # Iterate over the seed labels themselves so non-contiguous seeds are found.
    for seed in SEEDS:
        per_ckpt = []
        for ckpt in CHECKPOINTS:
            key = f"{seed}__{algo}__iter{ckpt}"
            if key in particles:
                d = nearest_mode_distances(particles[key], modes)
                per_ckpt.append(np.median(d))
            else:
                per_ckpt.append(np.nan)  # missing snapshot
        medians_per_seed.append(per_ckpt)

    arr = np.array(medians_per_seed)  # (n_seeds, n_checkpoints)
    med = np.nanmedian(arr, axis=0)
    q25 = np.nanpercentile(arr, 25, axis=0)
    q75 = np.nanpercentile(arr, 75, axis=0)
    color = ALGO_COLORS.get(algo, "#333")
    ax.fill_between(X_POS, q25, q75, color=color, alpha=0.12)
    ax.plot(X_POS, med, color=color, linewidth=1.5, label=algo)

ax.axhline(expected_median, color="#999", linewidth=0.8, linestyle="--")
ax.set_xticks(X_POS)
ax.set_xticklabels(X_LABELS_SPARSE, fontsize=6)
ax.set_xlabel("Iteration", fontsize=8)
ax.set_ylabel("Median dist. to nearest mode", fontsize=8)
ax.legend(fontsize=7, frameon=False, loc="upper right")
plt.show()

## 8. Per-dimension variance ratio

The ratio of empirical particle variance to the true mixture variance in
each coordinate. A ratio of 1.0 everywhere indicates correct marginal
variance. Ratios $< 1$ indicate variance collapse; $> 1$ indicates
over-dispersion.

In [None]:
true_var = np.array(target.variance)  # (d,)
# Coordinates whose true variance is (numerically) zero are divided by 1.0
# instead, so their plotted "ratio" is just the empirical variance.
safe_true_var = np.where(true_var > 1e-12, true_var, 1.0)

fig, axes = plt.subplots(
    1, N_ALGOS,
    figsize=(FULL_WIDTH, 1.8),
    sharey=True,
    constrained_layout=True,
)
# A single Axes is returned bare when N_ALGOS == 1; wrap it so zip() works.
if N_ALGOS == 1:
    axes = [axes]

dim_labels = [str(d) for d in range(DIM)]

for ax, algo in zip(axes, ALGOS):
    ratios_per_seed = []
    # Iterate over the seed labels themselves so non-contiguous seeds are found.
    for seed in SEEDS:
        key = f"{seed}__{algo}__iter{final_ckpt}"
        if key in particles:
            emp_var = np.var(particles[key], axis=0)
            ratios_per_seed.append(emp_var / safe_true_var)

    if ratios_per_seed:
        arr = np.array(ratios_per_seed)  # (n_seeds_found, DIM)
        median_r = np.nanmedian(arr, axis=0)
        q25 = np.nanpercentile(arr, 25, axis=0)
        q75 = np.nanpercentile(arr, 75, axis=0)

        color = ALGO_COLORS.get(algo, "#333")
        x = np.arange(DIM)
        # errorbar expects distances below/above the point, not absolute quantiles.
        ax.errorbar(
            x, median_r,
            yerr=[median_r - q25, q75 - median_r],
            fmt="o", color=color, markersize=3,
            markeredgecolor="white", markeredgewidth=0.4,
            capsize=2, capthick=0.8, linewidth=0.8, zorder=5,
        )

    ref_line(ax, 1.0)
    ax.set_title(algo, fontsize=9)
    ax.set_xticks(range(DIM))
    ax.set_xticklabels(dim_labels, fontsize=6)
    if ax is axes[0]:
        ax.set_ylabel("Var ratio (emp / true)", fontsize=8)

fig.supxlabel("Dimension", fontsize=8)
plt.show()

## 9. Summary table

Final-iteration metrics, median (IQR) across seeds.

In [None]:
table_metrics = ["energy_distance", "sliced_wasserstein", "mode_proximity", "mode_balance", "mean_error"]
COL_W = 28  # width of each metric column, excluding the 2-space separator

# Header
header = f"{'Algorithm':<12}" + "".join(f"  {m:<{COL_W}}" for m in table_metrics)
print(header)
print("-" * len(header))

for algo in ALGOS:
    cells = []
    for m in table_metrics:
        # .get at every level (as in gather_metric) so a seed missing from
        # metrics.json is treated as missing data rather than raising KeyError.
        vals = [
            metrics.get(s, {}).get(algo, {}).get(final_ckpt, {}).get(m)
            for s in SEEDS
        ]
        # JSON nulls arrive as None; treat them like NaN (missing).
        valid = [v for v in vals if v is not None and not np.isnan(v)]
        if valid:
            med = np.median(valid)
            q25, q75 = np.percentile(valid, [25, 75])
            cell = f"{med:>8.4f} ({q25:.4f}\u2013{q75:.4f})"
        else:
            cell = "N/A"
        # Pad every cell to the header's column width so columns line up.
        cells.append(f"  {cell:<{COL_W}}")
    print(f"{algo:<12}" + "".join(cells))