# Tracking progress in OTT solvers

This tutorial shows how to track progress and errors during iterations of the following solvers:

- {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`
- {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`
- {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`.

We'll see that we simply need to provide a callback function to the solvers.

In [1]:
import sys

if "google.colab" in sys.modules:
    %pip install -q git+https://github.com/ott-jax/ott@main
    %pip install -q tqdm

In [2]:
from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein

## How to track progress

{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit` and applies to both the functional and the class interfaces.

The solvers {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`, low-rank Sinkhorn {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`, and {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` only report progress if we pass a callback function (with a specific signature) to their initializer. This callback is then called at each iteration.

### Callback function signature

The required signature of the callback function is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.

The arguments are:

- status: a tuple of:
  - the current iteration index (0-based)
  - the number of inner iterations after which the error is computed
  - the maximum number of iterations
  - the current solver state: {class}`~ott.solvers.linear.sinkhorn.SinkhornState`, {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`, or {class}`~ott.solvers.quadratic.gromov_wasserstein.GWState`. For technical reasons, the type of this argument in the signature is simply {class}`~typing.NamedTuple` (the common super-type).

- args: unused, see {mod}`jax.experimental.host_callback`.

Note:

- Above, the {class}`~numpy.ndarray` values are passed by the underlying {mod}`~jax.experimental.host_callback` mechanism; each of them contains a single integer value and can safely be cast to `int`.
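As a minimal sketch of a custom callback that follows this signature, the function below records `(iteration, error)` pairs in a Python list instead of printing them. The names `make_recording_fn`, `history` and `FakeState` are hypothetical, and the code assumes, like the `tqdm` callback later in this tutorial, that the state exposes an `errors` array that is filled every `inner_iterations` iterations. Instead of running a solver, the callback is fed a hand-built status tuple with a stand-in state, so it can be checked without OTT.

```python
from typing import Any, Callable, List, NamedTuple, Tuple

import numpy as np


def make_recording_fn(history: List[Tuple[int, float]]) -> Callable[..., None]:
    """Return a callback that appends (iteration, error) pairs to `history`."""

    def recording_fn(status: Tuple[Any, ...], *args: Any) -> None:
        iteration, inner_iterations, _, state = status
        # Each ndarray holds a single integer, so casting with int() is safe.
        iteration = int(iteration) + 1  # 0-based -> 1-based
        inner_iterations = int(inner_iterations)
        # Errors are only computed every `inner_iterations` iterations.
        if iteration % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            error = float(errors[iteration // inner_iterations - 1])
            history.append((iteration, error))

    return recording_fn


# Hypothetical stand-in for a solver state: only the `errors` field is used.
class FakeState(NamedTuple):
    errors: np.ndarray


history: List[Tuple[int, float]] = []
fn = make_recording_fn(history)
state = FakeState(errors=np.array([0.5, 0.1, -1.0, -1.0]))
for it in range(20):  # simulate 20 iterations with inner_iterations=10
    fn((np.array(it), np.array(10), np.array(40), state), None)
print(history)  # [(10, 0.5), (20, 0.1)]
```

With a real solver, `fn` would be passed as `progress_fn`, like the callbacks in the examples of this tutorial. Since `jax.jit` treats `progress_fn` as a static argument, passing a freshly created closure is expected to trigger a new compilation.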

## Linear problem without tracking (default behavior)

Let's start with the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver and set up a basic linear problem:

In [3]:
rngs = jax.random.split(jax.random.PRNGKey(0), 2)
d, n_x, n_y = 2, 7, 11
x = jax.random.normal(rngs[0], (n_x, d))
y = jax.random.normal(rngs[1], (n_y, d)) + 0.5

In [4]:
geom = pointcloud.PointCloud(x, y)

This problem is very simple, so the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver converges quickly (`ot.n_iters` is only 7 below). Otherwise, the solver would keep iterating up to its default maximum of 2000 iterations.

In [5]:
solve_fn = jax.jit(sinkhorn.solve)
ot = solve_fn(geom)

print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092


For cases as simple as this one, it is fine not to track progress (the default behavior). However, when tackling larger problems, we will probably want to track the various metrics that the Sinkhorn algorithm updates as it iterates.

In the next sections, we show how to track progress for that same example.

## Examples

Here are a few examples of how to track progress for Sinkhorn and low-rank Sinkhorn.

### Tracking progress for Sinkhorn via the functional interface

#### With the basic callback function


{mod}`ott` provides a basic callback function, {func}`~ott.utils.default_progress_fn`, that we can use directly: it simply prints the iteration number and the error to the console. It can also serve as a basis for customizations.

Let's simply pass that basic callback as a static argument:

In [6]:
solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
ot = solve_fn(geom, a=None, b=None, progress_fn=utils.default_progress_fn)

print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

10 / 2000 -- 0.049124784767627716
20 / 2000 -- 0.019962385296821594
30 / 2000 -- 0.00910455733537674
40 / 2000 -- 0.004339158535003662
50 / 2000 -- 0.002111591398715973
60 / 2000 -- 0.001037590205669403
70 / 2000 -- 0.0005124583840370178
has converged: True, #iters: 7, cost: 1.2429015636444092


This reveals that the solver reports its metrics every 10 _inner_ iterations (the default value of `inner_iterations`).
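Put differently, in the OTT version used here `ot.n_iters` appears to count error evaluations rather than individual Sinkhorn iterations: 7 errors are printed, for iterations 10 to 70. The NumPy sketch below redoes that arithmetic with the error values copied from the output above. It assumes that unused slots of the error buffer hold the placeholder `-1` (the same convention the Gromov-Wasserstein example relies on with `out.costs != -1`) and that the buffer has one slot per evaluation.

```python
import numpy as np

inner_iterations = 10  # default reporting period, as printed above
max_iterations = 2000  # default maximum, as printed above ("x / 2000")

# Error buffer with one slot per evaluation, padded with -1 (assumed convention).
errors = -np.ones(max_iterations // inner_iterations)
printed = [
    0.049124784767627716, 0.019962385296821594, 0.00910455733537674,
    0.004339158535003662, 0.002111591398715973, 0.001037590205669403,
    0.0005124583840370178,
]
errors[: len(printed)] = printed

n_evaluations = int(np.sum(errors != -1))  # matches ot.n_iters
n_sinkhorn_iterations = n_evaluations * inner_iterations
print(n_evaluations, n_sinkhorn_iterations)  # 7 70
```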

#### With `tqdm`

Let's first define a function that updates a `tqdm` progress bar.

In [7]:
def progress_fn(status, *args):
    # Updates the global `pbar`, created by the `with tqdm() as pbar:` block.
    iteration, inner_iterations, total_iter, state = status
    iteration = int(iteration) + 1  # from [0;n-1] to [1;n]
    inner_iterations = int(inner_iterations)
    total_iter = int(total_iter)
    errors = np.asarray(state.errors).ravel()

    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
        error_idx = max(0, iteration // inner_iterations - 1)
        error = errors[error_idx]

        pbar.set_postfix_str(f"error: {error:0.6e}")
        pbar.total = total_iter // inner_iterations
        pbar.update()

and let's use it in the context of an existing `tqdm` progress bar:

In [8]:
with tqdm() as pbar:
    solve_fn = jax.jit(sinkhorn.solve, static_argnames=["progress_fn"])
    ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)

  4%|███▍                                                                                              | 7/200 [00:00<00:12, 15.94it/s, error: 5.124584e-04]


In [9]:
print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092
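The `progress_fn` above reads the global `pbar`, so it only works while that variable is bound. A possible alternative, sketched below under the same assumptions about the state's `errors` array, is a hypothetical factory `make_pbar_progress_fn` that closes over the bar it should update. To keep the sketch testable without `tqdm` or OTT, it is exercised with a stand-in `FakeBar` that only exposes the members the callback uses (`total`, `update`, `set_postfix_str`); a real `tqdm` bar receives the same calls.

```python
from typing import Any, Callable, NamedTuple

import numpy as np


def make_pbar_progress_fn(pbar: Any) -> Callable[..., None]:
    """Return a callback that updates `pbar` every `inner_iterations` steps."""

    def progress_fn(status: tuple, *args: Any) -> None:
        iteration, inner_iterations, total_iter, state = status
        iteration = int(iteration) + 1  # from [0;n-1] to [1;n]
        inner_iterations = int(inner_iterations)
        # Only update when an error has actually been computed.
        if iteration % inner_iterations == 0:
            errors = np.asarray(state.errors).ravel()
            error = errors[iteration // inner_iterations - 1]
            pbar.set_postfix_str(f"error: {error:0.6e}")
            pbar.total = int(total_iter) // inner_iterations
            pbar.update()

    return progress_fn


class FakeBar:
    """Hypothetical stand-in for a tqdm bar, recording what it receives."""

    def __init__(self) -> None:
        self.total = None
        self.n = 0
        self.postfix = ""

    def update(self, n: int = 1) -> None:
        self.n += n

    def set_postfix_str(self, s: str) -> None:
        self.postfix = s


class FakeState(NamedTuple):
    errors: np.ndarray


bar = FakeBar()
fn = make_pbar_progress_fn(bar)
state = FakeState(errors=np.array([0.25, 0.125, -1.0]))
for it in range(25):  # 25 iterations; errors are evaluated at 10 and 20
    fn((np.array(it), np.array(10), np.array(30), state), None)
print(bar.n, bar.total, bar.postfix)  # 2 3 error: 1.250000e-01
```

With a real bar, the call would look like `sinkhorn.Sinkhorn(progress_fn=make_pbar_progress_fn(pbar))` inside a `with tqdm() as pbar:` block.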


### Tracking progress for Sinkhorn via the class interface

Let's repeat this example, but this time we provide the callback function to the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class initializer and display progress with `tqdm`:

In [10]:
prob = linear_problem.LinearProblem(geom)

with tqdm() as pbar:
    solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  4%|███▍                                                                                              | 7/200 [00:00<00:11, 16.10it/s, error: 5.124584e-04]


In [11]:
print(
    f"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}"
)

has converged: True, #iters: 7, cost: 1.2429015636444092


### Tracking progress of Low-Rank Sinkhorn iterations via the class interface

We can also track the progress of the low-rank Sinkhorn solver. Because it currently doesn't have a functional interface, we can only use the class interface {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`:

In [12]:
prob = linear_problem.LinearProblem(geom)
rank = 2

with tqdm() as pbar:
    solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)
    ot = jax.jit(solver)(prob)

  8%|███████▊                                                                                         | 16/200 [00:00<00:09, 19.80it/s, error: 3.191826e-04]


In [13]:
print(f"has converged: {ot.converged}, cost: {ot.reg_ot_cost}")

has converged: True, cost: 1.7340877056121826


## Tracking progress of Gromov-Wasserstein iterations

We can track progress in the same way as with the Sinkhorn solvers. Let's define a simple quadratic problem (the same as in the {doc}`docs/tutorials/notebooks/gromov_wasserstein.ipynb` notebook):

In [14]:
# Samples spiral
def sample_spiral(
    n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0
):
    radius = jnp.linspace(min_radius, max_radius, n)
    angles = jnp.linspace(min_angle, max_angle, n)
    # Radial noise, drawn independently for each coordinate.
    eps = jax.random.normal(key, (2, n)) * noise
    x = (radius + eps[0]) * jnp.cos(angles)
    y = (radius + eps[1]) * jnp.sin(angles)
    return jnp.stack([x, y], axis=1)  # shape (n, 2)


# Samples Swiss roll
def sample_swiss_roll(
    n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1
):
    spiral = sample_spiral(
        n, min_radius, max_radius, key[0], min_angle, max_angle, noise
    )
    third_axis = jax.random.uniform(key[1], (n, 1)) * length
    swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))
    return swiss_roll


# Data parameters
n_spiral = 400
n_swiss_roll = 500
length = 10
min_radius = 3
max_radius = 10
noise = 0.8
min_angle = 0
max_angle = 9
angle_shift = 3

# Seed
seed = 14
key = jax.random.PRNGKey(seed)
key, *subkey = jax.random.split(key, 4)

In [15]:
spiral = sample_spiral(
    n_spiral,
    min_radius,
    max_radius,
    key=subkey[0],
    min_angle=min_angle + angle_shift,
    max_angle=max_angle + angle_shift,
    noise=noise,
)
swiss_roll = sample_swiss_roll(
    n_swiss_roll,
    min_radius,
    max_radius,
    key=subkey[1:],
    length=length,
    min_angle=min_angle,
    max_angle=max_angle,
)

and let's track progress while the solver iterates:

In [16]:
# apply Gromov-Wasserstein
geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)
geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)
prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)

solver = gromov_wasserstein.GromovWasserstein(
    epsilon=100.0,
    max_iterations=20,
    store_inner_errors=True,  # needed for reporting errors
    progress_fn=utils.default_progress_fn,  # callback function
)
out = solver(prob)

n_outer_iterations = jnp.sum(out.costs != -1)
has_converged = bool(out.linear_convergence[n_outer_iterations - 1])
print(f"{n_outer_iterations} outer iterations were needed.")
print(f"The last Sinkhorn iteration has converged: {has_converged}")
print(f"The outer loop of Gromov Wasserstein has converged: {out.converged}")
print(f"The final regularized GW cost is: {out.reg_gw_cost:.3f}")

1 / 20 -- -1.0
2 / 20 -- 0.13043604791164398
3 / 20 -- 0.08981532603502274
4 / 20 -- 0.06759563088417053
5 / 20 -- 0.05465726554393768
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 1183.608
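The bookkeeping in the cell above relies on `out.costs` being padded with `-1` after the last outer iteration: counting the entries different from `-1` gives the number of outer iterations, and index `n_outer_iterations - 1` then points at the last one. The NumPy sketch below replays that logic on hypothetical padded buffers; the numbers are made up for illustration and are not taken from the run above.

```python
import numpy as np

# Hypothetical buffers for max_iterations=6: 4 outer iterations were run,
# and the remaining cost slots keep the -1 placeholder.
costs = np.array([12.0, 9.5, 9.1, 9.05, -1.0, -1.0])
linear_convergence = np.array([True, True, False, True, False, False])

n_outer_iterations = int(np.sum(costs != -1))
# Convergence flag of the Sinkhorn solve in the last outer iteration.
has_converged = bool(linear_convergence[n_outer_iterations - 1])
print(n_outer_iterations, has_converged)  # 4 True
```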


That's it: this is how to track the progress of the Sinkhorn, low-rank Sinkhorn, and Gromov-Wasserstein solvers.