From 0327ae3bf58e5825954a95a65a8f0232e7e58e86 Mon Sep 17 00:00:00 2001
From: michalk8 <46717574+michalk8@users.noreply.github.com>
Date: Wed, 13 Sep 2023 18:03:29 +0200
Subject: [PATCH] Fix `n_iters` (#437)

* Start cleaning `n_iters`
* Remove dead code
* Update SD
* Update GW tutorial
* Remove unnecessary hierarchy in tutorials
* Remove `n_iters` property from SD
* Fix not passing `inner_iterations`
* Skip testing `flax` on 3.8
* Fix SD test
* Use `.toarray()`
* Update `FenchelConjugateLBFGS` for `jaxopt>=0.8`
* Fix missing `importorskip`
* Split `MetaInitializer` test
* [ci skip] Fix typos in bary docs
* Skip tests that need `optax`
* Remove relative pip install from MG notebook
---
 docs/index.rst                                |   2 +-
 docs/tutorials/index.rst                      |  42 +++----
 pyproject.toml                                |   8 +-
 src/ott/problems/linear/barycenter_problem.py |   8 +-
 src/ott/solvers/linear/sinkhorn.py            |  14 +--
 src/ott/solvers/linear/sinkhorn_lr.py         |  10 +-
 src/ott/solvers/nn/conjugate_solvers.py       |  16 +--
 .../solvers/quadratic/gromov_wasserstein.py   |   2 +-
 .../quadratic/gromov_wasserstein_lr.py        |  19 +--
 src/ott/solvers/quadratic/gw_barycenter.py    |   2 +-
 src/ott/tools/sinkhorn_divergence.py          |  36 +++---
 tests/geometry/graph_test.py                  |   4 +-
 .../initializers/linear/sinkhorn_init_test.py |  42 -------
 .../initializers/nn/meta_initializer_test.py  | 114 ++++++++++++++++++
 tests/solvers/nn/icnn_test.py                 |   3 +
 tests/solvers/nn/losses_test.py               |   3 +
 tests/solvers/nn/neuraldual_test.py           |   3 +
 .../gaussian_mixture/fit_gmm_pair_test.py     |   3 +
 tests/tools/map_estimator_test.py             |   3 +
 19 files changed, 209 insertions(+), 125 deletions(-)
 create mode 100644 tests/initializers/nn/meta_initializer_test.py

diff --git a/docs/index.rst b/docs/index.rst
index 170b1b6a..22450bb1 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -103,7 +103,7 @@ Packages
    :maxdepth: 1
    :caption: Examples

-   Getting Started
+   Getting Started
    tutorials/index

diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst
index d6d2cd42..63429a46 100644
--- a/docs/tutorials/index.rst
+++ b/docs/tutorials/index.rst
@@ -6,54 +6,54 @@ Geometry
 .. toctree::
    :maxdepth: 1

-   notebooks/introduction_grid
+   introduction_grid

 Linear Optimal Transport
 ------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/point_clouds
-   notebooks/One_Sinkhorn
-   notebooks/OTT_&_POT
-   notebooks/Hessians
-   notebooks/LRSinkhorn
-   notebooks/sinkhorn_divergence_gradient_flow
-   notebooks/sparse_monge_displacements
+   point_clouds
+   One_Sinkhorn
+   OTT_&_POT
+   Hessians
+   LRSinkhorn
+   sinkhorn_divergence_gradient_flow
+   sparse_monge_displacements

 Barycenters
 ^^^^^^^^^^^
 .. toctree::
    :maxdepth: 1

-   notebooks/Sinkhorn_Barycenters
-   notebooks/gmm_pair_demo
-   notebooks/wasserstein_barycenters_gmms
+   Sinkhorn_Barycenters
+   gmm_pair_demo
+   wasserstein_barycenters_gmms

 Miscellaneous
 ^^^^^^^^^^^^^
 .. toctree::
    :maxdepth: 1

-   notebooks/tracking_progress
-   notebooks/soft_sort
-   notebooks/application_biology
+   tracking_progress
+   soft_sort
+   application_biology

 Quadratic Optimal Transport
 ---------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/gromov_wasserstein
-   notebooks/GWLRSinkhorn
-   notebooks/gromov_wasserstein_multiomics
+   gromov_wasserstein
+   GWLRSinkhorn
+   gromov_wasserstein_multiomics

 Neural Optimal Transport
 ------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/neural_dual
-   notebooks/icnn_inits
-   notebooks/MetaOT
-   notebooks/Monge_Gap
+   neural_dual
+   icnn_inits
+   MetaOT
+   Monge_Gap

diff --git a/pyproject.toml b/pyproject.toml
index a100ee68..08abdc05 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -183,7 +183,10 @@ legacy_tox_ini = """
     skip_missing_interpreters = true

     [testenv]
-    extras = test,neural
+    extras =
+        test
+        # https://github.com/google/flax/issues/3329
+        py{3.9,3.10,3.11},py3.9-jax-default: neural
     pass_env = CUDA_*,PYTEST_*,CI
     commands_pre =
         gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
@@ -299,7 +302,8 @@ select = [
 unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
 target-version = "py38"

 [tool.ruff.per-file-ignores]
-"tests/*" = ["D", "PT004"]  # TODO(michalk8): remove `self.initialize` in `tests/`
+# TODO(michalk8): PT004 - remove `self.initialize`
+"tests/*" = ["D", "PT004", "E402"]
 "*/__init__.py" = ["F401"]
 "docs/*" = ["D"]
 "src/ott/types.py" = ["D102"]

diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py
index ebeba06b..ca5333a8 100644
--- a/src/ott/problems/linear/barycenter_problem.py
+++ b/src/ott/problems/linear/barycenter_problem.py
@@ -18,7 +18,7 @@
 from ott.geometry import costs, geometry, segment

-__all__ = ["FreeBarycenterProblem"]
+__all__ = ["FreeBarycenterProblem", "FixedBarycenterProblem"]


 @jax.tree_util.register_pytree_node_class
@@ -44,8 +44,8 @@ class FreeBarycenterProblem:
       Only used when ``y`` is not already segmented. When passing
       ``segment_ids``, 2 arguments must be specified for jitting to work:

-        - ``num_segments`` - the total number of measures.
-        - ``max_measure_size`` - maximum of support sizes of these measures.
+      - ``num_segments`` - the total number of measures.
+      - ``max_measure_size`` - maximum of support sizes of these measures.
   """

   def __init__(
@@ -158,7 +158,7 @@ class FixedBarycenterProblem:
     a: batch of histograms of shape ``[batch, num_a]`` where ``num_a`` matches
       the first value of the :attr:`~ott.geometry.Geometry.shape` attribute of
       ``geom``.
-    weights: ``[batch,]`` positive weights summing to :math`1`. Uniform by
+    weights: ``[batch,]`` positive weights summing to :math:`1`. Uniform by
       default.
   """
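With `FixedBarycenterProblem` now exported via `__all__`, it can be constructed directly. A minimal sketch, following the docstring above (the shared support, epsilon, and histogram values are illustrative, not from the patch):

import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import barycenter_problem

x = jnp.linspace(0.0, 1.0, 16)[:, None]        # shared support, num_a = 16
geom = pointcloud.PointCloud(x, epsilon=1e-2)
# `a` stacks the input histograms, shape [batch, num_a] per the docstring.
a = jnp.stack([jnp.full((16,), 1 / 16), jnp.full((16,), 1 / 16)])
prob = barycenter_problem.FixedBarycenterProblem(geom=geom, a=a)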
diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py
index f7b1e54a..36b1d968 100644
--- a/src/ott/solvers/linear/sinkhorn.py
+++ b/src/ott/solvers/linear/sinkhorn.py
@@ -318,7 +318,6 @@ class SinkhornOutput(NamedTuple):
       below the convergence threshold.
     inner_iterations: number of iterations that were run between two
       computations of errors.
-
   """

   f: Optional[jnp.ndarray] = None
@@ -424,10 +423,6 @@ def transport_cost_at_geom(
       return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
     return jnp.sum(self.matrix * other_geom.cost_matrix)

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return isinstance(self.ot_prob, linear_problem.LinearProblem)
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     return self.ot_prob.geom
@@ -440,17 +435,10 @@ def a(self) -> jnp.ndarray:  # noqa: D102
   def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

-  @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return True
-
-  # TODO(michalk8): this should be always present
   @property
   def n_iters(self) -> int:  # noqa: D102
     """Returns the total number of iterations that were needed to terminate."""
-    if self.errors is None:
-      return -1
-    return jnp.sum(self.errors > -1) * self.inner_iterations
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]:  # noqa: D102

diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py
index 740eed8c..0e4cafb3 100644
--- a/src/ott/solvers/linear/sinkhorn_lr.py
+++ b/src/ott/solvers/linear/sinkhorn_lr.py
@@ -175,6 +175,7 @@ class LRSinkhornOutput(NamedTuple):
   errors: jnp.ndarray
   ot_prob: linear_problem.LinearProblem
   epsilon: float
+  inner_iterations: int
   # TODO(michalk8): Optional is an artifact of the current impl., refactor
   reg_ot_cost: Optional[float] = None

@@ -205,10 +206,6 @@ def compute_reg_ot_cost(  # noqa: D102
         use_danskin=use_danskin
     )

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return isinstance(self.ot_prob, linear_problem.LinearProblem)
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     return self.ot_prob.geom
@@ -222,8 +219,8 @@ def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

   @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return True
+  def n_iters(self) -> int:  # noqa: D102
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def converged(self) -> bool:  # noqa: D102
@@ -773,6 +770,7 @@ def output_from_state(
         costs=state.costs,
         errors=state.errors,
         epsilon=self.epsilon,
+        inner_iterations=self.inner_iterations,
     )

   def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
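The new `n_iters` semantics can be checked directly: `errors` records one value per convergence check and is pre-filled with -1, so counting entries different from -1 and multiplying by the check frequency gives the total iteration count. A minimal sketch (the toy point clouds and `inner_iterations=10` are illustrative):

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (30, 2))
y = jax.random.normal(rng_y, (40, 2)) + 1.0

prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y, epsilon=1e-1))
out = sinkhorn.Sinkhorn(inner_iterations=10)(prob)

# Entries still equal to -1 correspond to checks that never ran, so:
assert out.n_iters == jnp.sum(out.errors != -1) * 10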
diff --git a/src/ott/solvers/nn/conjugate_solvers.py b/src/ott/solvers/nn/conjugate_solvers.py
index 871d7e1d..0758cf1a 100644
--- a/src/ott/solvers/nn/conjugate_solvers.py
+++ b/src/ott/solvers/nn/conjugate_solvers.py
@@ -75,16 +75,18 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver):
     max_iter: maximum number of iterations
     max_linesearch_iter: maximum number of line search iterations
     linesearch_type: type of line search
-    decrease_factor: decrease factor for a backtracking line search
-    ls_method: the line search method
+    linesearch_init: strategy for line search initialization
+    increase_factor: factor by which to increase the step size during
+      the line search
   """

   gtol: float = 1e-3
   max_iter: int = 10
   max_linesearch_iter: int = 10
-  linesearch_type: Literal["zoom", "backtracking"] = "backtracking"
-  decrease_factor: float = 0.66
-  ls_method: Literal["wolf", "strong-wolfe"] = "strong-wolfe"
+  linesearch_type: Literal["zoom", "backtracking",
+                           "hager-zhang"] = "backtracking"
+  linesearch_init: Literal["increase", "max", "current"] = "increase"
+  increase_factor: float = 1.5

   def solve(  # noqa: D102
       self,
@@ -98,9 +100,9 @@ def solve(  # noqa: D102
         fun=lambda x: f(x) - x.dot(y),
         tol=self.gtol,
         maxiter=self.max_iter,
-        decrease_factor=self.decrease_factor,
         linesearch=self.linesearch_type,
-        condition=self.ls_method,
+        linesearch_init=self.linesearch_init,
+        increase_factor=self.increase_factor,
         implicit_diff=False,
         unroll=False
     )
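The renamed keywords match `jaxopt.LBFGS` in `jaxopt>=0.8`, where `decrease_factor` and `condition` were removed. A standalone sketch of the same conjugate computation; the quadratic `f` is a toy stand-in for the ICNN potential OTT would pass here:

import jax.numpy as jnp
import jaxopt

def f(x):  # f(x) = ||x||^2 / 2, so argmin_x f(x) - <x, y> is x = y
  return 0.5 * jnp.dot(x, x)

y = jnp.array([1.0, -2.0])
solver = jaxopt.LBFGS(
    fun=lambda x: f(x) - x.dot(y),
    tol=1e-3,
    maxiter=10,
    linesearch="backtracking",
    linesearch_init="increase",
    increase_factor=1.5,
    implicit_diff=False,
    unroll=False,
)
x_star = solver.run(jnp.zeros_like(y)).params  # approximately equal to y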
diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py
index ca8656bc..1eff97d6 100644
--- a/src/ott/solvers/quadratic/gromov_wasserstein.py
+++ b/src/ott/solvers/quadratic/gromov_wasserstein.py
@@ -103,7 +103,7 @@ def primal_cost(self) -> float:
   def n_iters(self) -> int:  # noqa: D102
     if self.errors is None:
       return -1
-    return jnp.sum(self.errors > -1)
+    return jnp.sum(self.errors[:, 0] != -1)


 class GWState(NamedTuple):

diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py
index 0777f291..a6e46974 100644
--- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py
+++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py
@@ -116,7 +116,14 @@ def ent(x: jnp.ndarray) -> float:
   g = jax.lax.stop_gradient(g) if use_danskin else g

   out = LRGWOutput(
-      q=q, r=r, g=g, ot_prob=ot_prob, costs=None, errors=None, epsilon=None
+      q=q,
+      r=r,
+      g=g,
+      ot_prob=ot_prob,
+      costs=None,
+      errors=None,
+      epsilon=None,
+      inner_iterations=None,
   )
   cost = out.primal_cost - epsilon * (ent(q) + ent(r) + ent(g))

@@ -141,6 +148,7 @@ class LRGWOutput(NamedTuple):
   errors: jnp.ndarray
   ot_prob: quadratic_problem.QuadraticProblem
   epsilon: float
+  inner_iterations: int
   reg_gw_cost: Optional[float] = None

   def set(self, **kwargs: Any) -> "LRGWOutput":
@@ -170,10 +178,6 @@ def compute_reg_gw_cost(  # noqa: D102
         use_danskin=use_danskin
     )

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return False
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     """Linearized geometry."""
@@ -188,8 +192,8 @@ def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

   @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return False
+  def n_iters(self) -> int:  # noqa: D102
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def converged(self) -> bool:  # noqa: D102
@@ -792,6 +796,7 @@ def output_from_state(
         costs=state.costs,
         errors=state.errors,
         epsilon=self.epsilon,
+        inner_iterations=self.inner_iterations,
     )

   def _converged(self, state: LRGWState, iteration: int) -> bool:

diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py
index 0873e9c5..2e398aaf 100644
--- a/src/ott/solvers/quadratic/gw_barycenter.py
+++ b/src/ott/solvers/quadratic/gw_barycenter.py
@@ -62,7 +62,7 @@ def n_iters(self) -> int:
     """Number of iterations."""
     if self.gw_convergence is None:
       return -1
-    return jnp.sum(self.gw_convergence > -1)
+    return jnp.sum(self.gw_convergence != -1)


 @jax.tree_util.register_pytree_node_class

diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py
index 6b89f45e..653eaa26 100644
--- a/src/ott/tools/sinkhorn_divergence.py
+++ b/src/ott/tools/sinkhorn_divergence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import MappingProxyType
-from typing import Any, List, Mapping, NamedTuple, Optional, Tuple, Type
+from typing import Any, Mapping, NamedTuple, Optional, Tuple, Type

 import jax.numpy as jnp

@@ -26,16 +26,19 @@
     "SinkhornDivergenceOutput"
 ]

+Potentials_t = Tuple[jnp.ndarray, jnp.ndarray]
+

 class SinkhornDivergenceOutput(NamedTuple):  # noqa: D101
   divergence: float
-  potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]]
+  potentials: Tuple[Potentials_t, Potentials_t, Potentials_t]
   geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
   errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
                 Optional[jnp.ndarray]]
   converged: Tuple[bool, bool, bool]
   a: jnp.ndarray
   b: jnp.ndarray
+  n_iters: Tuple[int, int, int]

   def to_dual_potentials(self) -> "potentials.EntropicPotentials":
     """Return dual estimators :cite:`pooladian:22`, eq. 8."""
@@ -46,17 +49,6 @@ def to_dual_potentials(self) -> "potentials.EntropicPotentials":
         f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y
     )

-  @property
-  def n_iters(self) -> Tuple[int, int, int]:  # noqa: D102
-    """Returns 3 number of iterations that were needed to terminate."""
-    out = []
-    for i in range(3):
-      if self.errors[i] is None:
-        out.append(-1)
-      else:
-        out.append(jnp.sum(self.errors[i] > -1))
-    return out
-

 def sinkhorn_divergence(
     geom: Type[geometry.Geometry],
@@ -181,7 +173,10 @@ def _sinkhorn_divergence(
     # Create dummy output, corresponds to scenario where static_b is True.
    # This choice ensures that ``converged`` of this dummy output is True.
    out_yy = sinkhorn.SinkhornOutput(
-        errors=jnp.array([-jnp.inf]), reg_ot_cost=0.0, threshold=0.0
+        errors=jnp.array([-jnp.inf]),
+        reg_ot_cost=0.0,
+        threshold=0.0,
+        inner_iterations=0,
     )
   else:
     out_yy = linear.solve(geometry_yy, b, b, **kwargs_symmetric)

   div = (
       out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost)
       + 0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
   )
-  out = (out_xy, out_xx, out_yy)
   return SinkhornDivergenceOutput(
-      div, tuple([s.f, s.g] for s in out),
-      (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
-      tuple(s.converged for s in out), a, b
+      divergence=div,
+      potentials=((out_xy.f, out_xy.g), (out_xx.f, out_xx.g),
+                  (out_yy.f, out_yy.g)),
+      geoms=(geometry_xy, geometry_xx, geometry_yy),
+      errors=(out_xy.errors, out_xx.errors, out_yy.errors),
+      converged=(out_xy.converged, out_xx.converged, out_yy.converged),
+      a=a,
+      b=b,
+      n_iters=(out_xy.n_iters, out_xx.n_iters, out_yy.n_iters),
   )
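Since `n_iters` is now a plain field of `SinkhornDivergenceOutput`, it holds one count per sub-problem (x-y, x-x, y-y) instead of being recomputed from `errors`. A minimal sketch of reading it (toy inputs are illustrative):

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (25, 2))
y = jax.random.normal(rng_y, (35, 2)) + 0.5

out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, epsilon=1e-1
)
iters_xy, iters_xx, iters_yy = out.n_iters  # one count per Sinkhorn solve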
diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py
index 0c71fc44..18e683bb 100644
--- a/tests/geometry/graph_test.py
+++ b/tests/geometry/graph_test.py
@@ -47,7 +47,7 @@ def random_graph(
       G
   ) if return_laplacian else nx.linalg.adjacency_matrix(G)

-  return jnp.asarray(G.A)
+  return jnp.asarray(G.toarray())


 def gt_geometry(G: jnp.ndarray, *, epsilon: float = 1e-2) -> geometry.Geometry:
@@ -140,7 +140,7 @@ def test_approximates_ground_truth(
   def test_crank_nicolson_more_stable(self, t: Optional[float], n_steps: int):
     tol = 5 * t
     G = nx.linalg.adjacency_matrix(balanced_tree(r=2, h=5))
-    G = jnp.asarray(G.A, dtype=float)
+    G = jnp.asarray(G.toarray(), dtype=float)
     eye = jnp.eye(G.shape[0])

     be_geom = graph.Graph.from_graph(
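`.toarray()` is the portable way to densify SciPy sparse containers: the `.A` shorthand is not guaranteed on the newer `scipy.sparse` array classes that recent NetworkX versions return from `adjacency_matrix`. A small sketch (the random graph is illustrative):

import jax.numpy as jnp
import networkx as nx

G = nx.gnp_random_graph(10, 0.3, seed=0)
A = nx.linalg.adjacency_matrix(G)   # SciPy sparse; `.A` may be unavailable
A_dense = jnp.asarray(A.toarray())  # works for sparse matrices and arrays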
diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py
index 9974946a..70f5e750 100644
--- a/tests/initializers/linear/sinkhorn_init_test.py
+++ b/tests/initializers/linear/sinkhorn_init_test.py
@@ -19,7 +19,6 @@ import pytest

 from ott.geometry import geometry, pointcloud
 from ott.initializers.linear import initializers as linear_init
-from ott.initializers.nn import initializers as nn_init
 from ott.problems.linear import linear_problem
 from ott.solvers.linear import sinkhorn

@@ -281,44 +280,3 @@ def test_initializer_n_iter(
       assert default_out.n_iters > init_out.n_iters
     else:
       assert default_out.n_iters >= init_out.n_iters
-
-  @pytest.mark.parametrize("lse_mode", [True, False])
-  def test_meta_initializer(self, rng: jax.random.PRNGKeyArray, lse_mode: bool):
-    """Tests Meta initializer"""
-    n, m, d = 20, 20, 2
-    epsilon = 1e-2
-
-    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3)
-    a = ot_problem.a
-    b = ot_problem.b
-    geom = ot_problem.geom
-
-    # run sinkhorn
-    sink_out = run_sinkhorn(
-        x=ot_problem.geom.x,
-        y=ot_problem.geom.y,
-        initializer=linear_init.DefaultInitializer(),
-        a=ot_problem.a,
-        b=ot_problem.b,
-        epsilon=epsilon,
-        lse_mode=lse_mode
-    )
-
-    # overfit the initializer to the problem.
-    meta_initializer = nn_init.MetaInitializer(geom)
-    for _ in range(50):
-      _, _, meta_initializer.state = meta_initializer.update(
-          meta_initializer.state, a=a, b=b
-      )
-
-    prob = linear_problem.LinearProblem(geom, a, b)
-    solver = sinkhorn.Sinkhorn(initializer=meta_initializer, lse_mode=lse_mode)
-    meta_out = solver(prob)
-
-    # check initializer is better
-    if lse_mode:
-      assert sink_out.converged
-      assert meta_out.converged
-      assert sink_out.n_iters > meta_out.n_iters
-    else:
-      assert sink_out.n_iters >= meta_out.n_iters

diff --git a/tests/initializers/nn/meta_initializer_test.py b/tests/initializers/nn/meta_initializer_test.py
new file mode 100644
index 00000000..b6ed835a
--- /dev/null
+++ b/tests/initializers/nn/meta_initializer_test.py
@@ -0,0 +1,114 @@
+# Copyright OTT-JAX
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+_ = pytest.importorskip("flax")
+
+from ott.geometry import pointcloud
+from ott.initializers.linear import initializers as linear_init
+from ott.initializers.nn import initializers as nn_init
+from ott.problems.linear import linear_problem
+from ott.solvers.linear import sinkhorn
+
+
+def create_ot_problem(
+    rng: jax.random.PRNGKeyArray,
+    n: int,
+    m: int,
+    d: int,
+    epsilon: float = 1e-2,
+    batch_size: Optional[int] = None
+) -> linear_problem.LinearProblem:
+  # define ot problem
+  x_rng, y_rng = jax.random.split(rng)
+
+  mu_a = jnp.array([-1, 1]) * 5
+  mu_b = jnp.array([0, 0])
+
+  x = jax.random.normal(x_rng, (n, d)) + mu_a
+  y = jax.random.normal(y_rng, (m, d)) + mu_b
+
+  a = jnp.ones(n) / n
+  b = jnp.ones(m) / m
+
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size)
+
+  return linear_problem.LinearProblem(geom=geom, a=a, b=b)
+
+
+def run_sinkhorn(
+    x: jnp.ndarray,
+    y: jnp.ndarray,
+    *,
+    initializer: linear_init.SinkhornInitializer,
+    a: Optional[jnp.ndarray] = None,
+    b: Optional[jnp.ndarray] = None,
+    epsilon: float = 1e-2,
+    lse_mode: bool = True,
+) -> sinkhorn.SinkhornOutput:
+  """Runs Sinkhorn algorithm with given initializer."""
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
+  prob = linear_problem.LinearProblem(geom, a, b)
+  solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, initializer=initializer)
+  return solver(prob)
+
+
+@pytest.mark.fast()
+class TestMetaInitializer:
+
+  @pytest.mark.parametrize("lse_mode", [True, False])
+  def test_meta_initializer(self, rng: jax.random.PRNGKeyArray, lse_mode: bool):
+    """Tests Meta initializer"""
+    n, m, d = 20, 20, 2
+    epsilon = 1e-2
+
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3)
+    a = ot_problem.a
+    b = ot_problem.b
+    geom = ot_problem.geom
+
+    # run sinkhorn
+    sink_out = run_sinkhorn(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        initializer=linear_init.DefaultInitializer(),
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon,
+        lse_mode=lse_mode
+    )
+
+    # overfit the initializer to the problem.
+    meta_initializer = nn_init.MetaInitializer(geom)
+    for _ in range(50):
+      _, _, meta_initializer.state = meta_initializer.update(
+          meta_initializer.state, a=a, b=b
+      )
+
+    prob = linear_problem.LinearProblem(geom, a, b)
+    solver = sinkhorn.Sinkhorn(initializer=meta_initializer, lse_mode=lse_mode)
+    meta_out = solver(prob)
+
+    # check initializer is better
+    if lse_mode:
+      assert sink_out.converged
+      assert meta_out.converged
+      assert sink_out.n_iters > meta_out.n_iters
+    else:
+      assert sink_out.n_iters >= meta_out.n_iters
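The module-level `pytest.importorskip` guard used above (and in the test diffs below) skips an entire file when an optional dependency is absent; the imports that must follow the guard are why `E402` was added to the ruff per-file ignores earlier in this patch. A minimal sketch of the pattern:

import pytest

# Skip every test in this module when the optional `flax` dependency is
# missing (e.g. on Python 3.8, where the `neural` extra is not installed).
_ = pytest.importorskip("flax")

# Imports below the guard trigger ruff's E402 ("module level import not at
# top of file"), hence the per-file ignore added in pyproject.toml.
from ott.solvers.nn import models  # noqa: E402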
diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py
index a674cf63..2e500d9a 100644
--- a/tests/solvers/nn/icnn_test.py
+++ b/tests/solvers/nn/icnn_test.py
@@ -15,6 +15,9 @@
 import jax.numpy as jnp
 import numpy as np
 import pytest
+
+_ = pytest.importorskip("flax")
+
 from ott.solvers.nn import models

diff --git a/tests/solvers/nn/losses_test.py b/tests/solvers/nn/losses_test.py
index 2b50df95..9a1e8091 100644
--- a/tests/solvers/nn/losses_test.py
+++ b/tests/solvers/nn/losses_test.py
@@ -15,6 +15,9 @@
 import jax
 import numpy as np
 import pytest
+
+_ = pytest.importorskip("flax")
+
 from ott.geometry import costs
 from ott.solvers.nn import losses, models

diff --git a/tests/solvers/nn/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py
index 8bd92bc5..e242709b 100644
--- a/tests/solvers/nn/neuraldual_test.py
+++ b/tests/solvers/nn/neuraldual_test.py
@@ -17,6 +17,9 @@
 import jax
 import numpy as np
 import pytest
+
+_ = pytest.importorskip("flax")
+
 from ott.problems.nn import dataset
 from ott.solvers.nn import conjugate_solvers, models, neuraldual

diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py
index 22b1d430..6e5752e3 100644
--- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py
+++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py
@@ -21,6 +21,9 @@
     gaussian_mixture_pair,
 )

+# on 3.8, neural (flax/optax) is not installed
+_ = pytest.importorskip("optax")
+

 class TestFitGmmPair:

diff --git a/tests/tools/map_estimator_test.py b/tests/tools/map_estimator_test.py
index a917e755..fc4e1f19 100644
--- a/tests/tools/map_estimator_test.py
+++ b/tests/tools/map_estimator_test.py
@@ -16,6 +16,9 @@
 import jax.numpy as jnp
 import pytest
+
+_ = pytest.importorskip("flax")
+
 from ott.geometry import pointcloud
 from ott.problems.nn import dataset
 from ott.solvers.nn import losses, models