Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better tracking of n_iters #436

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SinkhornState(NamedTuple):
gv: Optional[jnp.ndarray] = None
old_fus: Optional[jnp.ndarray] = None
old_mapped_fus: Optional[jnp.ndarray] = None
iteration: int = -1

def set(self, **kwargs: Any) -> "SinkhornState":
"""Return a copy of self, with potential overwrites."""
Expand Down Expand Up @@ -329,6 +330,7 @@ class SinkhornOutput(NamedTuple):
threshold: Optional[jnp.ndarray] = None
converged: Optional[bool] = None
inner_iterations: Optional[int] = None
n_iters: int = -1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if we need to keep inner_iterations now (it was added to compute the right number of iterations), WDYT? I think it might make sense to keep it nonetheless (e.g. when looking at errors), maybe this was your conclusion too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be safely removed, was not being used anywhere and I don't think it helps when looking at the errors. Same goes for threshold, would also remove.
Will try in the upcoming weeks to unify the output interface for our problems.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or was/is there a reason why threshold was put there before? Maybe because of traceability?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The threshold is there to settle converged in the case where min_iterations==max_iterations. It was added specifically for that case.

For inner_iterations, i think we should have the user in mind. If the user saves a SinkhornOutput as a checkpoint, it's impossible to reconstruct the "true" number of iterations that were needed to reach a certain error level, nor the threshold that conditioned stopping (apart maybe now by looking at the last error not marked as -1, and then dividing n_iters by that number, but that does not sound very clean)


def set(self, **kwargs: Any) -> "SinkhornOutput":
"""Return a copy of self, with potential overwrites."""
Expand Down Expand Up @@ -444,14 +446,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def linear_output(self) -> bool: # noqa: D102
return True

# TODO(michalk8): this should be always present
@property
def n_iters(self) -> int:  # noqa: D102
    """Return the total number of iterations needed to terminate."""
    errs = self.errors
    if errs is None:
        return -1
    # One error entry is recorded every `inner_iterations` steps; entries
    # still at their -1 sentinel were never written, so count only the
    # recorded ones and scale back up to actual iterations.
    recorded = jnp.sum(errs > -1)
    return recorded * self.inner_iterations

