Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix n_iters #437

Merged
merged 17 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Packages
:maxdepth: 1
:caption: Examples

Getting Started <tutorials/notebooks/basic_ot_between_datasets>
Getting Started <tutorials/basic_ot_between_datasets>
tutorials/index

.. toctree::
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -22,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -38,7 +37,6 @@
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"%pip install -e ../../../\n",
"from ott.geometry import costs, pointcloud\n",
"from ott.problems.nn import dataset\n",
"from ott.solvers.linear import acceleration\n",
Expand All @@ -47,17 +45,6 @@
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -83,7 +70,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -199,7 +185,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -293,7 +278,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -340,7 +324,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -464,7 +447,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -531,7 +513,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -575,7 +556,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -649,9 +629,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "ott",
"language": "python",
"name": "python3"
"name": "ott"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -663,7 +643,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
16,285 changes: 16,285 additions & 0 deletions docs/tutorials/gromov_wasserstein.ipynb

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,54 @@ Geometry
.. toctree::
:maxdepth: 1

notebooks/introduction_grid
introduction_grid

Linear Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/point_clouds
notebooks/One_Sinkhorn
notebooks/OTT_&_POT
notebooks/Hessians
notebooks/LRSinkhorn
notebooks/sinkhorn_divergence_gradient_flow
notebooks/sparse_monge_displacements
point_clouds
One_Sinkhorn
OTT_&_POT
Hessians
LRSinkhorn
sinkhorn_divergence_gradient_flow
sparse_monge_displacements

Barycenters
^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/Sinkhorn_Barycenters
notebooks/gmm_pair_demo
notebooks/wasserstein_barycenters_gmms
Sinkhorn_Barycenters
gmm_pair_demo
wasserstein_barycenters_gmms

Miscellaneous
^^^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/tracking_progress
notebooks/soft_sort
notebooks/application_biology
tracking_progress
soft_sort
application_biology

Quadratic Optimal Transport
---------------------------
.. toctree::
:maxdepth: 1

notebooks/gromov_wasserstein
notebooks/GWLRSinkhorn
notebooks/gromov_wasserstein_multiomics
gromov_wasserstein
GWLRSinkhorn
gromov_wasserstein_multiomics

Neural Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/neural_dual
notebooks/icnn_inits
notebooks/MetaOT
notebooks/Monge_Gap
neural_dual
icnn_inits
MetaOT
Monge_Gap
16,279 changes: 0 additions & 16,279 deletions docs/tutorials/notebooks/gromov_wasserstein.ipynb

This file was deleted.

8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ legacy_tox_ini = """
skip_missing_interpreters = true

[testenv]
extras = test,neural
extras =
test
# https://github.com/google/flax/issues/3329
py{3.9,3.10,3.11},py3.9-jax-default: neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down Expand Up @@ -299,7 +302,8 @@ select = [
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
target-version = "py38"
[tool.ruff.per-file-ignores]
"tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/`
# TODO(michalk8): PO004 - remove `self.initialize`
"tests/*" = ["D", "PT004", "E402"]
"*/__init__.py" = ["F401"]
"docs/*" = ["D"]
"src/ott/types.py" = ["D102"]
Expand Down
8 changes: 4 additions & 4 deletions src/ott/problems/linear/barycenter_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from ott.geometry import costs, geometry, segment

__all__ = ["FreeBarycenterProblem"]
__all__ = ["FreeBarycenterProblem", "FixedBarycenterProblem"]


@jax.tree_util.register_pytree_node_class
Expand All @@ -44,8 +44,8 @@ class FreeBarycenterProblem:
Only used when ``y`` is not already segmented. When passing
``segment_ids``, 2 arguments must be specified for jitting to work:

- ``num_segments`` - the total number of measures.
- ``max_measure_size`` - maximum of support sizes of these measures.
- ``num_segments`` - the total number of measures.
- ``max_measure_size`` - maximum of support sizes of these measures.
"""

