diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c354bc85..884377a6b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,12 +11,15 @@ repos: hooks: - id: yapf additional_dependencies: [toml] -- repo: https://github.com/tomcatling/black-nb - rev: '0.7' +- repo: https://github.com/nbQA-dev/nbQA + rev: 1.6.0 hooks: - - id: black-nb + - id: nbqa-pyupgrade + args: [--py38-plus] + - id: nbqa-black + - id: nbqa-isort - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.11.4 hooks: - id: isort - repo: https://github.com/asottile/yesqa @@ -30,7 +33,7 @@ repos: - flake8-bugbear - flake8-blind-except - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.4.0 + rev: v2.5.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] @@ -65,7 +68,7 @@ repos: - flake8-blind-except args: [--docstring-convention, google] - repo: https://github.com/asottile/pyupgrade - rev: v3.2.2 + rev: v3.3.1 hooks: - id: pyupgrade args: [--py3-plus, --py37-plus, --keep-runtime-typing] diff --git a/README.md b/README.md index 39a137e33..1fc236331 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,13 @@ In the simple toy example below, we compute the optimal coupling matrix between ```python import jax import jax.numpy as jnp -from ott.tools import transport -# Samples two point clouds and their weights. -rngs = jax.random.split(jax.random.PRNGKey(0),4) + +from ott.geometry import pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +# Sample two point clouds and their weights. +rngs = jax.random.split(jax.random.PRNGKey(0), 4) n, m, d = 12, 14, 2 x = jax.random.normal(rngs[0], (n,d)) + 1 y = jax.random.uniform(rngs[1], (m,d)) @@ -62,15 +66,17 @@ a = jax.random.uniform(rngs[2], (n,)) b = jax.random.uniform(rngs[3], (m,)) a, b = a / jnp.sum(a), b / jnp.sum(b) # Computes the couplings using the Sinkhorn algorithm. -ot = transport.solve(x, y, a=a, b=b) -P = ot.matrix +geom = pointcloud.PointCloud(x, y) +prob = linear_problem.LinearProblem(geom, a, b) + +solver = sinkhorn.Sinkhorn() +out = solver(prob) ``` -The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix +The call to `solver(prob)` above works out the optimal transport solution. The `out` object contains a transport matrix (here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or -more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and -are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives, -and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). +more points from the second, as illustrated in the plot below. We provide more flexibility to define custom cost +functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).
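For readers following along with the updated README snippet, a one-line follow-up — a minimal sketch assuming the `matrix` property and `reg_ot_cost`/`converged` attributes of `SinkhornOutput` — materializes the coupling the paragraph above refers to:

```python
# Continues the README snippet above.
P = out.matrix          # the (12, 14) transport matrix discussed above
cost = out.reg_ot_cost  # value of the entropy-regularized OT objective
print(P.shape, cost, out.converged)
```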
![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png) diff --git a/docs/conf.py b/docs/conf.py index e1e5bb9eb..7e66f29d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,10 +70,6 @@ source_suffix = ['.rst'] autosummary_generate = True -autosummary_filename_map = { - "ott.solvers.linear.sinkhorn.sinkhorn": - "ott.solvers.linear.sinkhorn.sinkhorn-function" -} autodoc_typehints = 'description' diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst index 0605bd7e1..f2a92cc27 100644 --- a/docs/solvers/linear.rst +++ b/docs/solvers/linear.rst @@ -10,7 +10,7 @@ Sinkhorn Solvers .. autosummary:: :toctree: _autosummary - sinkhorn.sinkhorn + sinkhorn.solve sinkhorn.Sinkhorn sinkhorn.SinkhornOutput sinkhorn_lr.LRSinkhorn diff --git a/docs/solvers/quadratic.rst b/docs/solvers/quadratic.rst index 9f6ea7a38..6f572f54d 100644 --- a/docs/solvers/quadratic.rst +++ b/docs/solvers/quadratic.rst @@ -10,9 +10,9 @@ Gromov-Wasserstein Solvers .. autosummary:: :toctree: _autosummary + gromov_wasserstein.solve gromov_wasserstein.GromovWasserstein gromov_wasserstein.GWOutput - gromov_wasserstein.gromov_wasserstein Barycenter Solvers ------------------ diff --git a/docs/tools.rst b/docs/tools.rst index a0847d506..48edf620c 100644 --- a/docs/tools.rst +++ b/docs/tools.rst @@ -9,13 +9,6 @@ The tools package contains high level functions that build on outputs produced b They can be used to compute Sinkhorn divergences :cite:`sejourne:19`, instantiate transport matrices, provide differentiable approximations to ranks and quantile functions :cite:`cuturi:19`, etc. -Optimal Transport ------------------ -.. autosummary:: - :toctree: _autosummary - - transport.Transport - Segmented Sinkhorn ------------------ .. autosummary:: diff --git a/examples/fairness/main.py b/examples/fairness/main.py index b6b5aafba..9d3c663f0 100644 --- a/examples/fairness/main.py +++ b/examples/fairness/main.py @@ -15,11 +15,12 @@ from typing import Sequence -import jax from absl import app, flags, logging from clu import platform from ml_collections import config_flags +import jax + from ott.examples.fairness import train FLAGS = flags.FLAGS diff --git a/examples/fairness/train.py b/examples/fairness/train.py index 29c2869e8..5cf959eb7 100644 --- a/examples/fairness/train.py +++ b/examples/fairness/train.py @@ -17,10 +17,11 @@ import functools from typing import Any +import ml_collections + import flax import jax import jax.numpy as jnp -import ml_collections from flax import jax_utils from flax.metrics import tensorboard from flax.training import checkpoints, common_utils diff --git a/examples/soft_error/data.py b/examples/soft_error/data.py index 8f3678cfc..009e550e0 100644 --- a/examples/soft_error/data.py +++ b/examples/soft_error/data.py @@ -13,9 +13,10 @@ # limitations under the License. 
"""Data loading and data augmentation.""" -import jax import tensorflow as tf import tensorflow_datasets as tfds + +import jax from flax import jax_utils diff --git a/examples/soft_error/main.py b/examples/soft_error/main.py index 0c5bd9b09..480a9286c 100644 --- a/examples/soft_error/main.py +++ b/examples/soft_error/main.py @@ -15,11 +15,12 @@ from typing import Sequence -import jax from absl import app, flags, logging from clu import platform from ml_collections import config_flags +import jax + from ott.examples.soft_error import train FLAGS = flags.FLAGS diff --git a/examples/soft_error/train.py b/examples/soft_error/train.py index a364a1cea..bcd28539f 100644 --- a/examples/soft_error/train.py +++ b/examples/soft_error/train.py @@ -17,11 +17,12 @@ import functools from typing import Any +import ml_collections +import tensorflow_datasets as tfds + import flax import jax import jax.numpy as jnp -import ml_collections -import tensorflow_datasets as tfds from flax import jax_utils from flax.metrics import tensorboard from flax.training import checkpoints, common_utils diff --git a/pyproject.toml b/pyproject.toml index 827a20899..6e8114dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,9 @@ include = '\.ipynb$' profile = "black" include_trailing_comma = true multi_line_output = 3 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"] +known_numeric = ["numpy", "scipy", "pandas", "sklearn", "jax", "flax", "optax", "torch"] +known_plotting = ["matplotlib", "mpl_toolkits", "seaborn"] skip_glob = ["docs/*"] [tool.pytest.ini_options] diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index 65b91cc52..8e7124461 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -286,7 +286,7 @@ def _vectorized_update( """Inner loop DualSort Update. Args: - f : potential f, array of size n. + f: potential f, array of size n. modified_cost: cost matrix minus diagonal column-wise. Returns: diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index ed5c172ae..9ff72aa35 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -33,20 +33,20 @@ class LinearProblem: r"""Linear OT problem. This class describes the main ingredients appearing in a linear OT problem. - Namely, a `geom` object (including cost structure/points) describing point - clouds or the support of measures, followed by probability masses `a` and `b`. - Unabalancedness of the problem is also kept track of, through two coefficients - `tau_a` and `tau_b`, which are both kept between 0 and 1 + Namely, a ``geom`` object (including cost structure/points) describing point + clouds or the support of measures, followed by probability masses ``a`` and + ``b``. Unbalancedness of the problem is also kept track of, through two + coefficients ``tau_a`` and ``tau_b``, which are both kept between 0 and 1 (1 corresponding to a balanced OT problem). Args: geom: The ground geometry cost of the linear problem. a: The first marginal. If `None`, it will be uniform. b: The second marginal. If `None`, it will be uniform. - tau_a: If smaller than `1`, defines how much unbalanced the problem is on - the first marginal. - tau_b: If smaller than `1`, defines how much unbalanced the problem is on - the second marginal. + tau_a: If `< 1`, defines how much unbalanced the problem is + on the first marginal. 
+ tau_b: If `< 1`, defines how unbalanced the problem is + on the second marginal. """ def __init__( diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 388f60089..67200893f 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -32,7 +32,7 @@ @jax.tree_util.register_pytree_node_class class QuadraticProblem: - r"""Quadratic regularized OT problem. + r"""Quadratic OT problem. The quadratic loss of a single OT matrix is assumed to have the form given in :cite:`peyre:16`, eq. 4. @@ -48,37 +48,37 @@ class QuadraticProblem: Args: geom_xx: Ground geometry of the first space. geom_yy: Ground geometry of the second space. - geom_xy: Geometry defining the linear penalty term for Fused Gromov - Wasserstein. If `None`, the problem reduces to a - plain Gromov Wasserstein problem. - fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein, + geom_xy: Geometry defining the linear penalty term for + Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain + Gromov-Wasserstein problem. + fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if ``geom_xy`` is not specified. scale_cost: option to rescale the cost matrices: - - if `True`, use the default for each geometry. - - if `False`, keep the original scaling in geometries. + - if :obj:`True`, use the default for each geometry. + - if :obj:`False`, keep the original scaling in geometries. - if :class:`str`, use a specific method available in :class:`~ott.geometry.geometry.Geometry` or :class:`~ott.geometry.pointcloud.PointCloud`. - - if `None`, do not scale the cost matrices. + - if :obj:`None`, do not scale the cost matrices. - a: jnp.ndarray[n] representing the probability weights of the samples - from geom_xx. If None, it will be uniform. - b: jnp.ndarray[n] representing the probability weights of the samples - from geom_yy. If None, it will be uniform. + a: array representing the probability weights of the samples + from ``geom_xx``. If `None`, it will be uniform. + b: array representing the probability weights of the samples + from ``geom_yy``. If `None`, it will be uniform. loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear - part of the loss (see in the pydoc of the class lin1, lin2). The second - one is the quadratic part (quad1, quad2). By default, the loss - is set as the 4 functions representing the squared Euclidean loss, and - this property is taken advantage of in subsequent computations. See - Alternatively, KL loss can be specified in no less optimized way. - tau_a: if lower that 1.0, defines how much unbalanced the problem is on + part of the loss. The second one is the quadratic part (quad1, quad2). + By default, the loss is set as the 4 functions representing the squared + Euclidean loss, and this property is taken advantage of in subsequent + computations. Alternatively, the KL loss can be specified in an equally + optimized way. + tau_a: if `< 1.0`, defines how unbalanced the problem is on the first marginal. - tau_b: if lower that 1.0, defines how much unbalanced the problem is on + tau_b: if `< 1.0`, defines how unbalanced the problem is on the second marginal. gw_unbalanced_correction: Whether the unbalanced version of - :cite:`sejourne:21` is used. Otherwise ``tau_a`` and ``tau_b`` only affect + :cite:`sejourne:21` is used.
Otherwise, ``tau_a`` and ``tau_b`` only affect the inner Sinkhorn loop. ranks: Ranks of the cost matrices, see :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when @@ -274,7 +274,7 @@ def update_linearization( If the problem is unbalanced (``tau_a < 1.0 or tau_b < 1.0``), two cases are possible, as explained in :meth:`init_linearization` above. - Finally, it is also possible to consider a Fused Gromov Wasserstein problem. + Finally, it is also possible to consider a Fused Gromov-Wasserstein problem. Details about the resulting cost matrix are also given in :meth:`init_linearization`. diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index a22ac9ab6..1e1ed10fe 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -15,6 +15,7 @@ class AndersonAcceleration: """Implements Anderson acceleration for Sinkhorn.""" + # TODO(michalk8): use memory=0 as no Anderson acceleration? memory: int = 2 # Number of iterates considered to form interpolation. refresh_every: int = 1 # Recompute interpolation periodically. ridge_identity: float = 1e-2 # Ridge used in the linear system. diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 0ed908bbe..d25b0e872 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -13,8 +13,8 @@ # limitations under the License. """A Jax implementation of the Sinkhorn algorithm.""" from typing import ( + TYPE_CHECKING, Any, - Callable, Literal, Mapping, NamedTuple, @@ -28,7 +28,6 @@ import jax.numpy as jnp import numpy as np -from ott import utils from ott.geometry import geometry from ott.initializers.linear import initializers as init_lib from ott.math import fixed_point_loop @@ -38,7 +37,10 @@ from ott.solvers.linear import acceleration from ott.solvers.linear import implicit_differentiation as implicit_lib -__all__ = ["Sinkhorn", "SinkhornOutput"] +if TYPE_CHECKING: + from ott.solvers.linear.sinkhorn_lr import LRSinkhorn, LRSinkhornOutput + +__all__ = ["Sinkhorn", "SinkhornOutput", "solve"] class SinkhornState(NamedTuple): @@ -411,12 +413,237 @@ def to_dual_potentials(self) -> potentials.EntropicPotentials: @jax.tree_util.register_pytree_node_class class Sinkhorn: - """A Sinkhorn solver for linear reg-OT problem. + r"""Sinkhorn solver. + + The Sinkhorn algorithm is a fixed point iteration that solves a regularized + optimal transport (reg-OT) problem between two measures. + The optimization variables are a pair of vectors (called potentials, or + scalings when parameterized as exponentials of the former). Calling this + solver therefore returns a pair of optimal vectors. In addition to these, + it also returns the objective value achieved by these optimal vectors; + a vector of size ``max_iterations/inner_iterations`` that records the + values used to monitor convergence throughout the execution of the + algorithm (padded with `-1` if convergence happens before), as well as a + boolean to signify whether the algorithm has converged within the number of + iterations specified by the user. + + The reg-OT problem is specified by two measures, of respective sizes ``n`` and + ``m``. From the viewpoint of the ``Sinkhorn`` solver, these two measures are + only seen through a triplet (``geom``, ``a``, ``b``), where ``geom`` is a + ``Geometry`` object, and ``a`` and ``b`` are weight vectors of respective + sizes ``n`` and ``m``.
Starting from two initial values for those potentials + or scalings (both can be defined by the user by passing values in + ``init_dual_a`` or ``init_dual_b``), the Sinkhorn algorithm will use + elementary operations that are carried out by the ``geom`` object. + + Some maths: + Given a geometry ``geom``, which provides a cost matrix :math:`C` with its + regularization parameter :math:`\varepsilon` (or a kernel matrix :math:`K`), + the reg-OT problem consists in finding two vectors `f`, `g` of size ``n``, + ``m`` that maximize the following criterion. + + .. math:: + + \arg\max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b, + \phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon}, + e^{-C/\varepsilon} e^{g/\varepsilon} \rangle + + where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and + :math:`\phi_a^{*}(z) = \rho_a e^{z/\rho_a}`, its Legendre transform. + + That problem can also be written, instead, using positive scaling vectors + `u`, `v` of size ``n``, ``m``, handled with the kernel + :math:`K := e^{-C/\varepsilon}`, + + .. math:: + + \arg\max_{u, v >0} - \langle a,\phi_a^{*}(-\varepsilon\log u) \rangle - + \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - \varepsilon \langle u, K v \rangle + + Both of these problems correspond, in their *primal* formulation, to + solving the unbalanced optimal transport problem with a variable matrix + :math:`P` of size ``n`` x ``m``: + + .. math:: + + \arg\min_{P>0} \langle P,C \rangle -\varepsilon \text{KL}(P | ab^T) + + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n + | b) + + where :math:`KL` is the generalized Kullback-Leibler divergence. + + The very same primal problem can also be written using a kernel :math:`K` + instead of a cost :math:`C` as well: + + .. math:: + + \arg\min_{P} \varepsilon KL(P|K) + \rho_a \text{KL}(P\mathbf{1}_m | a) + + \rho_b \text{KL}(P^T \mathbf{1}_n | b) + + The *original* OT problem taught in linear programming courses is recovered + by using the formulation above relying on the cost :math:`C`, and letting + :math:`\varepsilon \rightarrow 0`, and :math:`\rho_a, \rho_b \rightarrow + \infty`. + In that case the entropy disappears, whereas the :math:`KL` regularization + terms above become constraints on the marginals of :math:`P`. This results in a + standard min-cost flow problem. This problem is not handled for now in this + toolbox, which focuses exclusively on the case :math:`\varepsilon > 0`. + + The *balanced* regularized OT problem is recovered for finite + :math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow + \infty`. This problem can be shown to be equivalent to a matrix scaling + problem, which can be solved using the Sinkhorn fixed-point algorithm. + To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the + ``Sinkhorn`` solver uses parameters :math:`\tau_a := \rho_a / + (\varepsilon + \rho_a)` and :math:`\tau_b := \rho_b / (\varepsilon + + \rho_b)` instead. Setting either of these parameters to 1 corresponds to + setting the corresponding :math:`\rho_a, \rho_b` to :math:`\infty`. + + The Sinkhorn algorithm solves the reg-OT problem by seeking optimal `f`, `g` + potentials (or alternatively their parametrization as positive scalings + `u`, `v`), rather than solving the primal problem in :math:`P`. + This is mostly for efficiency (potentials and scalings have an ``n + m`` + memory footprint, rather than the ``n m`` required to store `P`).
This is also + because both problems are, in fact, equivalent, since the optimal transport + :math:`P^*` can be recovered from optimal potentials :math:`f^*`, :math:`g^*` + or scalings :math:`u^*`, :math:`v^*`, using the geometry's cost or kernel + matrix respectively: + + .. math:: + + P^* = \exp\left(\frac{f^*\mathbf{1}_m^T + \mathbf{1}_n g^{*T} - + C}{\varepsilon}\right) \text{ or } P^* = \text{diag}(u^*) K + \text{diag}(v^*) - A Sinkhorn solver takes a linear OT problem object as an input and returns a - SinkhornOutput object that contains all the information required to compute - transports. See :func:`~ott.solvers.linear.sinkhorn.sinkhorn` - for a functional wrapper. + By default, the Sinkhorn algorithm solves this dual problem in `f, g` or + `u, v` using block coordinate ascent, i.e. devising an update for each `f` + and `g` (resp. `u` and `v`) that cancels their respective gradients, one at + a time. These two iterations are repeated ``inner_iterations`` times, after + which the norm of these gradients will be evaluated and compared with the + ``threshold`` value. The iterations are then repeated as long as that error + exceeds ``threshold``. + + Note on Sinkhorn updates: + The boolean flag ``lse_mode`` sets whether the algorithm is run in either: + + - log-sum-exp mode (``lse_mode=True``), in which case it is directly + defined in terms of updates to `f` and `g`, using log-sum-exp + computations. This requires access to the cost matrix :math:`C`, as it is + stored, or possibly computed on the fly by ``geom``. + + - kernel mode (``lse_mode=False``), in which case it will require access + to a matrix-vector multiplication operator :math:`z \rightarrow K z`, + where :math:`K` is either instantiated from :math:`C` as + :math:`\exp(-C/\varepsilon)`, or provided directly. In that case, rather + than optimizing on :math:`f` and :math:`g`, it is more convenient to + optimize on their so-called scaling formulations, + :math:`u := \exp(f / \varepsilon)` and :math:`v := \exp(g / \varepsilon)`. + While faster (applying matrices is faster than applying ``lse`` repeatedly + over lines), this mode is also less stable numerically, notably for + smaller :math:`\varepsilon`. + + In the source code, the variables ``f_u`` or ``g_v`` can be either regarded + as potentials (real) or scalings (positive) vectors, depending on the choice + of ``lse_mode`` by the user. Once optimization is carried out, we only + return dual variables in potential form, i.e. ``f`` and ``g``. + + In addition to standard Sinkhorn updates, the user can also use heavy-ball + type updates using a ``momentum`` parameter in ]0,2[. We also implement a + strategy that tries to set that parameter adaptively at + ``chg_momentum_from`` iterations, as a function of progress in the error, + as proposed in the literature. + + Another upgrade to the standard Sinkhorn updates lies in using Anderson + acceleration. This can be parameterized by setting the + otherwise null ``anderson`` to a positive integer. When selected, the + algorithm will recompute, every ``refresh_anderson_frequency`` iterations + (set by default to 1), an extrapolation of the most recently computed + ``anderson`` iterates. When using that option, notice that differentiation + (if required) can only be carried out using implicit differentiation, and + that all momentum-related parameters are ignored. + + The ``parallel_dual_updates`` flag is set to ``False`` by default.
In that + setting, ``g_v`` is first updated using the latest values for ``f_u`` and + ``g_v``, before proceeding to update ``f_u`` using that new value for + ``g_v``. When the flag is set to ``True``, both ``f_u`` and ``g_v`` are + updated simultaneously. Note that setting that choice to ``True`` requires + using some form of averaging (e.g. ``momentum=0.5``). Without such + averaging, ``parallel_dual_updates`` won't work on its own. + + Differentiation: + The optimal solutions ``f`` and ``g`` and the optimal objective + (``reg_ot_cost``) returned by the Sinkhorn algorithm can be differentiated + w.r.t. relevant inputs ``geom``, ``a`` and ``b`` using, by default, implicit + differentiation of the optimality conditions (``implicit_differentiation`` + set to ``True``). This choice has two consequences. + + - The termination criterion used to stop Sinkhorn (cancellation of + gradient of objective w.r.t. ``f_u`` and ``g_v``) is used to differentiate + ``f`` and ``g``, given a change in the inputs. These changes are computed + by solving a linear system. The arguments starting with + ``implicit_solver_*`` allow defining the linear solver that is used, and + controlling two types of regularization (we have observed that, + depending on the architecture, linear solves may require higher ridge + parameters to remain stable). The optimality conditions in Sinkhorn can be + analyzed as satisfying a ``z=z'`` condition, which is then + differentiated. It might be beneficial (e.g., as in :cite:`cuturi:20a`) + to use a preconditioning function ``precondition_fun`` to differentiate + instead ``h(z) = h(z')``. + + - The objective ``reg_ot_cost`` returned by Sinkhorn uses the so-called + envelope (or Danskin's) theorem. In that case, because it is assumed that + the gradients of the dual variables ``f_u`` and ``g_v`` w.r.t. dual + objective are zero (reflecting the fact that they are optimal), small + variations in ``f_u`` and ``g_v`` due to changes in inputs (such as + ``geom``, ``a`` and ``b``) are considered negligible. As a result, + ``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` when + evaluating the ``reg_ot_cost`` objective. Note that this approach is + `invalid` when computing higher-order derivatives. In that case the + ``use_danskin`` flag must be set to ``False``. + + An alternative yet more costly way to differentiate the outputs of the + Sinkhorn iterations is to use unrolling, i.e. reverse-mode differentiation + of the Sinkhorn loop. This is possible because Sinkhorn iterations are + wrapped in a custom fixed-point iteration loop, defined in + ``fixed_point_loop``, rather than a standard while loop. This is to ensure + the end result of this fixed-point loop can also be differentiated, if + needed, using standard JAX operations. To ensure backprop differentiability, + the ``fixed_point_loop.fixpoint_iter_backprop`` loop does checkpointing of + state variables (here ``f_u`` and ``g_v``) every ``inner_iterations``, and + backpropagates automatically, block by block, through blocks of + ``inner_iterations`` at a time. + + Note: + * The Sinkhorn algorithm may not converge within the maximum number of + iterations for possibly several reasons: + + 1. the regularizer (defined as ``epsilon`` in the geometry ``geom`` + object) is too small. Consider either switching to ``lse_mode=True`` + (at the price of slower execution), increasing ``epsilon``, or, if you + are unable or unwilling to increase ``epsilon``, increasing + ``max_iterations`` or ``threshold``.
+ 2. the probability weights ``a`` and ``b`` do not have the same total + mass, while using a balanced (``tau_a=tau_b=1.0``) setup. + Consider either normalizing ``a`` and ``b``, or setting ``tau_a`` + and/or ``tau_b`` below ``1.0``. + 3. OOM issues may arise when storing either cost or kernel matrices that + are too large in ``geom``. In the case where the ``geom`` geometry is + a ``PointCloud``, some of these issues might be solved by setting the + ``online`` flag to ``True``. This will trigger on-the-fly + re-computation of the cost/kernel matrix. + + * The weight vectors ``a`` and ``b`` can be passed on with coordinates that + have zero weight. This is then handled by relying on simple arithmetic for + ``inf`` values that will likely arise (due to :math:`\log 0` when + ``lse_mode`` is ``True``, or divisions by zero when ``lse_mode`` is + ``False``). Whenever that arithmetic is likely to produce ``NaN`` values + (due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we use + ``jnp.where`` conditional statements to carry ``inf`` rather than ``NaN`` + values. In reverse-mode differentiation, the inputs corresponding to + these 0 weights (a location `x`, or a row in the corresponding cost/kernel + matrix), and the weight itself will have ``NaN`` gradient values. This + reflects that these gradients are undefined, since these points were not + considered in the optimization and therefore have no impact on the output. Args: lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel @@ -857,400 +1084,42 @@ def _iterations_implicit_bwd(res, gr): ) -# Sets threshold, norm_errors, geom, a and b to be differentiable, as those are -# non static. Only differentiability w.r.t. geom, a and b will be used. +# Sets threshold, norm_errors, geom, a and b to be differentiable, as those are +# non-static. Only differentiability w.r.t. geom, a and b will be used. _iterations_implicit = jax.custom_vjp(iterations) _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd) -def make( - tau_a: float = 1.0, - tau_b: float = 1.0, - threshold: float = 1e-3, - norm_error: int = 1, - inner_iterations: int = 10, - min_iterations: int = 0, - max_iterations: int = 2000, - momentum: Optional[float] = None, - chg_momentum_from: Optional[int] = None, - anderson_acceleration: int = 0, - refresh_anderson_frequency: int = 1, - lse_mode: bool = True, - implicit_differentiation: bool = True, - implicit_solver_fun=jax.scipy.sparse.linalg.cg, - implicit_solver_ridge_kernel: float = 0.0, - implicit_solver_ridge_identity: float = 0.0, - implicit_solver_symmetric: bool = False, - precondition_fun: Optional[Callable[[float], float]] = None, - parallel_dual_updates: bool = False, - use_danskin: bool = None, - initializer: init_lib.SinkhornInitializer = init_lib.DefaultInitializer(), -) -> Sinkhorn: - """For backward compatibility.""" - del tau_a, tau_b - if not implicit_differentiation: - implicit_diff = None - else: - implicit_diff = implicit_lib.ImplicitDiff( - solver_fun=implicit_solver_fun, - ridge_kernel=implicit_solver_ridge_kernel, - ridge_identity=implicit_solver_ridge_identity, - symmetric=implicit_solver_symmetric, - precondition_fun=precondition_fun - ) - # If no params are passed, align default with that provide in Sinkhorn solver.
- if momentum is None and chg_momentum_from is None: - mom = acceleration.Momentum(start=300, error_threshold=1e-2) - elif momentum is None: - mom = acceleration.Momentum(start=chg_momentum_from) - elif chg_momentum_from is None: - mom = acceleration.Momentum(value=momentum) - else: - mom = acceleration.Momentum(start=chg_momentum_from, value=momentum) - - if anderson_acceleration > 0: - anderson = acceleration.AndersonAcceleration( - memory=anderson_acceleration, refresh_every=refresh_anderson_frequency - ) - else: - anderson = None - - return Sinkhorn( - lse_mode=lse_mode, - threshold=threshold, - norm_error=norm_error, - inner_iterations=inner_iterations, - min_iterations=min_iterations, - max_iterations=max_iterations, - momentum=mom, - anderson=anderson, - implicit_diff=implicit_diff, - parallel_dual_updates=parallel_dual_updates, - use_danskin=use_danskin, - initializer=initializer, - ) - - -@utils.deprecate(version="0.3.2", alt="Use the `Sinkhorn` class instead.") -def sinkhorn( +def solve( geom: geometry.Geometry, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, tau_a: float = 1.0, tau_b: float = 1.0, - init_dual_a: Optional[jnp.ndarray] = None, - init_dual_b: Optional[jnp.ndarray] = None, - **kwargs: Any, -): - r"""Solve regularized OT problem using Sinkhorn iterations. - - .. note:: - - This function has been deprecated and will be removed in ``0.3.2`` release. - Please use the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver - instead. - - The Sinkhorn algorithm is a fixed point iteration that solves a regularized - optimal transport (reg-OT) problem between two measures. - The optimization variables are a pair of vectors (called potentials, or - scalings when parameterized as exponentials of the former). Calling this - function returns therefore a pair of optimal vectors. In addition to these, - it also returns the objective value achieved by these optimal vectors; - a vector of size ``max_iterations/inner_iterations`` that records the vector - of values recorded to monitor convergence, throughout the execution of the - algorithm (padded with `-1` if convergence happens before), as well as a - boolean to signify whether the algorithm has converged within the number of - iterations specified by the user. - - The reg-OT problem is specified by two measures, of respective sizes ``n`` and - ``m``. From the viewpoint of the ``sinkhorn`` function, these two measures are - only seen through a triplet (``geom``, ``a``, ``b``), where ``geom`` is a - ``Geometry`` object, and ``a`` and ``b`` are weight vectors of respective - sizes ``n`` and ``m``. Starting from two initial values for those potentials - or scalings (both can be defined by the user by passing value in - ``init_dual_a`` or ``init_dual_b``), the Sinkhorn algorithm will use - elementary operations that are carried out by the ``geom`` object. - - Some maths: - Given a geometry ``geom``, which provides a cost matrix :math:`C` with its - regularization parameter :math:`\varepsilon`, (or a kernel matrix :math:`K`) - the reg-OT problem consists in finding two vectors `f`, `g` of size ``n``, - ``m`` that maximize the following criterion. - - .. math:: - - \arg\max_{f, g}{- \langle a, \phi_a^{*}(-f) \rangle - \langle b, - \phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon}, - e^{-C/\varepsilon} e^{-g/\varepsilon}} \rangle - - where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and - :math:`\phi_a^{*}(z) = \rho_a e^{z/\varepsilon}`, its Legendre transform. 
- - That problem can also be written, instead, using positive scaling vectors - `u`, `v` of size ``n``, ``m``, handled with the kernel - :math:`K := e^{-C/\varepsilon}`, - - .. math:: - - \arg\max_{u, v >0} - \langle a,\phi_a^{*}(-\varepsilon\log u) \rangle + - \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - \langle u, K v \rangle - - Both of these problems corresponds, in their *primal* formulation, to - solving the unbalanced optimal transport problem with a variable matrix - :math:`P` of size ``n`` x ``m``: - - .. math:: - - \arg\min_{P>0} \langle P,C \rangle -\varepsilon \text{KL}(P | ab^T) - + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n - | b) - - where :math:`KL` is the generalized Kullback-Leibler divergence. - - The very same primal problem can also be written using a kernel :math:`K` - instead of a cost :math:`C` as well: - - .. math:: - - \arg\min_{P} \varepsilon KL(P|K) + \rho_a \text{KL}(P\mathbf{1}_m | a) + - \rho_b \text{KL}(P^T \mathbf{1}_n | b) - - The *original* OT problem taught in linear programming courses is recovered - by using the formulation above relying on the cost :math:`C`, and letting - :math:`\varepsilon \rightarrow 0`, and :math:`\rho_a, \rho_b \rightarrow - \infty`. - In that case the entropy disappears, whereas the :math:`KL` regularization - above become constraints on the marginals of :math:`P`: This results in a - standard min cost flow problem. This problem is not handled for now in this - toolbox, which focuses exclusively on the case :math:`\varepsilon > 0`. - - The *balanced* regularized OT problem is recovered for finite - :math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow - \infty`. This problem can be shown to be equivalent to a matrix scaling - problem, which can be solved using the Sinkhorn fixed-point algorithm. - To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the - ``sinkhorn`` function uses parameters :math:`tau\_a := \rho_a / - (\varepsilon + \rho_a)` and :math:`tau\_b := \rho_b / (\varepsilon + - \rho_b)` instead. Setting either of these parameters to 1 corresponds to - setting the corresponding :math:`\rho_a, \rho_b` to :math:`\infty`. - - The Sinkhorn algorithm solves the reg-OT problem by seeking optimal `f`, `g` - potentials (or alternatively their parametrization as positive scalings - `u`, `v`), rather than solving the primal problem in :math:`P`. - This is mostly for efficiency (potentials and scalings have a ``n + m`` - memory footprint, rather than ``n m`` required to store `P`). This is also - because both problems are, in fact, equivalent, since the optimal transport - math:`P^*` can be recovered from optimal potentials :math:`f^*`, :math:`g^*` - or scalings :math:`u^*`, :math:`v^*`, using the geometry's cost or kernel - matrix respectively: - - .. math:: - - P^* = \exp\left(\frac{f^*\mathbf{1}_m^T + \mathbf{1}_n g^{*T} - - C}{\varepsilon}\right) \text{ or } P^* = \text{diag}(u^*) K - \text{diag}(v^*) - - By default, the Sinkhorn algorithm solves this dual problem in `f, g` or - `u, v` using block coordinate ascent, i.e. devising an update for each `f` - and `g` (resp. `u` and `v`) that cancels their respective gradients, one at - a time. These two iterations are repeated ``inner_iterations`` times, after - which the norm of these gradients will be evaluated and compared with the - ``threshold`` value. The iterations are then repeated as long as that error - exceeds ``threshold``. 
- - Note on Sinkhorn updates: - The boolean flag ``lse_mode`` sets whether the algorithm is run in either: - - - log-sum-exp mode (``lse_mode=True``), in which case it is directly - defined in terms of updates to `f` and `g`, using log-sum-exp - computations. This requires access to the cost matrix :math:`C`, as it is - stored, or possibly computed on the fly by ``geom``. - - - kernel mode (``lse_mode=False``), in which case it will require access - to a matrix vector multiplication operator :math:`z \rightarrow K z`, - where :math:`K` is either instantiated from :math:`C` as - :math:`\exp(-C/\varepsilon)`, or provided directly. In that case, rather - than optimizing on :math:`f` and :math:`g`, it is more convenient to - optimize on their so called scaling formulations, - :math:`u := \exp(f / \varepsilon)` and :math:`v := \exp(g / \varepsilon)`. - While faster (applying matrices is faster than applying ``lse`` repeatedly - over lines), this mode is also less stable numerically, notably for - smaller :math:`\varepsilon`. - - In the source code, the variables ``f_u`` or ``g_v`` can be either regarded - as potentials (real) or scalings (positive) vectors, depending on the choice - of ``lse_mode`` by the user. Once optimization is carried out, we only - return dual variables in potential form, i.e. ``f`` and ``g``. - - In addition to standard Sinkhorn updates, the user can also use heavy-ball - type updates using a ``momentum`` parameter in ]0,2[. We also implement a - strategy that tries to set that parameter adaptively at - ``chg_momentum_from`` iterations, as a function of progress in the error, - as proposed in the literature. - - Another upgrade to the standard Sinkhorn updates provided to the users lies - in using Anderson acceleration. This can be parameterized by setting the - otherwise null ``anderson`` to a positive integer. When selected,the - algorithm will recompute, every ``refresh_anderson_frequency`` (set by - default to 1) an extrapolation of the most recently computed ``anderson`` - iterates. When using that option, notice that differentiation (if required) - can only be carried out using implicit differentiation, and that all - momentum related parameters are ignored. - - The ``parallel_dual_updates`` flag is set to ``False`` by default. In that - setting, ``g_v`` is first updated using the latest values for ``f_u`` and - ``g_v``, before proceeding to update ``f_u`` using that new value for - ``g_v``. When the flag is set to ``True``, both ``f_u`` and ``g_v`` are - updated simultaneously. Note that setting that choice to ``True`` requires - using some form of averaging (e.g. ``momentum=0.5``). Without this, and on - its own ``parallel_dual_updates`` won't work. - - Differentiation: - The optimal solutions ``f`` and ``g`` and the optimal objective - (``reg_ot_cost``) outputted by the Sinkhorn algorithm can be differentiated - w.r.t. relevant inputs ``geom``, ``a`` and ``b`` using, by default, implicit - differentiation of the optimality conditions (``implicit_differentiation`` - set to ``True``). This choice has two consequences. - - - The termination criterion used to stop Sinkhorn (cancellation of - gradient of objective w.r.t. ``f_u`` and ``g_v``) is used to differentiate - ``f`` and ``g``, given a change in the inputs. These changes are computed - by solving a linear system. 
The arguments starting with - ``implicit_solver_*`` allow to define the linear solver that is used, and - to control for two types or regularization (we have observed that, - depending on the architecture, linear solves may require higher ridge - parameters to remain stable). The optimality conditions in Sinkhorn can be - analyzed as satisfying a ``z=z'`` condition, which are then - differentiated. It might be beneficial (e.g., as in :cite:`cuturi:20a`) - to use a preconditioning function ``precondition_fun`` to differentiate - instead ``h(z) = h(z')``. - - - The objective ``reg_ot_cost`` returned by Sinkhorn uses the so-called - envelope (or Danskin's) theorem. In that case, because it is assumed that - the gradients of the dual variables ``f_u`` and ``g_v`` w.r.t. dual - objective are zero (reflecting the fact that they are optimal), small - variations in ``f_u`` and ``g_v`` due to changes in inputs (such as - ``geom``, ``a`` and ``b``) are considered negligible. As a result, - ``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` when - evaluating the ``reg_ot_cost`` objective. Note that this approach is - `invalid` when computing higher order derivatives. In that case the - ``use_danskin`` flag must be set to ``False``. - - An alternative yet more costly way to differentiate the outputs of the - Sinkhorn iterations is to use unrolling, i.e. reverse mode differentiation - of the Sinkhorn loop. This is possible because Sinkhorn iterations are - wrapped in a custom fixed point iteration loop, defined in - ``fixed_point_loop``, rather than a standard while loop. This is to ensure - the end result of this fixed point loop can also be differentiated, if - needed, using standard JAX operations. To ensure backprop differentiability, - the ``fixed_point_loop.fixpoint_iter_backprop`` loop does checkpointing of - state variables (here ``f_u`` and ``g_v``) every ``inner_iterations``, and - backpropagates automatically, block by block, through blocks of - ``inner_iterations`` at a time. - - Note: - * The Sinkhorn algorithm may not converge within the maximum number of - iterations for possibly several reasons: - - 1. the regularizer (defined as ``epsilon`` in the geometry ``geom`` - object) is too small. Consider either switching to ``lse_mode=True`` - (at the price of a slower execution), increasing ``epsilon``, or, - alternatively, if you are unable or unwilling to increase ``epsilon``, - either increase ``max_iterations`` or ``threshold``. - 2. the probability weights ``a`` and ``b`` do not have the same total - mass, while using a balanced (``tau_a=tau_b=1.0``) setup. - Consider either normalizing ``a`` and ``b``, or set either ``tau_a`` - and/or ``tau_b<1.0``. - 3. OOMs issues may arise when storing either cost or kernel matrices that - are too large in ``geom``. In the case where, the ``geom`` geometry is - a ``PointCloud``, some of these issues might be solved by setting the - ``online`` flag to ``True``. This will trigger a re-computation on the - fly of the cost/kernel matrix. - - * The weight vectors ``a`` and ``b`` can be passed on with coordinates that - have zero weight. This is then handled by relying on simple arithmetic for - ``inf`` values that will likely arise (due to :math:`\log 0` when - ``lse_mode`` is ``True``, or divisions by zero when ``lse_mode`` is - ``False``). 
Whenever that arithmetic is likely to produce ``NaN`` values - (due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we use - ``jnp.where`` conditional statements to carry ``inf`` rather than ``NaN`` - values. In the reverse mode differentiation, the inputs corresponding to - these 0 weights (a location `x`, or a row in the corresponding cost/kernel - matrix), and the weight itself will have ``NaN`` gradient values. This is - reflects that these gradients are undefined, since these points were not - considered in the optimization and have therefore no impact on the output. + rank: int = -1, + **kwargs: Any +) -> Union[SinkhornOutput, 'LRSinkhornOutput']: + """Solve linear regularized OT problem using Sinkhorn iterations. Args: - geom: a Geometry object. + geom: The ground geometry cost of the linear problem. a: The first marginal. If `None`, it will be uniform. b: The second marginal. If `None`, it will be uniform. - tau_a: ratio rho/(rho+eps) between KL divergence regularizer to first - marginal and itself + epsilon regularizer used in the unbalanced - formulation. - tau_b: ratio rho/(rho+eps) between KL divergence regularizer to first - marginal and itself + epsilon regularizer used in the unbalanced - formulation. - init_dual_a: optional initialization for potentials/scalings w.r.t. - first marginal (``a``) of reg-OT problem. - init_dual_b: optional initialization for potentials/scalings w.r.t. - second marginal (``b``) of reg-OT problem. - threshold: tolerance used to stop the Sinkhorn iterations. This is - typically the deviation between a target marginal and the marginal of the - current primal solution when either or both tau_a and tau_b are 1.0 - (balanced or semi-balanced problem), or the relative change between two - successive solutions in the unbalanced case. - norm_error: - power used to define p-norm of error for marginal/target. - inner_iterations:the Sinkhorn error is not recomputed at each - iteration but every inner_num_iter instead. - min_iterations: the minimum number of Sinkhorn iterations carried - out before the error is computed and monitored. - max_iterations: the maximum number of Sinkhorn iterations. If - ``max_iterations`` is equal to ``min_iterations``, sinkhorn iterations are - run by default using a :func:`jax.lax.scan` loop rather than a custom, - unroll-able :func:`jax.lax.while_loop` that monitors convergence. - In that case, the error is not monitored and the ``converged`` flag - will return ``False`` as a consequence. - momentum: - a float in [0,2]. - chg_momentum_from: if positive, momentum is recomputed using the - adaptive rule provided in :cite:`lehmann:21` - after that number of iterations. - anderson_acceleration: int, if 0 (default), no acceleration. If positive, - use Anderson acceleration on the dual sinkhorn (in log/potential form), as - described `here `_ - and advocated in :cite:`chizat:20`, with a memory of size equal - to ``anderson_acceleration``. In that case, differentiation is - necessarily handled implicitly (``implicit_differentiation`` is set to - ``True``) and all ``momentum`` related parameters are ignored. - refresh_anderson_frequency: int, when using ``anderson_acceleration``, - recompute direction periodically every int sinkhorn iterations. - lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel - multiplication. - implicit_differentiation: ``True`` if using implicit differentiation, - ``False`` if unrolling Sinkhorn iterations. 
- linear_solve_kwargs: parametrization of linear solver when using implicit - differentiation. Arguments currently accepted appear in the optional - arguments of ``apply_inv_hessian``, namely ``linear_solver_fun``, a - Callable that specifies the linear solver, as well as ``ridge_kernel`` and - ``ridge_identity``, to be added to enforce stability of linear solve. - parallel_dual_updates: updates potentials or scalings in parallel if True, - sequentially (in Gauss-Seidel fashion) if False. - use_danskin: when ``True``, it is assumed the entropy regularized cost - is evaluated using optimal potentials that are frozen, i.e. whose - gradients have been stopped. This is useful when carrying out first order - differentiation, and is only valid (as with ``implicit_differentiation``) - when the algorithm has converged with a low tolerance. - kwargs: Additional keyword arguments (see above). + tau_a: If `< 1`, defines how unbalanced the problem is + on the first marginal. + tau_b: If `< 1`, defines how unbalanced the problem is + on the second marginal. + rank: Rank constraint on the coupling to minimize the linear OT problem + :cite:`scetbon:21`. If `-1`, no rank constraint is used. + kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or + :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`, + depending on ``rank``. Returns: - a ``SinkhornOutput`` named tuple. The tuple contains two optimal potential - vectors ``f`` and ``g``, the objective ``reg_ot_cost`` evaluated at those - solutions, an array of ``errors`` to monitor convergence every - ``inner_iterations`` and a flag ``converged`` that is ``True`` if the - algorithm has converged within the number of iterations that was predefined - by the user. + The Sinkhorn output. """ - sink = make(**kwargs) - ot_prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b) - return sink(ot_prob, (init_dual_a, init_dual_b)) + prob = linear_problem.LinearProblem(geom, a=a, b=b, tau_a=tau_a, tau_b=tau_b) + if rank > 0: + # Local import at call time; the top-level import of `LRSinkhorn` above + # is guarded by `TYPE_CHECKING` and is not available at runtime. + from ott.solvers.linear.sinkhorn_lr import LRSinkhorn + solver = LRSinkhorn(rank=rank, **kwargs) + else: + solver = Sinkhorn(**kwargs) + return solver(prob) diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index cb1072ee8..12ca108b7 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -223,7 +223,7 @@ class LRSinkhorn(sinkhorn.Sinkhorn): case. Args: - rank: The rank constraint on the coupling to minimize the linear OT problem + rank: Rank constraint on the coupling to minimize the linear OT problem gamma: The (inverse of) gradient step size used by mirror descent. gamma_rescale: Whether to rescale :math:`\gamma` every iteration as described in :cite:`scetbon:22b`.
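A brief usage sketch of the new functional entry point defined above. Point-cloud sizes and the ``epsilon`` value are arbitrary illustration choices; ``reg_ot_cost`` is the objective attribute both output types expose:

```python
import jax

from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

# Two small point clouds; uniform marginals are used when a, b are omitted.
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (30, 2))
y = jax.random.normal(rngs[1], (40, 2))
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)

out = sinkhorn.solve(geom)             # rank=-1: dispatches to Sinkhorn
out_lr = sinkhorn.solve(geom, rank=4)  # rank>0: dispatches to LRSinkhorn
print(out.reg_ot_cost, out_lr.reg_ot_cost)
```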
@@ -660,35 +660,3 @@ def run( ot_prob, lse_mode=solver.lse_mode, use_danskin=solver.use_danskin ) return out.set(ot_prob=ot_prob) - - -def make( - rank: int, - gamma: float = 1.0, - epsilon: float = 1e-4, - initializer: Literal['random', 'rank2', 'k-means'] = 'k-means', - lse_mode: bool = True, - threshold: float = 1e-3, - norm_error: int = 10, - inner_iterations: int = 1, - min_iterations: int = 0, - max_iterations: int = 2000, - use_danskin: bool = True, - implicit_diff: bool = False, - kwargs_dys: Optional[Mapping[str, Any]] = None -) -> LRSinkhorn: - return LRSinkhorn( - rank=rank, - gamma=gamma, - epsilon=epsilon, - initializer=initializer, - lse_mode=lse_mode, - threshold=threshold, - norm_error=norm_error, - inner_iterations=inner_iterations, - min_iterations=min_iterations, - max_iterations=max_iterations, - use_danskin=use_danskin, - implicit_diff=implicit_diff, - kwargs_dys=kwargs_dys - ) diff --git a/src/ott/solvers/nn/icnn.py b/src/ott/solvers/nn/icnn.py index bbb608433..5d24eb2d7 100644 --- a/src/ott/solvers/nn/icnn.py +++ b/src/ott/solvers/nn/icnn.py @@ -15,10 +15,10 @@ from typing import Any, Callable, Sequence, Tuple, Union +import flax.linen as nn import jax import jax.numpy as jnp import optax -from flax import linen as nn from flax.training import train_state from jax.nn import initializers diff --git a/src/ott/solvers/nn/layers.py b/src/ott/solvers/nn/layers.py index 2aa72862a..f43d280b7 100644 --- a/src/ott/solvers/nn/layers.py +++ b/src/ott/solvers/nn/layers.py @@ -13,9 +13,9 @@ from typing import Any, Callable, Tuple +import flax.linen as nn import jax import jax.numpy as jnp -from flax import linen as nn __all__ = ["PositiveDense", "PosDefPotentials"] diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index b411e22c3..d859b0b49 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -27,8 +27,7 @@ import jax import jax.numpy as jnp -from ott import utils -from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud +from ott.geometry import geometry, low_rank, pointcloud from ott.initializers.linear import initializers_lr from ott.initializers.quadratic import initializers as quad_initializers from ott.math import fixed_point_loop @@ -37,7 +36,7 @@ from ott.solvers import was_solver from ott.solvers.linear import sinkhorn, sinkhorn_lr -__all__ = ["GWOutput", "GromovWasserstein", "gromov_wasserstein"] +__all__ = ["GWOutput", "GromovWasserstein", "solve"] LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput] @@ -149,7 +148,7 @@ class GromovWasserstein(was_solver.WassersteinSolver): """Gromov-Wasserstein solver. Args: - args: Positional_arguments for + args: Positional arguments for :class:`~ott.solvers.was_solver.WassersteinSolver`. warm_start: Whether to initialize (low-rank) Sinkhorn calls using values from the previous iteration. If `None`, warm starts are not used for @@ -400,72 +399,7 @@ def body_fn( return solver.output_from_state(state) -def make( - epsilon: Union[epsilon_scheduler.Epsilon, float] = 1., - rank: int = -1, - max_iterations: int = 50, - warm_start: Optional[bool] = None, - store_inner_errors: bool = False, - linear_ot_solver_kwargs: Optional[Mapping[str, Any]] = None, - threshold: float = 1e-2, - min_iterations: int = 1, - **kwargs: Any, -) -> GromovWasserstein: - """Create a GromovWasserstein solver. - - Args: - epsilon: a regularization parameter or a epsilon_scheduler.Epsilon object. 
- rank: integer used to constrain the rank of GW solutions if >0. - max_iterations: the maximum number of outer iterations for - Gromov Wasserstein. - warm_start: Whether to initialize (low-rank) Sinkhorn calls using values - from the previous iteration. If `None`, it's enabled when using low-rank. - store_inner_errors: whether or not to return all the errors of the inner - Sinkhorn iterations. - linear_ot_solver_kwargs: Optionally a dictionary containing the keywords - arguments for the linear OT solver (e.g. sinkhorn) - threshold: threshold (progress between two iterate costs) used to stop GW. - min_iterations: see fixed_point_loop. - kwargs: additional kwargs for epsilon. - - Returns: - A GromovWasserstein solver. - """ - if linear_ot_solver_kwargs is None: - linear_ot_solver_kwargs = {} - - if rank == -1: - sink = sinkhorn.make(**linear_ot_solver_kwargs) - elif rank > 0: - # `rank` and `epsilon` are arguments of the `sinkhorn_lr` solver. As we are - # passing them to make, we should not pass them in `linear_ot_solver_kwargs` - # Therefore, the `rank` or `epsilon` passed to `linear_ot_solver_kwargs` are - # deleted. - _ = linear_ot_solver_kwargs.pop('rank', None) - _ = linear_ot_solver_kwargs.pop('epsilon', None) - sink = sinkhorn_lr.make( - rank=rank, epsilon=epsilon, **linear_ot_solver_kwargs - ) - else: - raise ValueError(f"Invalid value for `rank={rank}`.") - - return GromovWasserstein( - epsilon=epsilon, - rank=rank, - linear_ot_solver=sink, - threshold=threshold, - min_iterations=min_iterations, - max_iterations=max_iterations, - store_inner_errors=store_inner_errors, - warm_start=warm_start, - **kwargs - ) - - -@utils.deprecate( - version="0.3.2", alt="Use the `GromovWasserstein` class instead." -) -def gromov_wasserstein( +def solve( geom_xx: geometry.Geometry, geom_yy: geometry.Geometry, geom_xy: Optional[geometry.Geometry] = None, @@ -481,68 +415,69 @@ def gromov_wasserstein( tolerances: Union[float, Tuple[float, ...]] = 1e-2, **kwargs: Any, ) -> GWOutput: - """Solve a Gromov Wasserstein problem. + r"""Solve a quadratic regularized OT problem. + + The quadratic loss of a single OT matrix is assumed to + have the form given in :cite:`peyre:16`, eq. 4. - .. note:: + The two geometries below parameterize matrices :math:`C` and :math:`\bar{C}` + in that equation. The function :math:`L` (of two real values) is assumed + to match the form given in eq. 5, with our notation: - This function has been deprecated and will be removed in ``0.3.2`` release. - Please use the - :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` - solver instead. + .. math:: - Wrapper that instantiates a quadratic problem (possibly with linear term - if the problem is fused) and calls a solver to output a solution. + L(x, y) = \text{lin1}(x) + \text{lin2}(y) - \text{quad1}(x) \cdot \text{quad2}(y) Args: - geom_xx: Geometry for the first view. - geom_yy: Geometry for the second view. - geom_xy: Geometry representing the linear cost in FGW. - fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein, - i.e. loss = quadratic_loss + fused_penalty * linear_loss. + geom_xx: Ground geometry of the first space. + geom_yy: Ground geometry of the second space. + geom_xy: Geometry defining the linear penalty term for + Fused Gromov-Wasserstein. If `None`, the problem reduces to + a plain Gromov-Wasserstein problem. + fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein, + i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if ``geom_xy`` is not specified.
    scale_cost: option to rescale the cost matrices:

-      - if `True`, use the default for each geometry.
-      - if `False`, keep the original scaling in geometries.
+      - if :obj:`True`, use the default for each geometry.
+      - if :obj:`False`, keep the original scaling in geometries.
       - if :class:`str`, use a specific method available in
         :class:`~ott.geometry.geometry.Geometry` or
         :class:`~ott.geometry.pointcloud.PointCloud`.
-      - if `None`, do not scale the cost matrices.
-
-    a: jnp.ndarray[num_a,] or jnp.ndarray[batch,num_a] weights.
-    b: jnp.ndarray[num_b,] or jnp.ndarray[batch,num_b] weights.
-    loss: defaults to the square Euclidean distance. Can also pass 'kl'
-      to define the GW loss as KL loss.
-    tau_a: float between 0 and 1.0, parameter that controls the strength of the
-      KL divergence constraint between the weights and marginals of the
-      transport for the first view. If set to 1.0, then it is equivalent to a
-      hard constraint and if smaller to a softer constraint.
-    tau_b: float between 0 and 1.0, parameter that controls the strength of the
-      KL divergence constraint between the weights and marginals of the
-      transport for the second view. If set to 1.0, then it is equivalent to a
-      hard constraint and if smaller to a softer constraint.
-    gw_unbalanced_correction: True (default) if the unbalanced version of
-      :cite:`sejourne:21` is used, False if tau_a and tau_b
-      only affect the inner Sinkhorn loop.
-    ranks: Switch to a low rank approximation of all cost matrices, using
-      :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`, to gain speed.
-      This is only relevant if the geometries of interest are *not*
-      :class:`~ott.geometry.pointcloud.PointCloud` with `'sqeucl'` cost
-      function, in which case they would be low-rank by construction (as long
-      as the sizes of these point clouds is larger than dimension).
-      If `-1`, geometries are left as they are, and not converted.
-      If :class:`tuple`, these 2 or 3 :class:`int` specify the ranks of
-      ``geom_xx``, ``geom_yy`` and ``geom_xy``, respectively. If :class:`int`,
-      all 3 geometries are converted using that rank.
-    tolerances: Tolerances used when converting geometries to low-rank. Used when
+      - if :obj:`None`, do not scale the cost matrices.
+
+    a: array representing the probability weights of the samples
+      from ``geom_xx``. If `None`, it will be uniform.
+    b: array representing the probability weights of the samples
+      from ``geom_yy``. If `None`, it will be uniform.
+    loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
+      part of the loss. The second one is the quadratic part (quad1, quad2).
+      By default, the loss is set as the 4 functions representing the squared
+      Euclidean loss, and this property is taken advantage of in subsequent
+      computations. Alternatively, a KL loss can be specified in an equally
+      optimized way.
+    tau_a: if `< 1.0`, defines how unbalanced the problem is with respect to
+      the first marginal.
+    tau_b: if `< 1.0`, defines how unbalanced the problem is with respect to
+      the second marginal.
+    gw_unbalanced_correction: Whether the unbalanced version of
+      :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only
+      affect the inner Sinkhorn loop.
+    ranks: Ranks of the cost matrices, see
+      :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
      geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
-      `'sqeucl'` cost. If :class:`float`, that tolerance is shared across all
-      3 geometries.
+      `'sqeucl'` cost function. If `-1`, the geometries will not be converted
+      to low-rank.
If :class:`tuple`, it specifies the ranks of ``geom_xx``, + ``geom_yy`` and ``geom_xy``, respectively. If :class:`int`, rank is shared + across all geometries. + tolerances: Tolerances used when converting geometries to low-rank. Used + when geometries are not :class:`~ott.geometry.pointcloud.PointCloud` with + `'sqeucl'` cost. If :class:`float`, it is shared across all geometries. kwargs: Keyword arguments for :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`. Returns: - A GromovWassersteinState named tuple. + Gromov-Wasserstein output. """ prob = quadratic_problem.QuadraticProblem( geom_xx, @@ -559,5 +494,5 @@ def gromov_wasserstein( ranks=ranks, tolerances=tolerances ) - solver = make(**kwargs) + solver = GromovWasserstein(**kwargs) return solver(prob) diff --git a/src/ott/tools/__init__.py b/src/ott/tools/__init__.py index 1dd1945e5..4b7f20c82 100644 --- a/src/ott/tools/__init__.py +++ b/src/ott/tools/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. """OTT tools: A set of tools to use OT in differentiable ML pipelines.""" -from . import gaussian_mixture, k_means, plot, sinkhorn_divergence, soft_sort, transport +from . import gaussian_mixture, k_means, plot, sinkhorn_divergence, soft_sort diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index eccd3da68..f1fe519d7 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -13,10 +13,13 @@ # limitations under the License. """Pytree containing parameters for a pair of coupled Gaussian mixture models. """ # noqa: D200 +from typing import Any + import jax import jax.numpy as jnp from ott.geometry import costs, geometry, pointcloud +from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture @@ -129,19 +132,22 @@ def get_cost_matrix(self) -> jnp.ndarray: """Get matrix of W2^2 costs between all pairs of (gmm0, gmm1) components.""" return self.get_bures_geometry().cost_matrix - def get_sinkhorn(self, cost_matrix: jnp.ndarray) -> sinkhorn.SinkhornOutput: - """Get the output of ott.sinkhorn's method for a given cost matrix.""" + def get_sinkhorn( + self, cost_matrix: jnp.ndarray, **kwargs: Any + ) -> sinkhorn.SinkhornOutput: + """Get the output of Sinkhorn's method for a given cost matrix.""" # We use a Geometry here rather than the PointCloud created in - # in get_bures_geometry to avoid recomputing the cost matrix, since + # get_bures_geometry to avoid recomputing the cost matrix, since # the cost matrix is quite expensive geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=self.epsilon) - return sinkhorn.sinkhorn( + prob = linear_problem.LinearProblem( geom, a=self.gmm0.component_weights, b=self.gmm1.component_weights, tau_a=self.tau, tau_b=self.tau ) + return sinkhorn.Sinkhorn(**kwargs)(prob) def get_normalized_sinkhorn_coupling( self, diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index fec2e4037..4e819589b 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -13,19 +13,27 @@ # limitations under the License. 
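
For illustration (not part of this diff), the `gromov_wasserstein.py` change above folds the `make` factory and the deprecated wrapper into a single `solve` function that builds a `QuadraticProblem` and a `GromovWasserstein` solver. A minimal sketch of the new entry point; the sizes and `epsilon` below are arbitrary:

```python
import jax

from ott.geometry import pointcloud
from ott.solvers.quadratic import gromov_wasserstein

# Two point clouds living in spaces of different dimension.
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (30, 3))
y = jax.random.normal(rngs[1], (40, 5))

# Ground geometries playing the roles of C and C-bar in eq. 4 of peyre:16.
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)

# Weights a/b default to uniform; remaining kwargs (here epsilon) are
# forwarded to the GromovWasserstein solver.
out = gromov_wasserstein.solve(geom_xx, geom_yy, epsilon=1e-1)
print(out.matrix.shape)  # (30, 40) coupling
```
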
"""Plotting utils.""" -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import jax.numpy as jnp -import matplotlib.pyplot as plt import numpy as np import scipy + +import matplotlib.pyplot as plt from matplotlib import animation +from ott import utils from ott.geometry import pointcloud -from ott.tools import transport +from ott.solvers.linear import sinkhorn, sinkhorn_lr +from ott.solvers.quadratic import gromov_wasserstein +# TODO(michalk8): make sure all outputs conform to a unified transport interface +Transport = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, + gromov_wasserstein.GWOutput] -def bidimensional(x: jnp.ndarray, y: jnp.ndarray): + +def bidimensional(x: jnp.ndarray, + y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: """Apply PCA to reduce to bimensional data.""" if x.shape[1] < 3: return x, y @@ -72,7 +80,7 @@ def __init__( self._scale = scale self._cmap = cmap - def _scatter(self, ot: transport.Transport): + def _scatter(self, ot: Transport): """Compute the position and scales of the points on a 2D plot.""" if not isinstance(ot.geom, pointcloud.PointCloud): raise ValueError('So far we only plot PointCloud geometry.') @@ -95,7 +103,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray): result.append((xy[i, [0, 2]], xy[i, [1, 3]], strength)) return result - def __call__(self, ot: transport.Transport) -> List[plt.Artist]: + def __call__(self, ot: Transport) -> List[plt.Artist]: """Plot 2-D couplings. Projects via PCA if data is higher dimensional.""" x, y, sx, sy = self._scatter(ot) self._points_x = self.ax.scatter( @@ -123,8 +131,8 @@ def __call__(self, ot: transport.Transport) -> List[plt.Artist]: self._lines.append(line) return [self._points_x, self._points_y] + self._lines - def update(self, ot: transport.Transport) -> List[plt.Artist]: - """Update a plot with a transport.Transport instance.""" + def update(self, ot: Transport) -> List[plt.Artist]: + """Update a plot with a transport instance.""" x, y, _, _ = self._scatter(ot) self._points_x.set_offsets(x) self._points_y.set_offsets(y) @@ -159,11 +167,11 @@ def update(self, ot: transport.Transport) -> List[plt.Artist]: def animate( self, - transports: Sequence[transport.Transport], + transports: Sequence[Transport], frame_rate: float = 10.0 ) -> animation.FuncAnimation: - """Make an animation from several transport.Transport.""" - self(transports[0]) + """Make an animation from several transports.""" + _ = self(transports[0]) return animation.FuncAnimation( self.fig, lambda i: self.update(transports[i]), @@ -191,7 +199,7 @@ def _barycenters( def barycentric_projections( - arg: Union[transport.Transport, jnp.ndarray], + arg: Union[Transport, jnp.ndarray], a: jnp.ndarray = None, b: jnp.ndarray = None, matrix: jnp.ndarray = None, @@ -202,13 +210,17 @@ def barycentric_projections( if ax is None: _, ax = plt.subplots(1, 1, figsize=(8, 5)) - if isinstance(arg, transport.Transport): - ot = arg - return _barycenters(ax, ot.geom.y, ot.a, ot.b, ot.matrix, **kwargs) + if utils.is_jax_array(arg): + if matrix is None: + raise ValueError('The `matrix` argument cannot be None.') + + a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a + b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b + return _barycenters(ax, arg, a, b, matrix, **kwargs) - if matrix is None: - raise ValueError('The `matrix` argument cannot be None.') + if isinstance(arg, gromov_wasserstein.GWOutput): + geom = arg.linear_state.geom + 
else: + geom = arg.geom - a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a - b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b - return _barycenters(ax, arg, a, b, matrix, **kwargs) + return _barycenters(ax, geom.y, arg.a, arg.b, arg.matrix, **kwargs) diff --git a/src/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py index 886430d07..ae6052c60 100644 --- a/src/ott/tools/segment_sinkhorn.py +++ b/src/ott/tools/segment_sinkhorn.py @@ -18,6 +18,7 @@ import jax.numpy as jnp from ott.geometry import costs, pointcloud, segment +from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn @@ -111,19 +112,20 @@ def eval_fn( ) -> float: mask_x = padded_weight_x > 0. mask_y = padded_weight_y > 0. - return sinkhorn.sinkhorn( - pointcloud.PointCloud( - padded_x, - padded_y, - cost_fn=cost_fn, - src_mask=mask_x, - tgt_mask=mask_y, - **kwargs - ), - a=padded_weight_x, - b=padded_weight_y, - **sinkhorn_kwargs - ).reg_ot_cost + + geom = pointcloud.PointCloud( + padded_x, + padded_y, + cost_fn=cost_fn, + src_mask=mask_x, + tgt_mask=mask_y, + **kwargs, + ) + prob = linear_problem.LinearProblem( + geom, a=padded_weight_x, b=padded_weight_y + ) + solver = sinkhorn.Sinkhorn(**sinkhorn_kwargs) + return solver(prob).reg_ot_cost return segment._segment_interface( x, diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 426aeefaf..a362937e4 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -19,7 +19,7 @@ from ott.geometry import costs, geometry, pointcloud, segment from ott.problems.linear import linear_problem, potentials -from ott.solvers.linear import sinkhorn +from ott.solvers.linear import acceleration, sinkhorn __all__ = [ "sinkhorn_divergence", "segment_sinkhorn_divergence", @@ -135,19 +135,19 @@ def _sinkhorn_divergence( geometry_yy: a Cost object able to apply kernels with a certain epsilon, between elements of the view Y. a: jnp.ndarray[n]: the weight of each input point. The sum of - all elements of b must match that of a to converge. + all elements of ``b`` must match that of ``a`` to converge. b: jnp.ndarray[m]: the weight of each target point. The sum of - all elements of b must match that of a to converge. + all elements of ``b`` must match that of ``a`` to converge. symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for symmetric terms comparing x/x and y/y. - kwargs: Keyword arguments to :func:`~ott.solvers.linear.sinkhorn.sinkhorn`. + kwargs: Keyword arguments to :func:`~ott.solvers.linear.sinkhorn.Sinkhorn`. Returns: SinkhornDivergenceOutput named tuple. """ # When computing a Sinkhorn divergence, the (x,y) terms and (x,x) / (y,y) # terms are computed independently. The user might want to pass some - # sinkhorn_kwargs to parameterize sinkhorn's behavior, but those should + # sinkhorn_kwargs to parameterize Sinkhorn's behavior, but those should # only apply to the (x,y) part. For the (x,x) / (y,y) part we fall back # on a simpler choice (parallel_dual_updates + momentum 0.5) that is known # to work well in such settings. 
In the future we might want to give some @@ -159,18 +159,17 @@ def _sinkhorn_divergence( if symmetric_sinkhorn: kwargs_symmetric.update( parallel_dual_updates=True, - momentum=0.5, - chg_momentum_from=0, - anderson_acceleration=0, - implicit_solver_symmetric=True + momentum=acceleration.Momentum(start=0, value=0.5), + anderson=None, + # TODO(michalk8): implicit_diff ) - out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs) - out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **kwargs_symmetric) + out_xy = sinkhorn.solve(geometry_xy, a, b, **kwargs) + out_xx = sinkhorn.solve(geometry_xx, a, a, **kwargs_symmetric) if geometry_yy is None: - out_yy = sinkhorn.SinkhornOutput(errors=jnp.array([]), reg_ot_cost=0) + out_yy = sinkhorn.SinkhornOutput(errors=jnp.array([]), reg_ot_cost=0.0) else: - out_yy = sinkhorn.sinkhorn(geometry_yy, b, b, **kwargs_symmetric) + out_yy = sinkhorn.solve(geometry_yy, b, b, **kwargs_symmetric) div = ( out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) + diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 4209fe009..c481a41b1 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -20,7 +20,9 @@ import jax.numpy as jnp import numpy as np -from ott.tools import transport +from ott.geometry import pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn def transport_for_sort( @@ -30,7 +32,7 @@ def transport_for_sort( squashing_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, epsilon: float = 1e-2, **kwargs: Any, -) -> jnp.ndarray: +) -> sinkhorn.SinkhornOutput: r"""Solve reg. OT, from inputs to a weighted family of increasing values. Args: @@ -42,7 +44,8 @@ def transport_for_sort( sigmoid of whitened values by default. Can be set to be the identity by passing ``squashing_fun = lambda x : x`` instead. epsilon: the regularization parameter. - kwargs: keyword arguments for `sinkhorn` and / or `PointCloud`. + kwargs: keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. Returns: A jnp.ndarray num_points x num_target transport matrix, from all @@ -64,7 +67,12 @@ def transport_for_sort( num_targets = b.shape[0] y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis] - return transport.solve(x, y, a=a, b=b, epsilon=epsilon, **kwargs) + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + prob = linear_problem.LinearProblem(geom, a=a, b=b) + + solver = sinkhorn.Sinkhorn(**kwargs) + + return solver(prob) def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray: diff --git a/tests/conftest.py b/tests/conftest.py index d15f8482f..23ed538af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,13 @@ import itertools from typing import Any, Mapping, Optional, Sequence -import jax -import jax.numpy as jnp import pytest from _pytest.config.argparsing import Parser from _pytest.python import Metafunc +import jax +import jax.numpy as jnp + def pytest_generate_tests(metafunc: Metafunc) -> None: if not hasattr(metafunc.function, "pytestmark"): diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index c71bca00d..bf5e88cd7 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
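
For illustration (not part of this diff), the `_sinkhorn_divergence` helper above now routes all three terms through `sinkhorn.solve`. A small sketch of the public wrapper, assuming it keeps its geometry-class-first signature and the `divergence` field on its output; data and `epsilon` are arbitrary:

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (25, 2))
y = jax.random.normal(rngs[1], (30, 2)) + 1.0

# The first argument is the geometry class; the point clouds and remaining
# kwargs (here epsilon) instantiate the (x,y), (x,x) and (y,y) geometries.
out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
)
print(out.divergence)
```
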
"""Tests for the cost/norm functions.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 72a859d6a..268f7bd74 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -1,16 +1,17 @@ import time from typing import Any, Callable, Optional, Tuple, Union -import jax -import jax.experimental.sparse as jesp -import jax.numpy as jnp import networkx as nx -import numpy as np import pytest from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs from typing_extensions import Literal +import jax +import jax.experimental.sparse as jesp +import jax.numpy as jnp +import numpy as np + from ott.geometry import geometry, graph from ott.math import decomposition from ott.problems.linear import linear_problem diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index 5d4fb0eda..7498c1ba4 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -14,10 +14,11 @@ """Test Low-Rank Geometry.""" from typing import Callable, Optional, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, grid, low_rank, pointcloud diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 24ee95211..447ef8bb2 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -14,10 +14,11 @@ """Tests for apply_cost and apply_kernel.""" from typing import Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index 8a6752608..8978ff378 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -14,10 +14,11 @@ """Tests for the option to scale the cost matrix.""" from typing import Optional, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.linear import linear_problem @@ -60,7 +61,9 @@ def apply_sinkhorn( geom = pointcloud.PointCloud( x, y, epsilon=self.eps, scale_cost=scale_cost ) - out = sinkhorn.sinkhorn(geom, a, b) + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn() + out = solver(prob) transport = geom.transport_from_potentials(out.f, out.g) return geom, out, transport @@ -123,7 +126,9 @@ def apply_sinkhorn( scale_cost: Union[str, float] ): geom = geometry.Geometry(cost, epsilon=self.eps, scale_cost=scale_cost) - out = sinkhorn.sinkhorn(geom, a, b) + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn() + out = solver(prob) transport = geom.transport_from_potentials(out.f, out.g) return geom, out, transport diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index fe102d353..8c4023e1a 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -1,9 +1,10 @@ from typing import Optional, Sequence, Tuple, Type, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 4741f4d9b..4eb7e09d8 
100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -10,21 +10,27 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for Sinkhorn initializers.""" -import functools +from typing import Any, Literal, Optional + +import pytest import jax import jax.numpy as jnp import numpy as np -import pytest -import ott.initializers.nn.initializers from ott.geometry import geometry, pointcloud -from ott.initializers.linear import initializers as lin_init +from ott.initializers.linear import initializers as linear_init +from ott.initializers.nn import initializers as nn_init from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn -def create_sorting_problem(rng, n, epsilon=0.01, online=False): +def create_sorting_problem( + rng: jnp.ndarray, + n: int, + epsilon: float = 1e-2, + batch_size: Optional[int] = None +) -> linear_problem.LinearProblem: # define ot problem x_init = jnp.array([-1., 0, .22]) y_init = jnp.array([0., 0, 1.1]) @@ -36,24 +42,27 @@ def create_sorting_problem(rng, n, epsilon=0.01, online=False): x = jnp.sort(x) y = jnp.sort(y) - n = len(x) - m = len(y) + n, m = len(x), len(y) a = jnp.ones(n) / n b = jnp.ones(m) / m - batch_size = 3 if online else None geom = pointcloud.PointCloud( x.reshape(-1, 1), y.reshape(-1, 1), epsilon=epsilon, batch_size=batch_size ) - ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) + return linear_problem.LinearProblem(geom=geom, a=a, b=b) - return ot_problem - -def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): +def create_ot_problem( + rng: jnp.ndarray, + n: int, + m: int, + d: int, + epsilon: float = 1e-2, + batch_size: Optional[int] = None +) -> linear_problem.LinearProblem: # define ot problem x_rng, y_rng = jax.random.split(rng) @@ -66,79 +75,55 @@ def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False): a = jnp.ones(n) / n b = jnp.ones(m) / m - batch_size = 3 if online else None geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size) - ot_problem = linear_problem.LinearProblem(geom=geom, a=a, b=b) - return ot_problem - - -# define sinkhorn functions -@functools.partial(jax.jit, static_argnames=['lse_mode', 'vector_min']) -def run_sinkhorn_sort_init( - x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True -): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - sort_init = lin_init.SortingInitializer(vectorized_update=vector_min) - out = sinkhorn.sinkhorn( - geom, a=a, b=b, initializer=sort_init, lse_mode=lse_mode - ) - return out - + return linear_problem.LinearProblem(geom=geom, a=a, b=b) + + +def run_sinkhorn( + x: jnp.ndarray, + y: jnp.ndarray, + *, + initializer: Literal["default", "sorting", "gaussian"], + a: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, + epsilon: float = 1e-2, + lse_mode: bool = True, + **kwargs: Any +) -> sinkhorn.SinkhornOutput: + if initializer == "default": + init = linear_init.DefaultInitializer() + elif initializer == "sorting": + init = linear_init.SortingInitializer(**kwargs) + elif initializer == "gaussian": + init = linear_init.GaussianInitializer(**kwargs) + else: + raise NotImplementedError(initializer) -@functools.partial(jax.jit, static_argnames=['lse_mode']) -def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode) - return out - - 
-@functools.partial(jax.jit, static_argnames=['lse_mode']) -def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) - out = sinkhorn.sinkhorn( - geom, - a=a, - b=b, - initializer=lin_init.GaussianInitializer(), - lse_mode=lse_mode - ) - return out + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, initializer=init) + return solver(prob) @pytest.mark.fast class TestSinkhornInitializers: - def test_init_pytree(self): - - @jax.jit - def init_sort(): - init = lin_init.SortingInitializer() - return init - - @jax.jit - def init_gaus(): - init = lin_init.GaussianInitializer() - return init - - _ = init_gaus() - _ = init_sort() - @pytest.mark.parametrize( "init", [ "default", "gaussian", "sorting", - lin_init.DefaultInitializer(), "non-existent" + linear_init.DefaultInitializer(), "non-existent" ] ) def test_create_initializer(self, init: str): solver = sinkhorn.Sinkhorn(initializer=init) expected_types = { - "default": lin_init.DefaultInitializer, - "gaussian": lin_init.GaussianInitializer, - "sorting": lin_init.SortingInitializer, + "default": linear_init.DefaultInitializer, + "gaussian": linear_init.GaussianInitializer, + "sorting": linear_init.SortingInitializer, } - if isinstance(init, lin_init.SinkhornInitializer): + if isinstance(init, linear_init.SinkhornInitializer): assert solver.create_initializer() is init elif init == "non-existent": with pytest.raises(NotImplementedError, match=r""): @@ -155,71 +140,67 @@ def test_sorting_init(self, vector_min: bool, lse_mode: bool): """Tests sorting dual initializer.""" rng = jax.random.PRNGKey(42) n = 500 - epsilon = 0.01 + epsilon = 1e-2 + + ot_problem = create_sorting_problem(rng=rng, n=n, epsilon=epsilon) - ot_problem = create_sorting_problem( - rng=rng, n=n, epsilon=epsilon, online=False - ) - # run sinkhorn sink_out_base = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, + initializer="default", a=ot_problem.a, b=ot_problem.b, epsilon=epsilon ) - base_num_iter = jnp.sum(sink_out_base.errors > -1) - sink_out_init = run_sinkhorn_sort_init( + sink_out_init = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, + initializer="sorting", a=ot_problem.a, b=ot_problem.b, epsilon=epsilon, - vector_min=vector_min, + vectorized_update=vector_min, lse_mode=lse_mode ) - sort_num_iter = jnp.sum(sink_out_init.errors > -1) # check initializer is better or equal if lse_mode: - assert base_num_iter >= sort_num_iter + assert sink_out_base.converged + assert sink_out_init.converged + assert sink_out_base.n_iters > sink_out_init.n_iters def test_sorting_init_online(self, rng: jnp.ndarray): n = 100 - epsilon = 0.01 + epsilon = 1e-2 ot_problem = create_sorting_problem( - rng=rng, n=n, epsilon=epsilon, online=True + rng=rng, n=n, epsilon=epsilon, batch_size=5 ) - sort_init = lin_init.SortingInitializer(vectorized_update=True) + sort_init = linear_init.SortingInitializer(vectorized_update=True) with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem, lse_mode=True) def test_sorting_init_square_cost(self, rng: jnp.ndarray): - n = 100 - m = 150 - d = 1 - epsilon = 0.01 + n, m, d = 100, 150, 1 + epsilon = 1e-2 - ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False) - sort_init = lin_init.SortingInitializer(vectorized_update=True) + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon) + sort_init = linear_init.SortingInitializer(vectorized_update=True) with 
pytest.raises(AssertionError, match=r"square"):
      sort_init.init_dual_a(ot_problem, lse_mode=True)

   def test_default_initializer(self, rng: jnp.ndarray):
     """Tests the default initializer."""
-    n = 200
-    m = 200
-    d = 2
-    epsilon = 0.01
+    n, m, d = 200, 200, 2
+    epsilon = 1e-2

-    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3)

-    default_potential_a = lin_init.DefaultInitializer().init_dual_a(
+    default_potential_a = linear_init.DefaultInitializer().init_dual_a(
         ot_problem, lse_mode=True
     )
-    default_potential_b = lin_init.DefaultInitializer().init_dual_b(
+    default_potential_b = linear_init.DefaultInitializer().init_dual_b(
         ot_problem, lse_mode=True
     )
@@ -228,14 +209,12 @@ def test_default_initializer(self, rng: jnp.ndarray):
     np.testing.assert_array_equal(0., default_potential_b)

   def test_gauss_pointcloud_geom(self, rng: jnp.ndarray):
-    n = 200
-    m = 200
-    d = 2
-    epsilon = 0.01
+    n, m, d = 200, 200, 2
+    epsilon = 1e-2

-    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3)

-    gaus_init = lin_init.GaussianInitializer()
+    gaus_init = linear_init.GaussianInitializer()
     new_geom = geometry.Geometry(
         cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon
     )
@@ -247,50 +226,65 @@
     gaus_init.init_dual_a(ot_problem, lse_mode=True)

   @pytest.mark.parametrize('lse_mode', [True, False])
-  def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray):
+  @pytest.mark.parametrize("jit", [False, True])
+  @pytest.mark.parametrize("initializer", ["sorting", "gaussian"])
+  def test_initializer_n_iter(
+      self, rng: jnp.ndarray, lse_mode: bool, jit: bool,
+      initializer: Literal["sorting", "gaussian"]
+  ):
     """Tests the number of iterations used by different initializers."""
-    # define OT problem
-    n = 200
-    m = 200
-    d = 2
-    epsilon = 0.01
+    n, m, d = 200, 200, 2
+    epsilon = 1e-2

-    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+    if initializer == "gaussian":
+      ot_problem = create_ot_problem(
+          rng, n, m, d, epsilon=epsilon, batch_size=3
+      )
+    else:
+      ot_problem = create_sorting_problem(rng, n=n, epsilon=epsilon)
+
+    if jit:
+      run_fn = jax.jit(
+          run_sinkhorn, static_argnames=["initializer", "lse_mode"]
+      )
+    else:
+      run_fn = run_sinkhorn

     # run sinkhorn
-    sink_out = run_sinkhorn(
+    default_out = run_fn(
         x=ot_problem.geom.x,
         y=ot_problem.geom.y,
+        initializer="default",
         a=ot_problem.a,
         b=ot_problem.b,
         epsilon=epsilon,
-        lse_mode=lse_mode
+        lse_mode=lse_mode,
     )
-    base_num_iter = jnp.sum(sink_out.errors > -1)
-    sink_out = run_sinkhorn_gaus_init(
+
+    init_out = run_fn(
         x=ot_problem.geom.x,
         y=ot_problem.geom.y,
+        initializer=initializer,
         a=ot_problem.a,
         b=ot_problem.b,
         epsilon=epsilon,
         lse_mode=lse_mode
     )
-    gaus_num_iter = jnp.sum(sink_out.errors > -1)

-    # check initializer is better
     if lse_mode:
-      assert base_num_iter >= gaus_num_iter
+      assert default_out.converged
+      assert init_out.converged
+      assert default_out.n_iters > init_out.n_iters
+    else:
+      assert default_out.n_iters >= init_out.n_iters

   @pytest.mark.parametrize('lse_mode', [True, False])
-  def test_meta_initializer(self, lse_mode, rng: jnp.ndarray):
+  def test_meta_initializer(self, rng: jnp.ndarray, lse_mode: bool):
     """Tests the meta initializer."""
-    # define OT problem
-    n = 200
-    m = 200
-    d = 2
-    epsilon = 0.01
+    n, m, d = 200, 200, 2
+    epsilon = 1e-2

-    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon,
online=False) + ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3) a = ot_problem.a b = ot_problem.b geom = ot_problem.geom @@ -299,25 +293,28 @@ def test_meta_initializer(self, lse_mode, rng: jnp.ndarray): sink_out = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, + initializer="default", a=ot_problem.a, b=ot_problem.b, epsilon=epsilon, lse_mode=lse_mode ) - base_num_iter = jnp.sum(sink_out.errors > -1) # overfit the initializer to the problem. - meta_initializer = ott.initializers.nn.initializers.MetaInitializer(geom) + meta_initializer = nn_init.MetaInitializer(geom) for _ in range(100): _, _, meta_initializer.state = meta_initializer.update( meta_initializer.state, a=a, b=b ) - sink_out = sinkhorn.sinkhorn( - geom, a=a, b=b, initializer=meta_initializer, lse_mode=lse_mode - ) - meta_num_iter = jnp.sum(sink_out.errors > -1) + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn(initializer=meta_initializer, lse_mode=lse_mode) + meta_out = solver(prob) # check initializer is better if lse_mode: - assert base_num_iter >= meta_num_iter + assert sink_out.converged + assert meta_out.converged + assert sink_out.n_iters > meta_out.n_iters + else: + assert sink_out.n_iters >= meta_out.n_iters diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index 0a8001f21..b0dbb4757 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -11,10 +11,11 @@ # limitations under the License. """Tests for Sinkhorn initializers.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.initializers.linear import initializers_lr diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 255c27bc0..2a0a9e4cb 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -11,10 +11,11 @@ # limitations under the License. """Tests for Gromov-Wasserstein initializers.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as lin_init diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index f3368613b..fa14fd0de 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
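
For illustration (not part of this diff), the Sinkhorn initializer tests above all reduce to one pattern: the initializer, either an instance or one of the strings "default", "gaussian", "sorting", is handed to the `Sinkhorn` constructor rather than to a `sinkhorn()` call. A minimal sketch with arbitrary data:

```python
import jax

from ott.geometry import pointcloud
from ott.initializers.linear import initializers as linear_init
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (100, 2))
y = jax.random.normal(rngs[1], (110, 2))

geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
prob = linear_problem.LinearProblem(geom)  # uniform weights by default

# An instance works as well as a string; strings are resolved internally
# by `Sinkhorn.create_initializer`.
solver = sinkhorn.Sinkhorn(initializer=linear_init.GaussianInitializer())
out = solver(prob)
print(out.converged, out.n_iters)
```
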
"""Tests for the jvp of a custom implementation of lse.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import utils as mu diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 3eaa000fc..a0b853249 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -14,10 +14,11 @@ """Tests for matrix square roots.""" from typing import Any, Callable +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import matrix_square_root diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index 2755fffff..02866a130 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -1,7 +1,8 @@ +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem, potentials diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index c16e7ff13..74e91a621 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -14,10 +14,11 @@ import functools from typing import Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, segment from ott.problems.linear import barycenter_problem diff --git a/tests/solvers/linear/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py index cd0c99f10..2e8a96230 100644 --- a/tests/solvers/linear/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax.numpy as jnp import pytest +import jax.numpy as jnp + from ott.geometry import grid, pointcloud from ott.solvers.linear import discrete_barycenter as db diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 4c17c3be8..78ca522cc 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -13,18 +13,18 @@ # limitations under the License. 
"""Tests for the differentiability of reg_ot_cost w.r.t weights/locations.""" import functools -from typing import Tuple +from typing import Callable, List, Optional, Tuple + +import pytest import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.solvers.linear import sinkhorn -from ott.tools import transport class TestSinkhornImplicit: @@ -52,36 +52,34 @@ def test_implicit_differentiation_versus_autodiff( ): epsilon = 0.05 - def loss_g(a, x, implicit=True): - out = sinkhorn.sinkhorn( - geometry.Geometry( - cost_matrix=jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + - jnp.sum(self.y ** 2, axis=1)[jnp.newaxis, :] - - 2 * jnp.dot(x, self.y.T), - epsilon=epsilon - ), - a=a, - b=self.b, - tau_a=0.9, - tau_b=0.87, - threshold=threshold, - lse_mode=lse_mode, - implicit_differentiation=implicit + def loss_g(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True) -> float: + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + geom = geometry.Geometry( + cost_matrix=jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + + jnp.sum(self.y ** 2, axis=1)[jnp.newaxis, :] - + 2 * jnp.dot(x, self.y.T), + epsilon=epsilon ) - return out.reg_ot_cost - - def loss_pcg(a, x, implicit=True): - out = sinkhorn.sinkhorn( - pointcloud.PointCloud(x, self.y, epsilon=epsilon), - a=a, - b=self.b, - tau_a=1.0, - tau_b=0.95, - threshold=threshold, - lse_mode=lse_mode, - implicit_differentiation=implicit + prob = linear_problem.LinearProblem( + geom, a=a, b=self.b, tau_a=0.9, tau_b=0.87 + ) + solver = sinkhorn.Sinkhorn( + threshold=threshold, lse_mode=lse_mode, implicit_diff=implicit_diff + ) + return solver(prob).reg_ot_cost + + def loss_pcg( + a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True + ) -> float: + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon) + prob = linear_problem.LinearProblem( + geom, a=a, b=self.b, tau_a=1.0, tau_b=0.95 + ) + solver = sinkhorn.Sinkhorn( + threshold=threshold, lse_mode=lse_mode, implicit_diff=implicit_diff ) - return out.reg_ot_cost + return solver(prob).reg_ot_cost loss = loss_pcg if pcg else loss_g @@ -160,12 +158,13 @@ def test_autograd_sinkhorn( b = b / jnp.sum(b) def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: - return sinkhorn.sinkhorn( - pointcloud.PointCloud(x, y, epsilon=0.1), a=a, b=b, lse_mode=lse_mode - ).reg_ot_cost + geom = pointcloud.PointCloud(x, y, epsilon=1e-1) + prob = linear_problem.LinearProblem(geom, a=a, b=b) + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) + return solver(prob).reg_ot_cost - reg_ot_and_grad = jax.jit(jax.value_and_grad(reg_ot)) - _, grad_reg_ot = reg_ot_and_grad(a, b) + reg_ot_and_grad = jax.jit(jax.grad(reg_ot)) + grad_reg_ot = reg_ot_and_grad(a, b) delta = jax.random.uniform(keys[4], (n,)) delta = delta * (a > 0) # ensures only perturbing non-zero coords. 
delta = delta - jnp.sum(delta) / jnp.sum(a > 0) # center perturbation @@ -195,11 +194,13 @@ def test_gradient_sinkhorn_geometry( delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) eps = 1e-3 # perturbation magnitude - def loss_fn(cm): + def loss_fn(cm: jnp.ndarray): a = jnp.ones(cm.shape[0]) / cm.shape[0] b = jnp.ones(cm.shape[1]) / cm.shape[1] geom = geometry.Geometry(cm, epsilon=0.5) - out = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) + prob = linear_problem.LinearProblem(geom, a=a, b=b) + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) + out = solver(prob) return out.reg_ot_cost, (geom, out.f, out.g) # first calculation of gradient @@ -231,7 +232,7 @@ def loss_fn(cm): np.testing.assert_array_equal(jnp.isnan(custom_grad), False) @pytest.mark.fast.with_args( - "lse_mode,implicit_differentiation,min_iter,max_iter,epsilon,cost_fn", + "lse_mode,implicit,min_iter,max_iter,epsilon,cost_fn", [ (True, True, 0, 2000, 1e-3, costs.Euclidean()), (True, True, 1000, 1000, 1e-3, costs.Euclidean()), @@ -246,8 +247,8 @@ def loss_fn(cm): only_fast=[0, 1], ) def test_gradient_sinkhorn_euclidean( - self, rng: jnp.ndarray, lse_mode: bool, implicit_differentiation: bool, - min_iter: int, max_iter: int, epsilon: float, cost_fn: costs.CostFn + self, rng: jnp.ndarray, lse_mode: bool, implicit: bool, min_iter: int, + max_iter: int, epsilon: float, cost_fn: costs.CostFn ): """Test gradient w.r.t. locations x of reg-ot-cost.""" # TODO(cuturi): ensure scaling mode works with backprop. @@ -269,16 +270,16 @@ def test_gradient_sinkhorn_euclidean( def loss_fn(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[float, sinkhorn.SinkhornOutput]: + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = pointcloud.PointCloud(x, y, epsilon=epsilon, cost_fn=cost_fn) - out = sinkhorn.sinkhorn( - geom, - a, - b, + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn( lse_mode=lse_mode, - implicit_differentiation=implicit_differentiation, min_iterations=min_iter, max_iterations=max_iter, + implicit_diff=implicit_diff, ) + out = solver(prob) return out.reg_ot_cost, out delta = jax.random.normal(keys[0], (n, d)) @@ -325,7 +326,8 @@ def test_autoepsilon_differentiability(self, rng: jnp.ndarray): def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=None) # auto epsilon - return sinkhorn.sinkhorn(geom).reg_ot_cost + prob = linear_problem.LinearProblem(geom) + return sinkhorn.Sinkhorn()(prob).reg_ot_cost gradient = jax.grad(reg_ot_cost)(cost) np.testing.assert_array_equal(jnp.isnan(gradient), False) @@ -335,7 +337,8 @@ def test_differentiability_with_jit(self, rng: jnp.ndarray): def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=1e-2) - return sinkhorn.sinkhorn(geom).reg_ot_cost + prob = linear_problem.LinearProblem(geom) + return sinkhorn.Sinkhorn()(prob).reg_ot_cost cost = jax.random.uniform(rng, (15, 17)) gradient = jax.jit(jax.grad(reg_ot_cost))(cost) @@ -388,18 +391,14 @@ def test_apply_transport_jacobian( # general rule, even more so when using backprop. 
epsilon = 0.01 if lse_mode else 0.1 - def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): - out = transport.solve( - x, - y, - epsilon=epsilon, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - lse_mode=lse_mode, - implicit_differentiation=implicit - ) + def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool) -> jnp.ndarray: + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) + + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, implicit_diff=implicit_diff) + out = solver(prob) + return out.apply(vec, axis=axis) delta = delta_x if arg else delta_a @@ -489,19 +488,16 @@ def test_potential_jacobian_sinkhorn( # differentiating. epsilon = 0.01 if lse_mode else 0.1 - def loss_from_potential(a, x, implicit): - out = transport.solve( - x, - y, - epsilon=epsilon, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - lse_mode=lse_mode, - implicit_differentiation=implicit - ) - return jnp.sum(random_dir * out.solver_output.f) + def loss_from_potential(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) + + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, implicit_diff=implicit_diff) + + out = solver(prob) + + return jnp.sum(random_dir * out.f) # Compute implicit gradient loss_imp = jax.jit( @@ -560,14 +556,14 @@ def test_diff_sinkhorn_x_grid_x_perturbation( a = a.ravel() / jnp.sum(a) b = b.ravel() / jnp.sum(b) - def reg_ot(x): + def reg_ot(x: List[jnp.ndarray]) -> float: geom = grid.Grid(x=x, epsilon=1.0) - return sinkhorn.sinkhorn( - geom, a=a, b=b, threshold=0.1, lse_mode=lse_mode - ).reg_ot_cost + prob = linear_problem.LinearProblem(geom, a=a, b=b) + solver = sinkhorn.Sinkhorn(threshold=1e-1, lse_mode=lse_mode) + return solver(prob).reg_ot_cost - reg_ot_and_grad = jax.value_and_grad(reg_ot) - _, grad_reg_ot = reg_ot_and_grad(x) + reg_ot_and_grad = jax.grad(reg_ot) + grad_reg_ot = reg_ot_and_grad(x) delta = [jax.random.uniform(keys[i], (g,)) for i, g in enumerate(grid_size)] x_p_delta = [(xs + eps * delt) for xs, delt in zip(x, delta)] @@ -609,13 +605,13 @@ def test_diff_sinkhorn_x_grid_weights_perturbation( b = b.ravel() / jnp.sum(b) geom = grid.Grid(x=x, epsilon=1) - def reg_ot(a, b): - return sinkhorn.sinkhorn( - geom, a=a, b=b, threshold=0.001, lse_mode=lse_mode - ).reg_ot_cost + def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: + prob = linear_problem.LinearProblem(geom, a, b) + solver = sinkhorn.Sinkhorn(threshold=1e-3, lse_mode=lse_mode) + return solver(prob).reg_ot_cost - reg_ot_and_grad = jax.value_and_grad(reg_ot) - _, grad_reg_ot = reg_ot_and_grad(a, b) + reg_ot_and_grad = jax.grad(reg_ot) + grad_reg_ot = reg_ot_and_grad(a, b) delta = jax.random.uniform(keys[2], grid_size).ravel() delta = delta - jnp.mean(delta) @@ -668,23 +664,22 @@ def test_potential_jacobian_sinkhorn( epsilon = 0.01 if lse_mode else 0.1 def loss_from_potential( - a, x, precondition_fun=None, linear_solve_kwargs=None - ): - if linear_solve_kwargs is None: - linear_solve_kwargs = {} - out = transport.solve( - x, - y, - epsilon=epsilon, - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, - lse_mode=lse_mode, - precondition_fun=precondition_fun, - **linear_solve_kwargs + a: jnp.ndarray, + x: jnp.ndarray, + precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, + symmetric: bool 
= False + ) -> float: + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) + + implicit_diff = implicit_lib.ImplicitDiff( + symmetric=symmetric, precondition_fun=precondition_fun ) - return jnp.sum(random_dir * out.solver_output.f) + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, implicit_diff=implicit_diff) + + out = solver(prob) + + return jnp.sum(random_dir * out.f) # Compute implicit gradient loss_imp_no_precond = jax.jit( @@ -692,20 +687,18 @@ def loss_from_potential( functools.partial( loss_from_potential, precondition_fun=lambda x: x, - linear_solve_kwargs={'implicit_solver_symmetric': True} + symmetric=True, ), argnums=arg ) ) - loss_imp_log_precond = jax.jit( - jax.value_and_grad(loss_from_potential, argnums=arg) - ) + loss_imp_log_precond = jax.jit(jax.grad(loss_from_potential, argnums=arg)) _, g_imp_np = loss_imp_no_precond(a, x) imp_dif_np = jnp.sum(g_imp_np * (delta_a if arg == 0 else delta_x)) - _, g_imp_lp = loss_imp_log_precond(a, x) + g_imp_lp = loss_imp_log_precond(a, x) imp_dif_lp = jnp.sum(g_imp_lp * (delta_a if arg == 0 else delta_x)) # Compute finite difference diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index 4684fe2eb..6e33c54f6 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -13,12 +13,14 @@ # limitations under the License. """Tests for Sinkhorn when applied on a grid.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import grid, pointcloud +from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn @@ -39,13 +41,9 @@ def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): threshold = 0.01 geom = grid.Grid(grid_size=grid_size, epsilon=0.1) - errors = sinkhorn.sinkhorn( - geom, - a=a, - b=b, - threshold=threshold, - lse_mode=lse_mode, - ).errors + prob = linear_problem.LinearProblem(geom, a=a, b=b) + solver = sinkhorn.Sinkhorn(threshold=threshold, lse_mode=lse_mode) + errors = solver(prob).errors err = errors[jnp.isfinite(errors)][-1] assert threshold > err @@ -66,13 +64,8 @@ def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1), ]).transpose() geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon) - out_mat = sinkhorn.sinkhorn( - geometry_mat, - a=a, - b=b, - lse_mode=lse_mode, - ) - out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b, lse_mode=lse_mode) + out_mat = sinkhorn.solve(geometry_mat, a=a, b=b) + out_grid = sinkhorn.solve(geometry_grid, a=a, b=b) np.testing.assert_allclose( out_mat.reg_ot_cost, out_grid.reg_ot_cost, rtol=1e-5, atol=1e-5 ) @@ -93,8 +86,8 @@ def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1), ]).transpose() geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1) - sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode) - sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode) + sink_mat = sinkhorn.solve(geom_mat, a=a, b=b) + sink_grid = sinkhorn.solve(geom_grid, a=a, b=b) batch_a = 3 batch_b = 4 diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 5e9778812..48f799ef8 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -12,10 +12,11 @@ # See the License for the specific language 
governing permissions and # limitations under the License. """Tests Sinkhorn Low-Rank solver with various initializations.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import low_rank, pointcloud from ott.problems.linear import linear_problem diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index dc24354ee..24f9aef48 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -15,14 +15,17 @@ from typing import Optional, Tuple import chex +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud from ott.problems.linear import linear_problem -from ott.solvers.linear import acceleration, sinkhorn +from ott.solvers.linear import acceleration +from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn class TestSinkhornAnderson: @@ -37,7 +40,7 @@ class TestSinkhornAnderson: only_fast=0, ) def test_anderson( - self, rng: jnp.ndarray, lse_mode: float, tau_a: float, tau_b: float, + self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, shape: Tuple[int, int], refresh_anderson_frequency: int ): """Test efficiency of Anderson acceleration. @@ -71,19 +74,20 @@ def test_anderson( threshold = 1e-3 iterations_anderson = [] - anderson_memory = [0, 5] - for anderson_acceleration in anderson_memory: - out = sinkhorn.sinkhorn( - pointcloud.PointCloud(x, y, epsilon=epsilon), - a=a, - b=b, - tau_a=tau_a, - tau_b=tau_b, + anderson_memory = [None, 5] + for memory in anderson_memory: + anderson = None if memory is None else acceleration.AndersonAcceleration( + memory=memory, refresh_every=refresh_anderson_frequency + ) + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + prob = linear_problem.LinearProblem(geom, a, b, tau_a=tau_a, tau_b=tau_b) + solver = sinkhorn.Sinkhorn( lse_mode=lse_mode, threshold=threshold, - anderson_acceleration=anderson_acceleration, - refresh_anderson_frequency=refresh_anderson_frequency + anderson=anderson, ) + out = solver(prob) + errors = out.errors clean_errors = errors[errors > -1] # Check convergence @@ -146,10 +150,10 @@ def test_bures_point_cloud( cost_fn = costs.Bures(dimension=self.dim, regularization=1e-4) geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=self.eps) + prob = linear_problem.LinearProblem(geom, self.a, self.b) + solver = sinkhorn.Sinkhorn(threshold=thresh, lse_mode=lse_mode) + out = solver(prob) - out = sinkhorn.sinkhorn( - geom, a=self.a, b=self.b, lse_mode=lse_mode, threshold=thresh - ) err = out.errors[out.errors > -1][-1] assert out.converged @@ -194,26 +198,12 @@ def test_online_matches_offline_size(self, batch_size: int): self.x, self.y, epsilon=1, batch_size=batch_size ) - sol_online = sinkhorn.sinkhorn( - geom_online, - a=self.a, - b=self.b, - threshold=threshold, - lse_mode=True, - implicit_differentiation=True - ) + sol_online = sinkhorn.solve(geom_online) errors_online = sol_online.errors err_online = errors_online[errors_online > -1][-1] assert threshold > err_online - sol_offline = sinkhorn.sinkhorn( - geom_offline, - a=self.a, - b=self.b, - threshold=threshold, - lse_mode=True, - implicit_differentiation=True - ) + sol_offline = sinkhorn.solve(geom_offline) np.testing.assert_allclose( sol_online.matrix, sol_offline.matrix, rtol=rtol, atol=atol @@ -232,14 +222,9 @@ def callback(epsilon: float, batch_size: int) -> 
sinkhorn.SinkhornOutput: geom = pointcloud.PointCloud( self.x, self.y, epsilon=epsilon, batch_size=batch_size ) - return sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, - threshold=threshold, - lse_mode=True, - implicit_differentiation=True - ) + prob = linear_problem.LinearProblem(geom, self.a, self.b) + solver = sinkhorn.Sinkhorn(threshold=threshold) + return solver(prob) threshold = 1e-1 fun = jax.jit(callback, static_argnums=(1,)) if jit else callback @@ -271,18 +256,19 @@ def test_sinkhorn_unbalanced(self, lse_mode: bool, momentum: float): """Two point clouds, tested with various parameters.""" threshold = 1e-3 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1) - errors = sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, + prob = linear_problem.LinearProblem( + geom, self.a, self.b, tau_a=0.8, tau_b=0.9 + ) + solver = sinkhorn.Sinkhorn( threshold=threshold, - momentum=momentum, - inner_iterations=10, - norm_error=1, lse_mode=lse_mode, - tau_a=0.8, - tau_b=0.9 - ).errors + norm_error=1, + momentum=acceleration.Momentum(value=momentum), + inner_iterations=10 + ) + + errors = solver(prob).errors + err = errors[errors > -1][-1] assert threshold > err assert err > 0 @@ -359,14 +345,17 @@ def initialize(self, rng: jnp.ndarray): @pytest.mark.fast def test_jit_vs_non_jit_fwd(self): - def assert_output_close(x: jnp.ndarray, y: jnp.ndarray) -> None: + def assert_output_close( + x: sinkhorn.SinkhornOutput, y: sinkhorn.SinkhornOutput + ) -> None: """Assert SinkhornOutputs are close.""" x = tuple(a for a in x if (a is not None and isinstance(a, jnp.ndarray))) y = tuple(a for a in y if (a is not None and isinstance(a, jnp.ndarray))) return chex.assert_tree_all_close(x, y, atol=1e-6, rtol=0) - jitted_result = jax.jit(sinkhorn.sinkhorn)(self.geometry, self.a, self.b) - non_jitted_result = sinkhorn.sinkhorn(self.geometry, self.a, self.b) + geom = self.geometry + jitted_result = jax.jit(sinkhorn.solve)(geom, a=self.a, b=self.b) + non_jitted_result = sinkhorn.solve(geom, a=self.a, b=self.b) assert_output_close(non_jitted_result, jitted_result) @@ -374,7 +363,8 @@ def assert_output_close(x: jnp.ndarray, y: jnp.ndarray) -> None: def test_jit_vs_non_jit_bwd(self, implicit: bool): @jax.value_and_grad - def val_grad(a: jnp.ndarray, x: jnp.ndarray): + def val_grad(a: jnp.ndarray, x: jnp.ndarray) -> float: + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None geom = geometry.Geometry( cost_matrix=( jnp.sum(x ** 2, axis=1)[:, jnp.newaxis] + @@ -383,17 +373,11 @@ def val_grad(a: jnp.ndarray, x: jnp.ndarray): ), epsilon=self.epsilon ) - out = sinkhorn.sinkhorn( - geom, - a=a, - b=self.b, - tau_a=0.94, - tau_b=0.97, - threshold=1e-4, - lse_mode=True, - implicit_differentiation=implicit + prob = linear_problem.LinearProblem( + geom, a=a, b=self.b, tau_a=0.94, tau_b=0.97 ) - return out.reg_ot_cost + solver = sinkhorn.Sinkhorn(threshold=1e-4, implicit_diff=implicit_diff) + return solver(prob).reg_ot_cost jitted_loss, jitted_grad = jax.jit(val_grad)(self.a, self.x) non_jitted_loss, non_jitted_grad = val_grad(self.a, self.x) diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 2fbd2b311..cd6264dfd 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -13,14 +13,15 @@ # limitations under the License. 
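
The tests below switch to the `sinkhorn.solve` convenience wrapper, with momentum passed as an `acceleration.Momentum` object instead of the old `momentum` / `chg_momentum_from` scalars. An illustrative sketch (not part of this diff); that a positive `start` reproduces the old adaptive `chg_momentum_from` behavior is an assumption here:

```python
import jax

from ott.geometry import pointcloud
from ott.solvers.linear import acceleration, sinkhorn

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (50, 2))
y = jax.random.normal(rngs[1], (60, 2))
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)

# `value` is the relaxation factor; `start` > 0 is assumed to switch to a
# value computed adaptively from the observed errors after that iteration.
mom = acceleration.Momentum(start=30, value=1.0)
out = sinkhorn.solve(geom, momentum=mom, inner_iterations=10)
print(out.reg_ot_cost)
```
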
"""Tests for Sinkhorn.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, grid, pointcloud from ott.problems.linear import linear_problem -from ott.solvers.linear import sinkhorn +from ott.solvers.linear import acceleration, sinkhorn class TestSinkhorn: @@ -44,7 +45,7 @@ def initialize(self, rng: jnp.ndarray): self.b = b / jnp.sum(b) @pytest.mark.fast.with_args( - "lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error,cost_fn", + "lse_mode,mom_value,mom_start,inner_iterations,norm_error,cost_fn", [(True, 1.0, 29, 10, 1, costs.SqEuclidean()), (False, 1.0, 30, 10, 1, costs.SqPNorm(p=2.2)), (True, 1.0, 60, 1, 2, costs.Euclidean()), @@ -53,23 +54,27 @@ def initialize(self, rng: jnp.ndarray): only_fast=[0, -1], ) def test_euclidean_point_cloud( - self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error, - cost_fn + self, + lse_mode: bool, + mom_value: float, + mom_start: int, + inner_iterations: int, + norm_error: int, + cost_fn: costs.CostFn, ): """Two point clouds, tested with various parameters.""" threshold = 1e-3 + momentum = acceleration.Momentum(start=mom_start, value=mom_value) geom = pointcloud.PointCloud(self.x, self.y, cost_fn=cost_fn, epsilon=0.1) - out = sinkhorn.sinkhorn( + out = sinkhorn.solve( geom, a=self.a, b=self.b, - threshold=threshold, - momentum=momentum, - chg_momentum_from=chg_momentum_from, - inner_iterations=inner_iterations, + lse_mode=lse_mode, norm_error=norm_error, - lse_mode=lse_mode + inner_iterations=inner_iterations, + momentum=momentum ) errors = out.errors err = errors[errors > -1][-1] @@ -86,7 +91,7 @@ def test_autoepsilon(self): # needed in principle, but introduced here to test logic. geom_1 = pointcloud.PointCloud(self.x, self.y, relative_epsilon=True) # not jitting - f_1 = sinkhorn.sinkhorn( + f_1 = sinkhorn.solve( geom_1, a=self.a, b=self.b, @@ -97,10 +102,8 @@ def test_autoepsilon(self): # Second geom does not provide whether epsilon is relative. geom_2 = pointcloud.PointCloud(scale * self.x, scale * self.y) # jitting - compute_f = jax.jit( - lambda g, a, b: sinkhorn.sinkhorn(g, a, b, tau_a=.99, tau_b=.97).f - ) - f_2 = compute_f(geom_2, self.a, self.b) + compute_f = jax.jit(sinkhorn.solve, static_argnames=["tau_a", "tau_b"]) + f_2 = compute_f(geom_2, self.a, self.b, tau_a=0.99, tau_b=0.97).f # Ensure epsilon and optimal f's are a scale^2 apart (^2 comes from ^2 cost) np.testing.assert_allclose( @@ -129,32 +132,38 @@ def test_autoepsilon_with_decay( tau_b: float ): """Check that variations in init/decay work, and result in same solution.""" - - @jax.jit - def run_sinkhorn(geom: pointcloud.PointCloud) -> sinkhorn.SinkhornOutput: - return sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, - tau_a=tau_a, - tau_b=tau_b, - lse_mode=lse_mode, - threshold=1e-5 - ) - geom1 = pointcloud.PointCloud(self.x, self.y, init=init, decay=decay) geom2 = pointcloud.PointCloud(self.x, self.y) - out_1 = run_sinkhorn(geom1) - out_2 = run_sinkhorn(geom2) - # recenter if problem is balanced, since in that case solution is only - # valid up to additive constant. 
- if out_1.ot_prob.is_balanced: - # TODO(michalk8): remove after https://github.com/ott-jax/ott/pull/194 - f_1 = out_1.f - jnp.mean(out_1.f[jnp.isfinite(out_1.f)]) - f_2 = out_2.f - jnp.mean(out_2.f[jnp.isfinite(out_2.f)]) - else: - f_1, f_2 = out_1.f, out_2.f + run_fn = jax.jit( + sinkhorn.solve, + static_argnames=[ + "tau_a", "tau_b", "lse_mode", "threshold", "recenter_potentials" + ] + ) + out_1 = run_fn( + geom1, + self.a, + self.b, + tau_a=tau_a, + tau_b=tau_b, + lse_mode=lse_mode, + threshold=1e-5, + recenter_potentials=True + ) + out_2 = run_fn( + geom2, + self.a, + self.b, + tau_a=tau_a, + tau_b=tau_b, + lse_mode=lse_mode, + threshold=1e-5, + recenter_potentials=True + ) + # recenter the potentials, since in the balanced case the solution is + # only valid up to an additive constant + f_1, f_2 = out_1.f, out_2.f np.testing.assert_allclose(f_1, f_2, rtol=1e-4, atol=1e-4) @pytest.mark.fast @@ -162,13 +171,12 @@ def test_euclidean_point_cloud_min_iter(self): """Testing the min_iterations parameter.""" threshold = 1e-3 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1) - errors = sinkhorn.sinkhorn( + errors = sinkhorn.solve( geom, a=self.a, b=self.b, threshold=threshold, min_iterations=34, - implicit_differentiation=False ).errors err = errors[jnp.logical_and(errors > -1, jnp.isfinite(errors))][-1] assert threshold > err @@ -182,8 +190,8 @@ def test_geom_vs_point_cloud(self): geom_1 = pointcloud.PointCloud(self.x, self.y) geom_2 = geometry.Geometry(geom_1.cost_matrix) - f_1 = sinkhorn.sinkhorn(geom_1, a=self.a, b=self.b).f - f_2 = sinkhorn.sinkhorn(geom_2, a=self.a, b=self.b).f + f_1 = sinkhorn.solve(geom_1, a=self.a, b=self.b).f + f_2 = sinkhorn.solve(geom_2, a=self.a, b=self.b).f # re-centering to remove ambiguity on equality up to additive constant.
f_1 -= jnp.mean(f_1[jnp.isfinite(f_1)]) f_2 -= jnp.mean(f_2[jnp.isfinite(f_2)]) @@ -195,7 +203,7 @@ def test_online_euclidean_point_cloud(self, lse_mode: bool): """Testing the online way to handle geometry.""" threshold = 1e-3 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, batch_size=5) - errors = sinkhorn.sinkhorn( + errors = sinkhorn.solve( geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ).errors err = errors[errors > -1][-1] @@ -218,20 +226,20 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): self.x, self.y, cost_fn=costs.SqEuclidean(), epsilon=eps ) - out_online = sinkhorn.sinkhorn( + out_online = sinkhorn.solve( online_geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ) - out_batch = sinkhorn.sinkhorn( + out_batch = sinkhorn.solve( batch_geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ) - out_online_euc = sinkhorn.sinkhorn( + out_online_euc = sinkhorn.solve( online_geom_euc, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode ) - out_batch_euc = sinkhorn.sinkhorn( + out_batch_euc = sinkhorn.solve( batch_geom_euc, a=self.a, b=self.b, @@ -293,16 +301,16 @@ def test_apply_transport_geometry_from_potentials(self): for j, lse_mode in enumerate([True, False]): for i, batch_size in enumerate([16, None]): geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=0.2) - sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) + out = sinkhorn.solve(geom, a, b, lse_mode=lse_mode) transport_t_vec_a[i + 2 * j] = geom.apply_transport_from_potentials( - sink.f, sink.g, vec_a, axis=0 + out.f, out.g, vec_a, axis=0 ) transport_vec_b[i + 2 * j] = geom.apply_transport_from_potentials( - sink.f, sink.g, vec_b, axis=1 + out.f, out.g, vec_b, axis=1 ) - transport = geom.transport_from_potentials(sink.f, sink.g) + transport = geom.transport_from_potentials(out.f, out.g) np.testing.assert_allclose( transport_t_vec_a[i + 2 * j], @@ -347,10 +355,10 @@ def test_apply_transport_geometry_from_scalings(self): for j, lse_mode in enumerate([True, False]): for i, batch_size in enumerate([64, None]): geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=0.2) - sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode) + out = sinkhorn.solve(geom, a, b, lse_mode=lse_mode) - u = geom.scaling_from_potential(sink.f) - v = geom.scaling_from_potential(sink.g) + u = geom.scaling_from_potential(out.f) + v = geom.scaling_from_potential(out.g) transport_t_vec_a[i + 2 * j] = geom.apply_transport_from_scalings( u, v, vec_a, axis=0 @@ -390,7 +398,7 @@ def test_restart(self, lse_mode: bool): """Two point clouds, tested with various parameters.""" threshold = 1e-4 geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01) - out = sinkhorn.sinkhorn( + out = sinkhorn.solve( geom, a=self.a, b=self.b, @@ -424,16 +432,11 @@ def test_restart(self, lse_mode: bool): with pytest.raises(AssertionError): np.testing.assert_allclose(default_b, init_dual_b) - out_restarted = sinkhorn.sinkhorn( - geom, - a=self.a, - b=self.b, - threshold=threshold, - lse_mode=lse_mode, - init_dual_a=init_dual_a, - init_dual_b=init_dual_b, - inner_iterations=1 + prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b) + solver = sinkhorn.Sinkhorn( + threshold=threshold, lse_mode=lse_mode, inner_iterations=1 ) + out_restarted = solver(prob, (init_dual_a, init_dual_b)) errors_restarted = out_restarted.errors err_restarted = errors_restarted[errors_restarted > -1][-1] diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py index 
d3b0af292..830b9019b 100644 --- a/tests/solvers/nn/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for ICNN network architecture.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.solvers.nn import icnn diff --git a/tests/solvers/nn/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py index 526d8903f..2ef0e9d57 100644 --- a/tests/solvers/nn/neuraldual_test.py +++ b/tests/solvers/nn/neuraldual_test.py @@ -13,11 +13,12 @@ """Tests for implementation of ICNN-based Kantorovich dual by Makkuva+(2020).""" from typing import Iterator, Sequence, Tuple +import pytest +from typing_extensions import Literal + import jax import jax.numpy as jnp import numpy as np -import pytest -from typing_extensions import Literal from ott.solvers.nn import neuraldual diff --git a/tests/solvers/quadratic/fgw_barycenter_test.py b/tests/solvers/quadratic/fgw_barycenter_test.py index 2c9cc4268..63d6e5513 100644 --- a/tests/solvers/quadratic/fgw_barycenter_test.py +++ b/tests/solvers/quadratic/fgw_barycenter_test.py @@ -1,10 +1,11 @@ """Tests for Fused Gromov-Wasserstein barycenter.""" from typing import Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import pointcloud from ott.problems.quadratic import gw_barycenter as gwb diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 8d1d33f03..e6019d4e9 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for the Fused Gromov Wasserstein.""" -from typing import Tuple, Union +from typing import Literal, Tuple, Union + +import pytest import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn +from ott.solvers.quadratic import gromov_wasserstein from ott.solvers.quadratic import gromov_wasserstein as gw_solver @@ -48,209 +52,109 @@ def initialize(self, rng: jnp.ndarray): self.cy = jax.random.uniform(keys[5], (self.m, self.m)) self.cxy = jax.random.uniform(keys[6], (self.n, self.m)) - def test_fgw_flag_store_errors_fused(self): - """Tests whether errors are properly stored if requested.""" - threshold_sinkhorn = 1e-2 - geom_x = pointcloud.PointCloud(self.x) - geom_y = pointcloud.PointCloud(self.y) - geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - out = gw_solver.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - geom_xy=geom_xy, - fused_penalty=self.fused_penalty, - a=self.a, - b=self.b, - epsilon=.1 - ).errors - assert out is None - - out = gw_solver.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - geom_xy=geom_xy, - fused_penalty=self.fused_penalty, - a=self.a, - b=self.b, - epsilon=.1, - store_inner_errors=True, - sinkhorn_kwargs={ - 'threshold': threshold_sinkhorn - } - ).errors - out = out[jnp.sum(out > 0, axis=1) > 0, :] - last_errors = out[-1, :] - - assert threshold_sinkhorn > last_errors[last_errors > -1][-1] - assert out.ndim == 2 - @pytest.mark.fast.with_args("jit", [False, True], only_fast=0) def test_gradient_marginals_fgw_solver(self, jit: bool): """Test gradient w.r.t. 
probability weights.""" geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) - fused_penalty = self.fused_penalty def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001 - } - out = gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - loss='sqeucl', - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b ) - return out.reg_gw_cost, (out.linear_state.f, out.linear_state.g) - - grad_matrices = [None, None] - for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad(reg_gw, has_aux=True, argnums=(0, 1)) - if jit: - reg_gw_and_grad = jax.jit(reg_gw_and_grad, static_argnames="implicit") - (_, aux), grad_reg_gw = reg_gw_and_grad(self.a, self.b, implicit) - grad_matrices[i] = grad_reg_gw - grad_manual_a = aux[0] - jnp.log(self.a) - grad_manual_b = aux[1] - jnp.log(self.b) - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) - np.testing.assert_allclose( - grad_manual_a, grad_reg_gw[0], rtol=1e-2, atol=1e-2 + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + linear_solver = sinkhorn.Sinkhorn( + implicit_diff=implicit_diff, max_iterations=1000 ) - np.testing.assert_allclose( - grad_manual_b, grad_reg_gw[1], rtol=1e-2, atol=1e-2 + solver = gromov_wasserstein.GromovWasserstein( + linear_ot_solver=linear_solver, epsilon=1.0 ) - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 - ) - - @pytest.mark.fast.with_args(lse_mode=[False, True], only_fast=1) - def test_fgw_solver_pointcloud(self, lse_mode: bool): - """Test basic computations pointclouds.""" - def reg_gw(x, y, x_2, y_2, fused_penalty, a, b): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - geom_xy = pointcloud.PointCloud(x_2, y_2) - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs={ - "lse_mode": lse_mode - }, - ).reg_gw_cost - - cost = reg_gw( - self.x, self.y, self.x_2, self.y_2, self.fused_penalty, self.a, self.b - ) - assert cost is not None + out = solver(prob) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fgw_solver_pointcloud(self, lse_mode: bool): - """Test gradient w.r.t. 
pointclouds.""" - - def reg_gw(x, y, x_2, y_2, fused_penalty, a, b, implicit): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - geom_xy = pointcloud.PointCloud(x_2, y_2) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001, - 'lse_mode': lse_mode - } - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost + return out.reg_gw_cost, (out.linear_state.f, out.linear_state.g) grad_matrices = [None, None] + reg_fgw_grad = jax.grad(reg_gw, has_aux=True, argnums=(0, 1)) + if jit: + reg_fgw_grad = jax.jit(reg_fgw_grad, static_argnames="implicit") + for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad(reg_gw, argnums=(0, 1)) - _, grad_reg_gw = reg_gw_and_grad( - self.x, self.y, self.x_2, self.y_2, self.fused_penalty, self.a, - self.b, implicit + (g_a, g_b), aux = reg_fgw_grad(self.a, self.b, implicit) + grad_matrices[i] = (g_a, g_b) + grad_manual_a = aux[0] - jnp.log(self.a) + grad_manual_b = aux[1] - jnp.log(self.b) + assert not jnp.any(jnp.isnan(g_a)) + assert not jnp.any(jnp.isnan(g_b)) + np.testing.assert_allclose(grad_manual_a, g_a, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(grad_manual_b, g_b, rtol=1e-2, atol=1e-2) + + gi_a, gi_b = grad_matrices[0] + g_a, g_b = grad_matrices[1] + + np.testing.assert_allclose(g_a, gi_a, rtol=1e-02, atol=1e-02) + np.testing.assert_allclose(g_b, gi_b, rtol=1e-02, atol=1e-02) + + @pytest.mark.parametrize( + "lse_mode,is_cost", [(True, False), (False, True)], + ids=["lse-pc", "kernel-cost-mat"] + ) + def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): + """Test gradient w.r.t. the geometries.""" + + def reg_gw( + x: jnp.ndarray, y: jnp.ndarray, + xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + ): + if is_cost: + geom_x = geometry.Geometry(cost_matrix=x) + geom_y = geometry.Geometry(cost_matrix=y) + geom_xy = geometry.Geometry(cost_matrix=xy) + else: + geom_x = pointcloud.PointCloud(x) + geom_y = pointcloud.PointCloud(y) + geom_xy = pointcloud.PointCloud(xy[0], xy[1]) + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, geom_xy, fused_penalty=fused_penalty, a=a, b=b ) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) - - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 - ) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_fgw_solver_geometry(self, lse_mode: bool): - """Test gradient w.r.t. 
cost matrices.""" + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + linear_solver = sinkhorn.Sinkhorn( + lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000 + ) + solver = gromov_wasserstein.GromovWasserstein( + linear_ot_solver=linear_solver, epsilon=1.0, max_iterations=10 + ) - def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): - geom_x = geometry.Geometry(cost_matrix=cx) - geom_y = geometry.Geometry(cost_matrix=cy) - geom_xy = geometry.Geometry(cost_matrix=cxy) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001, - 'lse_mode': lse_mode - } - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost + return solver(prob).reg_gw_cost + if is_cost: + x, y, xy = self.cx, self.cy, self.cxy + else: + x, y, xy = self.x, self.y, (self.x_2, self.y_2) grad_matrices = [None, None] + reg_fgw_grad = jax.grad(reg_gw, argnums=(0, 1, 2)) + for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad(reg_gw, argnums=(0, 1, 2)) - _, grad_reg_gw = reg_gw_and_grad( - self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b, - implicit + grad_matrices[i] = reg_fgw_grad( + x, y, xy, self.fused_penalty, self.a, self.b, implicit ) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) + assert not jnp.any(jnp.isnan(grad_matrices[i][0])) + assert not jnp.any(jnp.isnan(grad_matrices[i][1])) - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][2], grad_matrices[1][2], rtol=1e-02, atol=1e-02 - ) + gi_x, gi_y, gi_xy = grad_matrices[0] + g_x, g_y, g_xy = grad_matrices[1] + + np.testing.assert_allclose(g_x, gi_x, rtol=1e-02, atol=1e-02) + np.testing.assert_allclose(g_y, gi_y, rtol=1e-02, atol=1e-02) + if is_cost: + np.testing.assert_allclose(g_xy, gi_xy, rtol=1e-02, atol=1e-02) + else: + np.testing.assert_allclose(g_xy[0], gi_xy[0], rtol=1e-02, atol=1e-02) + np.testing.assert_allclose(g_xy[1], gi_xy[1], rtol=1e-02, atol=1e-02) def test_fgw_adaptive_threshold(self): """Checking solution is improved with smaller threshold for convergence.""" @@ -260,99 +164,60 @@ def test_fgw_adaptive_threshold(self): # without warm start for calls to sinkhorn def loss_thre(threshold: float) -> float: - return gw_solver.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - geom_xy=geom_xy, - fused_penalty=self.fused_penalty_2, + prob = quadratic_problem.QuadraticProblem( + geom_x, + geom_y, + geom_xy, a=self.a, b=self.b, - epsilon=.1, - threshold=threshold - ).reg_gw_cost + fused_penalty=self.fused_penalty_2 + ) + solver = gromov_wasserstein.GromovWasserstein( + threshold=threshold, epsilon=1e-1 + ) + + return solver(prob).reg_gw_cost - assert loss_thre(1e-1) > loss_thre(1e-3) + assert loss_thre(1e-1) > loss_thre(1e-4) assert loss_thre(1e-3) > loss_thre(1e-5) @pytest.mark.parametrize("lse_mode", [False, True]) def test_gradient_fgw_solver_penalty(self, lse_mode: bool): """Test gradient w.r.t. 
penalty.""" - def reg_gw(cx, cy, cxy, fused_penalty, a, b, implicit): + def reg_gw( + cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + ) -> float: geom_x = geometry.Geometry(cost_matrix=cx) geom_y = geometry.Geometry(cost_matrix=cy) geom_xy = geometry.Geometry(cost_matrix=cxy) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001, - 'lse_mode': lse_mode - } - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, geom_xy, a=a, b=b, fused_penalty=fused_penalty + ) + + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + linear_solver = sinkhorn.Sinkhorn( + lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000 + ) + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver + ) + return solver(prob).reg_gw_cost grad_matrices = [None, None] for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad(reg_gw, argnums=(3,)) - _, grad_reg_gw = reg_gw_and_grad( + reg_fgw_grad = jax.grad(reg_gw, argnums=(3,)) + grad_matrices[i] = reg_fgw_grad( self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b, implicit ) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) + assert not jnp.any(jnp.isnan(grad_matrices[i][0])) + np.testing.assert_allclose( grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 ) - def test_effect_fused_penalty(self): - - def reg_fgw(x, y, x_2, y_2, fused_penalty, a, b): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - geom_xy = pointcloud.PointCloud(x_2, y_2) - sinkhorn_kwargs = {'max_iterations': 1001} - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - a=a, - b=b, - epsilon=1.0, - sinkhorn_kwargs=sinkhorn_kwargs - ) - - def reg_gw(x, y, a, b): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - sinkhorn_kwargs = {'max_iterations': 1001} - return gw_solver.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - epsilon=1.0, - sinkhorn_kwargs=sinkhorn_kwargs - ) - - fgw_output = reg_fgw( - self.x, self.y, self.x_2, self.y_2, self.fused_penalty, self.a, self.b - ) - gw_output = reg_gw(self.x, self.y, self.a, self.b) - assert fgw_output.reg_gw_cost > gw_output.reg_gw_cost - with pytest.raises(AssertionError): - np.testing.assert_array_almost_equal( - fgw_output.matrix[0, 0], gw_output.matrix[0, 0] - ) - @pytest.mark.limit_memory("400 MB") @pytest.mark.parametrize("jit", [False, True]) def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): @@ -366,17 +231,14 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) geom_xy = pointcloud.PointCloud(xx, yy) + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, geom_xy) - solver = gw_solver.gromov_wasserstein + solver = gromov_wasserstein.GromovWasserstein(rank=5) if jit: solver = jax.jit(solver, static_argnames="rank") - ot_gwlr = solver( - geom_x, - geom_y, - geom_xy, - rank=5, - ) + ot_gwlr = solver(prob) + res0 = ot_gwlr.apply(x.T, axis=0) res1 = ot_gwlr.apply(y.T, axis=1) @@ -399,19 +261,19 @@ def test_fgw_lr_generic_cost_matrix( geom_y = geometry.Geometry(cost_matrix=y @ y.T) 
geom_xy = geometry.Geometry(cost_matrix=xx @ yy.T) - problem = quadratic_problem.QuadraticProblem( + prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, geom_xy, ranks=cost_rank, tolerances=5e-1 ) - assert problem._is_low_rank_convertible - lr_prob = problem.to_low_rank() + assert prob._is_low_rank_convertible + lr_prob = prob.to_low_rank() assert lr_prob.is_low_rank - solver = gw_solver.GromovWasserstein(rank=5, epsilon=1) - out = solver(problem) + solver = gw_solver.GromovWasserstein(rank=5, epsilon=1.0) + out = solver(prob) assert solver.rank == 5 # make sure we don't modify the problem in-place - for geom in [problem.geom_xx, problem.geom_yy, problem.geom_xy]: + for geom in [prob.geom_xx, prob.geom_yy, prob.geom_xy]: assert not isinstance(geom, low_rank.LRCGeometry) ranks = (cost_rank,) * 3 if isinstance(cost_rank, int) else cost_rank for rank, geom in zip( @@ -422,3 +284,38 @@ def test_fgw_lr_generic_cost_matrix( assert out.converged assert out.reg_gw_cost > 0 np.testing.assert_array_equal(jnp.isfinite(out.costs), True) + + @pytest.mark.parametrize("scale_cost", ["mean", "max_cost"]) + def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]): + epsilon = 0.1 + fused_penalty = 1 + geom_x = pointcloud.PointCloud(self.x, scale_cost=1.) + geom_y = pointcloud.PointCloud(self.y, scale_cost=1.) + geom_xy = pointcloud.PointCloud(self.x_2, self.y_2, scale_cost=1.) + geom_x_scaled = pointcloud.PointCloud(self.x, scale_cost=scale_cost) + geom_y_scaled = pointcloud.PointCloud(self.y, scale_cost=scale_cost) + geom_xy_scaled = pointcloud.PointCloud( + self.x_2, self.y_2, scale_cost=scale_cost + ) + + prob_no_scale = quadratic_problem.QuadraticProblem( + geom_x_scaled, + geom_y_scaled, + geom_xy_scaled, + fused_penalty=fused_penalty, + scale_cost=False + ) + prob_scale = quadratic_problem.QuadraticProblem( + geom_x, + geom_y, + geom_xy, + fused_penalty=fused_penalty, + scale_cost=scale_cost + ) + solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon) + + gt = solver(prob_scale) + pred = solver(prob_no_scale) + + np.testing.assert_allclose(pred.matrix, gt.matrix) + np.testing.assert_allclose(pred.costs, gt.costs) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index d6b991a4b..dc6bf7242 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -13,10 +13,11 @@ """Tests for Gromov-Wasserstein barycenter.""" from typing import Any, Optional, Sequence, Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import pointcloud from ott.problems.quadratic import gw_barycenter as gwb diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index 1aa5609ca..a461b1a0c 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -14,13 +14,16 @@ """Tests for the Gromov Wasserstein.""" from typing import Tuple, Union +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem +from ott.solvers.linear import implicit_differentiation as implicit_lib +from ott.solvers.linear import sinkhorn from ott.solvers.quadratic import gromov_wasserstein @@ -110,64 +113,57 @@ def initialize(self, rng: jnp.ndarray): d_x = 2 d_y = 3 self.n, self.m = 6, 7 - keys = jax.random.split(rng, 8) + keys = jax.random.split(rng, 6) self.x = 
jax.random.uniform(keys[0], (self.n, d_x)) self.y = jax.random.uniform(keys[1], (self.m, d_y)) - a = jax.random.uniform(keys[2], (self.n,)) + 0.1 - b = jax.random.uniform(keys[3], (self.m,)) + 0.1 + a = jax.random.uniform(keys[2], (self.n,)) + 1e-1 + b = jax.random.uniform(keys[3], (self.m,)) + 1e-1 self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) self.cx = jax.random.uniform(keys[4], (self.n, self.n)) self.cy = jax.random.uniform(keys[5], (self.m, self.m)) - self.xx = jax.random.uniform(keys[6], (self.n, d_x)) - self.yy = jax.random.uniform(keys[7], (self.m, d_x)) + self.tau_a = 0.8 + self.tau_b = 0.9 def test_flag_store_errors(self): """Tests whether errors are properly stored if requested.""" threshold_sinkhorn = 1e-2 geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) - out = gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x, geom_yy=geom_y, a=self.a, b=self.b, epsilon=.1 - ).errors - assert out is None - - out = gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - a=self.a, - b=self.b, - epsilon=.1, - store_inner_errors=True, - sinkhorn_kwargs={ - 'threshold': threshold_sinkhorn - } - ).errors - - out = out[jnp.sum(out > 0, axis=1) > 0, :] - last_errors = out[-1, :] + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=self.a, b=self.b + ) + + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-1, store_inner_errors=False + ) + assert solver(prob).errors is None + + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1e-1, store_inner_errors=True + ) + errors = solver(prob).errors + + assert errors.ndim == 2 + errors = errors[jnp.sum(errors > 0, axis=1) > 0, :] + last_errors = errors[-1, :] assert threshold_sinkhorn > last_errors[last_errors > -1][-1] - assert out.ndim == 2 @pytest.mark.parametrize("jit", [False, True]) def test_gradient_marginals_gw(self, jit: bool): """Test gradient w.r.t. 
probability weights.""" - def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001 - } - out = gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - epsilon=1.0, - loss='sqeucl', - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs + def reg_gw(a: jnp.ndarray, b: jnp.ndarray, + implicit: bool) -> Tuple[float, Tuple[jnp.ndarray, jnp.ndarray]]: + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, a=a, b=b) + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + linear_solver = sinkhorn.Sinkhorn( + implicit_diff=implicit_diff, max_iterations=1000 + ) + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver ) + out = solver(prob) return out.reg_gw_cost, (out.linear_state.f, out.linear_state.g) geom_x = pointcloud.PointCloud(self.x) @@ -175,11 +171,11 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): grad_matrices = [None, None] for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad(reg_gw, has_aux=True, argnums=(0, 1)) + reg_gw_grad = jax.grad(reg_gw, has_aux=True, argnums=(0, 1)) if jit: - reg_gw_and_grad = jax.jit(reg_gw_and_grad, static_argnames="implicit") + reg_gw_grad = jax.jit(reg_gw_grad, static_argnames="implicit") - (_, aux), grad_reg_gw = reg_gw_and_grad(self.a, self.b, implicit) + grad_reg_gw, aux = reg_gw_grad(self.a, self.b, implicit) grad_matrices[i] = grad_reg_gw grad_manual_a = aux[0] - jnp.log(self.a) grad_manual_b = aux[1] - jnp.log(self.b) @@ -200,96 +196,81 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): ) @pytest.mark.fast - def test_gw_pointcloud(self): + @pytest.mark.parametrize("unbalanced", [False, True]) + def test_gw_pointcloud(self, unbalanced: bool): """Test basic computations pointclouds.""" - def reg_gw(x, y, a, b): + def reg_gw( + x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray + ) -> float: geom_x = pointcloud.PointCloud(x) geom_y = pointcloud.PointCloud(y) - return gromov_wasserstein.gromov_wasserstein( - geom_x, geom_y, a=a, b=b, epsilon=1.0, max_iterations=10 - ).reg_gw_cost + tau_a, tau_b = (self.tau_a, self.tau_b) if unbalanced else (1.0, 1.0) + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=a, b=b, tau_a=tau_a, tau_b=tau_b + ) + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1.0, max_iterations=10 + ) + return solver(prob).reg_gw_cost assert not jnp.isnan(reg_gw(self.x, self.y, self.a, self.b)) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gw_pointcloud(self, lse_mode: bool): - """Test gradient w.r.t. pointclouds.""" - - def reg_gw(x, y, a, b, implicit): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001, - 'lse_mode': lse_mode - } - return gromov_wasserstein.gromov_wasserstein( + @pytest.mark.parametrize( + "unbalanced,unbalanced_correction", [(False, False), (True, False), + (True, True)], + ids=["bal", "unbal-nocorr", "unbal-corr"] + ) + @pytest.mark.parametrize( + "lse_mode,is_cost", [(True, False), (False, True)], + ids=["lse-pc", "kernel-cost-mat"] + ) + def test_gradient_gw_geometry( + self, lse_mode: bool, is_cost: bool, unbalanced: bool, + unbalanced_correction: bool + ): + """Test gradient w.r.t. 
the geometries.""" + + def reg_gw( + x: jnp.ndarray, y: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + implicit: bool + ) -> float: + if is_cost: + geom_x = geometry.Geometry(cost_matrix=x) + geom_y = geometry.Geometry(cost_matrix=y) + else: + geom_x = pointcloud.PointCloud(x) + geom_y = pointcloud.PointCloud(y) + tau_a, tau_b = (self.tau_a, self.tau_b) if unbalanced else (1.0, 1.0) + prob = quadratic_problem.QuadraticProblem( geom_x, geom_y, a=a, b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost - - grad_matrices = [None, None] - for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad( - reg_gw, argnums=( - 0, - 1, - ) + tau_a=tau_a, + tau_b=tau_b, + gw_unbalanced_correction=unbalanced_correction ) - _, grad_reg_gw = reg_gw_and_grad(self.x, self.y, self.a, self.b, implicit) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 - ) + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None + lin_solver = sinkhorn.Sinkhorn( + lse_mode=lse_mode, max_iterations=1000, implicit_diff=implicit_diff + ) + solver = gromov_wasserstein.GromovWasserstein( + epsilon=1.0, max_iterations=10, linear_ot_solver=lin_solver + ) - @pytest.mark.parametrize("lse_mode", [False, True]) - def test_gradient_gw_geometry(self, lse_mode: bool): - """Test gradient w.r.t. cost matrices.""" - - def reg_gw(cx, cy, a, b, implicit): - geom_x = geometry.Geometry(cost_matrix=cx) - geom_y = geometry.Geometry(cost_matrix=cy) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001, - 'lse_mode': lse_mode - } - return gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost + return solver(prob).reg_gw_cost grad_matrices = [None, None] + x, y = (self.cx, self.cy) if is_cost else (self.x, self.y) + reg_gw_grad = jax.grad(reg_gw, argnums=(0, 1)) + for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad( - reg_gw, argnums=( - 0, - 1, - ) - ) - _, grad_reg_gw = reg_gw_and_grad( - self.cx, self.cy, self.a, self.b, implicit - ) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) + grad_matrices[i] = reg_gw_grad(x, y, self.a, self.b, implicit) + assert not jnp.any(jnp.isnan(grad_matrices[i][0])) + assert not jnp.any(jnp.isnan(grad_matrices[i][1])) + np.testing.assert_allclose( grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 ) @@ -302,18 +283,18 @@ def test_gw_adaptive_threshold(self): geom_x = pointcloud.PointCloud(self.x, self.x) geom_y = pointcloud.PointCloud(self.y, self.y) - def loss_thre(threshold): - return gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - a=self.a, - b=self.b, - epsilon=.1, - threshold=threshold - ).reg_gw_cost + def loss_thre(threshold: float) -> float: + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=self.a, b=self.b + ) + solver = gromov_wasserstein.GromovWasserstein( + threshold=threshold, epsilon=1e-1 + ) - assert loss_thre(1e-1), loss_thre(1e-3) - assert loss_thre(1e-3), loss_thre(1e-5) + return solver(prob).reg_gw_cost + + assert loss_thre(1e-1) >= loss_thre(1e-4) + 
assert loss_thre(1e-3) >= loss_thre(1e-5) @pytest.mark.fast def test_gw_lr(self, rng: jnp.ndarray): @@ -367,51 +348,15 @@ def test_gw_lr_matches_fused(self, rng: jnp.ndarray): # Test at least some difference when adding bigger entropic regularization assert jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) > 1e-3 - @pytest.mark.parametrize("scale_cost", ["mean", "max_cost"]) - def test_gw_fused_scale_cost(self, scale_cost: str): - epsilon = 0.1 - fused_penalty = 1 - geom_x = pointcloud.PointCloud(self.x, scale_cost=1.) - geom_y = pointcloud.PointCloud(self.y, scale_cost=1.) - geom_xy = pointcloud.PointCloud(self.xx, self.yy, scale_cost=1.) - geom_x_scaled = pointcloud.PointCloud(self.x, scale_cost=scale_cost) - geom_y_scaled = pointcloud.PointCloud(self.y, scale_cost=scale_cost) - geom_xy_scaled = pointcloud.PointCloud( - self.xx, self.yy, scale_cost=scale_cost - ) - - gt = gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x_scaled, - geom_yy=geom_y_scaled, - geom_xy=geom_xy_scaled, - fused_penalty=fused_penalty, - epsilon=epsilon, - scale_cost=False - ) - pred = gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - epsilon=epsilon, - scale_cost=scale_cost - ) - - np.testing.assert_allclose(pred.matrix, gt.matrix) - np.testing.assert_allclose(pred.costs, gt.costs) - @pytest.mark.parametrize("axis", [0, 1]) def test_gw_lr_apply(self, axis: int): geom_x = pointcloud.PointCloud(self.x) geom_y = pointcloud.PointCloud(self.y) - out = gromov_wasserstein.gromov_wasserstein( - geom_xx=geom_x, - geom_yy=geom_y, - a=self.a, - b=self.b, - epsilon=.1, - rank=2, + prob = quadratic_problem.QuadraticProblem( + geom_x, geom_y, a=self.a, b=self.b ) + solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-1, rank=2) + out = solver(prob) arr, matrix = (self.x, out.matrix) if axis == 0 else (self.y, out.matrix.T) res_apply = out.apply(arr.T, axis=axis) @@ -420,159 +365,24 @@ def test_gw_lr_apply(self, axis: int): np.testing.assert_allclose(res_apply, res_matrix, rtol=1e-5, atol=1e-5) def test_gw_lr_warm_start_helps(self, rng: jnp.ndarray): - key1, key2 = jax.random.split(rng, 2) rank = 3 + key1, key2 = jax.random.split(rng, 2) geom_x = pointcloud.PointCloud(jax.random.normal(key1, (100, 5))) geom_y = pointcloud.PointCloud(jax.random.normal(key2, (110, 6))) + prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) - out = gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - rank=rank, - store_inner_errors=True, - warm_start=False, - ) - out_warm_start = gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - rank=rank, - store_inner_errors=True, - warm_start=True, - ) - - cost = out.reg_gw_cost - cost_warm_start = out_warm_start.reg_gw_cost - assert cost_warm_start + 5. 
< cost - with pytest.raises(AssertionError): - np.testing.assert_allclose(out.matrix, out_warm_start.matrix) - - -class TestGromovWassersteinUnbalanced: - - @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): - d_x = 2 - d_y = 3 - self.n, self.m = 5, 6 - keys = jax.random.split(rng, 7) - self.x = jax.random.uniform(keys[0], (self.n, d_x)) - self.y = jax.random.uniform(keys[1], (self.m, d_y)) - a = jax.random.uniform(keys[2], (self.n,)) + 0.1 - b = jax.random.uniform(keys[3], (self.m,)) + 0.1 - self.a = a / jnp.sum(a) - self.b = b / jnp.sum(b) - self.cx = jax.random.uniform(keys[4], (self.n, self.n)) - self.cy = jax.random.uniform(keys[5], (self.m, self.m)) - self.tau_a = 0.8 - self.tau_b = 0.9 - - @pytest.mark.fast - def test_gw_pointcloud(self): - """Test basic computations pointclouds.""" - - def reg_gw(x, y, a, b): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - return gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - tau_a=self.tau_a, - tau_b=self.tau_b, - epsilon=1.0, - max_iterations=10 - ).reg_gw_cost - - cost = reg_gw(self.x, self.y, self.a, self.b) - assert not jnp.isnan(cost) - - @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gw_pointcloud(self, gw_unbalanced_correction: bool): - """Test gradient w.r.t. pointclouds.""" - - def reg_gw(x, y, a, b, implicit): - geom_x = pointcloud.PointCloud(x) - geom_y = pointcloud.PointCloud(y) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001 - } - return gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - tau_a=self.tau_a, - tau_b=self.tau_b, - gw_unbalanced_correction=gw_unbalanced_correction, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost - - grad_matrices = [None, None] - for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad( - reg_gw, argnums=( - 0, - 1, - ) - ) - _, grad_reg_gw = reg_gw_and_grad(self.x, self.y, self.a, self.b, implicit) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) - - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 + solver_cold = gromov_wasserstein.GromovWasserstein( + rank=rank, warm_start=False ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 + solver_warm = gromov_wasserstein.GromovWasserstein( + rank=rank, warm_start=True ) - @pytest.mark.parametrize("gw_unbalanced_correction", [False, True]) - def test_gradient_gw_geometry(self, gw_unbalanced_correction: bool): - """Test gradient w.r.t. 
cost matrices.""" - - def reg_gw(cx, cy, a, b, implicit): - geom_x = geometry.Geometry(cost_matrix=cx) - geom_y = geometry.Geometry(cost_matrix=cy) - sinkhorn_kwargs = { - 'implicit_differentiation': implicit, - 'max_iterations': 1001 - } - return gromov_wasserstein.gromov_wasserstein( - geom_x, - geom_y, - a=a, - b=b, - tau_a=self.tau_a, - tau_b=self.tau_b, - gw_unbalanced_correction=gw_unbalanced_correction, - epsilon=1.0, - max_iterations=10, - sinkhorn_kwargs=sinkhorn_kwargs - ).reg_gw_cost - - grad_matrices = [None, None] - for i, implicit in enumerate([True, False]): - reg_gw_and_grad = jax.value_and_grad( - reg_gw, argnums=( - 0, - 1, - ) - ) - _, grad_reg_gw = reg_gw_and_grad( - self.cx, self.cy, self.a, self.b, implicit - ) - grad_matrices[i] = grad_reg_gw - assert not jnp.any(jnp.isnan(grad_reg_gw[0])) - assert not jnp.any(jnp.isnan(grad_reg_gw[1])) + out_cold = solver_cold(prob) + out_warm = solver_warm(prob) - np.testing.assert_allclose( - grad_matrices[0][0], grad_matrices[1][0], rtol=1e-02, atol=1e-02 - ) - np.testing.assert_allclose( - grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 - ) + cost = out_cold.reg_gw_cost + cost_warm_start = out_warm.reg_gw_cost + assert (cost_warm_start + 5.0) < cost + with pytest.raises(AssertionError): + np.testing.assert_allclose(out_cold.matrix, out_warm.matrix) diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 33688d2cf..fa32a6907 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -13,9 +13,10 @@ # limitations under the License. """Tests for fit_gmm_pair.""" +import pytest + import jax import jax.numpy as jnp -import pytest from ott.tools.gaussian_mixture import ( fit_gmm, diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index ab101c085..f9ff660fc 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for fit_gmm_pair.""" +import pytest + import jax import jax.numpy as jnp import jax.test_util -import pytest from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index 8790a6d48..4846d06e0 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for gaussian_mixture_pair.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index 0c9857921..1d80b26f0 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -13,10 +13,11 @@ # limitations under the License. 
"""Tests for gaussian_mixture.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian_mixture, linalg diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index ee6018480..18b47f903 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for gaussian.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import gaussian, scale_tril diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 0e7a96c61..e1e23b9e6 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for linalg.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import linalg diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 5ea9d2fbd..d15a92592 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for probabilities.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.tools.gaussian_mixture import probabilities diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 32f2f93f3..39d68042e 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -13,10 +13,11 @@ # limitations under the License. """Tests for ScaleTriL.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 416a4a849..0cfa0fa8c 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -1,14 +1,15 @@ from typing import Any, Optional, Tuple, Union +import pytest +from typing_extensions import Literal + import jax import jax.numpy as jnp import numpy as np -import pytest from sklearn import datasets from sklearn.cluster import KMeans from sklearn.cluster._k_means_common import _is_same_clustering from sklearn.cluster._kmeans import kmeans_plusplus -from typing_extensions import Literal from ott.geometry import costs, pointcloud from ott.tools import k_means diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 44f2cbe5e..f6e43a4b7 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -13,12 +13,14 @@ # limitations under the License. 
"""Tests for Segmented Sinkhorn.""" +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, pointcloud +from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn from ott.tools import segment_sinkhorn from ott.tools.gaussian_mixture import gaussian_mixture @@ -43,14 +45,13 @@ def test_segment_sinkhorn_result(self, shuffle: bool): rngs = jax.random.split(self.rng, 4) x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) - geom_kwargs = dict(epsilon=0.014) - sinkhorn_kwargs = dict(threshold=.2e-2) - true_regotcost = sinkhorn.sinkhorn( - pointcloud.PointCloud(x, y, **geom_kwargs), - a=self._a, - b=self._b, - **sinkhorn_kwargs - ).reg_ot_cost + geom_kwargs = {"epsilon": 0.014} + sinkhorn_kwargs = {"threshold": 2e-3} + + geom = pointcloud.PointCloud(x, y, **geom_kwargs) + prob = linear_problem.LinearProblem(geom, a=self._a, b=self._b) + solver = sinkhorn.Sinkhorn(**sinkhorn_kwargs) + true_reg_ot_cost = solver(prob).reg_ot_cost if shuffle: # Now, shuffle the order of both arrays, but @@ -88,7 +89,7 @@ def test_segment_sinkhorn_result(self, shuffle: bool): **geom_kwargs ) - np.testing.assert_allclose(true_regotcost.repeat(2), segmented_regotcost) + np.testing.assert_allclose(true_reg_ot_cost.repeat(2), segmented_regotcost) def test_segment_sinkhorn_different_segment_sizes(self): # Test other array sizes @@ -116,13 +117,15 @@ def test_segment_sinkhorn_different_segment_sizes(self): assert segmented_regotcost.shape[0] == 2 assert segmented_regotcost[1] > segmented_regotcost[0] - true_regotcost = jnp.array([ - sinkhorn.sinkhorn(pointcloud.PointCloud(x, y, - epsilon=0.01),).reg_ot_cost - for x, y in zip((x1, x2), (y1, y2)) - ]) + true_reg_ot_cost = [] + for x, y in zip((x1, x2), (y1, y2)): + geom = pointcloud.PointCloud(x, y, epsilon=1e-2) + prob = linear_problem.LinearProblem(geom) + solver = sinkhorn.Sinkhorn() + true_reg_ot_cost.append(solver(prob).reg_ot_cost) + np.testing.assert_allclose( - segmented_regotcost, true_regotcost, atol=1e-4, rtol=1e-4 + segmented_regotcost, true_reg_ot_cost, atol=1e-4, rtol=1e-4 ) def test_sinkhorn_divergence_segment_custom_padding(self, rng): @@ -146,16 +149,17 @@ def g(rng, n): x1, x2, y1, y2 = (g(rngs[i], ns[i]) for i in range(4)) - true_regotcost = jnp.array([ - sinkhorn.sinkhorn( - pointcloud.PointCloud(x, y, cost_fn=b_cost, epsilon=0.1) - ).reg_ot_cost for x, y in zip((x1, x2), (y1, y2)) - ]) + true_reg_ot_cost = [] + for x, y in zip((x1, x2), (y1, y2)): + geom = pointcloud.PointCloud(x, y, cost_fn=b_cost, epsilon=1e-1) + prob = linear_problem.LinearProblem(geom) + solver = sinkhorn.Sinkhorn() + true_reg_ot_cost.append(solver(prob).reg_ot_cost) x = jnp.vstack((x1, x2)) y = jnp.vstack((y1, y2)) - segmented_regotcost = segment_sinkhorn.segment_sinkhorn( + segmented_reg_ot_cost = segment_sinkhorn.segment_sinkhorn( x, y, num_segments=2, @@ -166,4 +170,4 @@ def g(rng, n): sinkhorn_kwargs={'lse_mode': True}, epsilon=0.1, ) - np.testing.assert_allclose(segmented_regotcost, true_regotcost) + np.testing.assert_allclose(segmented_reg_ot_cost, true_reg_ot_cost) diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 0c768a953..e0a0472b4 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -14,13 +14,14 @@ """Tests for the Sinkhorn divergence.""" from typing import Any, Dict, Optional +import pytest + import jax 
import jax.numpy as jnp import numpy as np -import pytest from ott.geometry import costs, geometry, pointcloud -from ott.solvers.linear import sinkhorn +from ott.solvers.linear import acceleration, sinkhorn from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian_mixture @@ -47,7 +48,7 @@ def setUp(self, rng: jnp.ndarray): "epsilon": 1e-2 }, ) - def test_euclidean_point_cloud(self, cost_fn, epsilon): + def test_euclidean_point_cloud(self, cost_fn: costs.CostFn, epsilon: float): rngs = jax.random.split(self.rng, 2) x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) @@ -68,9 +69,9 @@ def test_euclidean_point_cloud(self, cost_fn, epsilon): geometry_xx = pointcloud.PointCloud(x, epsilon=epsilon, cost_fn=cost_fn) geometry_yy = pointcloud.PointCloud(y, epsilon=epsilon, cost_fn=cost_fn) - div2 = sinkhorn.sinkhorn(geometry_xy, self._a, self._b).reg_ot_cost - div2 -= 0.5 * sinkhorn.sinkhorn(geometry_xx, self._a, self._a).reg_ot_cost - div2 -= 0.5 * sinkhorn.sinkhorn(geometry_yy, self._b, self._b).reg_ot_cost + div2 = sinkhorn.solve(geometry_xy, self._a, self._b).reg_ot_cost + div2 -= 0.5 * sinkhorn.solve(geometry_xx, self._a, self._a).reg_ot_cost + div2 -= 0.5 * sinkhorn.solve(geometry_yy, self._b, self._b).reg_ot_cost np.testing.assert_allclose(div.divergence, div2, rtol=1e-5, atol=1e-5) @@ -80,7 +81,7 @@ def test_euclidean_point_cloud(self, cost_fn, epsilon): x, x, cost_fn=cost_fn, - epsilon=.1, + epsilon=1e-1, sinkhorn_kwargs={'inner_iterations': 1}, ) np.testing.assert_allclose(div.divergence, 0.0, rtol=1e-5, atol=1e-5) @@ -364,12 +365,12 @@ def g(rng, n): # yapf: disable @pytest.mark.fast.with_args( "sinkhorn_kwargs,epsilon", [ - ({"anderson_acceleration": 3}, 1e-2), - ({"anderson_acceleration": 6}, None), - ({"chg_momentum_from": 20}, 1e-3), - ({"chg_momentum_from": 30}, None), - ({"momentum": 1.05}, 1e-3), - ({"momentum": 1.01}, None), + ({"anderson": acceleration.AndersonAcceleration(memory=3)}, 1e-2), + ({"anderson": acceleration.AndersonAcceleration(memory=6)}, None), + ({"momentum": acceleration.Momentum(start=20)}, 1e-3), + ({"momentum": acceleration.Momentum(start=30)}, None), + ({"momentum": acceleration.Momentum(value=1.05)}, 1e-3), + ({"momentum": acceleration.Momentum(value=1.01)}, None), ], only_fast=[0, -1], ) diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index a6fe6bb37..319c9fdd5 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -15,11 +15,14 @@ import functools from typing import Tuple +import pytest + import jax import jax.numpy as jnp import numpy as np -import pytest +from ott.solvers.linear import acceleration +from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.tools import soft_sort @@ -41,10 +44,14 @@ def test_sort_array_squashing_momentum(self, rng: jnp.ndarray): axis=0, squashing_fun=lambda x: x, epsilon=5e-4, - chg_momentum_from=100 + momentum=acceleration.Momentum(start=100), ) xs_sig = soft_sort.sort( - x, axis=0, squashing_fun=None, epsilon=2e-4, chg_momentum_from=100 + x, + axis=0, + squashing_fun=None, + epsilon=2e-4, + momentum=acceleration.Momentum(start=100) ) # Notice xs_lin and xs_sig have no reason to be equal, since they use # different squashing functions, but they should be similar. 
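The two hunks above replace the loose acceleration keywords (`momentum`, `chg_momentum_from`, `anderson_acceleration`) with dataclasses from `ott.solvers.linear.acceleration`. Below is a minimal sketch of the new spelling, assuming the keywords forwarded via `sinkhorn_kwargs` in these tests are accepted directly by `sinkhorn.Sinkhorn`; the toy point clouds and epsilon are illustrative, not taken from the suite.

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import acceleration, sinkhorn

# Illustrative toy data (not from the test suite).
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (20, 3))
y = jax.random.normal(rngs[1], (25, 3))
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
prob = linear_problem.LinearProblem(geom)  # uniform weights by default

# chg_momentum_from=30  ->  Momentum(start=30): adaptive momentum after 30 iterations.
out_adaptive = sinkhorn.Sinkhorn(momentum=acceleration.Momentum(start=30))(prob)

# momentum=1.05  ->  Momentum(value=1.05): fixed over-relaxation value.
out_fixed = sinkhorn.Sinkhorn(momentum=acceleration.Momentum(value=1.05))(prob)

# anderson_acceleration=3  ->  AndersonAcceleration(memory=3).
out_anderson = sinkhorn.Sinkhorn(
    anderson=acceleration.AndersonAcceleration(memory=3)
)(prob)
```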
@@ -159,11 +166,12 @@ def test_soft_sort_jacobian(self, rng: jnp.ndarray, implicit: bool): random_dir = jax.random.normal(rngs[1], (b,)) / b def loss_fn(logits: jnp.ndarray) -> float: + implicit_diff = implicit_lib.ImplicitDiff() if implicit else None ranks_fn = functools.partial( soft_sort.ranks, axis=-1, num_targets=167, - implicit_differentiation=implicit + implicit_diff=implicit_diff, ) return jnp.sum(ranks_fn(logits)[:, idx_column] * random_dir) diff --git a/tests/tools/transport_test.py b/tests/tools/transport_test.py deleted file mode 100644 index 4954c85bd..000000000 --- a/tests/tools/transport_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for ott.tools.transport.""" - -import jax -import jax.numpy as jnp -import numpy as np -import pytest - -from ott.geometry import pointcloud -from ott.problems.linear import linear_problem -from ott.tools import transport - - -@pytest.mark.fast -class TestTransport: - """Tests for the Transport class.""" - - def test_transport_from_point(self, rng: jnp.ndarray): - rngs = jax.random.split(rng, 2) - num_a, num_b = 23, 17 - x = jax.random.uniform(rngs[0], (num_a, 4)) - y = jax.random.uniform(rngs[1], (num_b, 4)) - ot = transport.solve(x, y, threshold=1e-2) - - np.testing.assert_array_equal(ot.matrix.shape, (num_a, num_b)) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=1), ot.a, atol=1e-3) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=0), ot.b, atol=1e-3) - - def test_transport_from_geom(self, rng: jnp.ndarray): - rngs = jax.random.split(rng, 3) - num_a, num_b = 23, 17 - x = jax.random.uniform(rngs[0], (num_a, 4)) - y = jax.random.uniform(rngs[1], (num_b, 4)) - geom = pointcloud.PointCloud(x, y, epsilon=1e-2, batch_size=8) - b = jax.random.uniform(rngs[2], (num_b,)) - b /= jnp.sum(b) - ot = transport.solve(geom, b=b, threshold=1e-3) - - np.testing.assert_array_equal(ot.matrix.shape, (num_a, num_b)) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=1), ot.a, atol=1e-3) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=0), ot.b, atol=1e-3) - - def test_transport_from_problem(self, rng: jnp.ndarray): - rngs = jax.random.split(rng, 3) - num_a, num_b = 23, 17 - x = jax.random.uniform(rngs[0], (num_a, 4)) - y = jax.random.uniform(rngs[1], (num_b, 4)) - geom = pointcloud.PointCloud(x, y, batch_size=9) - b = jax.random.uniform(rngs[2], (num_b,)) - b /= jnp.sum(b) - pb = linear_problem.LinearProblem(geom, b=b) - ot = transport.solve(pb) - - np.testing.assert_array_equal(ot.matrix.shape, (num_a, num_b)) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=1), ot.a, atol=1e-3) - np.testing.assert_allclose(jnp.sum(ot.matrix, axis=0), ot.b, atol=1e-3) - - def test_transport_wrong_init(self, rng: jnp.ndarray): - rngs = jax.random.split(rng, 2) - num_a, num_b = 23, 17 - x = jax.random.uniform(rngs[0], (num_a, 4)) - y = jax.random.uniform(rngs[1], (num_b, 4)) - geom = pointcloud.PointCloud(x, y, epsilon=1e-2, batch_size=10) - with pytest.raises(AttributeError, 
match=r".*has no attribute.*'"): - transport.solve(geom, x, threshold=1e-3) - - with pytest.raises(ValueError, match="Cannot instantiate a transport"): - transport.solve('pointcloud', threshold=1e-3)