In [1]:
import os
import sys

# Backend/allocator settings are written to the environment *before* jax is
# imported: JAX captures its JAX_* environment variables while the package is
# being imported, so setting them afterwards may silently have no effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['JAX_PLATFORM_NAME'] = 'cpu'     # use cpu backend. set 'gpu' or 'tpu' to use those backends

from jax import config  # noqa: E402  (must come after the environment setup above)

# Enable float64 support; later cells request jnp.float64 explicitly.
config.update("jax_enable_x64", True)
# Put the parent directory on the import path so the `uot` package imported below resolves.
sys.path.insert(0, os.path.abspath(os.path.join("..")))

import time
import numpy as np
import jax.numpy as jnp
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from uot.problems.generators import GaussianMixtureGenerator
from uot.utils.costs import cost_euclid_squared
from uot.experiments.runner import run_pipeline
from uot.experiments.experiment import Experiment
from uot.problems.iterator import OnlineProblemIterator
from uot.solvers.solver_config import SolverConfig


from uot.solvers.back_and_forth import BackNForthSqEuclideanSolver
from uot.solvers.back_and_forth.forward_pushforward import _forward_pushforward_nd
from uot.solvers.back_and_forth.pushforward import adaptive_pushforward_nd
from uot.utils.metrics.pushforward_map_metrics import extra_grid_metrics


In [2]:
# Raise the level of both project loggers to WARNING.
from uot.utils.logging import logger
logger.setLevel("WARNING")
# NOTE: this second import rebinds the global name `logger`; after this cell
# `logger` refers to the object exported by `uot.problems.iterator`.
from uot.problems.iterator import logger
logger.setLevel("WARNING")

In [3]:
def _grid_coords_from_axes(axes):
    return jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)


def _monge_map_index_to_physical(monge_map, axes):
    arr = jnp.asarray(monge_map)
    spatial_shape = tuple(len(ax) for ax in axes)
    d = len(spatial_shape)
    if arr.ndim == len(spatial_shape):
        arr = arr[..., None]
    if arr.shape[0] == d and arr.ndim == len(spatial_shape) + 1:
        arr = jnp.moveaxis(arr, 0, -1)
    elif arr.shape[-1] != d:
        arr = arr.reshape(spatial_shape + (d,))
    spacings = jnp.array([float(ax[1] - ax[0]) if ax.shape[0] > 1 else 1.0 for ax in axes], dtype=arr.dtype)
    origins = jnp.array([float(ax[0]) for ax in axes], dtype=arr.dtype)
    reshape = (1,) * len(spatial_shape) + (d,)
    return origins.reshape(reshape) + arr * spacings.reshape(reshape)
# Names of the Monge-map diagnostic columns that later cells look up in the results table.
monge_metric_columns = [
    "tv_mu_to_nu",
    "ma_residual_L1", "ma_residual_Linf",
    "detJ_min", "detJ_max", "detJ_neg_frac",
    "phi_is_convex",
]
# Same list without the two Jacobian-determinant extremes.
plot_columns = [
    column
    for column in monge_metric_columns
    if column not in ("detJ_min", "detJ_max")
]
def solve_fn(problem, solver_instance, measures, costs, *args, **kwargs):
    """Run one solve, time it, and attach Monge-map diagnostics to the result.

    Parameters
    ----------
    problem
        Not used by this function; kept for the calling convention.
    solver_instance
        Object exposing ``solve(marginals=..., costs=..., **kwargs)`` that
        returns a mutable mapping containing at least ``"v_final"`` and
        ``"monge_map"``.
    measures
        Pair ``(mu, nu)`` of measures exposing ``for_grid_solver``.
    costs
        Forwarded unchanged to ``solver_instance.solve``.
    *args
        Ignored.
    **kwargs
        Solver parameters, forwarded unchanged to ``solver_instance.solve``.

    Returns
    -------
    The mapping returned by the solver, extended with ``"runtime"`` (seconds)
    and one float entry per metric returned by ``extra_grid_metrics``.
    """
    import jax  # local import: the file-level imports only bring in `jax.config` / `jax.numpy`

    start = time.perf_counter()
    results = solver_instance.solve(
        marginals=measures,
        costs=costs,
        **kwargs,
    )
    # JAX dispatches computations asynchronously: without waiting for the
    # returned arrays the timer could stop before the work has finished.
    jax.block_until_ready(results)
    stop = time.perf_counter()
    results["runtime"] = stop - start

    # Monge-map diagnostics (avoid re-loading tensors later)
    mu_measure, nu_measure = measures
    axes_mu, mu_nd = mu_measure.for_grid_solver(backend="jax", dtype=jnp.float64)
    _, nu_nd = nu_measure.for_grid_solver(backend="jax", dtype=jnp.float64)

    psi = jnp.asarray(results["v_final"]).reshape(mu_nd.shape)
    # NOTE(review): falls back to `adaptive_pushforward_nd` when the solver does
    # not expose `_pushforward_fn`; confirm this matches the `pushforward_fn`
    # the solve was configured with.
    pushforward_fn = getattr(solver_instance, "_pushforward_fn", adaptive_pushforward_nd)
    pushforward_mu, _ = pushforward_fn(mu_nd, -psi)

    # The map is converted from grid-index units to physical coordinates
    # before being compared against the physical grid coordinates X.
    T_phys = _monge_map_index_to_physical(results["monge_map"], axes_mu)
    X = _grid_coords_from_axes(axes_mu)

    monge_metrics = extra_grid_metrics(
        mu_nd=mu_nd,
        nu_nd=nu_nd,
        axes_mu=axes_mu,
        X=X,
        T=T_phys,
        pushforward_mu=pushforward_mu,
    )
    # Store plain Python floats so the metrics end up as scalar columns.
    results.update({key: float(val) for key, val in monge_metrics.items()})
    return results


# Experiment object that routes every solve through `solve_fn`.
exp = Experiment(
    name="Testing Back-and-Forth Solver",
    solve_fn=solve_fn,
)

# Hyper-parameter lists for the 1D runs; the grid below is their Cartesian
# product (one dict per combination).
tolerances = [1e-4]
max_iterations = [3000]
stepsizes = [4]
pushforward_fns = [
    adaptive_pushforward_nd,
    _forward_pushforward_nd,
]

solver_param_grid = [
    {
        "pushforward_fn": pushforward_fn,
        "maxiter": maxiter,
        "tol": tol,
        "stepsize": stepsize,
        "error_metric": 'tv_psi',
        "stepsize_lower_bound": 0.01,
    }
    for tol in tolerances
    for maxiter in max_iterations
    for stepsize in stepsizes
    for pushforward_fn in pushforward_fns
]

# A single solver entry; all parameter combinations share the same display name.
solvers = [
    SolverConfig(
        name="Back-and-Forth SqEuclid",
        solver=BackNForthSqEuclideanSolver,
        param_grid=solver_param_grid,
        is_jit=True,
    )
]

