[MRG] New Quickstart guide and revamp User guide #726

Merged: 16 commits, Mar 25, 2025
1 change: 1 addition & 0 deletions RELEASES.md
@@ -15,6 +15,7 @@
- Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -347,7 +347,7 @@ def __getattr__(cls, name):
}

sphinx_gallery_conf = {
"examples_dirs": ["../../examples", "../../examples/da"],
"examples_dirs": ["../../examples"],
"gallery_dirs": "auto_examples",
"filename_pattern": "plot_", # (?!barycenter_fgw)
"nested_sections": False,
5 changes: 3 additions & 2 deletions docs/source/index.rst
@@ -17,9 +17,10 @@ Contents
:maxdepth: 1

self
quickstart
all
auto_examples/plot_quickstart_guide
auto_examples/index
user_guide
all
releases
contributors
contributing
70 changes: 35 additions & 35 deletions docs/source/quickstart.rst → docs/source/user_guide.rst
@@ -1,6 +1,6 @@

Quick start guide
=================
User guide
==========

In the following we provide some pointers about which functions and classes
to use for different problems related to optimal transport (OT) and machine
@@ -136,12 +136,12 @@ instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in
memory because the cost matrix has to be computed. The exact solver is of time
complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been
proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very
large scale solvers.
large scale solvers. For all the generic solvers we need to compute the cost
matrix and the OT matrix of memory size :math:`\mathcal{O}(n^2)` which can be
prohibitive for very large scale problems.
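
As a rough illustration of the quadratic memory cost described above, the following sketch estimates the size of one dense ``n x n`` matrix stored as 64-bit floats. The helper name ``dense_matrix_gib`` and the sample sizes are illustrative assumptions and are not part of POT.

```python
def dense_matrix_gib(n, bytes_per_entry=8):
    """Approximate memory, in GiB, of a dense n x n matrix.

    bytes_per_entry=8 assumes float64 entries (an assumption for this sketch).
    """
    return n * n * bytes_per_entry / 2**30


# Memory needed for a single dense matrix (cost matrix or OT plan).
for n in (1_000, 10_000, 100_000):
    print(f"n = {n:>7}: {dense_matrix_gib(n):10.3f} GiB per matrix")
```

Since a generic solver holds both the cost matrix and the OT matrix, the total is at least twice the figure printed for each ``n``.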


If you need to solve OT with large number of samples, we recommend to use
entropic regularization and memory efficient implementation of Sinkhorn as
proposed in `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
If you need to solve OT with a large number of samples, we provide a "lazy" memory-efficient implementation of Sinkhorn in pure
Python and using `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
implementation is compatible with Pytorch and can handle large number of
samples. Another approach to estimate the Wasserstein distance for very large
number of sample is to use the trick from `Wasserstein GAN
@@ -193,15 +193,19 @@ that will return the optimal transport matrix :math:`\gamma^*`:

# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix

# unified API
T = ot.solve(M, a, b).plan # exact linear program

# classical API
T = ot.emd(a, b, M) # exact linear program

The method implemented for solving the OT problem is the network simplex. It is
implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
solver is quite efficient and uses sparsity of the solution.



.. minigallery:: ot.emd
.. minigallery:: ot.emd, ot.solve
:add-heading: Examples of use for :any:`ot.emd`
:heading-level: "

@@ -226,7 +230,12 @@ It can computed from an already estimated OT matrix with

# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
W = ot.emd2(a, b, M) # Wasserstein distance / EMD value

# Wasserstein distance / EMD value with unified API
W = ot.solve(M, a, b, return_matrix=False).value

# with classical API
W = ot.emd2(a, b, M)

Note that the well known `Wasserstein distance
<https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
@@ -246,7 +255,7 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
distance.

.. minigallery:: ot.emd2
.. minigallery:: ot.emd2, ot.solve
:add-heading: Examples of use for :any:`ot.emd2`
:heading-level: "

@@ -274,6 +283,10 @@ distributions. In the case when the finite sample dataset is supposed Gaussian,
we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
Monge mapping.

All those special cases are accessible with the unified API of POT through the
function :any:`ot.solve_sample` with the parameter :code:`method` that allows to
choose the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`).
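
To see why the 1D case admits a dedicated solver: for two samples of equal size with uniform weights, an optimal coupling matches the sorted points, so the distance reduces to a sort. The sketch below is a plain NumPy illustration under those assumptions. The function name ``wasserstein_1d_sorted`` is hypothetical, and this is not POT's implementation.

```python
import numpy as np


def wasserstein_1d_sorted(x, y, p=1):
    """p-Wasserstein distance between two 1D samples of equal size.

    Assumes uniform weights on both samples (an assumption of this sketch).
    """
    x = np.sort(np.asarray(x, dtype=float).ravel())
    y = np.sort(np.asarray(y, dtype=float).ravel())
    if x.shape != y.shape:
        raise ValueError("this sketch only handles samples of equal size")
    # Sorted points are matched one to one; average the p-th power of the gaps.
    return np.mean(np.abs(x - y) ** p) ** (1.0 / p)


x = [0.0, 1.0, 3.0]
y = [5.0, 4.0, 2.0]
print(wasserstein_1d_sorted(x, y, p=1))
print(wasserstein_1d_sorted(x, y, p=2))
```

Samples with unequal sizes or non-uniform weights need the quantile-function formulation instead, which this sketch does not cover.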


Regularized Optimal Transport
-----------------------------
@@ -330,13 +343,15 @@ The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
linear term. Note that the regularization parameter :math:`\lambda` in the
equation above is given to those functions with the parameter :code:`reg`.

>>> import ot
>>> a = [.5, .5]
>>> b = [.5, .5]
>>> M = [[0., 1.], [1., 0.]]
>>> ot.sinkhorn(a, b, M, 1)
array([[ 0.36552929, 0.13447071],
[ 0.13447071, 0.36552929]])
.. code:: python

# unified API
P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix
loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value

# classical API
P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix
loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value

More details about the algorithms used are given in the following note.
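
For readers who want to see what the ``reg`` parameter does, here is a minimal NumPy sketch of the Sinkhorn-Knopp scaling iterations, using the kernel ``exp(-M / reg)``. The function name ``sinkhorn_sketch`` and its stopping rule are assumptions made for illustration. It is not the code behind ``ot.sinkhorn`` or ``ot.solve``, and it has none of their numerical safeguards.

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, n_iter=1000, tol=1e-9):
    """Entropic OT plan via plain Sinkhorn-Knopp scaling (illustrative only)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    M = np.asarray(M, dtype=float)
    K = np.exp(-M / reg)  # Gibbs kernel; smaller reg gives a sharper plan
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale rows to match marginal a
        v = b / (K.T @ u)  # rescale columns to match marginal b
        P = u[:, None] * K * v[None, :]
        # Columns match b after the v update, so only check the rows.
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return u[:, None] * K * v[None, :]


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
P = sinkhorn_sketch(a, b, M, reg=1)
print(P)
```

On this 2x2 problem the sketch should agree, up to rounding, with the plan shown in the doctest that this hunk replaces.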

@@ -406,13 +421,10 @@ implementations are not optimized for speed but provide a robust implementation
of algorithms in [18]_ [19]_.


.. minigallery:: ot.sinkhorn
:add-heading: Examples of use for :any:`ot.sinkhorn`
.. minigallery:: ot.sinkhorn ot.sinkhorn2
:add-heading: Examples of use for Sinkhorn algorithm
:heading-level: "

.. minigallery:: ot.sinkhorn2
:add-heading: Examples of use for :any:`ot.sinkhorn2`
:heading-level: "


Other regularizations
@@ -969,18 +981,6 @@ For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=
It's important to note that the `numpy` backend cannot be disabled.


List of compatible modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This list will get longer for new releases and will hopefully disappear when POT
become fully implemented with the backend.

- :any:`ot.bregman`
- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead)
- :any:`ot.optim` (some functions use CPU only solvers with copy overhead)
- :any:`ot.sliced`
- :any:`ot.utils` (partial)


FAQ
---
2 changes: 1 addition & 1 deletion examples/plot_OT_2D_samples.py
@@ -65,7 +65,7 @@

# %% EMD

G0 = ot.emd(a, b, M)
G0 = ot.solve(M, a, b).plan

pl.figure(3)
pl.imshow(G0, interpolation="nearest")