Skip to content

Commit

Permalink
last fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Nov 27, 2023
1 parent d0b4d2d commit 6a77a9b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/ott/geometry/distrib_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
pointcloud.PointCloud(
x[:, None], y[:, None], cost_fn=self.ground_cost
)
)
), **self._kwargs_solve
)
return jnp.squeeze(out.ot_costs)

Expand Down
9 changes: 4 additions & 5 deletions src/ott/solvers/linear/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ class UnivariateOutput(NamedTuple): # noqa: D101
``0<=k<m+n``, and ``0<=s<d`` then writing ``i:=paired_indices[s,0,k]``
and ``j=paired_indices[s,1,k]``, point ``i`` sends
``mass_paired_indices[s,k]`` to point ``j``.
"""
prob: linear_problem.LinearProblem
ot_costs: float
paired_indices: Optional[jnp.ndarray]
mass_paired_indices: Optional[jnp.ndarray]
paired_indices: Optional[jnp.ndarray] = None
mass_paired_indices: Optional[jnp.ndarray] = None

@property
def transport_matrices(self) -> jnp.ndarray:
Expand All @@ -64,8 +63,8 @@ def transport_matrices(self) -> jnp.ndarray:
non-zero values, out of ``dnm`` total entries.
"""
assert self.paired_indices is not None, \
("[d, n, m] tensor of transports cannot be computed, likely because an"+
" approximate method was used (using either subsampling or quantiles).")
"[d, n, m] tensor of transports cannot be computed, likely because an" \
" approximate method was used (using either subsampling or quantiles)."

n, m = self.prob.geom.shape
if self.prob.is_equal_size and self.prob.is_uniform:
Expand Down
6 changes: 3 additions & 3 deletions tests/solvers/quadratic/lower_bound_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ def initialize(self, rng: jax.Array):
self.cy = jax.random.uniform(rngs[3], (self.m, self.m))

@pytest.mark.fast.with_args(
"cost_fn",
"ground_cost",
[costs.SqEuclidean(), costs.PNormP(1.5)],
only_fast=0,
)
def test_lb_pointcloud(self, cost_fn: costs.CostFn):
def test_lb_pointcloud(self, ground_cost: costs.TICost):
x, y = self.x, self.y

geom_x = pointcloud.PointCloud(x)
geom_y = pointcloud.PointCloud(y)
prob = quadratic_problem.QuadraticProblem(
geom_x, geom_y, a=self.a, b=self.b
)
distrib_cost = distrib_costs.UnivariateWasserstein(cost_fn=cost_fn)
distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost)
solver = lower_bound.LowerBoundSolver(
epsilon=1e-1, distrib_cost=distrib_cost
)
Expand Down

0 comments on commit 6a77a9b

Please sign in to comment.