seed = 55
n_problems = 22
n_points = 96
# keep list of datasets for visualizing
dataset_config = dict(
    name=f"Gaussian (1d, 1c, {n_points}p)",
    dim=1,
    num_components=1,
    n_points=n_points,
    num_datasets=n_problems,
    borders=(0, 1),
    cost_fn=cost_euclid_squared,
    use_jax=False,
    seed=seed,
    measure_mode="grid",
    cell_discretization="cell-centered",
)

# NOTE(review): `datasets` is not referenced again in the visible part of the notebook.
datasets = [
    GaussianMixtureGenerator(**dataset_config)
]
# The pipeline gets its own generator instance built from the same config.
iterators = [
    OnlineProblemIterator(
        GaussianMixtureGenerator(**dataset_config),
        num=n_problems,
        cache_gt=False,
    ),
]

# Run the experiment; the returned table is inspected by the following cells.
results = run_pipeline(
    experiment=exp,
    solvers=solvers,
    iterators=iterators,
    folds=1,
    progress=True,
)


Back-and-Forth SqEuclid({'pushforward_fn': <function _forward_pushforward_nd at 0x1503d3920>, 'maxiter': 3000, 'tol': 0.0001, 'stepsize': 4, 'error_metric': 'tv_psi', 'stepsize_lower_bound': 0.01}): 100%|██████████| 44/44 [00:17<00:00,  2.45it/s]


In [4]:
# Display the table returned by `run_pipeline` (one row per solved problem and setting).
results

Unnamed: 0,dataset,mu_size,nu_size,cost,monge_map,u_final,v_final,iterations,error,marginal_error_L2,...,condition_number_hessian_eigenvalues,status,problem_index,pushforward_fn,maxiter,tol,stepsize,error_metric,stepsize_lower_bound,name
0,"Gaussian (1d, 1c, 96p)",96,96,0.123188616979258,"[[0.0], [0.0], [4.547473508864641e-13], [2.728...","[6.364474586822717, 6.3608531220002416, 6.3573...","[-6.364474586822717, -6.364420333350495, -6.36...",3000,0.0850776540526388,0.0522437345410599,...,1.0,success,0,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
1,"Gaussian (1d, 1c, 96p)",96,96,0.0153164803703886,"[[0.0], [0.8429500250173305], [1.6344766578772...","[0.9921170219531585, 0.9921080100970691, 0.992...","[-0.9921170219531585, -0.9921080100970691, -0....",3000,0.0451021316885604,0.0282516591813824,...,1.0,success,1,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
2,"Gaussian (1d, 1c, 96p)",96,96,0.3032514970210719,"[[50.147547573038594], [51.00258327247185], [5...","[-0.22883118385048776, -0.22877693037826555, -...","[0.22883118385048776, 0.22338982669195492, 0.2...",3000,0.0132939202065036,0.0062738245700334,...,1.0,success,2,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
3,"Gaussian (1d, 1c, 96p)",96,96,0.0692698234337009,"[[0.0], [0.9387183340113552], [1.8664470134649...","[4.160493300826453, 4.160489633392994, 4.16048...","[-4.160493300826453, -4.160489633392994, -4.16...",3000,0.0870963463366521,0.0422748364805746,...,1.0,success,3,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
4,"Gaussian (1d, 1c, 96p)",96,96,0.0060072524542026,"[[0.0], [0.0], [7.105427357601002e-14], [4.263...","[0.14246476870558283, 0.1418941776979528, 0.14...","[-0.14246476870558283, -0.14241051523336062, -...",3000,0.0066023897628958,0.0029164791941311,...,1.0,success,4,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
5,"Gaussian (1d, 1c, 96p)",96,96,0.0864331831475218,"[[27.81273152018639], [28.7796031906023], [30....","[4.714616275759662, 4.714670529231884, 4.71483...","[-4.714616275759662, -4.717634150273571, -4.72...",3000,0.0899924706987553,0.0492454321080498,...,1.0,success,5,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
6,"Gaussian (1d, 1c, 96p)",96,96,0.0146301004136373,"[[0.0], [2.2737367544323206e-13], [0.0], [0.0]...","[0.7060024590436, 0.7050766357922894, 0.704108...","[-0.7060024590436, -0.7059482055713778, -0.705...",3000,0.0605580664140797,0.0255288879359205,...,1.0,success,6,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
7,"Gaussian (1d, 1c, 96p)",96,96,0.0183185282586694,"[[1.863952422809234e-05], [1.0000361386077543]...","[0.8811394929438994, 0.8811394949664172, 0.881...","[-0.8811394929438994, -0.8811394949664172, -0....",3000,0.0472355164554016,0.022491317090383,...,1.0,success,7,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
8,"Gaussian (1d, 1c, 96p)",96,96,0.0183036519089373,"[[9.738214152043497], [11.113323064343774], [1...","[0.9920415712983728, 0.992095824770595, 0.9922...","[-0.9920415712983728, -0.9930982351603567, -0....",3000,0.053035629755181,0.023719311013881,...,1.0,success,8,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid
9,"Gaussian (1d, 1c, 96p)",96,96,0.0164936832697028,"[[0.0], [0.9999539869854175], [1.9998893881104...","[0.8117303072814761, 0.8117303047015637, 0.811...","[-0.8117303072814761, -0.8117303047015637, -0....",3000,0.0487618697889521,0.0222470928409024,...,1.0,success,9,<function adaptive_pushforward_nd at 0x1503d32e0>,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid


In [5]:
# Fail early with a readable message when this cell is run before the pipeline cell.
if "results" not in globals():
    raise RuntimeError("Run the solver pipeline cell first to populate `results`.")

# Keep only the rows whose status is "success"; copy so later column edits
# do not touch `results`.
successful = results[results["status"] == "success"].copy()
if successful.empty:
    raise RuntimeError("Solver did not return any successful runs to visualize.")

def _format_pushforward_fn(value):
    if isinstance(value, str):
        return value
    name = getattr(value, "__name__", None)
    return name or str(value)

# Replace the entries of `pushforward_fn` by printable labels so the column can
# be sorted, compared and printed below.
successful["pushforward_fn"] = successful["pushforward_fn"].apply(_format_pushforward_fn)

# One row per distinct solver configuration among the successful runs.
param_columns = ["tol", "maxiter", "stepsize", "pushforward_fn"]
available_settings = (
    successful[param_columns]
    .drop_duplicates()
    .sort_values(param_columns)
    .reset_index(drop=True)
)

print("Available solver settings (use `setting_index` below):")
print(available_settings.assign(setting_index=available_settings.index))

setting_index = 1  # change this index to explore another configuration
if setting_index >= len(available_settings):
    raise IndexError(
        f"setting_index {setting_index} is out of range for {len(available_settings)} available settings"
    )

# Rows of `successful` that used the chosen configuration; the float-valued
# parameters (`tol`, `stepsize`) are compared with a tolerance via np.isclose.
selected_setting = available_settings.iloc[setting_index]
setting_mask = (
    np.isclose(successful["tol"], selected_setting["tol"])
    & (successful["maxiter"] == selected_setting["maxiter"])
    & np.isclose(successful["stepsize"], selected_setting["stepsize"])
    & (successful["pushforward_fn"] == selected_setting["pushforward_fn"])
)
selected_results = successful[setting_mask].copy()
if selected_results.empty:
    raise RuntimeError(
        "Selected configuration has no matching successful runs. "
        "Adjust `setting_index` or rerun the pipeline."
    )