def __init__(
Expand Down Expand Up @@ -158,7 +158,7 @@ class FixedBarycenterProblem:
a: batch of histograms of shape ``[batch, num_a]`` where ``num_a`` matches
the first value of the :attr:`~ott.geometry.Geometry.shape` attribute of
``geom``.
weights: ``[batch,]`` positive weights summing to :math`1`. Uniform by
weights: ``[batch,]`` positive weights summing to :math:`1`. Uniform by
default.
"""

Expand Down
14 changes: 1 addition & 13 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ class SinkhornOutput(NamedTuple):
below the convergence threshold.
inner_iterations: number of iterations that were run between two
computations of errors.

"""

f: Optional[jnp.ndarray] = None
Expand Down Expand Up @@ -424,10 +423,6 @@ def transport_cost_at_geom(
return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
return jnp.sum(self.matrix * other_geom.cost_matrix)

@property
def linear(self) -> bool: # noqa: D102
return isinstance(self.ot_prob, linear_problem.LinearProblem)

@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
Expand All @@ -440,17 +435,10 @@ def a(self) -> jnp.ndarray: # noqa: D102
def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b

@property
def linear_output(self) -> bool: # noqa: D102
return True

# TODO(michalk8): this should be always present
@property
def n_iters(self) -> int: # noqa: D102
"""Returns the total number of iterations that were needed to terminate."""
if self.errors is None:
return -1
return jnp.sum(self.errors > -1) * self.inner_iterations
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102
Expand Down
10 changes: 4 additions & 6 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class LRSinkhornOutput(NamedTuple):
errors: jnp.ndarray
ot_prob: linear_problem.LinearProblem
epsilon: float
inner_iterations: int
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None

Expand Down Expand Up @@ -205,10 +206,6 @@ def compute_reg_ot_cost( # noqa: D102
use_danskin=use_danskin
)

@property
def linear(self) -> bool: # noqa: D102
return isinstance(self.ot_prob, linear_problem.LinearProblem)

@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
Expand All @@ -222,8 +219,8 @@ def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b

@property
def linear_output(self) -> bool: # noqa: D102
return True
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
Expand Down Expand Up @@ -773,6 +770,7 @@ def output_from_state(
costs=state.costs,
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down
16 changes: 9 additions & 7 deletions src/ott/solvers/nn/conjugate_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,18 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver):
max_iter: maximum number of iterations
max_linesearch_iter: maximum number of line search iterations
linesearch_type: type of line search
decrease_factor: decrease factor for a backtracking line search
ls_method: the line search method
linesearch_init: strategy for line search initialization
increase_factor: factor by which to increase the step size during
the line search
"""

gtol: float = 1e-3
max_iter: int = 10
max_linesearch_iter: int = 10
linesearch_type: Literal["zoom", "backtracking"] = "backtracking"
decrease_factor: float = 0.66
ls_method: Literal["wolf", "strong-wolfe"] = "strong-wolfe"
linesearch_type: Literal["zoom", "backtracking",
"hager-zhang"] = "backtracking"
linesearch_init: Literal["increase", "max", "current"] = "increase"
increase_factor: float = 1.5

def solve( # noqa: D102
self,
Expand All @@ -98,9 +100,9 @@ def solve( # noqa: D102
fun=lambda x: f(x) - x.dot(y),
tol=self.gtol,
maxiter=self.max_iter,
decrease_factor=self.decrease_factor,
linesearch=self.linesearch_type,
condition=self.ls_method,
linesearch_init=self.linesearch_init,
increase_factor=self.increase_factor,
implicit_diff=False,
unroll=False
)
Expand Down
2 changes: 1 addition & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def primal_cost(self) -> float:
def n_iters(self) -> int: # noqa: D102
if self.errors is None:
return -1
return jnp.sum(self.errors > -1)
return jnp.sum(self.errors[:, 0] != -1)


class GWState(NamedTuple):
Expand Down
Loading