@property
def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102
u = self.ot_prob.geom.scaling_from_potential(self.f)
Expand Down Expand Up @@ -993,7 +987,7 @@ def one_iteration(
ot_prob,
)
errors = state.errors.at[iteration // self.inner_iterations, :].set(err)
state = state.set(errors=errors)
state = state.set(iteration=iteration, errors=errors)

if self.progress_fn is not None:
jax.experimental.io_callback(
Expand Down Expand Up @@ -1094,7 +1088,8 @@ def output_from_state(
errors=state.errors[:, 0],
threshold=jnp.array(self.threshold),
converged=converged,
inner_iterations=self.inner_iterations
inner_iterations=self.inner_iterations,
n_iters=state.iteration + 1,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class LRSinkhornState(NamedTuple):
costs: jnp.ndarray
errors: jnp.ndarray
crossed_threshold: bool
iteration: int = -1

def compute_error( # noqa: D102
self, previous_state: "LRSinkhornState"
Expand Down Expand Up @@ -178,6 +179,7 @@ class LRSinkhornOutput(NamedTuple):
epsilon: float
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None
n_iters: int = -1

def set(self, **kwargs: Any) -> "LRSinkhornOutput":
"""Return a copy of self, with potential overwrites."""
Expand Down Expand Up @@ -709,6 +711,7 @@ def one_iteration(
costs=state.costs.at[it].set(cost),
errors=state.errors.at[it].set(error),
crossed_threshold=crossed_threshold,
iteration=iteration,
)

if self.progress_fn is not None:
Expand Down Expand Up @@ -779,6 +782,7 @@ def output_from_state(
costs=state.costs,
errors=state.errors,
epsilon=self.epsilon,
n_iters=state.iteration + 1,
)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down
16 changes: 8 additions & 8 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class GWOutput(NamedTuple):
linearization of GW.
geom: The geometry underlying the local linearization.
old_transport_mass: Holds total mass of transport at previous iteration.
n_iters: Number of the outer GW iterations.
"""

costs: Optional[jnp.ndarray] = None
Expand All @@ -71,6 +72,7 @@ class GWOutput(NamedTuple):
geom: Optional[geometry.Geometry] = None
# Intermediate values.
old_transport_mass: float = 1.0
n_iters: int = -1

def set(self, **kwargs: Any) -> "GWOutput":
"""Return a copy of self, possibly with overwrites."""
Expand Down Expand Up @@ -99,12 +101,6 @@ def primal_cost(self) -> float:
"""Return transport cost of current linear OT solution at geometry."""
return self.linear_state.transport_cost_at_geom(other_geom=self.geom)

@property
def n_iters(self) -> int:  # noqa: D102
    """Number of recorded outer iterations, or ``-1`` if errors are absent."""
    errs = self.errors
    # Entries still equal to the -1 sentinel were never filled in.
    return -1 if errs is None else jnp.sum(errs > -1)


class GWState(NamedTuple):
"""State of the Gromov-Wasserstein solver.
Expand All @@ -122,6 +118,7 @@ class GWState(NamedTuple):
when not using warm start.
errors: Holds sequence of vectors of errors of the Sinkhorn algorithm
at each iteration.
iteration: The current outer GW iteration.
"""

costs: jnp.ndarray
Expand All @@ -131,6 +128,7 @@ class GWState(NamedTuple):
old_transport_mass: float
rngs: Optional[jax.random.PRNGKeyArray] = None
errors: Optional[jnp.ndarray] = None
iteration: int = -1

def set(self, **kwargs: Any) -> "GWState":
"""Return a copy of self, possibly with overwrites."""
Expand All @@ -155,7 +153,8 @@ def update( # noqa: D102
costs=costs,
linear_convergence=linear_convergence,
errors=errors,
old_transport_mass=old_transport_mass
old_transport_mass=old_transport_mass,
iteration=iteration,
)


Expand Down Expand Up @@ -321,7 +320,8 @@ def output_from_state(
errors=state.errors,
linear_state=state.linear_state,
geom=state.linear_pb.geom,
old_transport_mass=state.old_transport_mass
old_transport_mass=state.old_transport_mass,
n_iters=state.iteration + 1,
)

def create_initializer(
Expand Down
4 changes: 4 additions & 0 deletions src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class LRGWState(NamedTuple):
costs: jnp.ndarray
errors: jnp.ndarray
crossed_threshold: bool
iteration: int = -1

def compute_error( # noqa: D102
self, previous_state: "LRGWState"
Expand Down Expand Up @@ -142,6 +143,7 @@ class LRGWOutput(NamedTuple):
ot_prob: quadratic_problem.QuadraticProblem
epsilon: float
reg_gw_cost: Optional[float] = None
n_iters: int = -1

def set(self, **kwargs: Any) -> "LRGWOutput":
"""Return a copy of self, with potential overwrites."""
Expand Down Expand Up @@ -721,6 +723,7 @@ def one_iteration(
costs=state.costs.at[it].set(cost),
errors=state.errors.at[it].set(error),
crossed_threshold=crossed_threshold,
iteration=iteration,
)

if self.progress_fn is not None:
Expand Down Expand Up @@ -793,6 +796,7 @@ def output_from_state(
costs=state.costs,
errors=state.errors,
epsilon=self.epsilon,
n_iters=state.iteration + 1,
)

def _converged(self, state: LRGWState, iteration: int) -> bool:
Expand Down
13 changes: 4 additions & 9 deletions src/ott/solvers/quadratic/gw_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,12 @@ class GWBarycenterState(NamedTuple):
costs: Optional[jnp.ndarray] = None
costs_bary: Optional[jnp.ndarray] = None
gw_convergence: Optional[jnp.ndarray] = None
n_iters: int = -1

def set(self, **kwargs: Any) -> "GWBarycenterState":
"""Return a copy of self, possibly with overwrites."""
return self._replace(**kwargs)

@property
def n_iters(self) -> int:
    """Number of iterations."""
    conv = self.gw_convergence
    if conv is None:
        return -1
    # Count only entries that were actually written (sentinel is -1).
    return jnp.sum(conv > -1)


@jax.tree_util.register_pytree_node_class
class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
Expand Down Expand Up @@ -254,14 +248,15 @@ def solve_gw(
costs=costs,
costs_bary=costs_bary,
errors=errors,
gw_convergence=gw_convergence
gw_convergence=gw_convergence,
n_iters=iteration,
)

def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState:
"""No-op."""
# TODO(michalk8): just for consistency with continuous barycenter
# will be refactored in the future to create an output
return state
return state.set(n_iters=state.n_iters + 1)

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
children, aux = super().tree_flatten()
Expand Down
46 changes: 29 additions & 17 deletions src/ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import MappingProxyType
from typing import Any, List, Mapping, NamedTuple, Optional, Tuple, Type
from typing import Any, Mapping, NamedTuple, Optional, Tuple, Type

import jax.numpy as jnp

Expand All @@ -26,16 +26,19 @@
"SinkhornDivergenceOutput"
]

Potentials_t = Tuple[jnp.ndarray, jnp.ndarray]


class SinkhornDivergenceOutput(NamedTuple): # noqa: D101
divergence: float
potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]]
potentials: Tuple[Potentials_t, Potentials_t, Potentials_t]
geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
Optional[jnp.ndarray]]
converged: Tuple[bool, bool, bool]
a: jnp.ndarray
b: jnp.ndarray
n_iters: Tuple[int, int, int] = (-1, -1, -1)

def to_dual_potentials(self) -> "potentials.EntropicPotentials":
"""Return dual estimators :cite:`pooladian:22`, eq. 8."""
Expand All @@ -46,17 +49,6 @@ def to_dual_potentials(self) -> "potentials.EntropicPotentials":
f_xy, g_xy, prob_xy, f_xx=f_x, g_yy=g_y
)

@property
def n_iters(self) -> Tuple[int, int, int]:  # noqa: D102
    """Return the three numbers of iterations needed to terminate.

    One entry per Sinkhorn run — (x, y), (x, x) and (y, y) — with ``-1``
    when no errors were recorded for that run.
    """
    # Fix: the original built and returned a `list`, contradicting the
    # declared `Tuple[int, int, int]` return type; it also hard-coded
    # `range(3)` instead of iterating over `self.errors` directly.
    return tuple(
        -1 if errs is None else jnp.sum(errs > -1) for errs in self.errors
    )


def sinkhorn_divergence(
geom: Type[geometry.Geometry],
Expand Down Expand Up @@ -190,14 +182,34 @@ def _sinkhorn_divergence(
out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
)
out = (out_xy, out_xx, out_yy)

return SinkhornDivergenceOutput(
div, tuple([s.f, s.g] for s in out),
(geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
tuple(s.converged for s in out), a, b
divergence=div,
potentials=((out_xy.f, out_xy.g), (out_xx.f, out_xx.g),
(out_yy.f, out_yy.g)),
geoms=(geometry_xy, geometry_xx, geometry_yy),
errors=(out_xy.errors, out_xx.errors, out_yy.errors),
converged=(out_xy.converged, out_xx.converged, out_yy.converged),
a=a,
b=b,
n_iters=(out_xy.n_iters, out_xx.n_iters, out_yy.n_iters),
)


"""
class SinkhornDivergenceOutput(NamedTuple): # noqa: D101
divergence: float
potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]]
geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
Optional[jnp.ndarray]]
converged: Tuple[bool, bool, bool]
a: jnp.ndarray
b: jnp.ndarray
n_iters: Tuple[int, int, int] = (-1, -1, -1)
"""


def segment_sinkhorn_divergence(
x: jnp.ndarray,
y: jnp.ndarray,
Expand Down
Loading