# Human-readable description of the chosen configuration, reused in figure titles.
selected_setting_label = (
    f"tol={selected_setting['tol']:.1e}, "
    f"maxiter={int(selected_setting['maxiter'])}, "
    f"stepsize={selected_setting['stepsize']}, "
    f"pushforward_fn={selected_setting['pushforward_fn']}"
)
print(f"Using setting #{setting_index}: {selected_setting_label}")


Available solver settings (use `setting_index` below):
      tol  maxiter  stepsize           pushforward_fn  setting_index
0  0.0001     3000         4  _forward_pushforward_nd              0
1  0.0001     3000         4  adaptive_pushforward_nd              1
Using setting #1: tol=1.0e-04, maxiter=3000, stepsize=4, pushforward_fn=adaptive_pushforward_nd


In [6]:
# Build a fresh generator from the same config and materialise `n_problems` problems.
# NOTE(review): the loop below pairs these with result rows by `problem_index`,
# which presumably relies on the generator being deterministic for a fixed seed -- confirm.
visual_generator = GaussianMixtureGenerator(**dataset_config)
visual_iterator = visual_generator.generate()
problems_for_visuals = [next(visual_iterator) for _ in range(n_problems)]


def _measure_to_arrays(measure):
    points, weights = measure.to_discrete(include_zeros=True)
    points = np.asarray(points)
    weights = np.asarray(weights)
    if points.ndim == 2 and points.shape[1] == 1:
        points = points[:, 0]
    return points.reshape(-1), weights.reshape(-1)


def _sort_curve(x, *ys):
    order = np.argsort(x)
    sorted_x = np.asarray(x)[order]
    sorted_arrays = [np.asarray(arr)[order] for arr in ys]
    return sorted_x, sorted_arrays


def _to_float(value):
    if value is None:
        return float('nan')
    if isinstance(value, (int, float)):
        return float(value)
    if hasattr(value, 'item'):
        try:
            return float(value.item())
        except (TypeError, ValueError):
            return float('nan')
    try:
        return float(value)
    except (TypeError, ValueError):
        return float('nan')


# One record per problem for the selected configuration, ordered by problem index.
records = (
    selected_results.sort_values("problem_index")
    .drop_duplicates(subset="problem_index")
    .to_dict("records")
)
visual_payload = []
for row in records:
    idx = int(row["problem_index"])
    # Skip result rows for which no regenerated problem is available.
    if idx >= len(problems_for_visuals):
        continue
    mu_measure, nu_measure = problems_for_visuals[idx].get_marginals()
    mu_x, mu_w = _measure_to_arrays(mu_measure)
    nu_x, nu_w = _measure_to_arrays(nu_measure)

    axes_mu, mu_nd = mu_measure.for_grid_solver(backend="jax", dtype=jnp.float64)
    psi = jnp.asarray(row["v_final"]).reshape(mu_nd.shape)
    # NOTE(review): always uses `adaptive_pushforward_nd`, whatever pushforward
    # the selected setting was solved with -- confirm this is intended.
    pushforward_nd, _ = adaptive_pushforward_nd(mu_nd, -psi)
    pushforward_flat = np.asarray(pushforward_nd).reshape(-1)

    # `solve_fn` treats `monge_map` as grid-index coordinates (it converts it
    # with `_monge_map_index_to_physical`). Convert here as well so that T(x)
    # is drawn in the same physical units as the identity line and the x axis.
    monge_phys = np.asarray(_monge_map_index_to_physical(row["monge_map"], axes_mu))
    # Keep the first (only, in 1D) component of the map at every grid node.
    monge_vals = monge_phys.reshape(-1, monge_phys.shape[-1])[:, 0]

    mu_x_sorted, (mu_w_sorted, monge_sorted) = _sort_curve(mu_x, mu_w, monge_vals)
    nu_x_sorted, (nu_w_sorted, pushforward_sorted) = _sort_curve(
        nu_x, nu_w, pushforward_flat
    )

    visual_payload.append(
        {
            "problem_index": idx,
            "mu_x": mu_x_sorted,
            "mu_w": mu_w_sorted,
            "nu_x": nu_x_sorted,
            "nu_w": nu_w_sorted,
            "pushforward": pushforward_sorted,
            "monge": monge_sorted,
            "identity": mu_x_sorted.copy(),
            "iterations": int(row["iterations"]),
            "error": _to_float(row["error"]),
            "runtime": _to_float(row["runtime"]),
            "marginal_error_l2": _to_float(row.get("marginal_error_L2", np.nan)),
            "cost": _to_float(row.get("cost", np.nan)),
        }
    )

visual_payload.sort(key=lambda item: item["problem_index"])
if not visual_payload:
    raise RuntimeError("No overlapping problems between generated data and solver results.")

first_entry = visual_payload[0]

# The shared x axis carries curves drawn over both the source grid (mu_x) and
# the target grid (nu_x), so its range has to cover both of them.
grid_min = min(
    min(np.min(entry["mu_x"]), np.min(entry["nu_x"])) for entry in visual_payload
)
grid_max = max(
    max(np.max(entry["mu_x"]), np.max(entry["nu_x"])) for entry in visual_payload
)
coord_margin = 0.05 * (grid_max - grid_min if grid_max > grid_min else 1.0)

# The lower panel shows both T(x) and the identity line; take the range over
# both so neither curve is clipped.
value_min = min(
    min(np.min(entry["monge"]), np.min(entry["identity"])) for entry in visual_payload
) - coord_margin
value_max = max(
    max(np.max(entry["monge"]), np.max(entry["identity"])) for entry in visual_payload
) + coord_margin

# Common upper bound for the density panel across all problems.
global_density_max = max(
    max(np.max(entry["mu_w"]), np.max(entry["nu_w"]), np.max(entry["pushforward"]))
    for entry in visual_payload
)
density_margin = 0.05 * global_density_max if global_density_max > 0 else 0.1

# Two stacked panels sharing the x axis: densities on top, transport map below.
fig = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    row_heights=[0.6, 0.4],
    vertical_spacing=0.08,
)

# Initial traces come from the first problem. The order in which they are added
# (mu, nu, pushforward, map, identity) is the order the frame data below uses.
fig.add_trace(
    go.Scatter(x=first_entry["mu_x"], y=first_entry["mu_w"], name="μ (source)", mode="lines"),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(x=first_entry["nu_x"], y=first_entry["nu_w"], name="ν (target)", mode="lines"),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(x=first_entry["nu_x"], y=first_entry["pushforward"], name="T#μ", mode="lines", line=dict(dash="dash")),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(x=first_entry["mu_x"], y=first_entry["monge"], name="T(x)", mode="lines"),
    row=2,
    col=1,
)
fig.add_trace(
    go.Scatter(x=first_entry["mu_x"], y=first_entry["identity"], name="Identity", mode="lines", line=dict(dash="dot")),
    row=2,
    col=1,
)

