diff --git a/src/ott/geometry/distrib_costs.py b/src/ott/geometry/distrib_costs.py index 590c8a13..ce24c959 100644 --- a/src/ott/geometry/distrib_costs.py +++ b/src/ott/geometry/distrib_costs.py @@ -72,7 +72,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pointcloud.PointCloud( x[:, None], y[:, None], cost_fn=self.ground_cost ) - ) + ), **self._kwargs_solve ) return jnp.squeeze(out.ot_costs) diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index b7bab62e..868f4fdc 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -49,12 +49,11 @@ class UnivariateOutput(NamedTuple): # noqa: D101 ``0<=k jnp.ndarray: @@ -64,8 +63,8 @@ def transport_matrices(self) -> jnp.ndarray: non-zero values, out of ``dnm`` total entries. """ assert self.paired_indices is not None, \ - ("[d, n, m] tensor of transports cannot be computed, likely because an"+ - " approximate method was used (using either subsampling or quantiles).") + "[d, n, m] tensor of transports cannot be computed, likely because an" \ + " approximate method was used (using either subsampling or quantiles)." n, m = self.prob.geom.shape if self.prob.is_equal_size and self.prob.is_uniform: diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 0fc9211c..2766e564 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -39,11 +39,11 @@ def initialize(self, rng: jax.Array): self.cy = jax.random.uniform(rngs[3], (self.m, self.m)) @pytest.mark.fast.with_args( - "cost_fn", + "ground_cost", [costs.SqEuclidean(), costs.PNormP(1.5)], only_fast=0, ) - def test_lb_pointcloud(self, cost_fn: costs.CostFn): + def test_lb_pointcloud(self, ground_cost: costs.TICost): x, y = self.x, self.y geom_x = pointcloud.PointCloud(x) @@ -51,7 +51,7 @@ def test_lb_pointcloud(self, cost_fn: costs.CostFn): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, a=self.a, b=self.b ) - distrib_cost = distrib_costs.UnivariateWasserstein(cost_fn=cost_fn) + distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost) solver = lower_bound.LowerBoundSolver( epsilon=1e-1, distrib_cost=distrib_cost )