From 6cd26b8c7c11a30b2fb2884a477aa73a04791b32 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 12 Sep 2023 11:40:50 +0200 Subject: [PATCH 1/3] Better tracking of `n_iters` --- src/ott/solvers/linear/sinkhorn.py | 15 ++---- src/ott/solvers/linear/sinkhorn_lr.py | 4 ++ .../solvers/quadratic/gromov_wasserstein.py | 16 +++---- .../quadratic/gromov_wasserstein_lr.py | 4 ++ src/ott/solvers/quadratic/gw_barycenter.py | 13 ++---- src/ott/tools/sinkhorn_divergence.py | 46 ++++++++++++------- 6 files changed, 54 insertions(+), 44 deletions(-) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index f7b1e54a8..f52affd30 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -57,6 +57,7 @@ class SinkhornState(NamedTuple): gv: Optional[jnp.ndarray] = None old_fus: Optional[jnp.ndarray] = None old_mapped_fus: Optional[jnp.ndarray] = None + iteration: int = -1 def set(self, **kwargs: Any) -> "SinkhornState": """Return a copy of self, with potential overwrites.""" @@ -329,6 +330,7 @@ class SinkhornOutput(NamedTuple): threshold: Optional[jnp.ndarray] = None converged: Optional[bool] = None inner_iterations: Optional[int] = None + n_iters: int = -1 def set(self, **kwargs: Any) -> "SinkhornOutput": """Return a copy of self, with potential overwrites.""" @@ -444,14 +446,6 @@ def b(self) -> jnp.ndarray: # noqa: D102 def linear_output(self) -> bool: # noqa: D102 return True - # TODO(michalk8): this should be always present - @property - def n_iters(self) -> int: # noqa: D102 - """Returns the total number of iterations that were needed to terminate.""" - if self.errors is None: - return -1 - return jnp.sum(self.errors > -1) * self.inner_iterations - @property def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102 u = self.ot_prob.geom.scaling_from_potential(self.f) @@ -993,7 +987,7 @@ def one_iteration( ot_prob, ) errors = state.errors.at[iteration // self.inner_iterations, :].set(err) - state = state.set(errors=errors) + state = state.set(iteration=iteration, errors=errors) if self.progress_fn is not None: jax.experimental.io_callback( @@ -1094,7 +1088,8 @@ def output_from_state( errors=state.errors[:, 0], threshold=jnp.array(self.threshold), converged=converged, - inner_iterations=self.inner_iterations + inner_iterations=self.inner_iterations, + n_iters=state.iteration + 1, ) @property diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 5185a9faf..142c5529d 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -51,6 +51,7 @@ class LRSinkhornState(NamedTuple): costs: jnp.ndarray errors: jnp.ndarray crossed_threshold: bool + iteration: int = -1 def compute_error( # noqa: D102 self, previous_state: "LRSinkhornState" @@ -178,6 +179,7 @@ class LRSinkhornOutput(NamedTuple): epsilon: float # TODO(michalk8): Optional is an artifact of the current impl., refactor reg_ot_cost: Optional[float] = None + n_iters: int = -1 def set(self, **kwargs: Any) -> "LRSinkhornOutput": """Return a copy of self, with potential overwrites.""" @@ -709,6 +711,7 @@ def one_iteration( costs=state.costs.at[it].set(cost), errors=state.errors.at[it].set(error), crossed_threshold=crossed_threshold, + iteration=iteration, ) if self.progress_fn is not None: @@ -779,6 +782,7 @@ def output_from_state( costs=state.costs, errors=state.errors, epsilon=self.epsilon, + n_iters=state.iteration + 1, ) def _converged(self, state: LRSinkhornState, iteration: int) -> bool: diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 24626049a..4345da897 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -61,6 +61,7 @@ class GWOutput(NamedTuple): linearization of GW. geom: The geometry underlying the local linearization. old_transport_mass: Holds total mass of transport at previous iteration. + n_iters: Number of the outer GW iterations. """ costs: Optional[jnp.ndarray] = None @@ -71,6 +72,7 @@ class GWOutput(NamedTuple): geom: Optional[geometry.Geometry] = None # Intermediate values. old_transport_mass: float = 1.0 + n_iters: int = -1 def set(self, **kwargs: Any) -> "GWOutput": """Return a copy of self, possibly with overwrites.""" @@ -99,12 +101,6 @@ def primal_cost(self) -> float: """Return transport cost of current linear OT solution at geometry.""" return self.linear_state.transport_cost_at_geom(other_geom=self.geom) - @property - def n_iters(self) -> int: # noqa: D102 - if self.errors is None: - return -1 - return jnp.sum(self.errors > -1) - class GWState(NamedTuple): """State of the Gromov-Wasserstein solver. @@ -122,6 +118,7 @@ class GWState(NamedTuple): when not using warm start. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. + iteration: The current outer GW iteration. """ costs: jnp.ndarray @@ -131,6 +128,7 @@ class GWState(NamedTuple): old_transport_mass: float rngs: Optional[jax.random.PRNGKeyArray] = None errors: Optional[jnp.ndarray] = None + iteration: int = -1 def set(self, **kwargs: Any) -> "GWState": """Return a copy of self, possibly with overwrites.""" @@ -155,7 +153,8 @@ def update( # noqa: D102 costs=costs, linear_convergence=linear_convergence, errors=errors, - old_transport_mass=old_transport_mass + old_transport_mass=old_transport_mass, + iteration=iteration, ) @@ -321,7 +320,8 @@ def output_from_state( errors=state.errors, linear_state=state.linear_state, geom=state.linear_pb.geom, - old_transport_mass=state.old_transport_mass + old_transport_mass=state.old_transport_mass, + n_iters=state.iteration + 1, ) def create_initializer( diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index ec56d430c..e2743ba9e 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -53,6 +53,7 @@ class LRGWState(NamedTuple): costs: jnp.ndarray errors: jnp.ndarray crossed_threshold: bool + iteration: int = -1 def compute_error( # noqa: D102 self, previous_state: "LRGWState" @@ -142,6 +143,7 @@ class LRGWOutput(NamedTuple): ot_prob: quadratic_problem.QuadraticProblem epsilon: float reg_gw_cost: Optional[float] = None + n_iters: int = -1 def set(self, **kwargs: Any) -> "LRGWOutput": """Return a copy of self, with potential overwrites.""" @@ -721,6 +723,7 @@ def one_iteration( costs=state.costs.at[it].set(cost), errors=state.errors.at[it].set(error), crossed_threshold=crossed_threshold, + iteration=iteration, ) if self.progress_fn is not None: @@ -793,6 +796,7 @@ def output_from_state( costs=state.costs, errors=state.errors, epsilon=self.epsilon, + n_iters=state.iteration + 1, ) def _converged(self, state: LRGWState, iteration: int) -> bool: diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 0873e9c53..3f554cc68 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -52,18 +52,12 @@ class GWBarycenterState(NamedTuple): costs: Optional[jnp.ndarray] = None costs_bary: Optional[jnp.ndarray] = None gw_convergence: Optional[jnp.ndarray] = None + n_iters: int = -1 def set(self, **kwargs: Any) -> "GWBarycenterState": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs) - @property - def n_iters(self) -> int: - """Number of iterations.""" - if self.gw_convergence is None: - return -1 - return jnp.sum(self.gw_convergence > -1) - @jax.tree_util.register_pytree_node_class class GromovWassersteinBarycenter(was_solver.WassersteinSolver): @@ -254,14 +248,15 @@ def solve_gw( costs=costs, costs_bary=costs_bary, errors=errors, - gw_convergence=gw_convergence + gw_convergence=gw_convergence, + n_iters=iteration, ) def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: """No-op.""" # TODO(michalk8): just for consistency with continuous barycenter # will be refactored in the future to create an output - return state + return state.set(n_iters=state.n_iters + 1) def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux = super().tree_flatten() diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 6b89f45e3..4eadee2db 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import MappingProxyType -from typing import Any, List, Mapping, NamedTuple, Optional, Tuple, Type +from typing import Any, Mapping, NamedTuple, Optional, Tuple, Type import jax.numpy as jnp @@ -26,16 +26,19 @@ "SinkhornDivergenceOutput" ] +Potentials_t = Tuple[jnp.ndarray, jnp.ndarray] + class SinkhornDivergenceOutput(NamedTuple): # noqa: D101 divergence: float - potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]] + potentials: Tuple[Potentials_t, Potentials_t, Potentials_t] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] converged: Tuple[bool, bool, bool] a: jnp.ndarray b: jnp.ndarray + n_iters: Tuple[int, int, int] = (-1, -1, -1) def to_dual_potentials(self) -> "potentials.EntropicPotentials": """Return dual estimators :cite:`pooladian:22`, eq. 8.""" @@ -46,17 +49,6 @@ def to_dual_potentials(self) -> "potentials.EntropicPotentials": f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y ) - @property - def n_iters(self) -> Tuple[int, int, int]: # noqa: D102 - """Returns 3 number of iterations that were needed to terminate.""" - out = [] - for i in range(3): - if self.errors[i] is None: - out.append(-1) - else: - out.append(jnp.sum(self.errors[i] > -1)) - return out - def sinkhorn_divergence( geom: Type[geometry.Geometry], @@ -190,14 +182,34 @@ def _sinkhorn_divergence( out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) + 0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2 ) - out = (out_xy, out_xx, out_yy) + return SinkhornDivergenceOutput( - div, tuple([s.f, s.g] for s in out), - (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out), - tuple(s.converged for s in out), a, b + divergence=div, + potentials=((out_xy.f, out_xy.g), (out_xx.f, out_xx.g), + (out_yy.f, out_yy.g)), + geoms=(geometry_xy, geometry_xx, geometry_yy), + errors=(out_xy.errors, out_xx.errors, out_yy.errors), + converged=(out_xy.converged, out_xx.converged, out_yy.converged), + a=a, + b=b, + n_iters=(out_xy.n_iters, out_xx.n_iters, out_yy.n_iters), ) +""" +class SinkhornDivergenceOutput(NamedTuple): # noqa: D101 + divergence: float + potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]] + geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] + errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray]] + converged: Tuple[bool, bool, bool] + a: jnp.ndarray + b: jnp.ndarray + n_iters: Tuple[int, int, int] = (-1, -1, -1) +""" + + def segment_sinkhorn_divergence( x: jnp.ndarray, y: jnp.ndarray, From 7f9b508beea19121e5b3f964e5c74a1a2b2ea38f Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:44:30 +0200 Subject: [PATCH 2/3] Remove `inner_iterations` --- src/ott/solvers/linear/sinkhorn.py | 6 +----- src/ott/solvers/linear/sinkhorn_lr.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index f52affd30..c549e729e 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -317,9 +317,7 @@ class SinkhornOutput(NamedTuple): algorithm. converged: whether the output corresponds to a solution whose error is below the convergence threshold. - inner_iterations: number of iterations that were run between two - computations of errors. - + n_iters: Total number of Sinkhorn iterations. """ f: Optional[jnp.ndarray] = None @@ -329,7 +327,6 @@ class SinkhornOutput(NamedTuple): ot_prob: Optional[linear_problem.LinearProblem] = None threshold: Optional[jnp.ndarray] = None converged: Optional[bool] = None - inner_iterations: Optional[int] = None n_iters: int = -1 def set(self, **kwargs: Any) -> "SinkhornOutput": @@ -1088,7 +1085,6 @@ def output_from_state( errors=state.errors[:, 0], threshold=jnp.array(self.threshold), converged=converged, - inner_iterations=self.inner_iterations, n_iters=state.iteration + 1, ) diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 142c5529d..96c031535 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -173,7 +173,7 @@ class LRSinkhornOutput(NamedTuple): g: jnp.ndarray costs: jnp.ndarray # TODO(michalk8): must be called `errors`, because of `store_inner_errors` - # in future, enforce via class hierarchy + # in the future, enforce via class hierarchy errors: jnp.ndarray ot_prob: linear_problem.LinearProblem epsilon: float From 799db562f7e7745332207494facb06e7c0e423ea Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:09:07 +0200 Subject: [PATCH 3/3] Extend fixed-point iter --- src/ott/geometry/costs.py | 2 +- src/ott/geometry/graph.py | 5 ++- .../initializers/linear/initializers_lr.py | 2 +- src/ott/math/fixed_point_loop.py | 37 ++++++++++++------- src/ott/math/matrix_square_root.py | 2 +- .../solvers/linear/continuous_barycenter.py | 2 +- src/ott/solvers/linear/discrete_barycenter.py | 2 +- src/ott/solvers/linear/lr_utils.py | 4 +- src/ott/solvers/linear/sinkhorn.py | 13 ++++--- src/ott/solvers/linear/sinkhorn_lr.py | 12 +++--- .../solvers/quadratic/gromov_wasserstein.py | 10 ++--- .../quadratic/gromov_wasserstein_lr.py | 8 ++-- src/ott/solvers/quadratic/gw_barycenter.py | 12 +++--- src/ott/tools/k_means.py | 4 +- 14 files changed, 62 insertions(+), 53 deletions(-) diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 9f1a6c3a0..d8632172a 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -778,7 +778,7 @@ def init_state() -> Tuple[jnp.ndarray, float]: ) return cov_init, diffs - cov, diffs = fixed_point_loop.fixpoint_iter( + (cov, diffs), _ = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=min_iterations, diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index c7dac0c99..04098c35a 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -175,7 +175,7 @@ def body_fn( else: constants = L, None - return fixpoint_fn( + (res, _), _ = fixpoint_fn( cond_fn=(lambda *_, **__: True) if force_scan else conf_fn, body_fn=body_fn, min_iterations=self.n_steps if force_scan else 1, @@ -183,7 +183,8 @@ def body_fn( inner_iterations=1, constants=constants, state=state, - )[1] + ) + return res @property def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 06c1f5b65..49dbade57 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -652,4 +652,4 @@ def body_fn( inner_iterations=inner_iterations, constants=consts, state=init_fn(), - ).factor + )[0].factor diff --git a/src/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py index 9034eba62..ae954df8b 100644 --- a/src/ott/math/fixed_point_loop.py +++ b/src/ott/math/fixed_point_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable +from typing import Any, Callable, Tuple import jax import jax.numpy as jnp @@ -24,7 +24,7 @@ def fixpoint_iter( cond_fn: Callable[[int, Any, Any], bool], body_fn: Callable[[Any, Any, Any, Any], Any], min_iterations: int, max_iterations: int, inner_iterations: int, constants: Any, state: Any -): +) -> Tuple[Any, int]: """Implementation of a fixed point loop. This fixed point loop iterator applies ``body_fn`` to a tuple @@ -51,7 +51,8 @@ def fixpoint_iter( state : state variable Returns: - outputs state returned by ``body_fn`` upon termination. + outputs state returned by ``body_fn`` upon termination and + the total number of iterations. """ # noqa: D401 # If number of minimal iterations matches maximal number, force a scan instead # of a while loop. @@ -83,20 +84,23 @@ def one_iteration(iteration_state, compute_error): return (iteration_state, None) if force_scan else iteration_state if force_scan: + n_iters = max_iterations // inner_iterations (_, state), _ = jax.lax.scan( lambda carry, x: unrolled_body_fn(carry), (0, state), None, length=max_iterations // inner_iterations ) else: - _, state = jax.lax.while_loop(max_cond_fn, unrolled_body_fn, (0, state)) - return state + n_iters, state = jax.lax.while_loop( + max_cond_fn, unrolled_body_fn, (0, state) + ) + return state, n_iters def fixpoint_iter_fwd( cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, constants, state -): +) -> Tuple[Any, int]: """Forward iteration of fixed point iteration to handle backpropagation. The main difference with fixpoint_iter is the checkpointing, in variable @@ -114,7 +118,8 @@ def fixpoint_iter_fwd( state : state variable Returns: - outputs state returned by body_fn upon termination. + outputs state returned by body_fn upon termination and + the total number of iterations. """ force_scan = min_iterations == max_iterations compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1 @@ -166,7 +171,7 @@ def one_iteration(iteration_state, compute_error): max_cond_fn, unrolled_body_fn, (0, states, state) ) - return state, (constants, iteration, states) + return (state, iteration), (constants, iteration, states) def fixpoint_iter_bwd( @@ -174,6 +179,8 @@ def fixpoint_iter_bwd( ): """Backward iteration of fixed point iteration, using checkpointed states.""" del cond_fn + g, _ = g # out, iteration + force_scan = (min_iterations == max_iterations) constants, iteration, states = res # The tree may contain some python floats @@ -187,16 +194,15 @@ def bwd_cond_fn(iteration_g_gconst): return iteration >= 0 def unrolled_body_fn_no_errors(iteration, constants, state): - compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool) - def one_iteration(iteration_state, compute_error): + def one_iteration(iteration_state, _): iteration, state = iteration_state - state = body_fn(iteration, constants, state, compute_error) + state = body_fn(iteration, constants, state, False) iteration += 1 return (iteration, state), None iteration_state, _ = jax.lax.scan( - one_iteration, (iteration, state), compute_error_flags + one_iteration, (iteration, state), xs=None, length=inner_iterations ) _, state = iteration_state return state @@ -207,7 +213,10 @@ def unrolled_body_fn(iteration_g_gconst): lambda x: x[iteration // inner_iterations], states ) _, pullback = jax.vjp( - unrolled_body_fn_no_errors, iteration, constants, state + unrolled_body_fn_no_errors, + iteration, + constants, + state, ) _, gi_constants, g_state = pullback(g) g_constants = jax.tree_util.tree_map( @@ -218,7 +227,7 @@ def unrolled_body_fn(iteration_g_gconst): if force_scan: (_, g_state, g_constants), _ = jax.lax.scan( - lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants), + lambda carry, _: unrolled_body_fn(carry), (0, g, g_constants), None, length=max_iterations // inner_iterations ) diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 4a0177780..7796cf190 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -107,7 +107,7 @@ def new_err(x, norm_x, y): dtype=dtype) state = (errors, y, z) const = (x, threshold) - errors, y, z = fixed_point_loop.fixpoint_iter_backprop( + (errors, y, z), _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, const, state ) diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index c4d717fb7..1c0f85680 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -220,7 +220,7 @@ def body_fn( iteration, bar_prob, solver.linear_ot_solver, solver.store_inner_errors ) - state = fixed_point_loop.fixpoint_iter( + state, _ = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=solver.min_iterations, diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index dcfdc1470..d6799d30a 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -229,7 +229,7 @@ def body_fn(iteration, const, state, compute_error): state = (errors, d, f_u, g_v) - state = fixed_point_loop.fixpoint_iter_backprop( + state, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, const, state ) diff --git a/src/ott/solvers/linear/lr_utils.py b/src/ott/solvers/linear/lr_utils.py index 8ade265c9..661004978 100644 --- a/src/ott/solvers/linear/lr_utils.py +++ b/src/ott/solvers/linear/lr_utils.py @@ -169,7 +169,7 @@ def body_fn( err=jnp.inf, ) - state: State = fixed_point_loop.fixpoint_iter_backprop( + state, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, init_state ) @@ -306,7 +306,7 @@ def body_fn( err=jnp.inf ) - state: State = fixed_point_loop.fixpoint_iter_backprop( + state, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, init_state ) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index c549e729e..2deac09c4 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -57,7 +57,6 @@ class SinkhornState(NamedTuple): gv: Optional[jnp.ndarray] = None old_fus: Optional[jnp.ndarray] = None old_mapped_fus: Optional[jnp.ndarray] = None - iteration: int = -1 def set(self, **kwargs: Any) -> "SinkhornState": """Return a copy of self, with potential overwrites.""" @@ -984,7 +983,7 @@ def one_iteration( ot_prob, ) errors = state.errors.at[iteration // self.inner_iterations, :].set(err) - state = state.set(iteration=iteration, errors=errors) + state = state.set(errors=errors) if self.progress_fn is not None: jax.experimental.io_callback( @@ -1029,7 +1028,9 @@ def init_state( return self.anderson.init_maps(ot_prob, state) if self.anderson else state def output_from_state( - self, ot_prob: linear_problem.LinearProblem, state: SinkhornState + self, + ot_prob: linear_problem.LinearProblem, + state: SinkhornState, ) -> SinkhornOutput: """Create an output from a loop state. @@ -1085,7 +1086,6 @@ def output_from_state( errors=state.errors[:, 0], threshold=jnp.array(self.threshold), converged=converged, - n_iters=state.iteration + 1, ) @property @@ -1168,11 +1168,12 @@ def body_fn( const = ot_prob, solver state = solver.init_state(ot_prob, init) - state = fix_point( + state, n_iters = fix_point( cond_fn, body_fn, solver.min_iterations, solver.max_iterations, solver.inner_iterations, const, state ) - return solver.output_from_state(ot_prob, state) + out = solver.output_from_state(ot_prob, state) + return out.set(n_iters=n_iters) def _iterations_taped( diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 96c031535..89cafe395 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -51,7 +51,6 @@ class LRSinkhornState(NamedTuple): costs: jnp.ndarray errors: jnp.ndarray crossed_threshold: bool - iteration: int = -1 def compute_error( # noqa: D102 self, previous_state: "LRSinkhornState" @@ -526,7 +525,7 @@ def recompute_couplings( g = jnp.exp(gamma * h) return q, r, g - state_inner = fixed_point_loop.fixpoint_iter_backprop( + state_inner, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner ) @@ -623,7 +622,7 @@ def recompute_couplings( r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g - state_inner = fixed_point_loop.fixpoint_iter_backprop( + state_inner, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner ) @@ -711,7 +710,6 @@ def one_iteration( costs=state.costs.at[it].set(cost), errors=state.errors.at[it].set(error), crossed_threshold=crossed_threshold, - iteration=iteration, ) if self.progress_fn is not None: @@ -763,13 +761,16 @@ def init_state( ) def output_from_state( - self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState + self, + ot_prob: linear_problem.LinearProblem, + state: LRSinkhornState, ) -> LRSinkhornOutput: """Create an output from a loop state. Args: ot_prob: the transport problem. state: a LRSinkhornState. + n_iters: Number of iteration. Returns: A LRSinkhornOutput. @@ -782,7 +783,6 @@ def output_from_state( costs=state.costs, errors=state.errors, epsilon=self.epsilon, - n_iters=state.iteration + 1, ) def _converged(self, state: LRSinkhornState, iteration: int) -> bool: diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 4345da897..40961519e 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -118,7 +118,6 @@ class GWState(NamedTuple): when not using warm start. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. - iteration: The current outer GW iteration. """ costs: jnp.ndarray @@ -128,7 +127,6 @@ class GWState(NamedTuple): old_transport_mass: float rngs: Optional[jax.random.PRNGKeyArray] = None errors: Optional[jnp.ndarray] = None - iteration: int = -1 def set(self, **kwargs: Any) -> "GWState": """Return a copy of self, possibly with overwrites.""" @@ -154,7 +152,6 @@ def update( # noqa: D102 linear_convergence=linear_convergence, errors=errors, old_transport_mass=old_transport_mass, - iteration=iteration, ) @@ -321,7 +318,6 @@ def output_from_state( linear_state=state.linear_state, geom=state.linear_pb.geom, old_transport_mass=state.old_transport_mass, - n_iters=state.iteration + 1, ) def create_initializer( @@ -412,7 +408,7 @@ def body_fn( return new_state - state = fixed_point_loop.fixpoint_iter( + state, n_iters = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=solver.min_iterations, @@ -421,8 +417,8 @@ def body_fn( constants=solver, state=solver.init_state(prob, init, rng=rng) ) - - return solver.output_from_state(state) + out = solver.output_from_state(state) + return out.set(n_iters=n_iters) def solve( diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index e2743ba9e..66a46e77a 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -535,7 +535,7 @@ def recompute_couplings( g = jnp.exp(gamma * h) return q, r, g - state_inner = fixed_point_loop.fixpoint_iter_backprop( + state_inner, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner ) @@ -634,7 +634,7 @@ def recompute_couplings( r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1)) return q, r, g - state_inner = fixed_point_loop.fixpoint_iter_backprop( + state_inner, _ = fixed_point_loop.fixpoint_iter_backprop( cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner ) @@ -777,7 +777,9 @@ def init_state( ) def output_from_state( - self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState + self, + ot_prob: quadratic_problem.QuadraticProblem, + state: LRGWState, ) -> LRGWOutput: """Create an output from a loop state. diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 3f554cc68..4f1a84777 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -120,8 +120,9 @@ def __call__( The solution. """ state = self.init_state(problem, bar_size, **kwargs) - state = iterations(self, problem, state) - return self.output_from_state(state) + state, n_iters = iterations(self, problem, state) + out = self.output_from_state(state) + return out.set(n_iters=n_iters) def init_state( self, @@ -249,14 +250,13 @@ def solve_gw( costs_bary=costs_bary, errors=errors, gw_convergence=gw_convergence, - n_iters=iteration, ) def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: """No-op.""" # TODO(michalk8): just for consistency with continuous barycenter # will be refactored in the future to create an output - return state.set(n_iters=state.n_iters + 1) + return state def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux = super().tree_flatten() @@ -305,7 +305,7 @@ def init_transports( def iterations( # noqa: D103 solver: GromovWassersteinBarycenter, problem: gw_barycenter.GWBarycenterProblem, init_state: GWBarycenterState -) -> GWBarycenterState: +) -> Tuple[GWBarycenterState, int]: def cond_fn( iteration: int, constants: GromovWassersteinBarycenter, @@ -331,4 +331,4 @@ def body_fn( inner_iterations=1, constants=(solver, problem), state=init_state, - ) + )[0] diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index 4b0b02dda..d8510d70a 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -163,7 +163,7 @@ def body_fn( state = init_fn(geom, rng) constants = (geom, jnp.arange(geom.shape[0])) - state = fixed_point_loop.fixpoint_iter( + state, _ = fixed_point_loop.fixpoint_iter( lambda *_, **__: True, body_fn, min_iterations=k - 1, @@ -324,7 +324,7 @@ def finalize_fn(const: KMeansConst, state: KMeansState) -> KMeansState: x_weights = jnp.hstack([weights[:, None] * geom.x, weights[:, None]]) const = KMeansConst(geom, x_weights) - state = fixpoint_fn( + state, _ = fixpoint_fn( cond_fn, body_fn, min_iterations=min_iterations,