# Fixed axis ranges (computed over all problems) so the view does not jump between frames.
fig.update_yaxes(title_text="Density", row=1, col=1, range=[0, global_density_max + density_margin])
fig.update_yaxes(title_text="Mapped position", row=2, col=1, range=[value_min, value_max])
fig.update_xaxes(title_text="Grid coordinate", row=2, col=1, range=[grid_min - coord_margin, grid_max + coord_margin])

fig.update_layout(
    height=720,
    width=960,
    title=f"Back-and-Forth transport — {selected_setting_label}",
    legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=1.0),
)

# One animation frame per problem, named after its problem index; each frame
# lists replacement data for the five traces in the order they were added above.
frames = []
for entry in visual_payload:
    frames.append(
        go.Frame(
            name=str(entry["problem_index"]),
            data=[
                go.Scatter(x=entry["mu_x"], y=entry["mu_w"]),
                go.Scatter(x=entry["nu_x"], y=entry["nu_w"]),
                go.Scatter(x=entry["nu_x"], y=entry["pushforward"]),
                go.Scatter(x=entry["mu_x"], y=entry["monge"]),
                go.Scatter(x=entry["mu_x"], y=entry["identity"]),
            ],
        )
    )

fig.frames = frames
# Slider with one step per frame; each step jumps straight to the frame of that name.
fig.update_layout(
    sliders=[
        {
            "active": 0,
            "currentvalue": {"prefix": "Problem index: "},
            "steps": [
                {
                    "args": [[frame.name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                    "label": frame.name,
                    "method": "animate",
                }
                for frame in fig.frames
            ],
        }
    ],
)
# Play / Pause buttons placed above the plot area.
fig.update_layout(
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "direction": "left",
            "pad": {"r": 10, "t": 50},
            "y": 1.2,
            "x": 0.0,
            "buttons": [
                {"label": "Play", "method": "animate", "args": [None, {"fromcurrent": True}]},
                {"label": "Pause", "method": "animate", "args": [[None], {"mode": "immediate"}]},
            ],
        }
    ],
)

fig.show()


In [7]:
# Summary statistics (min / max / mean / median, rounded to 4 decimals) of the
# per-problem run metrics collected in `visual_payload`, rendered as a table.
metrics_columns = ["iterations", "marginal_error_l2", "runtime", "cost"]
metrics_df = pd.DataFrame(
    [{col: entry[col] for col in metrics_columns} for entry in visual_payload]
)
metrics_summary = metrics_df.agg(["min", "max", "mean", "median"]).round(4)
metrics_summary.index.name = None

# First table column holds the statistic names, followed by one column per metric.
header_labels = ["stat", *metrics_columns]
cell_values = [metrics_summary.index.tolist()]
cell_values.extend(metrics_summary[col].tolist() for col in metrics_columns)

stats_fig = go.Figure(
    data=[
        go.Table(
            header={
                "values": header_labels,
                "fill_color": "#1f77b4",
                "font": {"color": "white"},
            },
            cells={"values": cell_values, "fill_color": "#f5f5f5"},
        )
    ]
)
stats_fig.update_layout(title=f"Solver run statistics — {selected_setting_label}")
stats_fig.show()


In [8]:
def _describe_setting(row):
    """Render one solver configuration (a row holding `param_columns`) as a label."""
    return (
        f"tol={row['tol']:.1e}, maxiter={int(row['maxiter'])}, "
        f"stepsize={row['stepsize']}, pf={row['pushforward_fn']}"
    )


# Distinct solver configurations among the successful runs, each with its label.
settings_overview = successful[param_columns].drop_duplicates()
settings_overview = settings_overview.assign(
    setting_label=settings_overview.apply(_describe_setting, axis=1)
)

# (tol, maxiter, stepsize, pushforward_fn) -> label
label_map = {
    (tol, maxiter, stepsize, pushforward_fn): label
    for tol, maxiter, stepsize, pushforward_fn, label in settings_overview[
        ["tol", "maxiter", "stepsize", "pushforward_fn", "setting_label"]
    ].itertuples(index=False, name=None)
}


def _lookup_setting_label(row):
    """Return the label registered in `label_map` for the configuration of ``row``."""
    key = (row["tol"], row["maxiter"], row["stepsize"], row["pushforward_fn"])
    return label_map[key]


# Attach the label of its configuration to every successful run.
successful_with_labels = successful.copy()
successful_with_labels["setting_label"] = successful_with_labels.apply(
    _lookup_setting_label, axis=1
)

successful_with_labels

Unnamed: 0,dataset,mu_size,nu_size,cost,monge_map,u_final,v_final,iterations,error,marginal_error_L2,...,status,problem_index,pushforward_fn,maxiter,tol,stepsize,error_metric,stepsize_lower_bound,name,setting_label
0,"Gaussian (1d, 1c, 96p)",96,96,0.123188616979258,"[[0.0], [0.0], [4.547473508864641e-13], [2.728...","[6.364474586822717, 6.3608531220002416, 6.3573...","[-6.364474586822717, -6.364420333350495, -6.36...",3000,0.0850776540526388,0.0522437345410599,...,success,0,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
1,"Gaussian (1d, 1c, 96p)",96,96,0.0153164803703886,"[[0.0], [0.8429500250173305], [1.6344766578772...","[0.9921170219531585, 0.9921080100970691, 0.992...","[-0.9921170219531585, -0.9921080100970691, -0....",3000,0.0451021316885604,0.0282516591813824,...,success,1,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
2,"Gaussian (1d, 1c, 96p)",96,96,0.3032514970210719,"[[50.147547573038594], [51.00258327247185], [5...","[-0.22883118385048776, -0.22877693037826555, -...","[0.22883118385048776, 0.22338982669195492, 0.2...",3000,0.0132939202065036,0.0062738245700334,...,success,2,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
3,"Gaussian (1d, 1c, 96p)",96,96,0.0692698234337009,"[[0.0], [0.9387183340113552], [1.8664470134649...","[4.160493300826453, 4.160489633392994, 4.16048...","[-4.160493300826453, -4.160489633392994, -4.16...",3000,0.0870963463366521,0.0422748364805746,...,success,3,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
4,"Gaussian (1d, 1c, 96p)",96,96,0.0060072524542026,"[[0.0], [0.0], [7.105427357601002e-14], [4.263...","[0.14246476870558283, 0.1418941776979528, 0.14...","[-0.14246476870558283, -0.14241051523336062, -...",3000,0.0066023897628958,0.0029164791941311,...,success,4,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
5,"Gaussian (1d, 1c, 96p)",96,96,0.0864331831475218,"[[27.81273152018639], [28.7796031906023], [30....","[4.714616275759662, 4.714670529231884, 4.71483...","[-4.714616275759662, -4.717634150273571, -4.72...",3000,0.0899924706987553,0.0492454321080498,...,success,5,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
6,"Gaussian (1d, 1c, 96p)",96,96,0.0146301004136373,"[[0.0], [2.2737367544323206e-13], [0.0], [0.0]...","[0.7060024590436, 0.7050766357922894, 0.704108...","[-0.7060024590436, -0.7059482055713778, -0.705...",3000,0.0605580664140797,0.0255288879359205,...,success,6,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
7,"Gaussian (1d, 1c, 96p)",96,96,0.0183185282586694,"[[1.863952422809234e-05], [1.0000361386077543]...","[0.8811394929438994, 0.8811394949664172, 0.881...","[-0.8811394929438994, -0.8811394949664172, -0....",3000,0.0472355164554016,0.022491317090383,...,success,7,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
8,"Gaussian (1d, 1c, 96p)",96,96,0.0183036519089373,"[[9.738214152043497], [11.113323064343774], [1...","[0.9920415712983728, 0.992095824770595, 0.9922...","[-0.9920415712983728, -0.9930982351603567, -0....",3000,0.053035629755181,0.023719311013881,...,success,8,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."
9,"Gaussian (1d, 1c, 96p)",96,96,0.0164936832697028,"[[0.0], [0.9999539869854175], [1.9998893881104...","[0.8117303072814761, 0.8117303047015637, 0.811...","[-0.8117303072814761, -0.8117303047015637, -0....",3000,0.0487618697889521,0.0222470928409024,...,success,9,adaptive_pushforward_nd,3000,0.0001,4,tv_psi,0.01,Back-and-Forth SqEuclid,"tol=1.0e-04, maxiter=3000, stepsize=4, pf=adap..."


