In [2]:
import sys, os

# JAX reads its flag environment variables (e.g. JAX_PLATFORM_NAME) when the
# `jax` package is first imported, so they must be set *before* any jax import.
# Previously `from jax import config` ran first, so the platform choice below
# was silently ignored.
# NOTE(review): if jax was already imported earlier in this kernel session,
# changing these still has no effect — restart the kernel after editing them.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['JAX_PLATFORM_NAME'] = 'cpu'     # use cpu backend. set 'gpu' or 'tpu' to use those backends

from jax import config

config.update("jax_enable_x64", True)  # float64 by default for all jax arrays

# Make the parent directory importable (presumably the project root holding
# the `uot` package). Guarded so re-running this cell does not prepend
# duplicate entries to sys.path.
_project_root = os.path.abspath("..")
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

import time
import numpy as np
import jax.numpy as jnp
import pandas as pd
import plotly.graph_objects as go

from uot.problems.generators import GaussianMixtureGenerator
from uot.utils.costs import cost_euclid_squared
from uot.experiments.runner import run_pipeline
from uot.experiments.experiment import Experiment
from uot.problems.iterator import OnlineProblemIterator
from uot.solvers.solver_config import SolverConfig


from uot.solvers.back_and_forth import BackNForthSqEuclideanSolver

In [None]:
def solve_fn(problem, solver_instance, measures, costs, *args, **kwargs):
    """Run one solver on one problem and record its wall-clock runtime.

    Args:
        problem: The problem being solved. Unused here; presumably part of the
            signature the experiment runner calls ``solve_fn`` with.
        solver_instance: Object exposing
            ``solve(marginals=..., costs=..., **kwargs)`` that returns a dict.
        measures: Forwarded to ``solve`` as ``marginals=``.
        costs: Forwarded to ``solve`` as ``costs=``.
        *args: Ignored.
        **kwargs: Solver hyper-parameters forwarded to ``solve``.

    Returns:
        The dict returned by ``solver_instance.solve`` with an added
        ``"runtime"`` entry: elapsed seconds measured with
        ``time.perf_counter``.
    """
    import jax  # local import: only `config` / `jax.numpy` are bound at module level

    start = time.perf_counter()
    results = solver_instance.solve(
        marginals=measures,
        costs=costs,
        **kwargs,
    )
    # JAX dispatches work asynchronously: `solve` may return before the device
    # computation has finished, so without this the timer would only measure
    # dispatch. `block_until_ready` waits on every array leaf of the result
    # and passes non-array leaves through untouched.
    # NOTE(review): with a jitted solver the first call per input shape also
    # includes compilation time.
    jax.block_until_ready(results)
    stop = time.perf_counter()
    results["runtime"] = stop - start
    return results


# The experiment only wraps the per-problem `solve_fn`; solver configurations
# and problem iterators are handed to `run_pipeline` separately.
exp = Experiment(solve_fn=solve_fn, name="Testing Back-and-Forth Solver")

# Single hyper-parameter point for the back-and-forth solver.
# Another error metric that was tried here is 'h1_psi' (presumably the
# non-relative variant of 'h1_psi_relative').
bnf_params = dict(
    maxiter=1000,
    tol=1e-4,
    stepsize=4,
    error_metric='h1_psi_relative',
    stepsize_lower_bound=0.01,
)

solvers = [
    SolverConfig(
        name="Back-and-Forth SqEuclid",
        solver=BackNForthSqEuclideanSolver,
        param_grid=[bnf_params],
        is_jit=True,
    ),
]

# Problem set: `n_problems` single-component Gaussian mixtures in 1-D on
# borders (0, 1), generated with `measure_mode="grid"` and
# `cell_discretization="cell-centered"` at `n_points` points each.
seed = 55
n_problems = 8
n_points = 128
n_dims = 1
n_components = 1

gaussian_generator = GaussianMixtureGenerator(
    name=f"Gaussian ({n_dims}d, {n_components}c, {n_points}p)",
    dim=n_dims,
    num_components=n_components,
    n_points=n_points,
    num_datasets=n_problems,
    borders=(0, 1),
    cost_fn=cost_euclid_squared,
    use_jax=False,
    seed=seed,
    measure_mode="grid",
    cell_discretization="cell-centered",
)

iterators = [
    OnlineProblemIterator(gaussian_generator, num=n_problems, cache_gt=False),
]

# One fold with progress reporting; the progress bar and log lines captured
# below are this call's output.
pipeline_options = {"folds": 1, "progress": True}
results = run_pipeline(
    experiment=exp, solvers=solvers, iterators=iterators, **pipeline_options
)

2025-11-07 15:33:58,234 uot INFO: starting pipeline...


Running experiments:   0%|          | 0/8 [00:00<?, ?it/s]

2025-11-07 15:33:58,304 uot INFO: Running experiments...
2025-11-07 15:34:01,066 uot.problems.iterator INFO: Generated problem <TwoMarginalProblem[Gaussian (1d, 1c, 128p)] 128x128        with (<map object at 0x7f0d04305210>)>
2025-11-07 15:34:01,066 uot.problems.iterator INFO: Generated problem <TwoMarginalProblem[Gaussian (1d, 1c, 128p)] 128x128        with (<map object at 0x7f0d04305210>)>
2025-11-07 15:34:01,333 uot INFO: Starting BackNForthSqEuclideanSolver with {'maxiter': 1000, 'tol': 0.0001, 'stepsize': 4, 'error_metric': 'h1_psi_relative', 'stepsize_lower_bound': 0.01} on <TwoMarginalProblem[Gaussian (1d, 1c, 128p)] 128x128        with (<map object at 0x7f0d0421c040>)>
