Remove functional API (#222)
* Remove `sinkhorn` function

* Fix `sinkhorn_divergence` test

* Remove `gromov_wasserstein` function

* Remove `make` functions

* Fix `soft_sort` and Jacobian tests

* Remove `Transport` interface

* Fix Jacobian test

* Fix `soft_sort` and tests

* Clean up some tests

* Fix wrong `value_and_grad` usage

* Update notebooks, isort and pre-commit

* [ci skip] Fix rendering in `Sinkhorn`

* Handle TODOs, clean initializer tests

* Add `sinkhorn.solve` utility

* Re-add `gromov_wasserstein.solve`, polish docs

* Remove redundant line from `pyproject.toml`

* Polish quad docs

* Add rank to `sinkhorn.solve`

* Add `rank` to `sinkhorn.solve`
michalk8 committed Jan 7, 2023
1 parent 12ee5d6 commit 7c64f16
Showing 66 changed files with 1,221 additions and 1,773 deletions.
15 changes: 9 additions & 6 deletions .pre-commit-config.yaml
@@ -11,12 +11,15 @@ repos:
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/tomcatling/black-nb
rev: '0.7'
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.0
hooks:
- id: black-nb
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.11.4
hooks:
- id: isort
- repo: https://github.com/asottile/yesqa
@@ -30,7 +33,7 @@ repos:
- flake8-bugbear
- flake8-blind-except
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.4.0
rev: v2.5.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
@@ -65,7 +68,7 @@ repos:
- flake8-blind-except
args: [--docstring-convention, google]
- repo: https://github.com/asottile/pyupgrade
rev: v3.2.2
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py37-plus, --keep-runtime-typing]
24 changes: 15 additions & 9 deletions README.md
@@ -52,25 +52,31 @@ In the simple toy example below, we compute the optimal coupling matrix between
```python
import jax
import jax.numpy as jnp
from ott.tools import transport
# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0),4)

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings using the Sinkhorn algorithm.
ot = transport.solve(x, y, a=a, b=b)
P = ot.matrix
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a, b)

solver = sinkhorn.Sinkhorn()
out = solver(prob)
```

The call to `solve` above works out the optimal transport solution. The `ot` object contains a transport matrix
The call to `solver(prob)` above works out the optimal transport solution. The `out` object contains a transport matrix
(here of size $12\times 14$) that quantifies a `link strength` between each point of the first point cloud, to one or
more points from the second, as illustrated in the plot below. In this toy example, most choices were arbitrary, and
are reflected in the crude `solve` API. We provide far more flexibility to define custom cost functions, objectives,
and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).
more points from the second, as illustrated in the plot below. We provide more flexibility to define custom cost
functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).

![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png)
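
To make the notion of a coupling matrix in the README hunk above concrete, here is a minimal pure-Python sketch of entropy-regularized Sinkhorn scaling. It is not OTT's implementation: the function name `sinkhorn_coupling`, the `epsilon` and `num_iters` defaults, and the tiny 1-D inputs are all assumptions made for illustration. It only shows that the resulting matrix has row sums close to `a` and column sums close to `b`.

```python
import math


def sinkhorn_coupling(cost, a, b, epsilon=0.5, num_iters=500):
    """Entropy-regularized OT via plain Sinkhorn scaling (illustrative only)."""
    n, m = len(a), len(b)
    # Gibbs kernel K[i][j] = exp(-cost[i][j] / epsilon).
    kernel = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(num_iters):
        # Rescale rows so that the row sums of the coupling match `a`.
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so that the column sums of the coupling match `b`.
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Coupling P = diag(u) K diag(v).
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


# Two tiny 1-D point clouds with squared Euclidean cost (hypothetical data).
x, y = [0.0, 1.0, 2.0], [0.5, 1.5]
a, b = [0.2, 0.5, 0.3], [0.6, 0.4]
cost = [[(xi - yj) ** 2 for yj in y] for xi in x]
P = sinkhorn_coupling(cost, a, b)

row_sums = [sum(row) for row in P]
col_sums = [sum(P[i][j] for i in range(len(a))) for j in range(len(b))]
```

Each entry `P[i][j]` plays the role of the "link strength" between point `i` of the first cloud and point `j` of the second; a smaller `epsilon` would give a sharper coupling at the cost of slower convergence.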

4 changes: 0 additions & 4 deletions docs/conf.py
@@ -70,10 +70,6 @@
source_suffix = ['.rst']

autosummary_generate = True
autosummary_filename_map = {
"ott.solvers.linear.sinkhorn.sinkhorn":
"ott.solvers.linear.sinkhorn.sinkhorn-function"
}

autodoc_typehints = 'description'

2 changes: 1 addition & 1 deletion docs/solvers/linear.rst
@@ -10,7 +10,7 @@ Sinkhorn Solvers
.. autosummary::
:toctree: _autosummary

sinkhorn.sinkhorn
sinkhorn.solve
sinkhorn.Sinkhorn
sinkhorn.SinkhornOutput
sinkhorn_lr.LRSinkhorn
2 changes: 1 addition & 1 deletion docs/solvers/quadratic.rst
@@ -10,9 +10,9 @@ Gromov-Wasserstein Solvers
.. autosummary::
:toctree: _autosummary

gromov_wasserstein.solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput
gromov_wasserstein.gromov_wasserstein

Barycenter Solvers
------------------
7 changes: 0 additions & 7 deletions docs/tools.rst
@@ -9,13 +9,6 @@ The tools package contains high level functions that build on outputs produced b
They can be used to compute Sinkhorn divergences :cite:`sejourne:19`, instantiate transport matrices,
provide differentiable approximations to ranks and quantile functions :cite:`cuturi:19`, etc.

Optimal Transport
-----------------
.. autosummary::
:toctree: _autosummary

transport.Transport

Segmented Sinkhorn
------------------
.. autosummary::
3 changes: 2 additions & 1 deletion examples/fairness/main.py
@@ -15,11 +15,12 @@

from typing import Sequence

import jax
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

import jax

from ott.examples.fairness import train

FLAGS = flags.FLAGS
3 changes: 2 additions & 1 deletion examples/fairness/train.py
@@ -17,10 +17,11 @@
import functools
from typing import Any

import ml_collections

import flax
import jax
import jax.numpy as jnp
import ml_collections
from flax import jax_utils
from flax.metrics import tensorboard
from flax.training import checkpoints, common_utils
3 changes: 2 additions & 1 deletion examples/soft_error/data.py
@@ -13,9 +13,10 @@
# limitations under the License.
"""Data loading and data augmentation."""

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

import jax
from flax import jax_utils


3 changes: 2 additions & 1 deletion examples/soft_error/main.py
@@ -15,11 +15,12 @@

from typing import Sequence

import jax
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

import jax

from ott.examples.soft_error import train

FLAGS = flags.FLAGS
5 changes: 3 additions & 2 deletions examples/soft_error/train.py
@@ -17,11 +17,12 @@
import functools
from typing import Any

import ml_collections
import tensorflow_datasets as tfds

import flax
import jax
import jax.numpy as jnp
import ml_collections
import tensorflow_datasets as tfds
from flax import jax_utils
from flax.metrics import tensorboard
from flax.training import checkpoints, common_utils
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -100,6 +100,9 @@ include = '\.ipynb$'
profile = "black"
include_trailing_comma = true
multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
known_numeric = ["numpy", "scipy", "pandas", "sklearn", "jax", "flax", "optax", "torch"]
known_plotting = ["matplotlib", "mpl_toolkits", "seaborn"]
skip_glob = ["docs/*"]

[tool.pytest.ini_options]
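
The new `sections`, `known_numeric` and `known_plotting` settings explain why `import jax` moved below `absl`, `clu` and `ml_collections` in the example scripts. The sketch below is a simplified model of that grouping, not isort's actual algorithm: `section_of`, `grouped`, the hand-picked stdlib subset and the choice of `ott` as first party are assumptions for illustration only.

```python
# Section order and custom groups copied from the `[tool.isort]` lines above.
SECTIONS = [
    "FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY",
    "LOCALFOLDER",
]
KNOWN_NUMERIC = {
    "numpy", "scipy", "pandas", "sklearn", "jax", "flax", "optax", "torch",
}
KNOWN_PLOTTING = {"matplotlib", "mpl_toolkits", "seaborn"}

# Assumptions for this sketch only: a tiny stdlib subset, `ott` as first party.
STDLIB_SUBSET = {"functools", "typing"}
FIRST_PARTY = {"ott"}


def section_of(module: str) -> str:
    """Return the section a module would fall into under this simplified model."""
    if module.startswith("."):
        return "LOCALFOLDER"
    top = module.split(".")[0]
    if top == "__future__":
        return "FUTURE"
    if top in STDLIB_SUBSET:
        return "STDLIB"
    if top in KNOWN_NUMERIC:
        return "NUMERIC"
    if top in KNOWN_PLOTTING:
        return "PLOTTING"
    if top in FIRST_PARTY:
        return "FIRSTPARTY"
    # Anything unrecognized is treated as a generic third-party package.
    return "THIRDPARTY"


def grouped(modules):
    """Sort modules by section order, then alphabetically within a section."""
    return sorted(modules, key=lambda m: (SECTIONS.index(section_of(m)), m))


order = grouped(
    ["jax", "absl", "typing", "ott.examples.fairness", "clu", "ml_collections"]
)
```

Under this model, `order` places `typing` first, then the generic third-party packages, then `jax` in its own numeric group, and the `ott` import last, which mirrors the reordering seen in `examples/fairness/main.py`.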
2 changes: 1 addition & 1 deletion src/ott/initializers/linear/initializers.py
@@ -286,7 +286,7 @@ def _vectorized_update(
"""Inner loop DualSort Update.
Args:
f : potential f, array of size n.
f: potential f, array of size n.
modified_cost: cost matrix minus diagonal column-wise.
Returns:
16 changes: 8 additions & 8 deletions src/ott/problems/linear/linear_problem.py
@@ -33,20 +33,20 @@ class LinearProblem:
r"""Linear OT problem.
This class describes the main ingredients appearing in a linear OT problem.
Namely, a `geom` object (including cost structure/points) describing point
clouds or the support of measures, followed by probability masses `a` and `b`.
Unabalancedness of the problem is also kept track of, through two coefficients
`tau_a` and `tau_b`, which are both kept between 0 and 1
Namely, a ``geom`` object (including cost structure/points) describing point
clouds or the support of measures, followed by probability masses ``a`` and
``b``. Unbalancedness of the problem is also kept track of, through two
coefficients ``tau_a`` and ``tau_b``, which are both kept between 0 and 1
(1 corresponding to a balanced OT problem).
Args:
geom: The ground geometry cost of the linear problem.
a: The first marginal. If `None`, it will be uniform.
b: The second marginal. If `None`, it will be uniform.
tau_a: If smaller than `1`, defines how much unbalanced the problem is on
the first marginal.
tau_b: If smaller than `1`, defines how much unbalanced the problem is on
the second marginal.
tau_a: If `< 1`, defines how much unbalanced the problem is
on the first marginal.
tau_b: If `< 1`, defines how much unbalanced the problem is
on the second marginal.
"""

def __init__(
42 changes: 21 additions & 21 deletions src/ott/problems/quadratic/quadratic_problem.py
@@ -32,7 +32,7 @@

@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
r"""Quadratic regularized OT problem.
r"""Quadratic OT problem.
The quadratic loss of a single OT matrix is assumed to
have the form given in :cite:`peyre:16`, eq. 4.
@@ -48,37 +48,37 @@ class QuadraticProblem:
Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for Fused Gromov
Wasserstein. If `None`, the problem reduces to a
plain Gromov Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein,
geom_xy: Geometry defining the linear penalty term for
Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain
Gromov-Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
scale_cost: option to rescale the cost matrices:
- if `True`, use the default for each geometry.
- if `False`, keep the original scaling in geometries.
- if :obj:`True`, use the default for each geometry.
- if :obj:`False`, keep the original scaling in geometries.
- if :class:`str`, use a specific method available in
:class:`~ott.geometry.geometry.Geometry` or
:class:`~ott.geometry.pointcloud.PointCloud`.
- if `None`, do not scale the cost matrices.
- if :obj:`None`, do not scale the cost matrices.
a: jnp.ndarray[n] representing the probability weights of the samples
from geom_xx. If None, it will be uniform.
b: jnp.ndarray[n] representing the probability weights of the samples
from geom_yy. If None, it will be uniform.
a: array representing the probability weights of the samples
from ``geom_xx``. If `None`, it will be uniform.
b: array representing the probability weights of the samples
from ``geom_yy``. If `None`, it will be uniform.
loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
part of the loss (see in the pydoc of the class lin1, lin2). The second
one is the quadratic part (quad1, quad2). By default, the loss
is set as the 4 functions representing the squared Euclidean loss, and
this property is taken advantage of in subsequent computations. See
Alternatively, KL loss can be specified in no less optimized way.
tau_a: if lower that 1.0, defines how much unbalanced the problem is on
part of the loss. The second one is the quadratic part (quad1, quad2).
By default, the loss is set as the 4 functions representing the squared
Euclidean loss, and this property is taken advantage of in subsequent
computations. Alternatively, KL loss can be specified in no less optimized
way.
tau_a: if `< 1.0`, defines how much unbalanced the problem is on
the first marginal.
tau_b: if lower that 1.0, defines how much unbalanced the problem is on
tau_b: if `< 1.0`, defines how much unbalanced the problem is on
the second marginal.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise ``tau_a`` and ``tau_b`` only affect
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only affect
the inner Sinkhorn loop.
ranks: Ranks of the cost matrices, see
:meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
@@ -274,7 +274,7 @@ def update_linearization(
If the problem is unbalanced (``tau_a < 1.0 or tau_b < 1.0``), two cases are
possible, as explained in :meth:`init_linearization` above.
Finally, it is also possible to consider a Fused Gromov Wasserstein problem.
Finally, it is also possible to consider a Fused Gromov-Wasserstein problem.
Details about the resulting cost matrix are also given in
:meth:`init_linearization`.
1 change: 1 addition & 0 deletions src/ott/solvers/linear/acceleration.py
@@ -15,6 +15,7 @@
class AndersonAcceleration:
"""Implements Anderson acceleration for Sinkhorn."""

# TODO(michalk8): use memory=0 as no Anderson acceleration?
memory: int = 2 # Number of iterates considered to form interpolation.
refresh_every: int = 1 # Recompute interpolation periodically.
ridge_identity: float = 1e-2 # Ridge used in the linear system.