In [31]:
# Result column -> subplot title, first row (general solver metrics).
metric_titles = {
    "iterations": "Iterations",
    "error": "Solver Error",
    "marginal_error_L2": "Marginal Error L2",
    "runtime": "Runtime (s)",
    "cost": "Cost",
}

# Result column -> subplot title, second row (Monge-map diagnostics).
monge_metric_titles = {
    "tv_mu_to_nu": "TV distance",
    "ma_residual_L1": "MA residual L1",
    "ma_residual_Linf": "MA residual Linf",
    "detJ_neg_frac": "det(J) < 0 fraction",
    "phi_is_convex": "Convexity fraction",
}

max_cols = max(len(metric_titles), len(monge_metric_titles))


def _padded_titles(titles, width):
    """Return the subplot titles of one row, padded with "" up to ``width`` entries."""
    labels = list(titles.values())
    return labels + [""] * (width - len(labels))


# `make_subplots` consumes the titles in row-major order, so each row must
# contribute exactly `max_cols` entries; otherwise the second-row titles shift
# as soon as the two dictionaries differ in length.
subplot_titles = _padded_titles(metric_titles, max_cols) + _padded_titles(monge_metric_titles, max_cols)

dist_fig = make_subplots(
    rows=2,
    cols=max_cols,
    horizontal_spacing=0.05,
    vertical_spacing=0.30,
    subplot_titles=subplot_titles,
)


def _add_violin_row(fig, frame, titles, row, legendgroup):
    """Add one violin per metric in ``titles`` to subplot row ``row`` of ``fig``.

    ``frame`` must contain a ``setting_label`` column; violins are grouped on
    the x axis by that label (its ", " separators are turned into line breaks).
    Metrics missing from ``frame`` are skipped, and values that cannot be
    converted to float are treated as missing.
    """
    for col_idx, (metric, title) in enumerate(titles.items(), start=1):
        if metric not in frame.columns:
            continue
        # `_to_float` maps None / unconvertible values to NaN, which the mask
        # below removes (a bare float() would raise before the mask is applied).
        series = frame[metric].apply(_to_float)
        metric_mask = series.notna()
        x_vals = ['<br>'.join(setting.split(', ')) for setting in frame.loc[metric_mask, "setting_label"]]
        y_vals = series[metric_mask]
        fig.add_trace(
            go.Violin(
                x=x_vals,
                y=y_vals,
                name=title,
                legendgroup=legendgroup,
                box_visible=True,
                meanline_visible=True,
                opacity=0.9,
                points='all',
                showlegend=False,
            ),
            row=row,
            col=col_idx,
        )
        fig.update_xaxes(tickangle=45, row=row, col=col_idx)


_add_violin_row(dist_fig, successful_with_labels, metric_titles, row=1, legendgroup="general")
_add_violin_row(dist_fig, successful_with_labels, monge_metric_titles, row=2, legendgroup="monge")

dist_fig.update_layout(
    title="Performance & Monge-map statistics across solver settings",
    height=1100,
    width=300 * max_cols,
)
dist_fig.show()


In [10]:
# Every Monge diagnostic must be present in the selected results before summarising.
missing_cols = [
    col for col in monge_metric_columns if col not in selected_results.columns
]
if missing_cols:
    raise RuntimeError(f"Missing Monge metric columns in results: {missing_cols}")

# Per-problem diagnostics; rows in which all diagnostics are missing are dropped.
monge_metrics_df_1d = selected_results.loc[
    :, ["problem_index", "pushforward_fn", *monge_metric_columns]
]
monge_metrics_df_1d = monge_metrics_df_1d.dropna(subset=monge_metric_columns, how="all")
monge_metrics_df_1d = monge_metrics_df_1d.reset_index(drop=True)

if monge_metrics_df_1d.empty:
    raise RuntimeError("Selected results do not contain Monge diagnostics.")

# min / max / mean / median of every diagnostic, rounded to 4 decimals.
monge_summary_1d = (
    monge_metrics_df_1d[monge_metric_columns]
    .agg(["min", "max", "mean", "median"])
    .round(4)
)
monge_table_1d = go.Figure(
    data=[
        go.Table(
            header={
                "values": ["stat", *monge_metric_columns],
                "fill_color": "#1f77b4",
                "font": {"color": "white"},
            },
            cells={
                "values": [
                    monge_summary_1d.index.tolist(),
                    *(monge_summary_1d[col].tolist() for col in monge_metric_columns),
                ],
                "fill_color": "#f5f5f5",
            },
        )
    ]
)
monge_table_1d.update_layout(title="Monge-map diagnostics (1D) — summary")
monge_table_1d.show()


## 2D Gaussian transport experiments

Repeat the same evaluation pipeline on 2D tensor grids to compare the pushforward implementations in a richer setting.


In [11]:
seed_2d = 101
n_problems_2d = 12
n_points_2d = 96

# 2D dataset configuration (uniform tensors on [0, 1]^2)
dataset_config_2d = dict(
    name=f"Gaussian (2d, 2c, {n_points_2d}x{n_points_2d})",
    dim=2,
    num_components=2,
    n_points=n_points_2d,
    num_datasets=n_problems_2d,
    borders=(0, 1),
    cost_fn=cost_euclid_squared,
    use_jax=False,
    seed=seed_2d,
    measure_mode="grid",
    cell_discretization="cell-centered",
)

