From 0b82212304c350c64b1c11a9335d61683bc0f4e0 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Fri, 24 Nov 2023 19:05:32 +0100 Subject: [PATCH 1/4] refactoring --- src/ott/geometry/distrib_cost.py | 82 ++++++ src/ott/problems/linear/linear_problem.py | 10 + src/ott/solvers/linear/univariate.py | 289 ++++++++++++++------ src/ott/solvers/quadratic/lower_bound.py | 40 +-- src/ott/tools/sinkhorn_divergence.py | 4 +- tests/solvers/linear/univariate_test.py | 140 ++++++---- tests/solvers/quadratic/lower_bound_test.py | 110 +------- 7 files changed, 409 insertions(+), 266 deletions(-) create mode 100644 src/ott/geometry/distrib_cost.py diff --git a/src/ott/geometry/distrib_cost.py b/src/ott/geometry/distrib_cost.py new file mode 100644 index 000000000..bf2757e21 --- /dev/null +++ b/src/ott/geometry/distrib_cost.py @@ -0,0 +1,82 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import jax +import jax.numpy as jnp + +from ott.geometry import costs, pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import univariate + +__all__ = [ + "UnivariateWasserstein", +] + + +@jax.tree_util.register_pytree_node_class +class UnivariateWasserstein(costs.CostFn): + """1D Wasserstein cost for two 1D distributions. + + This ground cost between considers vectors as a family of values. The + Wasserstein distance between them is the 1D OT cost, using a user-defined + ground cost. + """ + + def __init__( + self, + ground_cost: Optional[costs.TICost] = None, + kwargs_solve: Optional[Any] = None, + **kwargs: Any + ): + super().__init__() + if ground_cost is None: + self.ground_cost = costs.SqEuclidean() + else: + self.ground_cost = ground_cost + self._kwargs_solve = {} if kwargs_solve is None else kwargs_solve + self._kwargs = kwargs + self._solver = univariate.UnivariateSolver(**kwargs) + + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + """Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist. + + Args: + x: vector + y: vector + kwargs: arguments passed on when calling the + :class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include + random key, or specific instructions to subsample or compute using + quantiles. + + Returns: + The transport cost. + """ + out = self._solver( + linear_problem.LinearProblem( + pointcloud.PointCloud( + x[:, None], y[:, None], cost_fn=self.ground_cost + ) + ), **self._kwargs_solve + ) + return jnp.squeeze(out.ot_costs) + + def tree_flatten(self): # noqa: D102 + return (), (self.ground_cost, self._kwargs_solve, self._kwargs) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + gc, kws, kw = aux_data + return cls(gc, kws, **kw) diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 7c206aa63..84d29ebd4 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -78,6 +78,16 @@ def is_balanced(self) -> bool: """Whether the problem is balanced.""" return self.tau_a == 1.0 and self.tau_b == 1.0 + @property + def is_uniform(self) -> bool: + """Test if no weights were passed.""" + return self._a is None and self._b is None + + @property + def is_equal_size(self) -> bool: + """Test if square shape, i.e. n == m.""" + return self.geom.shape[0] == self.geom.shape[1] + @property def epsilon(self) -> float: """Entropic regularization.""" diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index 2b6392227..334b0ee75 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -12,134 +12,247 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Literal, Optional +from typing import NamedTuple, Optional, Union import jax import jax.numpy as jnp -from ott.geometry import costs +from ott.geometry import costs, pointcloud +from ott.problems.linear import linear_problem -__all__ = ["UnivariateSolver"] +__all__ = ["UnivariateOutput", "UnivariateSolver"] + + +class UnivariateOutput(NamedTuple): # noqa: D101 + prob: linear_problem.LinearProblem + ot_costs: float + paired_indices: jax.Array + mass_paired_indices: jax.Array + + @property + def transport_matrices(self) -> jax.Array: + """Output a ``[d,n,m]`` tensor of all ``[n,m]`` transport matrices.""" + assert self.paired_indices is not None, "[d,n,m] Transports *not* computed" + + n, m = self.prob.geom.shape + if self.prob.is_equal_size and self.prob.is_uniform: + transport_matrices_from_indices = jax.vmap( + lambda idx, idy: jnp.eye(n)[idx, :][:, idy].T, in_axes=[0, 0] + ) + return transport_matrices_from_indices( + self.paired_indices[:, 0, :], self.paired_indices[:, 1, :] + ) + + # raveled indexing of entries. + indices = self.paired_indices[:, 0] * m + self.paired_indices[:, 1] + # segment sum is needed to collect several contributions + return jax.vmap( + lambda idx, mass: jax.ops.segment_sum( + mass, idx, indices_are_sorted=True, num_segments=n * m + ).reshape(n, m), + in_axes=[0, 0] + )(indices, self.mass_paired_indices) + + @property + def mean_transport_matrix(self) -> jax.Array: + """Return the mean tranport matrix, averaged over slices.""" + return jnp.mean(self.transport_matrices, axis=0) @jax.tree_util.register_pytree_node_class class UnivariateSolver: - r"""1-D OT solver. + r"""Univariate solver to compute 1D OT distance over slices of data. + + Computes 1-Dimensional optimal transport distance between two $d$-dimensional + point clouds. The total distance is the sum of univariate Wasserstein + distances on the $d$ slices of data: given two weighted point-clouds, stored + as ``[n,d]`` and ``[m,d]`` in a + :class:`~ott.problems.linear.linear_problem.LinearProblem` object, with + respective weights ``a`` and ``b``, the solver + computes ``d`` OT distances between each of these ``[n,1]`` and ``[m,1]`` + slices. The distance is computed using the analytical formula by default, + which involves sorting each of the slices independently. The optimal transport + matrices are also outputted when possible (described in sparse form, i.e. + pairs of indices and mass transferred between those indices). - .. warning:: - This solver assumes uniform marginals, a non-uniform marginal solver - is coming soon. + When weights ``a`` and ``b`` are uniform, and ``n=m``, the computation only + involves comparing sorted entries per slice, and ``d`` assignments are given. - Computes the 1-Dimensional optimal transport distance between two histograms. + The user may also supply a ``num_subsamples`` parameter to extract as many + points from the original point cloud, sampled with probability masses ``a`` + and ``b``. This then simply applied the method above to the subsamples, to + output ``d`` costs, but assignments are not provided. + + When the problem is not uniform or not of equal size, the method defaults to + an inversion of the CDF, and outputs both costs and transport matrix in sparse + form. + + When a ``quantiles`` argument is passed, either specifying explicit quantiles + or a grid of quantiles, the distance is evaluated by comparing the quantiles + of the two point clouds on each slice. The OT costs are returned but + assignments are not provided. Args: - sort_fn: The sorting function. If :obj:`None`, - use :func:`hard-sorting `. - cost_fn: The cost function for transport. If :obj:`None`, defaults to - :class:`PNormP(2) `. - method: The method used for computing the distance on the line. Options - currently supported are: - - - `'subsample'` - Take a stratified sub-sample of the distances. - - `'quantile'` - Take equally spaced quantiles of the distances. - - `'equal'` - No subsampling is performed, requires distributions to have - the same number of points. - - `'wasserstein'` - Compute the distance using the explicit solution - involving inverse CDFs. - - n_subsamples: The number of samples to draw for the "quantile" or - "subsample" methods. + num_subsamples: option to reduce the size of inputs by doing random + subsampling, taken into account marginal probabilities. + quantiles: When a vector or a number of quantiles is passed, the distance + is computed by evaluating the cost function on the sectional (one for each + dimension) quantiles of the two point cloud distributions described in the + problem. """ def __init__( self, - sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, - cost_fn: Optional[costs.CostFn] = None, - method: Literal["subsample", "quantile", "wasserstein", - "equal"] = "subsample", - n_subsamples: int = 100, + num_subsamples: Optional[int] = None, + quantiles: Optional[Union[int, jnp.ndarray]] = None, ): - self.sort_fn = jnp.sort if sort_fn is None else sort_fn - self.cost_fn = costs.PNormP(2) if cost_fn is None else cost_fn - self.method = method - self.n_subsamples = n_subsamples + self._quantiles = quantiles + self.num_subsamples = num_subsamples + + @property + def quantiles(self): + """Quantiles' values used to evaluate OT cost.""" + if self._quantiles is None: + return None + if isinstance(self._quantiles, int): + return jnp.linspace(0.0, 1.0, self._quantiles) + return self._quantiles + + @property + def num_quantiles(self): + """Number of quantiles used to evaluate OT cost.""" + return 0 if self._quantiles is None else self.quantiles.shape[0] def __call__( self, - x: jnp.ndarray, - y: jnp.ndarray, - a: Optional[jnp.ndarray] = None, - b: Optional[jnp.ndarray] = None + prob: linear_problem.LinearProblem, + rng: Optional[jax.Array] = None, ) -> float: - """Computes the Univariate OT Distance between `x` and `y`. + """Computes Univariate Distance between the `d` dimensional slices. Args: - x: The first distribution of shape ``[n,]`` or ``[n, 1]``. - y: The second distribution of shape ``[m,]`` or ``[m, 1]``. - a: The first marginals when ``method = 'wasserstein'``. If :obj:`None`, - uniform will be used. - b: The second marginals when ``method = 'wasserstein'``. If :obj:`None`, - uniform will be used. + prob: describing, in its geometry attribute, the two point clouds + ``x`` and ``y`` (of respective sizes ``[n,d]`` and ``[m,d]``) and + a ground ``cost_fn`` for between two scalars. The ``[n,]`` and ``[m,]`` + size probability weights vectors are stored in attributes ``a`` and + ``b``. + rng: used for random downsampling, if used. + return_transport: whether to return an average transport matrix (across + slices). Not available when approximating the distance computation + using subsamples, or quantiles. Returns: - The OT distance. + The OT distance, and possibly the transport matrix averaged by + considering all matrices arising from 1D transport on each of the ``d`` + dimensional slices of the input. """ - x = x.squeeze(-1) if x.ndim == 2 else x - y = y.squeeze(-1) if y.ndim == 2 else y - assert x.ndim == 1, x.ndim - assert y.ndim == 1, y.ndim - - n, m = x.shape[0], y.shape[0] - - if self.method == "equal": - xx, yy = self.sort_fn(x), self.sort_fn(y) - elif self.method == "subsample": - assert self.n_subsamples <= n, (self.n_subsamples, x) - assert self.n_subsamples <= m, (self.n_subsamples, y) - - sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y) - xx = sorted_x[jnp.linspace(0, n, num=self.n_subsamples).astype(int)] - yy = sorted_y[jnp.linspace(0, m, num=self.n_subsamples).astype(int)] - elif self.method == "quantile": - sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y) - xx = jnp.quantile(sorted_x, q=jnp.linspace(0, 1, self.n_subsamples)) - yy = jnp.quantile(sorted_y, q=jnp.linspace(0, 1, self.n_subsamples)) - elif self.method == "wasserstein": - return self._cdf_distance(x, y, a, b) + geom = prob.geom + n, m = geom.shape + rng = jax.random.PRNGKey(0) if rng is None else rng + geom_is_pc = isinstance(geom, pointcloud.PointCloud) + assert geom_is_pc, "Geometry object in problem must be a PointCloud." + cost_is_TI = isinstance(geom.cost_fn, costs.TICost) + assert cost_is_TI, "Geometry's cost must be translation invariant." + x, y = geom.x, geom.y + + # check if problem has the property uniform / same number of points + is_uniform_same_size = prob.is_uniform and prob.is_equal_size + if self.num_subsamples: + rng1, rng2 = jax.random.split(rng, 2) + if prob.is_uniform: + x = x[jnp.linspace(0, n, num=self.num_subsamples).astype(int), :] + y = y[jnp.linspace(0, m, num=self.num_subsamples).astype(int), :] + else: + x = jax.random.choice(rng1, x, (self.num_subsamples,), p=prob.a, axis=0) + y = jax.random.choice(rng2, y, (self.num_subsamples,), p=prob.b, axis=0) + n = m = self.num_subsamples + # now that both are subsampled, consider them as uniform/same size. + is_uniform_same_size = True + + if self.quantiles is None: + if is_uniform_same_size: + i_x, i_y = jnp.argsort(x, axis=0), jnp.argsort(y, axis=0) + x = jnp.take_along_axis(x, i_x, axis=0) + y = jnp.take_along_axis(y, i_y, axis=0) + ot_costs = jax.vmap(geom.cost_fn.h, in_axes=[0])(x.T - y.T) / n + + if self.num_subsamples: + # When subsampling, the pairing computed have no meaning w.r.t. + # original data. + paired_indices, mass_paired_indices = None, None + else: + paired_indices = jnp.stack([i_x, i_y]).transpose([2, 0, 1]) + mass_paired_indices = jnp.ones((n,)) / n + + else: + ot_costs, paired_indices, mass_paired_indices = jax.vmap( + self._quantile_distance_and_transport, + in_axes=[1, 1, None, None, None] + )(x, y, prob.a, prob.b, geom.cost_fn) + else: - raise NotImplementedError(f"Method `{self.method}` not implemented.") + assert prob.is_uniform, "Quantile method only valid for uniform marginals" + x_q = jnp.quantile(x, self.quantiles, axis=0) + y_q = jnp.quantile(y, self.quantiles, axis=0) + ot_costs = jax.vmap(geom.cost_fn.pairwise, in_axes=[1, 1])(x_q, y_q) + ot_costs /= self.num_quantiles + paired_indices = None + mass_paired_indices = None - # re-scale when subsampling - return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0]) + return UnivariateOutput( + prob=prob, + ot_costs=ot_costs, + paired_indices=paired_indices, + mass_paired_indices=mass_paired_indices + ) - def _cdf_distance( - self, x: jnp.ndarray, y: jnp.ndarray, a: Optional[jnp.ndarray], - b: Optional[jnp.ndarray] + def _quantile_distance_and_transport( + self, x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + cost_fn: costs.TICost ): - # Implementation based on `scipy` implementation for + # Implementation inspired by `scipy` implementation for # :func: - a = jnp.ones_like(x) if a is None else a - a /= jnp.sum(a) - b = jnp.ones_like(y) if b is None else b - b /= jnp.sum(b) + def sort_and_argsort(x: jnp.array, return_argsort: bool, **kwargs): + if return_argsort: + i_x = jnp.argsort(x, **kwargs) + return x[i_x], i_x + return jnp.sort(x), None + + x, i_x = sort_and_argsort(x, True) + y, i_y = sort_and_argsort(y, True) all_values = jnp.concatenate([x, y]) - all_values_sorter = jnp.argsort(all_values) - all_values_sorted = all_values[all_values_sorter] - x_pdf = jnp.concatenate([a, jnp.zeros(y.shape)])[all_values_sorter] - y_pdf = jnp.concatenate([jnp.zeros(x.shape), b])[all_values_sorter] + all_values_sorted, all_values_sorter = sort_and_argsort(all_values, True) + + x_pdf = jnp.concatenate([a[i_x], jnp.zeros_like(b)])[all_values_sorter] + y_pdf = jnp.concatenate([jnp.zeros_like(a), b[i_y]])[all_values_sorter] x_cdf = jnp.cumsum(x_pdf) y_cdf = jnp.cumsum(y_pdf) - quantiles = jnp.sort(jnp.concatenate([x_cdf, y_cdf])) - x_cdf_inv = all_values_sorted[jnp.searchsorted(x_cdf, quantiles)] - y_cdf_inv = all_values_sorted[jnp.searchsorted(y_cdf, quantiles)] - return jnp.sum( - jax.vmap(self.cost_fn)(y_cdf_inv[1:, None], x_cdf_inv[1:, None]) * - jnp.diff(quantiles) + x_y_cdfs = jnp.concatenate([x_cdf, y_cdf]) + quantile_levels, _ = sort_and_argsort(x_y_cdfs, False) + + i_x_cdf_inv = jnp.searchsorted(x_cdf, quantile_levels) + x_cdf_inv = all_values_sorted[i_x_cdf_inv] + i_y_cdf_inv = jnp.searchsorted(y_cdf, quantile_levels) + y_cdf_inv = all_values_sorted[i_y_cdf_inv] + + diff_q = jnp.diff(quantile_levels) + cost = jnp.sum( + jax.vmap(cost_fn.h)(y_cdf_inv[1:, None] - x_cdf_inv[1:, None]) * diff_q ) + n = x.shape[0] + + i_in_sorted_x_of_quantile = all_values_sorter[i_x_cdf_inv] % n + i_in_sorted_y_of_quantile = all_values_sorter[i_y_cdf_inv] - n + + orig_i = (i_x[i_in_sorted_x_of_quantile])[1:] + orig_j = (i_y[i_in_sorted_y_of_quantile])[1:] + + return cost, jnp.stack([orig_i, orig_j]), diff_q + def tree_flatten(self): # noqa: D102 aux_data = vars(self).copy() return [], aux_data diff --git a/src/ott/solvers/quadratic/lower_bound.py b/src/ott/solvers/quadratic/lower_bound.py index 27d5f0283..10762d681 100644 --- a/src/ott/solvers/quadratic/lower_bound.py +++ b/src/ott/solvers/quadratic/lower_bound.py @@ -16,10 +16,10 @@ import jax -from ott.geometry import geometry +from ott.geometry import distrib_cost, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers import linear -from ott.solvers.linear import sinkhorn, univariate +from ott.solvers.linear import sinkhorn __all__ = ["LowerBoundSolver"] @@ -28,14 +28,7 @@ class LowerBoundSolver: """Lower bound OT solver :cite:`memoli:11`. - .. warning:: - As implemented, this solver assumes uniform marginals, - non-uniform marginal solver coming soon! - Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3. - there is an uneven number of points in the distributions, then we perform a - stratified subsample of the distribution of distances to approximate - the Wasserstein distance between the local distributions of distances. Args: epsilon: Entropy regularization for the resulting linear problem. @@ -49,17 +42,28 @@ def __init__( **kwargs: Any, ): self.epsilon = epsilon - self.univariate_solver = univariate.UnivariateSolver(**kwargs) def __call__( self, prob: quadratic_problem.QuadraticProblem, - **kwargs: Any, + rng: Optional[jax.Array] = None, + kwargs_univsolver: Optional[Any] = None, + epsilon: Optional[float] = None, + **kwargs ) -> sinkhorn.SinkhornOutput: """Run the Histogram transport solver. Args: prob: Quadratic OT problem. + kwargs_univsolver: keyword args to + create the :class:`~ott.solvers.linear.univariate.UnivariateSolver`, + used to compute a ``[n,m]`` cost matrix, using the linearization + approach. This might rely, for instance, on subsampling or quantile + reduction to speed up computations. + rng: random key, possibly used when computing 1D costs when using + subsampling. + epsilon: entropic regularization passed on to solve the linearization of + the quadratic problem using 1D costs. kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`. Returns: @@ -67,13 +71,13 @@ def __call__( """ dists_xx = prob.geom_xx.cost_matrix dists_yy = prob.geom_yy.cost_matrix - cost_xy = jax.vmap( - jax.vmap(self.univariate_solver, in_axes=(0, None), out_axes=-1), - in_axes=(None, 0), - out_axes=-1, - )(dists_xx, dists_yy) - - geom_xy = geometry.Geometry(cost_matrix=cost_xy, epsilon=self.epsilon) + kwargs_univsolver = {} if kwargs_univsolver is None else kwargs_univsolver + geom_xy = pointcloud.PointCloud( + dists_xx, + dists_yy, + cost_fn=distrib_cost.UnivariateWasserstein(**kwargs_univsolver), + epsilon=self.epsilon + ) return linear.solve(geom_xy, **kwargs) diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 51de97613..ffc741a97 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -51,7 +51,7 @@ def to_dual_potentials(self) -> "potentials.EntropicPotentials": f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y ) - def tree_flatten_foo(self): # noqa: D102 + def tree_flatten(self): # noqa: D102 return [ self.divergence, self.potentials, @@ -65,7 +65,7 @@ def tree_flatten_foo(self): # noqa: D102 } @classmethod - def tree_unflatten_foo(cls, aux_data, children): # noqa: D102 + def tree_unflatten(cls, aux_data, children): # noqa: D102 div, pots, geoms, a, b = children return cls(div, pots, geoms, a=a, b=b, **aux_data) diff --git a/tests/solvers/linear/univariate_test.py b/tests/solvers/linear/univariate_test.py index 221f295cd..9aa671fba 100644 --- a/tests/solvers/linear/univariate_test.py +++ b/tests/solvers/linear/univariate_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools + import jax import jax.numpy as jnp import numpy as np @@ -18,7 +20,8 @@ import scipy as sp from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem -from ott.solvers.linear import sinkhorn, univariate +from ott.solvers import linear +from ott.solvers.linear import univariate class TestUnivariate: @@ -26,13 +29,14 @@ class TestUnivariate: @pytest.fixture(autouse=True) def initialize(self, rng: jax.Array): self.rng = rng - self.n = 17 - self.m = 29 + self.n = 7 + self.m = 5 + self.d = 2 self.rng, *rngs = jax.random.split(self.rng, 5) - self.x = jax.random.uniform(rngs[0], [self.n]) - self.y = jax.random.uniform(rngs[1], [self.m]) - a = jax.random.uniform(rngs[2], [self.n]) - b = jax.random.uniform(rngs[3], [self.m]) + self.x = jax.random.uniform(rngs[0], (self.n, self.d)) + self.y = jax.random.uniform(rngs[1], (self.m, self.d)) + a = jax.random.uniform(rngs[2], (self.n,)) + b = jax.random.uniform(rngs[3], (self.m,)) # adding zero weights to test proper handling a = a.at[0].set(0) @@ -40,92 +44,122 @@ def initialize(self, rng: jax.Array): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - @pytest.mark.parametrize( - "cost_fn", [ - costs.SqEuclidean(), - costs.PNormP(1.0), - costs.PNormP(2.0), - costs.PNormP(1.7) - ] - ) + @pytest.mark.parametrize("cost_fn", [costs.SqEuclidean(), costs.PNormP(1.8)]) def test_cdf_distance_and_sinkhorn(self, cost_fn: costs.CostFn): """The Univariate distance coincides with the sinkhorn solver""" - univariate_solver = univariate.UnivariateSolver( - method="wasserstein", cost_fn=cost_fn + univariate_solver = univariate.UnivariateSolver() + geom = pointcloud.PointCloud(self.x, self.y, cost_fn=cost_fn) + prob = linear_problem.LinearProblem(geom=geom, a=self.a, b=self.b) + out = jax.jit(univariate_solver)(prob) + costs_1d, matrices_1d = out.ot_costs, out.transport_matrices + mean_matrices_1d = out.mean_transport_matrix + + @jax.jit + @functools.partial(jax.vmap, in_axes=[1, 1, None, None]) + def sliced_sinkhorn(x, y, a, b): + geom = pointcloud.PointCloud( + x[:, None], y[:, None], cost_fn=cost_fn, epsilon=0.0015 + ) + out = linear.solve(geom, a=self.a, b=self.b) + return out.primal_cost, out.matrix, out.converged + + costs_sink, matrices_sink, converged = sliced_sinkhorn( + self.x, self.y, self.a, self.b ) - distance = univariate_solver(self.x, self.y, self.a, self.b) + assert jnp.all(converged) + scale = 1 / (self.n * self.m) - geom = pointcloud.PointCloud( - x=self.x[:, None], y=self.y[:, None], cost_fn=cost_fn, epsilon=5e-3 + np.testing.assert_allclose(costs_1d, costs_sink, atol=scale, rtol=1e-1) + + np.testing.assert_allclose( + jnp.mean(matrices_1d, axis=0).sum(1), self.a, atol=1e-3 + ) + np.testing.assert_allclose( + jnp.mean(matrices_1d, axis=0).sum(0), self.b, atol=1e-3 ) - prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b) - sinkhorn_solver = jax.jit(sinkhorn.Sinkhorn(max_iterations=10_000)) - sinkhorn_soln = sinkhorn_solver(prob) np.testing.assert_allclose( - sinkhorn_soln.primal_cost, distance, atol=0, rtol=1e-1 + matrices_sink, matrices_1d, atol=0.5 * scale, rtol=1e-1 + ) + np.testing.assert_allclose( + jnp.mean(matrices_sink, axis=0), + mean_matrices_1d, + atol=0.5 * scale, + rtol=1e-1 ) @pytest.mark.fast() def test_cdf_distance_and_scipy(self): """The OTT solver coincides with scipy solver""" - - # The `scipy` solver only has the solution for p=1.0 visible - univariate_solver = univariate.UnivariateSolver( - method="wasserstein", cost_fn=costs.PNormP(1.0) - ) - ott_distance = univariate_solver(self.x, self.y, self.a, self.b) - - scipy_distance = sp.stats.wasserstein_distance( - self.x, self.y, self.a, self.b - ) - - np.testing.assert_allclose(scipy_distance, ott_distance, atol=0, rtol=1e-2) + x, y, a, b = self.x, self.y, self.a, self.b + # The `scipy` solver only computes the solution for p=1.0 visible + + # non-uniform: vanilla or subsampling + geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNormP(1.0)) + prob = linear_problem.LinearProblem(geom=geom, a=a, b=b) + ott_d = univariate.UnivariateSolver()(prob).ot_costs[0] + scipy_d = sp.stats.wasserstein_distance(x[:, 0], y[:, 0], a, b) + np.testing.assert_allclose(scipy_d, ott_d, atol=1e-2, rtol=1e-2) + + num_subsamples = 100 + ott_dss = univariate.UnivariateSolver(num_subsamples=num_subsamples + )(prob).ot_costs[0] + np.testing.assert_allclose(scipy_d, ott_dss, atol=1e2, rtol=1e-2) + + # uniform variants + prob = linear_problem.LinearProblem(geom=geom) + scipy_d2 = sp.stats.wasserstein_distance(x[:, 0], y[:, 0]) + + ott_d = univariate.UnivariateSolver()(prob).ot_costs[0] + ott_dq = univariate.UnivariateSolver(quantiles=8)(prob).ot_costs[0] + np.testing.assert_allclose(scipy_d2, ott_d, atol=1e-2, rtol=1e-2) + np.testing.assert_allclose(scipy_d2, ott_dq, atol=1e-1, rtol=1e-1) @pytest.mark.fast() - def test_cdf_grad( + def test_univariate_grad( self, rng: jax.Array, ): # TODO: Once a `check_grad` function is implemented, replace the code # blocks before with `check_grad`'s. - cost_fn = costs.SqEuclidean() rngs = jax.random.split(rng, 4) eps, tol = 1e-4, 1e-3 + x, y = self.x[:, 1][:, None], self.y[:, 1][:, None] + a, b = self.a, self.b + solver = univariate.UnivariateSolver() - solver = univariate.UnivariateSolver(method="wasserstein", cost_fn=cost_fn) + def univ_dist(x, y, a, b): + geom = pointcloud.PointCloud(x, y) + prob = linear_problem.LinearProblem(geom=geom, a=a, b=b) + return jnp.squeeze(solver(prob).ot_costs) - grad_x, grad_y, grad_a, grad_b = jax.jit(jax.grad(solver, (0, 1, 2, 3)) - )(self.x, self.y, self.a, self.b) + grad_x, grad_y, grad_a, grad_b = jax.jit(jax.grad(univ_dist, (0, 1, 2, 3)) + )(x, y, a, b) # Checking geometric grads: - v_x = jax.random.normal(rngs[0], shape=self.x.shape) + v_x = jax.random.normal(rngs[0], shape=x.shape) v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * eps - expected = solver(self.x + v_x, self.y, self.a, - self.b) - solver(self.x - v_x, self.y, self.a, self.b) + expected = univ_dist(x + v_x, y, a, b) - univ_dist(x - v_x, y, a, b) actual = 2.0 * jnp.vdot(v_x, grad_x) np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) - v_y = jax.random.normal(rngs[1], shape=self.y.shape) + v_y = jax.random.normal(rngs[1], shape=y.shape) v_y = (v_y / jnp.linalg.norm(v_y, axis=-1, keepdims=True)) * eps - expected = solver(self.x, self.y + v_y, self.a, - self.b) - solver(self.x, self.y - v_y, self.a, self.b) + expected = univ_dist(x, y + v_y, a, b) - univ_dist(x, y - v_y, a, b) actual = 2.0 * jnp.vdot(v_y, grad_y) np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) # Checking probability grads: - v_a = jax.random.normal(rngs[2], shape=self.x.shape) + v_a = jax.random.normal(rngs[2], shape=a.shape) v_a -= jnp.mean(v_a, axis=-1, keepdims=True) v_a = (v_a / jnp.linalg.norm(v_a, axis=-1, keepdims=True)) * eps - expected = solver(self.x, self.y, self.a + v_a, - self.b) - solver(self.x, self.y, self.a - v_a, self.b) + expected = univ_dist(x, y, a + v_a, b) - univ_dist(x, y, a - v_a, b) actual = 2.0 * jnp.vdot(v_a, grad_a) np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) - v_b = jax.random.normal(rngs[3], shape=self.y.shape) + v_b = jax.random.normal(rngs[3], shape=b.shape) v_b -= jnp.mean(v_b, axis=-1, keepdims=True) v_b = (v_b / jnp.linalg.norm(v_b, axis=-1, keepdims=True)) * eps - expected = solver(self.x, self.y, self.a, self.b + - v_b) - solver(self.x, self.y, self.a, self.b - v_b) + expected = univ_dist(x, y, a, b + v_b) - univ_dist(x, y, a, b - v_b) actual = 2.0 * jnp.vdot(v_b, grad_b) np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 6a15bd20a..a6f7d4f11 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -12,19 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Callable - import jax import jax.numpy as jnp -import numpy as np import pytest from ott.geometry import costs, pointcloud -from ott.initializers.linear import initializers from ott.problems.quadratic import quadratic_problem -from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.solvers.quadratic import lower_bound -from ott.tools import soft_sort class TestLowerBoundSolver: @@ -46,18 +39,12 @@ def initialize(self, rng: jax.Array): self.cy = jax.random.uniform(rngs[3], (self.m, self.m)) @pytest.mark.fast.with_args( - "epsilon_sort,method,cost_fn", - [(0.0, "subsample", costs.SqEuclidean()), - (1e-1, "quantile", costs.PNormP(1.5)), (1.0, "equal", costs.SqPNorm(1)), - (None, "subsample", costs.PNormP(3.1))], + "cost_fn", + [costs.SqEuclidean(), costs.PNormP(1.5)], only_fast=0, ) - def test_lb_pointcloud( - self, epsilon_sort: float, method: str, cost_fn: costs.CostFn - ): - n_sub = min([self.x.shape[0], self.y.shape[0]]) - x, y = (self.x[:n_sub], - self.y[:n_sub]) if method == "equal" else (self.x, self.y) + def test_lb_pointcloud(self, cost_fn: costs.CostFn): + x, y = self.x, self.y geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) @@ -65,95 +52,8 @@ def test_lb_pointcloud( geom_x, geom_y, a=self.a, b=self.b ) - if epsilon_sort is not None and epsilon_sort <= 0.0: - sort_fn = None - else: - sort_fn = functools.partial( - soft_sort.sort, - epsilon=epsilon_sort, - min_iterations=100, - max_iterations=100, - ) - - solver = lower_bound.LowerBoundSolver( - epsilon=1e-1, - sort_fn=sort_fn, - cost_fn=cost_fn, - method=method, - n_subsamples=4, - ) + solver = lower_bound.LowerBoundSolver(epsilon=1e-1, cost_fn=cost_fn) out = jax.jit(solver)(prob) - np.testing.assert_allclose( - out.primal_cost, jnp.sum(out.geom.cost_matrix * out.matrix), rtol=1e-3 - ) - assert not jnp.isnan(out.reg_ot_cost) - - @pytest.mark.parametrize("method", ["subsample", "quantile", "equal"]) - @pytest.mark.parametrize( - "sort_fn", - [ - None, - functools.partial( - soft_sort.sort, - epsilon=1e-3, - implicit_diff=False, - # soft sort uses `sorting` initializer, which uses while loop - # which is not reverse-mode diff. - initializer=initializers.DefaultInitializer(), - min_iterations=10, - max_iterations=10, - ), - functools.partial( - soft_sort.sort, - epsilon=1e-1, - implicit_diff=implicit_lib.ImplicitDiff(), - initializer=initializers.DefaultInitializer(), - min_iterations=0, - max_iterations=100, - ) - ] - ) - def test_lb_grad( - self, rng: jax.Array, sort_fn: Callable[[jnp.ndarray], jnp.ndarray], - method: str - ): - - def fn(x: jnp.ndarray, y: jnp.ndarray) -> float: - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) - - solver = lower_bound.LowerBoundSolver( - epsilon=5e-2, - sort_fn=sort_fn, - cost_fn=costs.SqEuclidean(), - method=method, - n_subsamples=n_sub, - ) - return solver(prob).reg_ot_cost - - rng1, rng2 = jax.random.split(rng) - eps, tol = 1e-4, 1e-3 - - n_sub = min(self.x.shape[0], self.y.shape[0]) - if method == "equal": - x, y = self.x[:n_sub], self.y[:n_sub] - else: - x, y = self.x, self.y - - grad_x, grad_y = jax.jit(jax.grad(fn, (0, 1)))(x, y) - - v_x = jax.random.normal(rng1, shape=x.shape) - v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * eps - expected = fn(x + v_x, y) - fn(x - v_x, y) - actual = 2.0 * jnp.vdot(v_x, grad_x) - np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) - - v_y = jax.random.normal(rng2, shape=y.shape) - v_y = (v_y / jnp.linalg.norm(v_y, axis=-1, keepdims=True)) * eps - expected = (fn(x, y + v_y) - fn(x, y - v_y)) - actual = 2.0 * jnp.vdot(v_y, grad_y) - np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol) From 80a43b2b8061f5892119f8b6e22474eb679472d2 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 27 Nov 2023 17:09:02 +0100 Subject: [PATCH 2/4] refactoring following michal comments --- docs/geometry.rst | 1 + src/ott/geometry/__init__.py | 1 + .../{distrib_cost.py => distrib_costs.py} | 37 +-- src/ott/math/utils.py | 15 +- src/ott/problems/linear/linear_problem.py | 4 +- src/ott/solvers/linear/univariate.py | 246 +++++++++++------- src/ott/solvers/quadratic/lower_bound.py | 41 +-- 7 files changed, 217 insertions(+), 128 deletions(-) rename src/ott/geometry/{distrib_cost.py => distrib_costs.py} (71%) diff --git a/docs/geometry.rst b/docs/geometry.rst index adc0a3880..537606d6f 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -66,6 +66,7 @@ Cost Functions costs.ElasticSTVS costs.ElasticSqKOverlap costs.SoftDTW + distrib_costs.UnivariateWasserstein Utilities --------- diff --git a/src/ott/geometry/__init__.py b/src/ott/geometry/__init__.py index 5890e0935..8c97722d2 100644 --- a/src/ott/geometry/__init__.py +++ b/src/ott/geometry/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import ( costs, + distrib_costs, epsilon_scheduler, geometry, graph, diff --git a/src/ott/geometry/distrib_cost.py b/src/ott/geometry/distrib_costs.py similarity index 71% rename from src/ott/geometry/distrib_cost.py rename to src/ott/geometry/distrib_costs.py index bf2757e21..46b2d411c 100644 --- a/src/ott/geometry/distrib_cost.py +++ b/src/ott/geometry/distrib_costs.py @@ -32,33 +32,36 @@ class UnivariateWasserstein(costs.CostFn): This ground cost between considers vectors as a family of values. The Wasserstein distance between them is the 1D OT cost, using a user-defined ground cost. + + Args: + kwargs: arguments passed on when calling the + :class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include + random key, or specific instructions to subsample or compute using + quantiles. + """ def __init__( self, ground_cost: Optional[costs.TICost] = None, - kwargs_solve: Optional[Any] = None, + solver: Optional[univariate.UnivariateSolver] = None, **kwargs: Any ): super().__init__() - if ground_cost is None: - self.ground_cost = costs.SqEuclidean() - else: - self.ground_cost = ground_cost - self._kwargs_solve = {} if kwargs_solve is None else kwargs_solve - self._kwargs = kwargs - self._solver = univariate.UnivariateSolver(**kwargs) + + self.ground_cost = ( + costs.SqEuclidean() if ground_cost is None else ground_cost + ) + + self._solver = univariate.UnivariateSolver() if solver is None else solver + self._kwargs_solve = kwargs def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist. Args: - x: vector - y: vector - kwargs: arguments passed on when calling the - :class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include - random key, or specific instructions to subsample or compute using - quantiles. + x: vector, array of shape ``[n,]`` + y: vector, array of shape ``[m,]`` Returns: The transport cost. @@ -73,10 +76,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return jnp.squeeze(out.ot_costs) def tree_flatten(self): # noqa: D102 - return (), (self.ground_cost, self._kwargs_solve, self._kwargs) + return (self.ground_cost, self._solver), self._kwargs_solve @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - del children - gc, kws, kw = aux_data - return cls(gc, kws, **kw) + return cls(*children, **aux_data) diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 8e7ea90ee..3331b3a61 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -29,6 +29,7 @@ "gen_js", "logsumexp", "softmin", + "sort_and_argsort", "barycentric_projection", ] @@ -220,3 +221,15 @@ def barycentric_projection( return jax.vmap( lambda m, y: cost_fn.barycenter(m, y)[0], in_axes=[0, None] )(matrix, y) + + +def sort_and_argsort( + x: jnp.array, + *, + argsort: bool = False +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Unified function that returns both sort and argsort, if latter needed.""" + if argsort: + i_x = jnp.argsort(x) + return x[i_x], i_x + return jnp.sort(x), None diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 84d29ebd4..c89c7cd0f 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -80,12 +80,12 @@ def is_balanced(self) -> bool: @property def is_uniform(self) -> bool: - """Test if no weights were passed.""" + """True if no weights ``a,b`` were passed, and have defaulted to uniform.""" return self._a is None and self._b is None @property def is_equal_size(self) -> bool: - """Test if square shape, i.e. n == m.""" + """True if square shape, i.e. ``n == m``.""" return self.geom.shape[0] == self.geom.shape[1] @property diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index 334b0ee75..5ca4db25a 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -11,28 +11,61 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Tuple, Union import jax import jax.numpy as jnp +from ott import utils from ott.geometry import costs, pointcloud +from ott.math import utils as mu from ott.problems.linear import linear_problem -__all__ = ["UnivariateOutput", "UnivariateSolver"] +__all__ = [ + "UnivariateOutput", "UnivariateSolver", "uniform_distance", + "quantile_distance" +] class UnivariateOutput(NamedTuple): # noqa: D101 + """Holds the output of a UnivariateSolver. + + Objects of this class contain both solutions and problem definition of a + univariate OT problem. + + Args: + prob: OT problem between 2 weighted ``[n, d]`` and ``[m, d]`` point clouds. + ot_costs: ``[d,]`` optimal transport cost values, computed independently + along each of the ``d`` slices. + paired_indices: ``None`` if no transport was computed / recorded (e.g. when + using quantiles or subsampling approximations). Otherwise, output a tensor + of shape ``[d, 2, m+n]``, of ``m+n`` pairs of indices, for which the + optimal transport assigns mass, on each slice of the ``d`` slices + described in the dataset. Namely, for each index ``0<=k jax.Array: - """Output a ``[d,n,m]`` tensor of all ``[n,m]`` transport matrices.""" - assert self.paired_indices is not None, "[d,n,m] Transports *not* computed" + def transport_matrices(self) -> jnp.ndarray: + """Outputs a ``[d, n, m]`` tensor of all ``[n, m]`` transport matrices. + + This tensor will be extremely sparse, since it will have at most ``d(n+m)`` + non-zero values, out of ``dnm`` total entries. + """ + assert self.paired_indices is not None, \ + ("[d, n, m] tensor of transports cannot be computed, likely because an"+ + " approximate method was used (using either subsampling or quantiles).") n, m = self.prob.geom.shape if self.prob.is_equal_size and self.prob.is_uniform: @@ -54,7 +87,7 @@ def transport_matrices(self) -> jax.Array: )(indices, self.mass_paired_indices) @property - def mean_transport_matrix(self) -> jax.Array: + def mean_transport_matrix(self) -> jnp.ndarray: """Return the mean tranport matrix, averaged over slices.""" return jnp.mean(self.transport_matrices, axis=0) @@ -66,10 +99,10 @@ class UnivariateSolver: Computes 1-Dimensional optimal transport distance between two $d$-dimensional point clouds. The total distance is the sum of univariate Wasserstein distances on the $d$ slices of data: given two weighted point-clouds, stored - as ``[n,d]`` and ``[m,d]`` in a + as ``[n, d]`` and ``[m, d]`` in a :class:`~ott.problems.linear.linear_problem.LinearProblem` object, with respective weights ``a`` and ``b``, the solver - computes ``d`` OT distances between each of these ``[n,1]`` and ``[m,1]`` + computes ``d`` OT distances between each of these ``[n, 1]`` and ``[m, 1]`` slices. The distance is computed using the analytical formula by default, which involves sorting each of the slices independently. The optimal transport matrices are also outputted when possible (described in sparse form, i.e. @@ -110,7 +143,7 @@ def __init__( self.num_subsamples = num_subsamples @property - def quantiles(self): + def quantiles(self) -> jnp.ndarray: """Quantiles' values used to evaluate OT cost.""" if self._quantiles is None: return None @@ -119,50 +152,49 @@ def quantiles(self): return self._quantiles @property - def num_quantiles(self): + def num_quantiles(self) -> int: """Number of quantiles used to evaluate OT cost.""" - return 0 if self._quantiles is None else self.quantiles.shape[0] + return 0 if self.quantiles is None else self.quantiles.shape[0] def __call__( self, prob: linear_problem.LinearProblem, rng: Optional[jax.Array] = None, - ) -> float: + ) -> UnivariateOutput: """Computes Univariate Distance between the `d` dimensional slices. Args: - prob: describing, in its geometry attribute, the two point clouds - ``x`` and ``y`` (of respective sizes ``[n,d]`` and ``[m,d]``) and - a ground ``cost_fn`` for between two scalars. The ``[n,]`` and ``[m,]`` - size probability weights vectors are stored in attributes ``a`` and - ``b``. - rng: used for random downsampling, if used. - return_transport: whether to return an average transport matrix (across - slices). Not available when approximating the distance computation - using subsamples, or quantiles. + prob: describing, in its :attr:`~ott.problems.linear.LinearProblem.geom` + attribute, the two point clouds ``x`` and ``y`` + (of respective sizes ``[n, d]`` and ``[m, d]``) and a ground + `TI cost ` between two scalars. + The ``[n,]`` and ``[m,]`` size probability weights vectors are stored in + attributes `:attr:`~ott.problems.linear.LinearProblem.a` and + :attr:`~ott.problems.linear.LinearProblem.b` + rng: used for random downsampling, if specified in the solver. Returns: - The OT distance, and possibly the transport matrix averaged by - considering all matrices arising from 1D transport on each of the ``d`` - dimensional slices of the input. + An output object, that computs ``d`` OT costs, in addition to, possibly, + paired lists of indices and their corresponding masses, on each of the + ``d`` dimensional slices of the input. """ geom = prob.geom n, m = geom.shape - rng = jax.random.PRNGKey(0) if rng is None else rng - geom_is_pc = isinstance(geom, pointcloud.PointCloud) - assert geom_is_pc, "Geometry object in problem must be a PointCloud." - cost_is_TI = isinstance(geom.cost_fn, costs.TICost) - assert cost_is_TI, "Geometry's cost must be translation invariant." + rng = utils.default_prng_key(rng) if rng is None else rng + assert isinstance(geom, pointcloud.PointCloud), \ + "Geometry object in problem must be a PointCloud." + assert isinstance(geom.cost_fn, costs.TICost), \ + "Geometry's cost must be translation invariant." x, y = geom.x, geom.y # check if problem has the property uniform / same number of points is_uniform_same_size = prob.is_uniform and prob.is_equal_size if self.num_subsamples: - rng1, rng2 = jax.random.split(rng, 2) if prob.is_uniform: x = x[jnp.linspace(0, n, num=self.num_subsamples).astype(int), :] y = y[jnp.linspace(0, m, num=self.num_subsamples).astype(int), :] else: + rng1, rng2 = jax.random.split(rng, 2) x = jax.random.choice(rng1, x, (self.num_subsamples,), p=prob.a, axis=0) y = jax.random.choice(rng2, y, (self.num_subsamples,), p=prob.b, axis=0) n = m = self.num_subsamples @@ -171,19 +203,9 @@ def __call__( if self.quantiles is None: if is_uniform_same_size: - i_x, i_y = jnp.argsort(x, axis=0), jnp.argsort(y, axis=0) - x = jnp.take_along_axis(x, i_x, axis=0) - y = jnp.take_along_axis(y, i_y, axis=0) - ot_costs = jax.vmap(geom.cost_fn.h, in_axes=[0])(x.T - y.T) / n - - if self.num_subsamples: - # When subsampling, the pairing computed have no meaning w.r.t. - # original data. - paired_indices, mass_paired_indices = None, None - else: - paired_indices = jnp.stack([i_x, i_y]).transpose([2, 0, 1]) - mass_paired_indices = jnp.ones((n,)) / n - + ot_costs, paired_indices, mass_paired_indices = uniform_distance( + x, y, geom.cost_fn, return_pairs=not self.num_subsamples + ) else: ot_costs, paired_indices, mass_paired_indices = jax.vmap( self._quantile_distance_and_transport, @@ -191,7 +213,8 @@ def __call__( )(x, y, prob.a, prob.b, geom.cost_fn) else: - assert prob.is_uniform, "Quantile method only valid for uniform marginals" + assert prob.is_uniform, \ + "The ``quantiles`` method can only be used with uniform marginals." x_q = jnp.quantile(x, self.quantiles, axis=0) y_q = jnp.quantile(y, self.quantiles, axis=0) ot_costs = jax.vmap(geom.cost_fn.pairwise, in_axes=[1, 1])(x_q, y_q) @@ -206,57 +229,104 @@ def __call__( mass_paired_indices=mass_paired_indices ) - def _quantile_distance_and_transport( - self, x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, - cost_fn: costs.TICost - ): - # Implementation inspired by `scipy` implementation for - # :func: - def sort_and_argsort(x: jnp.array, return_argsort: bool, **kwargs): - if return_argsort: - i_x = jnp.argsort(x, **kwargs) - return x[i_x], i_x - return jnp.sort(x), None + def tree_flatten(self): # noqa: D102 + aux_data = vars(self).copy() + return [], aux_data - x, i_x = sort_and_argsort(x, True) - y, i_y = sort_and_argsort(y, True) + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + return cls(*children, **aux_data) - all_values = jnp.concatenate([x, y]) - all_values_sorted, all_values_sorter = sort_and_argsort(all_values, True) - x_pdf = jnp.concatenate([a[i_x], jnp.zeros_like(b)])[all_values_sorter] - y_pdf = jnp.concatenate([jnp.zeros_like(a), b[i_y]])[all_values_sorter] +def uniform_distance( + x: jnp.ndarray, + y: jnp.ndarray, + cost_fn: costs.TICost, + return_pairs: bool = True +) -> Tuple[float, Optional[jnp.ndarray], Optional[jnp.ndarray]]: + """Distance between two equal-size families of uniformly weighted values x/y. - x_cdf = jnp.cumsum(x_pdf) - y_cdf = jnp.cumsum(y_pdf) + Args: + x: a vector ``[n,]`` of real values + y: a vector ``[n,]`` of real values + cost_fn: a translation invariant cost function, i.e. ``c(x,y) = h(x-y).`` + return_pairs: whether to return mapped pairs. + + Returns: + optimal transport cost, a list of ``n+m`` paired indices, and their + corresponding transport mass. Note that said mass can be null in some + entries, but sums to 1.0 + """ + n = x.shape[0] + i_x, i_y = jnp.argsort(x, axis=0), jnp.argsort(y, axis=0) + x = jnp.take_along_axis(x, i_x, axis=0) + y = jnp.take_along_axis(y, i_y, axis=0) + ot_costs = jax.vmap(cost_fn.h, in_axes=[0])(x.T - y.T) / n + + if return_pairs: + paired_indices, mass_paired_indices = None, None + else: + paired_indices = jnp.stack([i_x, i_y]).transpose([2, 0, 1]) + mass_paired_indices = jnp.ones((n,)) / n + return ot_costs, paired_indices, mass_paired_indices + + +def quantile_distance( + x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + cost_fn: costs.TICost +) -> Tuple[float, jnp.ndarray, jnp.ndarray]: + """Computes distance between quantile functions of distributions (a,x)/(b,y). - x_y_cdfs = jnp.concatenate([x_cdf, y_cdf]) - quantile_levels, _ = sort_and_argsort(x_y_cdfs, False) + Args: + x: a vector ``[n,]`` of real values + y: a vector ``[m,]`` of real values + a: a vector ``[n,]`` of nonnegative weights summing to 1. + b: a vector ``[m,]`` of nonnegative weights summing to 1. + cost_fn: a translation invariant cost function, i.e. ``c(x,y) = h(x-y).`` + + Notes: + Implementation inspired by `scipy` implementation for + :func:`~scipy.stats.wasserstein_distance`, but can be used with other costs, + not just :math:`c(x,y)=|x-y|`. + + Returns: + optimal transport cost, a list of ``n+m`` paired indices, and their + corresponding transport mass. Note that said mass can be null in some + entries, but sums to 1.0 + """ + x, i_x = mu.sort_and_argsort(x, argsort=True) + y, i_y = mu.sort_and_argsort(y, argsort=True) - i_x_cdf_inv = jnp.searchsorted(x_cdf, quantile_levels) - x_cdf_inv = all_values_sorted[i_x_cdf_inv] - i_y_cdf_inv = jnp.searchsorted(y_cdf, quantile_levels) - y_cdf_inv = all_values_sorted[i_y_cdf_inv] + all_values = jnp.concatenate([x, y]) + all_values_sorted, all_values_sorter = mu.sort_and_argsort( + all_values, argsort=True + ) - diff_q = jnp.diff(quantile_levels) - cost = jnp.sum( - jax.vmap(cost_fn.h)(y_cdf_inv[1:, None] - x_cdf_inv[1:, None]) * diff_q - ) + x_pdf = jnp.concatenate([a[i_x], jnp.zeros_like(b)])[all_values_sorter] + y_pdf = jnp.concatenate([jnp.zeros_like(a), b[i_y]])[all_values_sorter] - n = x.shape[0] + x_cdf = jnp.cumsum(x_pdf) + y_cdf = jnp.cumsum(y_pdf) - i_in_sorted_x_of_quantile = all_values_sorter[i_x_cdf_inv] % n - i_in_sorted_y_of_quantile = all_values_sorter[i_y_cdf_inv] - n + x_y_cdfs = jnp.concatenate([x_cdf, y_cdf]) + quantile_levels, _ = mu.sort_and_argsort(x_y_cdfs, argsort=False) - orig_i = (i_x[i_in_sorted_x_of_quantile])[1:] - orig_j = (i_y[i_in_sorted_y_of_quantile])[1:] + i_x_cdf_inv = jnp.searchsorted(x_cdf, quantile_levels) + x_cdf_inv = all_values_sorted[i_x_cdf_inv] + i_y_cdf_inv = jnp.searchsorted(y_cdf, quantile_levels) + y_cdf_inv = all_values_sorted[i_y_cdf_inv] - return cost, jnp.stack([orig_i, orig_j]), diff_q + diff_q = jnp.diff(quantile_levels) + cost = jnp.sum( + jax.vmap(cost_fn.h)(y_cdf_inv[1:, None] - x_cdf_inv[1:, None]) * diff_q + ) - def tree_flatten(self): # noqa: D102 - aux_data = vars(self).copy() - return [], aux_data + n = x.shape[0] - @classmethod - def tree_unflatten(cls, aux_data, children): # noqa: D102 - return cls(*children, **aux_data) + i_in_sorted_x_of_quantile = all_values_sorter[i_x_cdf_inv] % n + i_in_sorted_y_of_quantile = all_values_sorter[i_y_cdf_inv] - n + + orig_i = i_x[i_in_sorted_x_of_quantile][1:] + orig_j = i_y[i_in_sorted_y_of_quantile][1:] + + return cost, jnp.stack([orig_i, orig_j]), diff_q diff --git a/src/ott/solvers/quadratic/lower_bound.py b/src/ott/solvers/quadratic/lower_bound.py index 10762d681..b5ec24f05 100644 --- a/src/ott/solvers/quadratic/lower_bound.py +++ b/src/ott/solvers/quadratic/lower_bound.py @@ -16,7 +16,7 @@ import jax -from ott.geometry import distrib_cost, pointcloud +from ott.geometry import distrib_costs, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers import linear from ott.solvers.linear import sinkhorn @@ -32,6 +32,9 @@ class LowerBoundSolver: Args: epsilon: Entropy regularization for the resulting linear problem. + cost_fn: Univariate Wasserstein cost, used to compare two point clouds in + different spaces, where each point is seen as its distribution of costs + to other points in its point-cloud. kwargs: Keyword arguments for :class:`~ott.solvers.linear.univariate.UnivariateSolver`. """ @@ -39,44 +42,44 @@ class LowerBoundSolver: def __init__( self, epsilon: Optional[float] = None, - **kwargs: Any, + distrib_cost: Optional[distrib_costs.UnivariateWasserstein] = None, ): self.epsilon = epsilon + if distrib_cost is None: + distrib_cost = distrib_costs.UnivariateWasserstein() + self.distrib_cost = distrib_cost def __call__( self, prob: quadratic_problem.QuadraticProblem, - rng: Optional[jax.Array] = None, - kwargs_univsolver: Optional[Any] = None, epsilon: Optional[float] = None, - **kwargs + rng: Optional[jax.Array] = None, + **kwargs: Any ) -> sinkhorn.SinkhornOutput: - """Run the Histogram transport solver. + """Compute a lower-bound for the GW problem using a simple linearization. + + This solver handles a quadratic problem by computing first a proxy ``[n,m]`` + cost-matrix, inject it into a linear OT solver, to output a first OT matrix + that can be used either to linearize/initialize the resolution of the GW + problem, or more simply as a simple GW solution. Args: prob: Quadratic OT problem. - kwargs_univsolver: keyword args to - create the :class:`~ott.solvers.linear.univariate.UnivariateSolver`, - used to compute a ``[n,m]`` cost matrix, using the linearization - approach. This might rely, for instance, on subsampling or quantile - reduction to speed up computations. - rng: random key, possibly used when computing 1D costs when using - subsampling. epsilon: entropic regularization passed on to solve the linearization of the quadratic problem using 1D costs. + rng: random key, possibly used when computing 1D costs when using + subsampling. kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`. Returns: - The Histogram transport output. + A linear OT output, an approximation of the OT coupling obtained using + the lower bound provided by :cite:`memoli:11`. """ dists_xx = prob.geom_xx.cost_matrix dists_yy = prob.geom_yy.cost_matrix - kwargs_univsolver = {} if kwargs_univsolver is None else kwargs_univsolver + geom_xy = pointcloud.PointCloud( - dists_xx, - dists_yy, - cost_fn=distrib_cost.UnivariateWasserstein(**kwargs_univsolver), - epsilon=self.epsilon + dists_xx, dists_yy, cost_fn=self.distrib_cost, epsilon=self.epsilon ) return linear.solve(geom_xy, **kwargs) From d0b4d2df0d6c44e27c26b9517dc630ac1a3a6507 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 27 Nov 2023 17:52:22 +0100 Subject: [PATCH 3/4] fix tests --- src/ott/geometry/distrib_costs.py | 7 ++++--- src/ott/solvers/linear/univariate.py | 11 +++++------ src/ott/solvers/quadratic/lower_bound.py | 19 ++++++++++--------- tests/solvers/quadratic/lower_bound_test.py | 8 +++++--- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/ott/geometry/distrib_costs.py b/src/ott/geometry/distrib_costs.py index 46b2d411c..590c8a139 100644 --- a/src/ott/geometry/distrib_costs.py +++ b/src/ott/geometry/distrib_costs.py @@ -47,6 +47,7 @@ def __init__( solver: Optional[univariate.UnivariateSolver] = None, **kwargs: Any ): + from ott.solvers.linear import univariate super().__init__() self.ground_cost = ( @@ -71,13 +72,13 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pointcloud.PointCloud( x[:, None], y[:, None], cost_fn=self.ground_cost ) - ), **self._kwargs_solve + ) ) return jnp.squeeze(out.ot_costs) def tree_flatten(self): # noqa: D102 - return (self.ground_cost, self._solver), self._kwargs_solve + return (self.ground_cost,), (self._solver,) @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - return cls(*children, **aux_data) + return cls(*children, *aux_data) diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index 5ca4db25a..b7bab62e5 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -180,7 +180,7 @@ def __call__( """ geom = prob.geom n, m = geom.shape - rng = utils.default_prng_key(rng) if rng is None else rng + rng = utils.default_prng_key(rng) assert isinstance(geom, pointcloud.PointCloud), \ "Geometry object in problem must be a PointCloud." assert isinstance(geom.cost_fn, costs.TICost), \ @@ -208,8 +208,7 @@ def __call__( ) else: ot_costs, paired_indices, mass_paired_indices = jax.vmap( - self._quantile_distance_and_transport, - in_axes=[1, 1, None, None, None] + quantile_distance, in_axes=[1, 1, None, None, None] )(x, y, prob.a, prob.b, geom.cost_fn) else: @@ -230,12 +229,12 @@ def __call__( ) def tree_flatten(self): # noqa: D102 - aux_data = vars(self).copy() - return [], aux_data + return None, (self.num_subsamples, self._quantiles) @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - return cls(*children, **aux_data) + del children + return cls(*aux_data) def uniform_distance( diff --git a/src/ott/solvers/quadratic/lower_bound.py b/src/ott/solvers/quadratic/lower_bound.py index b5ec24f05..7e36cbb91 100644 --- a/src/ott/solvers/quadratic/lower_bound.py +++ b/src/ott/solvers/quadratic/lower_bound.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import jax -from ott.geometry import distrib_costs, pointcloud +from ott.geometry import pointcloud + +if TYPE_CHECKING: + from ott.solvers.linear import distrib_costs from ott.problems.quadratic import quadratic_problem from ott.solvers import linear from ott.solvers.linear import sinkhorn @@ -42,8 +45,9 @@ class LowerBoundSolver: def __init__( self, epsilon: Optional[float] = None, - distrib_cost: Optional[distrib_costs.UnivariateWasserstein] = None, + distrib_cost: Optional["distrib_costs.UnivariateWasserstein"] = None, ): + from ott.geometry import distrib_costs self.epsilon = epsilon if distrib_cost is None: distrib_cost = distrib_costs.UnivariateWasserstein() @@ -81,15 +85,12 @@ def __call__( geom_xy = pointcloud.PointCloud( dists_xx, dists_yy, cost_fn=self.distrib_cost, epsilon=self.epsilon ) - return linear.solve(geom_xy, **kwargs) def tree_flatten(self): # noqa: D102 - return [self.epsilon, self.univariate_solver], {} + return (self.epsilon, self.distrib_cost), None @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - epsilon, solver = children - obj = cls(epsilon, **aux_data) - obj.univariate_solver = solver - return obj + del aux_data + return cls(*children) diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index a6f7d4f11..0fc9211c7 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -15,7 +15,7 @@ import jax import jax.numpy as jnp import pytest -from ott.geometry import costs, pointcloud +from ott.geometry import costs, distrib_costs, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers.quadratic import lower_bound @@ -51,8 +51,10 @@ def test_lb_pointcloud(self, cost_fn: costs.CostFn): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, a=self.a, b=self.b ) - - solver = lower_bound.LowerBoundSolver(epsilon=1e-1, cost_fn=cost_fn) + distrib_cost = distrib_costs.UnivariateWasserstein(cost_fn=cost_fn) + solver = lower_bound.LowerBoundSolver( + epsilon=1e-1, distrib_cost=distrib_cost + ) out = jax.jit(solver)(prob) From 6a77a9b18b319063282ac9c04802a85710006d33 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 27 Nov 2023 18:12:13 +0100 Subject: [PATCH 4/4] last fixes --- src/ott/geometry/distrib_costs.py | 2 +- src/ott/solvers/linear/univariate.py | 9 ++++----- tests/solvers/quadratic/lower_bound_test.py | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ott/geometry/distrib_costs.py b/src/ott/geometry/distrib_costs.py index 590c8a139..ce24c9597 100644 --- a/src/ott/geometry/distrib_costs.py +++ b/src/ott/geometry/distrib_costs.py @@ -72,7 +72,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: pointcloud.PointCloud( x[:, None], y[:, None], cost_fn=self.ground_cost ) - ) + ), **self._kwargs_solve ) return jnp.squeeze(out.ot_costs) diff --git a/src/ott/solvers/linear/univariate.py b/src/ott/solvers/linear/univariate.py index b7bab62e5..868f4fdce 100644 --- a/src/ott/solvers/linear/univariate.py +++ b/src/ott/solvers/linear/univariate.py @@ -49,12 +49,11 @@ class UnivariateOutput(NamedTuple): # noqa: D101 ``0<=k jnp.ndarray: @@ -64,8 +63,8 @@ def transport_matrices(self) -> jnp.ndarray: non-zero values, out of ``dnm`` total entries. """ assert self.paired_indices is not None, \ - ("[d, n, m] tensor of transports cannot be computed, likely because an"+ - " approximate method was used (using either subsampling or quantiles).") + "[d, n, m] tensor of transports cannot be computed, likely because an" \ + " approximate method was used (using either subsampling or quantiles)." n, m = self.prob.geom.shape if self.prob.is_equal_size and self.prob.is_uniform: diff --git a/tests/solvers/quadratic/lower_bound_test.py b/tests/solvers/quadratic/lower_bound_test.py index 0fc9211c7..2766e564d 100644 --- a/tests/solvers/quadratic/lower_bound_test.py +++ b/tests/solvers/quadratic/lower_bound_test.py @@ -39,11 +39,11 @@ def initialize(self, rng: jax.Array): self.cy = jax.random.uniform(rngs[3], (self.m, self.m)) @pytest.mark.fast.with_args( - "cost_fn", + "ground_cost", [costs.SqEuclidean(), costs.PNormP(1.5)], only_fast=0, ) - def test_lb_pointcloud(self, cost_fn: costs.CostFn): + def test_lb_pointcloud(self, ground_cost: costs.TICost): x, y = self.x, self.y geom_x = pointcloud.PointCloud(x) @@ -51,7 +51,7 @@ def test_lb_pointcloud(self, cost_fn: costs.CostFn): prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, a=self.a, b=self.b ) - distrib_cost = distrib_costs.UnivariateWasserstein(cost_fn=cost_fn) + distrib_cost = distrib_costs.UnivariateWasserstein(ground_cost=ground_cost) solver = lower_bound.LowerBoundSolver( epsilon=1e-1, distrib_cost=distrib_cost )