# Tracking progress in Sinkhorn and Low-Rank Sinkhorn

This tutorial shows how to track progress and errors during iterations of {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` algorithms.

We use the same basic example as in the {doc}`basic_ot_between_datasets` notebook.

## Without tracking (default behavior)

Let's recap the basic example we use in this notebook:

In [1]:
import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
    %pip install -q tqdm

In [2]:
from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr

In [3]:
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5

In [4]:
geom = pointcloud.PointCloud(x, y)

This problem is very simple, so the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver converges after only 7 outer iterations, that is, 70 inner iterations. Otherwise, it would keep iterating until it reached the default limit of 200 outer iterations (2000 inner iterations).

In [5]:
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom)

print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092


For such a small problem, not tracking progress (the default) is fine. However, when tackling larger problems, we will probably want to track the metrics that the Sinkhorn algorithm updates at each iteration.

In the next sections, we show how to track progress for that same example.

## How to track progress

{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit`, and applies to both the functional interface and the class interface.

The {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and low-rank Sinkhorn {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers only report progress if we pass a callback function with a specific signature, either to the class initializer or as the `progress_fn` argument of the functional interface. The solver calls this callback at each iteration.

### Callback function signature

The required signature of the callback function is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.

The arguments are:

- status: a tuple of:
  - the current iteration index (0-based)
  - the number of inner iterations after which the error is computed
  - the total number of iterations
  - the current {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`. For technical reasons, the type of this argument in the signature is simply {class}`~typing.NamedTuple` (the common super-type).

- args: unused, see {mod}`jax.experimental.host_callback`.

Note:

- Above, the {class}`~numpy.ndarray` types are required by the underlying mechanism, {mod}`~jax.experimental.host_callback`, but each array simply contains a single integer value and can be safely cast to `int`.
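
As a minimal sketch of a callback matching this signature, the function below unpacks `status`, casts the counters to `int`, and appends the latest error to a list instead of printing it. The names `history`, `record_progress`, and `FakeState` are hypothetical; `FakeState` is only a stand-in exposing an `errors` field so that the callback can be exercised without running a solver, and the error lookup follows the same indexing as the `tqdm` example in this tutorial.

```python
from typing import Any, List, NamedTuple, Tuple

# Hypothetical container for (inner iteration count, error) pairs.
history: List[Tuple[int, float]] = []


class FakeState(NamedTuple):
    """Hypothetical stand-in for a solver state; only `errors` is used here."""

    errors: List[float]


def record_progress(
    status: Tuple[Any, Any, Any, NamedTuple], *args: Any
) -> None:
    iteration, inner_iterations, total_iter, state = status
    # Each counter holds a single integer value, so a plain cast is enough.
    iteration = int(iteration) + 1  # from 0-based to 1-based
    inner_iterations = int(inner_iterations)
    # Errors are only computed every `inner_iterations` steps.
    if iteration % inner_iterations == 0:
        idx = iteration // inner_iterations - 1
        history.append((iteration, float(state.errors[idx])))


# Simulate 30 iterations with an error computed every 10 steps.
state = FakeState(errors=[0.05, 0.02, 0.009])
for it in range(30):
    record_progress((it, 10, 2000, state))

print(history)
```

A function of this shape could then be passed as `progress_fn` in the same way as the callbacks used in the examples of this tutorial.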

## Examples

Here are a few examples of how to track progress for Sinkhorn and low-rank Sinkhorn.

### Tracking progress for Sinkhorn via the functional interface

#### With the basic callback function


{mod}`ott` provides a basic callback function, {func}`~ott.utils.default_progress_fn`, that we can use directly; it simply prints the iteration and the error to the console. It can also serve as a basis for customizations.

Here, we simply pass that basic callback as a static argument:

In [6]:
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
ot = solve_fn(geom, a=None, b=None, progress_fn=utils.default_progress_fn)

print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

10 / 2000 -- 0.049124784767627716
20 / 2000 -- 0.019962385296821594
30 / 2000 -- 0.00910455733537674
40 / 2000 -- 0.004339158535003662
50 / 2000 -- 0.002111591398715973
60 / 2000 -- 0.001037590205669403
70 / 2000 -- 0.0005124583840370178
has converged: True, #iters: 7, cost: 1.2429015636444092


This reveals that the solver reports its metrics every 10 _inner_ iterations (default value).
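
The relation between these counts can be checked with a little arithmetic. The sketch below assumes the values visible in the output above (a limit of 2000 inner iterations, an error reported every 10 inner iterations, and 7 reported iterations); the variable names are illustrative only.

```python
# Values read off the printed output above; names are illustrative.
max_inner_iterations = 2000  # denominator in "10 / 2000"
inner_iterations = 10  # an error is reported every 10 inner iterations
n_iters = 7  # value of `ot.n_iters`

# Maximum number of error evaluations the solver could perform.
max_outer_iterations = max_inner_iterations // inner_iterations
# Inner iterations actually run before convergence.
inner_iterations_run = n_iters * inner_iterations

print(max_outer_iterations)  # 200
print(inner_iterations_run)  # 70
```

This matches the last line printed by the callback, `70 / 2000`.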

#### With `tqdm`

Here, we first define a function that updates a `tqdm` progress bar.

In [7]:
def progress_fn(status, *args):
    # `pbar` is looked up in the enclosing scope, so it must exist when the solver runs.
    iteration, inner_iterations, total_iter, state = status
    iteration = int(iteration) + 1  # from [0;n-1] to [1;n]
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    errors = np.asarray(state.errors).ravel()

    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
        error_idx = max(0, iteration // inner_iterations - 1)
        error = errors[error_idx]

        pbar.set_postfix_str(f"error: {error:0.6e}")
        pbar.total = total_iter // inner_iterations
        pbar.update()

We then use it as before, but within the context of an existing `tqdm` progress bar:

In [8]:
with tqdm() as pbar:
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)

  4%|███▍                                                                                              | 7/200 [00:00<00:10, 18.58it/s, error: 5.124584e-04]


In [9]:
print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092
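
The `progress_fn` above reads `pbar` from the enclosing scope, so it only works while a variable with that name exists. One possible alternative, sketched here under the assumption that the bar only needs `total`, `set_postfix_str` and `update` (the members used above), is a factory that binds the bar explicitly. `make_progress_fn`, `RecordingBar`, and `FakeState` are hypothetical names; the last two are stand-ins used to exercise the callback without `tqdm` or a solver.

```python
from typing import Any, Callable, List, NamedTuple, Optional


def make_progress_fn(bar: Any) -> Callable[..., None]:
    """Return a callback bound to `bar` instead of a global variable."""

    def progress_fn(status, *args) -> None:
        iteration, inner_iterations, total_iter, state = status
        iteration = int(iteration) + 1  # from 0-based to 1-based
        inner_iterations = int(inner_iterations)
        total_iter = int(total_iter)
        # Only report when an error has just been computed.
        if iteration % inner_iterations == 0:
            # With a real solver state, flatten `errors` first,
            # as done in the cell above.
            error = state.errors[iteration // inner_iterations - 1]
            bar.set_postfix_str(f"error: {float(error):0.6e}")
            bar.total = total_iter // inner_iterations
            bar.update()

    return progress_fn


class RecordingBar:
    """Hypothetical stand-in exposing the three members used above."""

    def __init__(self) -> None:
        self.total: Optional[int] = None
        self.n = 0
        self.postfix = ""

    def set_postfix_str(self, s: str) -> None:
        self.postfix = s

    def update(self) -> None:
        self.n += 1


class FakeState(NamedTuple):
    """Hypothetical stand-in for a solver state; only `errors` is used."""

    errors: List[float]


bar = RecordingBar()
callback = make_progress_fn(bar)
state = FakeState(errors=[0.05, 0.02])
for it in range(20):
    callback((it, 10, 2000, state))

print(bar.n, bar.total, bar.postfix)
```

With `tqdm`, this would presumably be used as `progress_fn=make_progress_fn(pbar)` inside the `with tqdm() as pbar:` block.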


### Tracking progress for Sinkhorn via the class interface

Here, we provide the callback function to the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class initializer and display progress with `tqdm`:

In [10]:
prob = linear_problem.LinearProblem(geom)

with tqdm() as pbar:
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  4%|███▍                                                                                              | 7/200 [00:00<00:10, 18.95it/s, error: 5.124584e-04]


In [11]:
print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092


### Tracking progress of Low-Rank Sinkhorn iterations via the class interface

We can also track the progress of the low-rank Sinkhorn solver. Because it currently doesn't have a functional interface, we can only use the class interface, {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`:

In [12]:
prob = linear_problem.LinearProblem(geom)
rank = 2

with tqdm() as pbar:
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  8%|███████▊                                                                                         | 16/200 [00:00<00:07, 23.01it/s, error: 3.191826e-04]


In [13]:
print(f"has converged: {ot.converged}, cost: {ot.reg_ot_cost}")

has converged: True, cost: 1.7340877056121826


That's it: this is how to track progress and errors during Sinkhorn and low-rank Sinkhorn iterations.