# NOTE(review): `datasets_2d` is not referenced again in the visible part of the notebook.
datasets_2d = [
    GaussianMixtureGenerator(**dataset_config_2d)
]
iterators_2d = [
    OnlineProblemIterator(
        GaussianMixtureGenerator(**dataset_config_2d),
        num=n_problems_2d,
        cache_gt=False,
    ),
]

# NOTE: the names below (`tolerances`, `max_iterations`, `stepsizes`,
# `pushforward_fns`, `solver_param_grid`, `solvers`) were already bound by the
# 1D experiment cell and are rebound here with the 2D values.
tolerances = [1e-3]
max_iterations = [500]
stepsizes = [1]
pushforward_fns = [
    adaptive_pushforward_nd,
    _forward_pushforward_nd,
]

# Cartesian product of the lists above; the 2D runs use a different
# `error_metric` than the 1D runs.
solver_param_grid = [
    {
        "pushforward_fn": pushforward_fn,
        "maxiter": maxiter,
        "tol": tol,
        "stepsize": stepsize,
        "error_metric": 'h1_psi_relative',
        "stepsize_lower_bound": 0.01,
    }
    for tol in tolerances
    for maxiter in max_iterations
    for stepsize in stepsizes
    for pushforward_fn in pushforward_fns
]

solvers = [
    SolverConfig(
        name="Back-and-Forth SqEuclid",
        solver=BackNForthSqEuclideanSolver,
        param_grid=solver_param_grid,
        is_jit=True,
    )
]

# Reuses the `exp` experiment object (and therefore `solve_fn`) from the 1D section.
results_2d = run_pipeline(
    experiment=exp,
    solvers=solvers,
    iterators=iterators_2d,
    folds=1,
    progress=True,
)


Back-and-Forth SqEuclid({'pushforward_fn': <function _forward_pushforward_nd at 0x1503d3920>, 'maxiter': 500, 'tol': 0.001, 'stepsize': 1, 'error_metric': 'h1_psi_relative', 'stepsize_lower_bound': 0.01}): 100%|██████████| 24/24 [00:47<00:00,  1.96s/it]


In [12]:
# Fail early with a readable message when this cell is run before the 2D pipeline cell.
if "results_2d" not in globals():
    raise RuntimeError("Run the 2D pipeline cell to populate `results_2d`.")

# Keep only the rows whose status is "success"; copy so the edit below does
# not touch `results_2d`.
successful_2d = results_2d[results_2d["status"] == "success"].copy()
if successful_2d.empty:
    raise RuntimeError("2D solver did not return any successful runs to visualize.")

# Reuses `_format_pushforward_fn` and `param_columns` defined in the 1D selection cell.
successful_2d["pushforward_fn"] = successful_2d["pushforward_fn"].apply(_format_pushforward_fn)

# One row per distinct solver configuration among the successful 2D runs.
available_settings_2d = (
    successful_2d[param_columns]
    .drop_duplicates()
    .sort_values(param_columns)
    .reset_index(drop=True)
)

print("Available 2D solver settings (use `setting_index_2d` below):")
print(available_settings_2d.assign(setting_index=available_settings_2d.index))

setting_index_2d = 1
if setting_index_2d >= len(available_settings_2d):
    raise IndexError(
        f"setting_index_2d {setting_index_2d} is out of range for {len(available_settings_2d)} available settings"
    )

# Rows of `successful_2d` that used the chosen configuration; the float-valued
# parameters (`tol`, `stepsize`) are compared with a tolerance via np.isclose.
selected_setting_2d = available_settings_2d.iloc[setting_index_2d]
setting_mask_2d = (
    np.isclose(successful_2d["tol"], selected_setting_2d["tol"])
    & (successful_2d["maxiter"] == selected_setting_2d["maxiter"])
    & np.isclose(successful_2d["stepsize"], selected_setting_2d["stepsize"])
    & (successful_2d["pushforward_fn"] == selected_setting_2d["pushforward_fn"])
)
selected_results_2d = successful_2d[setting_mask_2d].copy()
if selected_results_2d.empty:
    raise RuntimeError("Selected 2D configuration has no matching successful runs.")

# Human-readable description of the chosen 2D configuration.
selected_setting_label_2d = (
    f"tol={selected_setting_2d['tol']:.1e}, "
    f"maxiter={int(selected_setting_2d['maxiter'])}, "
    f"stepsize={selected_setting_2d['stepsize']}, "
    f"pushforward_fn={selected_setting_2d['pushforward_fn']}"
)
print(f"Using 2D setting #{setting_index_2d}: {selected_setting_label_2d}")


Available 2D solver settings (use `setting_index_2d` below):
     tol  maxiter  stepsize           pushforward_fn  setting_index
0  0.001      500         1  _forward_pushforward_nd              0
1  0.001      500         1  adaptive_pushforward_nd              1
Using 2D setting #1: tol=1.0e-03, maxiter=500, stepsize=1, pushforward_fn=adaptive_pushforward_nd


In [24]:
# Build a fresh generator from the 2D config and materialise `n_problems_2d` problems.
# NOTE(review): pairing these with result rows by `problem_index` presumably
# relies on the generator being deterministic for a fixed seed -- confirm.
visual_generator_2d = GaussianMixtureGenerator(**dataset_config_2d)
visual_iterator_2d = visual_generator_2d.generate()
problems_for_visuals_2d = [next(visual_iterator_2d) for _ in range(n_problems_2d)]


def _measure_to_grid_2d(measure):
    axes, grid = measure.for_grid_solver(backend="jax", dtype=jnp.float64)
    axes_np = [np.asarray(ax) for ax in axes]
    return axes_np, np.asarray(grid)


def _monge_map_displacement(row, axes_np):
    monge = np.asarray(row["monge_map"])
    spatial_shape = tuple(len(ax) for ax in axes_np)
    d = len(spatial_shape)
    monge = monge.reshape((*spatial_shape, d))
    spacings = np.array([float(ax[1] - ax[0]) if ax.shape[0] > 1 else 1.0 for ax in axes_np])
    origins = np.array([float(ax[0]) for ax in axes_np])
    coords = np.stack(np.meshgrid(*axes_np, indexing="ij"), axis=-1)
    monge_phys = origins + monge * spacings
    displacement = monge_phys - coords
    return coords, np.linalg.norm(displacement, axis=-1)


def _to_float(value):
    if value is None:
        return float("nan")
    if isinstance(value, (int, float)):
        return float(value)
    if hasattr(value, "item"):
        try:
            return float(value.item())
        except (TypeError, ValueError):
            return float("nan")
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("nan")


# One result record per problem for the selected setting, ordered by problem index.
records_2d = (
    selected_results_2d.sort_values("problem_index")
    .drop_duplicates(subset="problem_index")
    .to_dict("records")
)

