Fix n_iters (#437)
* Start cleaning `n_iters`

* Remove dead code

* Update SD

* Update GW tutorial

* Remove unnecessary hierarchy in tutorials

* Remove `n_iters` property from SD

* Fix not passing `inner_iterations`

* Skip testing `flax` on 3.8

* Fix SD test

* Use `.toarray()`

* Update `FenchelConjugateLBFGS` for `jaxopt>=0.8`

* Fix missing `importorskip`

* Split `MetaInitializer` test

* [ci skip] Fix typos in bary docs

* Skip tests that need `optax`

* Remove relative pip install from MG notebook
michalk8 committed Sep 13, 2023
1 parent a3d1202 commit 0327ae3
Showing 19 changed files with 209 additions and 125 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -103,7 +103,7 @@ Packages
    :maxdepth: 1
    :caption: Examples

-   Getting Started <tutorials/notebooks/basic_ot_between_datasets>
+   Getting Started <tutorials/basic_ot_between_datasets>
    tutorials/index

 .. toctree::
42 changes: 21 additions & 21 deletions docs/tutorials/index.rst
@@ -6,54 +6,54 @@ Geometry
 .. toctree::
    :maxdepth: 1

-   notebooks/introduction_grid
+   introduction_grid

 Linear Optimal Transport
 ------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/point_clouds
-   notebooks/One_Sinkhorn
-   notebooks/OTT_&_POT
-   notebooks/Hessians
-   notebooks/LRSinkhorn
-   notebooks/sinkhorn_divergence_gradient_flow
-   notebooks/sparse_monge_displacements
+   point_clouds
+   One_Sinkhorn
+   OTT_&_POT
+   Hessians
+   LRSinkhorn
+   sinkhorn_divergence_gradient_flow
+   sparse_monge_displacements

 Barycenters
 ^^^^^^^^^^^
 .. toctree::
    :maxdepth: 1

-   notebooks/Sinkhorn_Barycenters
-   notebooks/gmm_pair_demo
-   notebooks/wasserstein_barycenters_gmms
+   Sinkhorn_Barycenters
+   gmm_pair_demo
+   wasserstein_barycenters_gmms

 Miscellaneous
 ^^^^^^^^^^^^^
 .. toctree::
    :maxdepth: 1

-   notebooks/tracking_progress
-   notebooks/soft_sort
-   notebooks/application_biology
+   tracking_progress
+   soft_sort
+   application_biology

 Quadratic Optimal Transport
 ---------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/gromov_wasserstein
-   notebooks/GWLRSinkhorn
-   notebooks/gromov_wasserstein_multiomics
+   gromov_wasserstein
+   GWLRSinkhorn
+   gromov_wasserstein_multiomics

 Neural Optimal Transport
 ------------------------
 .. toctree::
    :maxdepth: 1

-   notebooks/neural_dual
-   notebooks/icnn_inits
-   notebooks/MetaOT
-   notebooks/Monge_Gap
+   neural_dual
+   icnn_inits
+   MetaOT
+   Monge_Gap
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -183,7 +183,10 @@ legacy_tox_ini = """
     skip_missing_interpreters = true
     [testenv]
-    extras = test,neural
+    extras =
+        test
+        # https://github.com/google/flax/issues/3329
+        py{3.9,3.10,3.11},py3.9-jax-default: neural
     pass_env = CUDA_*,PYTEST_*,CI
     commands_pre =
         gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
@@ -299,7 +302,8 @@ select = [
 unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
 target-version = "py38"
 [tool.ruff.per-file-ignores]
-"tests/*" = ["D", "PT004"]  # TODO(michalk8): remove `self.initialize` in `tests/`
+# TODO(michalk8): PT004 - remove `self.initialize`
+"tests/*" = ["D", "PT004", "E402"]
 "*/__init__.py" = ["F401"]
 "docs/*" = ["D"]
 "src/ott/types.py" = ["D102"]
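Note: the new `extras` block uses tox's factor-conditional settings. The `neural` extra (which pulls in `flax`) is installed only in environments whose name matches the `py{3.9,3.10,3.11}` or `py3.9-jax-default` factors, so Python 3.8 environments run without it; this implements the "Skip testing `flax` on 3.8" item from the commit message, working around the linked flax issue.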
8 changes: 4 additions & 4 deletions src/ott/problems/linear/barycenter_problem.py
@@ -18,7 +18,7 @@

 from ott.geometry import costs, geometry, segment

-__all__ = ["FreeBarycenterProblem"]
+__all__ = ["FreeBarycenterProblem", "FixedBarycenterProblem"]


 @jax.tree_util.register_pytree_node_class
@@ -44,8 +44,8 @@ class FreeBarycenterProblem:
       Only used when ``y`` is not already segmented. When passing
       ``segment_ids``, 2 arguments must be specified for jitting to work:
-    - ``num_segments`` - the total number of measures.
-    - ``max_measure_size`` - maximum of support sizes of these measures.
+      - ``num_segments`` - the total number of measures.
+      - ``max_measure_size`` - maximum of support sizes of these measures.
   """

   def __init__(
@@ -158,7 +158,7 @@ class FixedBarycenterProblem:
     a: batch of histograms of shape ``[batch, num_a]`` where ``num_a`` matches
       the first value of the :attr:`~ott.geometry.Geometry.shape` attribute of
       ``geom``.
-    weights: ``[batch,]`` positive weights summing to :math`1`. Uniform by
+    weights: ``[batch,]`` positive weights summing to :math:`1`. Uniform by
       default.
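As the `FreeBarycenterProblem` docstring above states, passing `segment_ids` requires `num_segments` and `max_measure_size` for jitting to work. A minimal sketch of such a call (the arrays are made up, and it assumes the segmentation keyword arguments are forwarded by the constructor as described above):

import jax.numpy as jnp

from ott.problems.linear import barycenter_problem

# Two measures of sizes 3 and 2, stacked into one flat point cloud.
y = jnp.ones((5, 2))
segment_ids = jnp.array([0, 0, 0, 1, 1])

prob = barycenter_problem.FreeBarycenterProblem(
    y=y,
    segment_ids=segment_ids,
    num_segments=2,      # total number of measures
    max_measure_size=3,  # maximum of the support sizes of these measures
)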
14 changes: 1 addition & 13 deletions src/ott/solvers/linear/sinkhorn.py
@@ -318,7 +318,6 @@ class SinkhornOutput(NamedTuple):
       below the convergence threshold.
     inner_iterations: number of iterations that were run between two
       computations of errors.
   """

   f: Optional[jnp.ndarray] = None
@@ -424,10 +423,6 @@ def transport_cost_at_geom(
       return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
     return jnp.sum(self.matrix * other_geom.cost_matrix)

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return isinstance(self.ot_prob, linear_problem.LinearProblem)
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     return self.ot_prob.geom
@@ -440,17 +435,10 @@ def a(self) -> jnp.ndarray:  # noqa: D102
   def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

-  @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return True
-
-  # TODO(michalk8): this should be always present
   @property
   def n_iters(self) -> int:  # noqa: D102
     """Returns the total number of iterations that were needed to terminate."""
-    if self.errors is None:
-      return -1
-    return jnp.sum(self.errors > -1) * self.inner_iterations
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]:  # noqa: D102
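A minimal sketch of the counting rule behind the fixed `n_iters`, assuming `errors` is a pre-allocated buffer filled with the sentinel `-1` and written once every `inner_iterations` steps (the values below are made up). A plausible reason to prefer `!= -1` over the old `> -1`: a diverged run can leave `nan` in `errors`, and `nan > -1` evaluates to `False` while `nan != -1` evaluates to `True`, so the new rule still counts those iterations.

import jax.numpy as jnp

inner_iterations = 10
# Three errors were computed before the solver stopped; two slots stayed unused.
errors = jnp.array([0.5, 0.1, 0.001, -1.0, -1.0])

n_iters = jnp.sum(errors != -1) * inner_iterations  # 3 * 10 = 30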
10 changes: 4 additions & 6 deletions src/ott/solvers/linear/sinkhorn_lr.py
@@ -175,6 +175,7 @@ class LRSinkhornOutput(NamedTuple):
   errors: jnp.ndarray
   ot_prob: linear_problem.LinearProblem
   epsilon: float
+  inner_iterations: int
   # TODO(michalk8): Optional is an artifact of the current impl., refactor
   reg_ot_cost: Optional[float] = None
@@ -205,10 +206,6 @@ def compute_reg_ot_cost(  # noqa: D102
         use_danskin=use_danskin
     )

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return isinstance(self.ot_prob, linear_problem.LinearProblem)
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     return self.ot_prob.geom
@@ -222,8 +219,8 @@ def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

   @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return True
+  def n_iters(self) -> int:  # noqa: D102
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def converged(self) -> bool:  # noqa: D102
@@ -773,6 +770,7 @@ def output_from_state(
         costs=state.costs,
         errors=state.errors,
         epsilon=self.epsilon,
+        inner_iterations=self.inner_iterations,
     )

   def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down
16 changes: 9 additions & 7 deletions src/ott/solvers/nn/conjugate_solvers.py
@@ -75,16 +75,18 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver):
     max_iter: maximum number of iterations
     max_linesearch_iter: maximum number of line search iterations
     linesearch_type: type of line search
-    decrease_factor: decrease factor for a backtracking line search
-    ls_method: the line search method
+    linesearch_init: strategy for line search initialization
+    increase_factor: factor by which to increase the step size during
+      the line search
   """

   gtol: float = 1e-3
   max_iter: int = 10
   max_linesearch_iter: int = 10
-  linesearch_type: Literal["zoom", "backtracking"] = "backtracking"
-  decrease_factor: float = 0.66
-  ls_method: Literal["wolf", "strong-wolfe"] = "strong-wolfe"
+  linesearch_type: Literal["zoom", "backtracking",
+                           "hager-zhang"] = "backtracking"
+  linesearch_init: Literal["increase", "max", "current"] = "increase"
+  increase_factor: float = 1.5

   def solve(  # noqa: D102
@@ -98,9 +100,9 @@ def solve(  # noqa: D102
         fun=lambda x: f(x) - x.dot(y),
         tol=self.gtol,
         maxiter=self.max_iter,
-        decrease_factor=self.decrease_factor,
         linesearch=self.linesearch_type,
-        condition=self.ls_method,
+        linesearch_init=self.linesearch_init,
+        increase_factor=self.increase_factor,
         implicit_diff=False,
         unroll=False
     )
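For context, here is a minimal sketch of the updated call against `jaxopt>=0.8`, where `decrease_factor` and `condition` were dropped in favor of `linesearch_init` and `increase_factor`; the potential `f` and the target vector `y` are made up for illustration:

import jax.numpy as jnp
from jaxopt import LBFGS

def f(x):
  # A strongly convex potential; for f(x) = 0.5 * |x|^2 the conjugate's
  # argmin of f(x) - <x, y> is y itself, which makes the result easy to check.
  return 0.5 * jnp.sum(x ** 2)

y = jnp.arange(4.0)
solver = LBFGS(
    fun=lambda x: f(x) - x.dot(y),
    tol=1e-3,
    maxiter=10,
    linesearch="backtracking",   # "zoom" and "hager-zhang" are also accepted
    linesearch_init="increase",
    increase_factor=1.5,
    implicit_diff=False,
    unroll=False,
)
x_star = solver.run(init_params=jnp.zeros(4)).params  # approximately equal to y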
2 changes: 1 addition & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein.py
@@ -103,7 +103,7 @@ def primal_cost(self) -> float:
   def n_iters(self) -> int:  # noqa: D102
     if self.errors is None:
       return -1
-    return jnp.sum(self.errors > -1)
+    return jnp.sum(self.errors[:, 0] != -1)


 class GWState(NamedTuple):
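A sketch of why the quadratic solver now indexes `errors[:, 0]`, assuming `errors` here is two-dimensional, with one row of inner (linear-solver) errors per outer Gromov-Wasserstein iteration; the values are made up:

import jax.numpy as jnp

errors = jnp.array([
    [0.8, 0.2, 0.05],    # outer iteration 1: three inner errors recorded
    [0.3, 0.1, -1.0],    # outer iteration 2: inner loop stopped early
    [-1.0, -1.0, -1.0],  # outer iteration 3: never ran
])

jnp.sum(errors > -1)         # old rule: 5, counts individual inner errors
jnp.sum(errors[:, 0] != -1)  # new rule: 2, counts outer iterations that ran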
19 changes: 12 additions & 7 deletions src/ott/solvers/quadratic/gromov_wasserstein_lr.py
@@ -116,7 +116,14 @@ def ent(x: jnp.ndarray) -> float:
   g = jax.lax.stop_gradient(g) if use_danskin else g

   out = LRGWOutput(
-      q=q, r=r, g=g, ot_prob=ot_prob, costs=None, errors=None, epsilon=None
+      q=q,
+      r=r,
+      g=g,
+      ot_prob=ot_prob,
+      costs=None,
+      errors=None,
+      epsilon=None,
+      inner_iterations=None,
   )

   cost = out.primal_cost - epsilon * (ent(q) + ent(r) + ent(g))
@@ -141,6 +148,7 @@ class LRGWOutput(NamedTuple):
   errors: jnp.ndarray
   ot_prob: quadratic_problem.QuadraticProblem
   epsilon: float
+  inner_iterations: int
   reg_gw_cost: Optional[float] = None

   def set(self, **kwargs: Any) -> "LRGWOutput":
@@ -170,10 +178,6 @@ def compute_reg_gw_cost(  # noqa: D102
         use_danskin=use_danskin
     )

-  @property
-  def linear(self) -> bool:  # noqa: D102
-    return False
-
   @property
   def geom(self) -> geometry.Geometry:  # noqa: D102
     """Linearized geometry."""
@@ -188,8 +192,8 @@ def b(self) -> jnp.ndarray:  # noqa: D102
     return self.ot_prob.b

   @property
-  def linear_output(self) -> bool:  # noqa: D102
-    return False
+  def n_iters(self) -> int:  # noqa: D102
+    return jnp.sum(self.errors != -1) * self.inner_iterations

   @property
   def converged(self) -> bool:  # noqa: D102
@@ -792,6 +796,7 @@ def output_from_state(
         costs=state.costs,
         errors=state.errors,
         epsilon=self.epsilon,
+        inner_iterations=self.inner_iterations,
     )

   def _converged(self, state: LRGWState, iteration: int) -> bool:
2 changes: 1 addition & 1 deletion src/ott/solvers/quadratic/gw_barycenter.py
@@ -62,7 +62,7 @@ def n_iters(self) -> int:
     """Number of iterations."""
     if self.gw_convergence is None:
       return -1
-    return jnp.sum(self.gw_convergence > -1)
+    return jnp.sum(self.gw_convergence != -1)


 @jax.tree_util.register_pytree_node_class
(Diffs for the remaining 9 changed files are not shown.)
