# Tracking progress and metrics in Sinkhorn and Low-rank Sinkhorn

This tutorial shows how to track progress and errors during iterations of the Sinkhorn and Low-rank Sinkhorn algorithms.

We reuse part of the "Getting Started" notebook, including its example.

In [1]:
import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main

In [2]:
from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.utils import default_progress_fn

In [3]:
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5

In [4]:
geom = pointcloud.PointCloud(x, y)

In [5]:
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom, a=None, b=None)

By default, no progress is reported:

In [6]:
ot = solve_fn(geom)

As the Sinkhorn algorithm iterates, it updates several metrics. On larger problems you will likely want to monitor them, or at least follow the solver's progress.

By default, the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` implementation does not report progress. If we pass a callback function with the signature described below, Sinkhorn calls it each time it updates its internal metrics.

The callback signature is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.

The arguments are:

- `status`: a tuple consisting of:
  - the current iteration number
  - the number of inner iterations after which the error is computed
  - the total number of iterations
  - the current {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`. For technical reasons, the type of this argument in the signature is simply `NamedTuple` (the common super-type).

- `args`: unused, see {mod}`jax.experimental.host_callback`.

As we show later, the same discussion applies to {class}`ott.solvers.linear.sinkhorn_lr.LRSinkhorn`.
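The signature described above can be exercised by hand, without running a solver. The following is a minimal sketch, not part of the library: `MockState` and `make_recording_fn` are hypothetical names, and the error values are made up. It builds a callback that stores `(iteration, error)` pairs in a list, using the same indexing logic as the `tqdm` callback shown later in this tutorial.

```python
from typing import Any, Callable, List, NamedTuple, Tuple

import numpy as np


class MockState(NamedTuple):
    """Hypothetical stand-in for a solver state; only `errors` is modelled."""

    errors: np.ndarray


def make_recording_fn(history: List[Tuple[int, float]]) -> Callable[..., None]:
    """Return a callback that appends (iteration, error) pairs to `history`."""

    def record(status: Tuple[Any, Any, Any, NamedTuple], *args: Any) -> None:
        iteration, inner_iterations, total_iter, state = status
        iteration = int(iteration) + 1  # shift to a 1-based count
        inner_iterations = int(inner_iterations)
        # Errors are only computed every `inner_iterations` iterations.
        if iteration % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            error_idx = max(0, iteration // inner_iterations - 1)
            history.append((iteration, float(errors[error_idx])))

    return record


# Drive the callback by hand with made-up error values.
mock_errors = np.zeros((4, 1))
mock_errors[:2, 0] = [0.05, 0.02]

history: List[Tuple[int, float]] = []
record = make_recording_fn(history)
for it in range(20):
    record((it, 10, 40, MockState(errors=mock_errors)))

print(history)  # → [(10, 0.05), (20, 0.02)]
```

Building the callback with a closure keeps the recorded values in a list owned by the caller rather than in a global variable. Assuming the real solver state exposes `errors` in the layout used here, `record` could be passed as `progress_fn` in the same way as the callbacks in the sections that follow.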

## Tracking progress of Sinkhorn iterations

We show two alternative ways:

- using the functional interface: `sinkhorn.solve`
- using the class interface: `sinkhorn.Sinkhorn`

### 1. Using the functional interface

As an example, we use a basic callback function, {func}`~ott.utils.default_progress_fn`, which simply prints the iteration and the error to the console.

In [7]:
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
ot = solve_fn(geom, a=None, b=None, progress_fn=default_progress_fn)

10 / 2000 -- 0.049124784767627716
20 / 2000 -- 0.019962385296821594
30 / 2000 -- 0.00910455733537674
40 / 2000 -- 0.004339158535003662
50 / 2000 -- 0.002111591398715973
60 / 2000 -- 0.001037590205669403
70 / 2000 -- 0.0005124583840370178


In the above case, the functional interface relies on the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class. By default, this solver stops after 2000 iterations if it has not converged earlier, and it reports its metrics every 10 iterations (the default number of inner iterations). In this basic example, convergence is reached after 70 iterations, that is, after 7 error evaluations.
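To make the relation between printed lines and iterations concrete, the arithmetic can be sketched as follows. The numbers are the defaults quoted above and the seven lines of output; the variable names are chosen for illustration only.

```python
max_iterations = 2000  # default iteration cap quoted in the text
inner_iterations = 10  # default reporting interval quoted in the text
reported_lines = 7     # number of lines printed in the output above

# Number of error evaluations that fit into the iteration budget.
n_slots = max_iterations // inner_iterations

# Iteration counts at which a line was printed.
report_points = [inner_iterations * (k + 1) for k in range(reported_lines)]

print(n_slots)        # → 200
print(report_points)  # → [10, 20, 30, 40, 50, 60, 70]
```

The last entry, 70, is the iteration at which the solver stopped, and 200 is the total shown by the progress bars later in this tutorial.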

We can also provide any function with the signature specified above:

In [8]:
def progress_fn(status, *args):
    iteration, inner_iterations, total_iter, state = status
    iteration = int(iteration) + 1
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    errors = np.asarray(state.errors).ravel()

    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
        error_idx = max(0, iteration // inner_iterations - 1)
        error = errors[error_idx]

        pbar.set_postfix_str(f"error: {error:0.6e}")
        pbar.total = total_iter // inner_iterations
        pbar.update()

Still using the functional interface, we now create a `tqdm` progress bar that displays the number of completed error evaluations and the current error.

As before, Sinkhorn converges after only a few error evaluations because the problem is simple.

In [9]:
with tqdm() as pbar:
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)

  4%|███▍                                                                                              | 7/200 [00:00<00:12, 15.80it/s, error: 5.124584e-04]


### 2. Using the class interface

Here, we adapt the previous example, but we provide the callback function to the class initializer.

In [10]:
prob = linear_problem.LinearProblem(geom)

with tqdm() as pbar:
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  4%|███▍                                                                                              | 7/200 [00:00<00:12, 15.69it/s, error: 5.124584e-04]


## Tracking progress of Low-rank Sinkhorn iterations

Low-rank Sinkhorn currently doesn't have a functional interface, so we use only the class interface.

In [11]:
prob = linear_problem.LinearProblem(geom)
rank = 2

with tqdm() as pbar:
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  8%|███████▊                                                                                         | 16/200 [00:00<00:09, 18.65it/s, error: 3.191826e-04]


That's it: we now know how to track progress and errors during Sinkhorn and Low-rank Sinkhorn iterations.