# Pair every result record with its regenerated problem and collect everything
# the figures below need (densities, pushforward, displacement, scalar metrics).
visual_payload_2d = []
for row in records_2d:
    idx = int(row["problem_index"])
    # Skip records without a matching regenerated problem. The lower bound
    # stops a negative index from silently wrapping around to the list's end.
    if not 0 <= idx < len(problems_for_visuals_2d):
        continue
    mu_measure, nu_measure = problems_for_visuals_2d[idx].get_marginals()
    axes_mu, mu_grid = _measure_to_grid_2d(mu_measure)
    _, nu_grid = _measure_to_grid_2d(nu_measure)

    # Reuse the source grid extracted above rather than gridding the measure a
    # second time; it is handed back to JAX as float64 for the pushforward.
    mu_nd = jnp.asarray(mu_grid, dtype=jnp.float64)
    psi = jnp.asarray(row["v_final"]).reshape(mu_nd.shape)
    # NOTE(review): the adaptive pushforward is used for the plot regardless of
    # the pushforward_fn of the selected setting — confirm this is intended.
    pushforward_nd, _ = adaptive_pushforward_nd(mu_nd, -psi)

    _, displacement_norm = _monge_map_displacement(row, axes_mu)

    visual_payload_2d.append(
        {
            "problem_index": idx,
            # Grids use "ij" indexing: axis 0 runs along heatmap rows (y),
            # axis 1 along heatmap columns (x).
            "x_axis": axes_mu[1],
            "y_axis": axes_mu[0],
            "mu_grid": mu_grid,
            "nu_grid": nu_grid,
            "pushforward_grid": np.asarray(pushforward_nd),
            "displacement_norm": displacement_norm,
            "iterations": int(row["iterations"]),
            "error": _to_float(row["error"]),
            "runtime": _to_float(row["runtime"]),
            "marginal_error_l2": _to_float(row.get("marginal_error_L2", np.nan)),
            "cost": _to_float(row.get("cost", np.nan)),
        }
    )

visual_payload_2d.sort(key=lambda item: item["problem_index"])
if not visual_payload_2d:
    raise RuntimeError("No overlapping 2D problems between generated data and solver results.")

# Shared colour limits so every animation frame uses the same scale.
first_entry_2d = visual_payload_2d[0]
density_max_2d = max(
    max(
        np.max(entry["mu_grid"]),
        np.max(entry["nu_grid"]),
        np.max(entry["pushforward_grid"]),
    )
    for entry in visual_payload_2d
)
displacement_max_2d = max(np.max(entry["displacement_norm"]) for entry in visual_payload_2d)

def _heatmap_panels_2d(entry):
    """Build the four heatmap traces (μ, ν, T#μ, ‖T(x) - x‖) for one payload entry.

    The three density panels share ``coloraxis``; the displacement panel uses
    ``coloraxis2``. Traces are returned in subplot order (row-major).
    """
    panel_sources = (
        ("mu_grid", "coloraxis"),
        ("nu_grid", "coloraxis"),
        ("pushforward_grid", "coloraxis"),
        ("displacement_norm", "coloraxis2"),
    )
    return [
        go.Heatmap(z=entry[grid_key], x=entry["x_axis"], y=entry["y_axis"], coloraxis=axis_ref)
        for grid_key, axis_ref in panel_sources
    ]


# 2x2 grid of heatmaps, initially showing the first problem.
fig_2d = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=["μ (source)", "ν (target)", "T#μ", "‖T(x) - x‖"],
    horizontal_spacing=0.05,
    vertical_spacing=0.08,
    specs=[[{"type": "heatmap"} for _ in range(2)] for _ in range(2)],
)

panel_cells_2d = [(1, 1), (1, 2), (2, 1), (2, 2)]
for panel_trace, (panel_row, panel_col) in zip(_heatmap_panels_2d(first_entry_2d), panel_cells_2d):
    fig_2d.add_trace(panel_trace, row=panel_row, col=panel_col)

# Colour scales are fixed across frames, and the overall figure chrome is set here.
fig_2d.update_layout(
    coloraxis=dict(
        colorscale="Viridis",
        cmin=0,
        cmax=density_max_2d,
        colorbar=dict(title="Density", x=-0.18, len=0.8),
    ),
    coloraxis2=dict(
        colorscale="Magma",
        cmin=0,
        cmax=displacement_max_2d,
        colorbar=dict(title="‖T(x) - x‖", x=1.05, len=0.8),
    ),
    title=f"2D Back-and-Forth transport — {selected_setting_label_2d}",
    height=900,
    width=1000,
)

# One animation frame per problem, named after its problem index.
frames_2d = [
    go.Frame(name=str(entry["problem_index"]), data=_heatmap_panels_2d(entry))
    for entry in visual_payload_2d
]
fig_2d.frames = frames_2d

# Slider: one step per frame, jumping to it immediately with a full redraw.
slider_steps_2d = [
    {
        "args": [[frame.name], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": frame.name,
        "method": "animate",
    }
    for frame in fig_2d.frames
]
play_button_2d = {"label": "Play", "method": "animate", "args": [None, {"fromcurrent": True}]}
pause_button_2d = {"label": "Pause", "method": "animate", "args": [[None], {"mode": "immediate"}]}

fig_2d.update_layout(
    sliders=[
        {
            "active": 0,
            "currentvalue": {"prefix": "Problem index: "},
            "steps": slider_steps_2d,
        }
    ],
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "y": 1.2,
            "x": 0.0,
            "buttons": [play_button_2d, pause_button_2d],
        }
    ],
)
fig_2d.show()


In [14]:
# Scalar per-problem metrics taken from the visual payload built above.
metrics_columns = ["iterations", "marginal_error_l2", "runtime", "cost"]
metrics_df_2d = pd.DataFrame(
    [dict((col, entry[col]) for col in metrics_columns) for entry in visual_payload_2d]
)

# min / max / mean / median per metric, rounded for display.
metrics_summary_2d = metrics_df_2d.agg(["min", "max", "mean", "median"]).round(4)
metrics_summary_2d.index.name = None

# go.Table expects column-wise values: the statistic names first, then one list per metric.
header_labels = ["stat", *metrics_columns]
cell_values = [metrics_summary_2d.index.tolist()]
cell_values.extend(metrics_summary_2d[col].tolist() for col in metrics_columns)

stats_table_2d = go.Table(
    header=dict(values=header_labels, fill_color="#1f77b4", font=dict(color="white")),
    cells=dict(values=cell_values, fill_color="#f5f5f5"),
)
stats_fig_2d = go.Figure(data=[stats_table_2d])
stats_fig_2d.update_layout(title=f"2D solver run statistics — {selected_setting_label_2d}")
stats_fig_2d.show()


In [15]:
# One row per distinct solver configuration, with a printable label.
settings_overview_2d = (
    successful_2d[param_columns]
    .drop_duplicates()
    .assign(setting_label=lambda df: df.apply(
        lambda row: f"tol={row['tol']:.1e}, maxiter={int(row['maxiter'])}, stepsize={row['stepsize']}, pf={row['pushforward_fn']}", axis=1
    ))
)
# Lookup from the configuration tuple to its label.
label_map_2d = {
    (row.tol, row.maxiter, row.stepsize, row.pushforward_fn): row.setting_label
    for row in settings_overview_2d.itertuples(index=False)
}

# Attach the configuration label to every successful run.
successful_with_labels_2d = successful_2d.copy()
successful_with_labels_2d["setting_label"] = successful_with_labels_2d.apply(
    lambda row: label_map_2d[(row["tol"], row["maxiter"], row["stepsize"], row["pushforward_fn"])],
    axis=1,
)

metric_titles = {
    "iterations": "Iterations",
    "error": "Solver Error",
    "marginal_error_L2": "Marginal Error L2",
    "runtime": "Runtime (s)",
    "cost": "Cost",
}

# One violin subplot per metric, grouped on the x axis by solver configuration.
dist_fig_2d = make_subplots(
    rows=1,
    cols=len(metric_titles),
    horizontal_spacing=0.05,
    subplot_titles=list(metric_titles.values()),
)

for idx, (metric, title) in enumerate(metric_titles.items(), start=1):
    # _to_float turns None / non-scalar entries into NaN, so the notna() mask
    # below can drop them; a bare float() would raise on None before the mask
    # is ever applied.
    series = successful_with_labels_2d[metric].apply(_to_float)
    metric_mask = series.notna()
    # Put each parameter of the label on its own line of the tick text.
    x_vals = [
        setting.replace(', ', '<br>')
        for setting in successful_with_labels_2d.loc[metric_mask, "setting_label"]
    ]
    y_vals = series[metric_mask]
    # Only the first trace contributes a legend entry; the rest share its group.
    showlegend = idx == 1
    dist_fig_2d.add_trace(
        go.Violin(
            x=x_vals,
            y=y_vals,
            name="Violin",
            legendgroup="violin",
            box_visible=True,
            meanline_visible=True,
            opacity=0.8,
            showlegend=showlegend,
            points='all',
        ),
        row=1,
        col=idx,
    )

dist_fig_2d.update_layout(
    title="2D performance distributions across solver settings",
)
dist_fig_2d.show()


In [16]:
# Fail early when the selected results lack any expected Monge diagnostic column.
missing_cols_2d = []
for col in monge_metric_columns:
    if col not in selected_results_2d.columns:
        missing_cols_2d.append(col)
if missing_cols_2d:
    raise RuntimeError(f"Missing Monge metric columns in 2D results: {missing_cols_2d}")

# Keep the identifying columns plus the diagnostics, dropping rows where every
# diagnostic is missing.
monge_metrics_df_2d = selected_results_2d[["problem_index", "pushforward_fn"] + monge_metric_columns]
monge_metrics_df_2d = monge_metrics_df_2d.dropna(subset=monge_metric_columns, how="all")
monge_metrics_df_2d = monge_metrics_df_2d.reset_index(drop=True)

if monge_metrics_df_2d.empty:
    raise RuntimeError("Selected 2D results do not contain Monge diagnostics.")

# min / max / mean / median per diagnostic, rounded for display.
monge_summary_2d = monge_metrics_df_2d[monge_metric_columns].agg(["min", "max", "mean", "median"]).round(4)

# Column-wise table content: statistic names first, then one list per diagnostic.
monge_table_header_2d = dict(
    values=["stat"] + monge_metric_columns,
    fill_color="#1f77b4",
    font=dict(color="white"),
)
monge_table_cells_2d = dict(
    values=[monge_summary_2d.index.tolist()] + [monge_summary_2d[col].tolist() for col in monge_metric_columns],
    fill_color="#f5f5f5",
)
monge_table_2d = go.Figure(data=[go.Table(header=monge_table_header_2d, cells=monge_table_cells_2d)])
monge_table_2d.update_layout(title="Monge-map diagnostics (2D) — summary")
monge_table_2d.show()


In [32]:
metric_titles = {
    "iterations": "Iterations",
    "error": "Solver Error",
    "marginal_error_L2": "Marginal Error L2",
    "runtime": "Runtime (s)",
    "cost": "Cost",
}

monge_metric_titles = {
    "tv_mu_to_nu": "TV distance",
    "ma_residual_L1": "MA residual L1",
    "ma_residual_Linf": "MA residual Linf",
    "detJ_neg_frac": "det(J) < 0 fraction",
    "phi_is_convex": "Convexity fraction",
}


def _pad_titles_2d(titles, width):
    """Return ``titles`` as a list padded with empty strings up to ``width`` entries."""
    padded = list(titles)
    return padded + [""] * (width - len(padded))


def _add_violin_row_2d(fig, frame, titles, row, legendgroup, skip_missing):
    """Add one violin subplot per metric in ``titles`` to row ``row`` of ``fig``.

    Parameters
    ----------
    fig : plotly figure created with ``make_subplots``
    frame : DataFrame holding the metric columns and a ``setting_label`` column
    titles : dict mapping metric column name -> trace name
    row : 1-based subplot row to fill; metrics go to columns 1, 2, ...
    legendgroup : legend group assigned to every trace of the row
    skip_missing : when True, metrics absent from ``frame`` leave their subplot
        empty; when False a missing column raises ``KeyError``.
    """
    for col_idx, (metric, title) in enumerate(titles.items(), start=1):
        if skip_missing and metric not in frame.columns:
            continue
        # _to_float maps None / non-scalar entries to NaN so the notna() mask
        # can drop them; a bare float() would raise on None first.
        series = frame[metric].apply(_to_float)
        metric_mask = series.notna()
        # Put each parameter of the label on its own line of the tick text.
        x_vals = [setting.replace(', ', '<br>') for setting in frame.loc[metric_mask, "setting_label"]]
        fig.add_trace(
            go.Violin(
                x=x_vals,
                y=series[metric_mask],
                name=title,
                legendgroup=legendgroup,
                box_visible=True,
                meanline_visible=True,
                opacity=0.9,
                points='all',
                showlegend=False,
            ),
            row=row,
            col=col_idx,
        )
        fig.update_xaxes(tickangle=45, row=row, col=col_idx)


max_cols = max(len(metric_titles), len(monge_metric_titles))
# subplot_titles is consumed in row-major order, so each row must contribute
# exactly max_cols entries; otherwise the second row's titles shift whenever
# the two dictionaries differ in length.
subplot_titles = (
    _pad_titles_2d(metric_titles.values(), max_cols)
    + _pad_titles_2d(monge_metric_titles.values(), max_cols)
)

dist_fig_2d = make_subplots(
    rows=2,
    cols=max_cols,
    horizontal_spacing=0.05,
    vertical_spacing=0.30,
    subplot_titles=subplot_titles,
)

# Row 1: general solver metrics (all required). Row 2: Monge-map diagnostics
# (optional columns are skipped when absent).
_add_violin_row_2d(
    dist_fig_2d, successful_with_labels_2d, metric_titles,
    row=1, legendgroup="general", skip_missing=False,
)
_add_violin_row_2d(
    dist_fig_2d, successful_with_labels_2d, monge_metric_titles,
    row=2, legendgroup="monge", skip_missing=True,
)

dist_fig_2d.update_layout(
    title="2D performance & Monge-map statistics across solver settings",
    height=1100,
    width=300 * max_cols,
)
dist_fig_2d.show()
