# Constrained Optimal Transport

This notebook provides a tutorial on using the {class}`OTT's <ott.solvers.linear.sinkhorn.Sinkhorn>` to solve constrained optimal transport (COT) problems. The tutorial heavily relies on the paper {cite}`tang:24` and the ott-jax package.

## Packages Import

In [104]:
from dataclasses import dataclass
from typing import Any, Optional, Tuple

from tqdm import trange

import jax
import jax.numpy as jnp

# from jax.scipy.sparse import csr_matrix
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

from ott.geometry import pointcloud
from ott.geometry.geometry import Geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

## Utils

In [10]:
@dataclass
class NewtonWrapper:
    """Wrapper for the Newton solver.
    This wrapper is used to provide a consistent interface for the
    Newton solver, which is used to solve the linear problem.

    Attributes:
        obj: The objective function value.
        grad: The gradient of the objective function.
        hess: The Hessian of the objective function.
    """

    obj: Any
    grad: Any
    hess: Any

## Constrained Linear Problem (CLP)

Given $a, b \in \mathbb{R}^n$, a cost matrix $C \in \mathbb{R}^{n \times n}$, and matrices $D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} \in \mathbb{R}^{n \times n}$, the corresponding constrained optimal transport problem (COT) is formulated as:

$$
\min_{P \in \mathcal{U}(a, b)} \langle C, P \rangle \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle \geq 0 \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$

Where:
- $ \mathcal{U}(a, b) $ is the set of transport plans $ P $ such that $ P1 = a $ and $ P^T1 = b $ (i.e., the marginals are fixed),
- $ \langle C, P \rangle $ represents the inner product between the cost matrix $C$ and the transport plan $P$, which is computed as $ \sum_{i,j} C_{ij} P_{ij} $,
- $ D_k $ are matrices representing **non-negativity** constraints on the transport plan (i.e., $ \langle D_k, P \rangle \geq 0 $ for $ k = 1, \dots, K $),
- $ D_{K+l} $ are matrices representing **equality** constraints on the transport plan (i.e., $ \langle D_{K+l}, P \rangle = 0 $ for $ l = 1, \dots, L $).

### Core Attributes of Constrained Linear Problem (CLP)

The core attributes of the constrained linear problem are implemented in the `ConstrainedLinearProblem` class. The class utilizes the `Geometry` class to define the cost matrix and the underlying geometry of the problem.

#### Key Attributes:

- **`a` (jnp.array)**: The first marginal distribution $a $ with shape $ (n,) $.
- **`b` (jnp.array)**: The second marginal distribution $ b $ with shape $ (n,) $.
- **`geom` (Geometry)**: An object that encapsulates the geometry (e.g., distance or structure) of the problem. This can include additional information, such as the cost matrix $ C $.
- **`constraints` (jnp.array, optional)**: The constraint matrices $ D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} $, typically of shape $ (n, n, K+L) $, specifying linear constraints on the transport matrix.

This class provides methods to:
- Round an approximate transport matrix to ensure it lies in the feasible set $ \mathcal{U}(a, b) $.
- Compute the cost of an approximate transport plan, taking into account the cost matrix.

In [100]:
class ConstrainedLinearProblem:
    """
    This class implements a constrained linear problem with a cost matrix,
    constraints, and a regularization parameter epsilon. It provides methods
    to round an approximate transport matrix to ensure it lies in the set U(a,b)
    and to compute the cost of an approximate transport plan.

    Attributes:
        a (jnp.array): distribution a of shape (n,)
        b (jnp.array): distribution b of shape (n,)
        epsilon (float): regularization parameter
        cosh_matrix (jnp.array): cost matrix of shape (n,n)
        constraints (jnp.array): constraints for the transport problem of shape (n,n,K+L)

    Methods:
        round(F: jnp.array) -> jnp.array:
            Round an approximate transport matrix F to ensure it lies in
            the set U(a,b)={F: F1=a, F^T1=b, F>=0}.
        cost(P: jnp.array) -> jnp.array:
            Compute the cost of an approximate transport plan P.
    """

    def __init__(
        self,
        geom: Geometry,
        a: jnp.array,
        b: jnp.array,
        constraints: jnp.array,
        K: int,
    ):
        self.a = a
        self.b = b
        self.geom = geom
        self.constraints = constraints
        self.K = K

    def round(self, F: jnp.array) -> jnp.array:
        """Round an approximate transport matrix F to ensure it lies in
        the set U(a,b)={F: F1=a, F^T1=b, F>=0}.

        Args:
            F (jnp.array): approximate transport matrix of shape (n,n)

        Returns:
            jnp.array: rounded transport matrix lying in U(a,b)
        """
        X = jnp.diag(jnp.minimum(self.a / jnp.sum(F, axis=1), 1))
        F = X @ F
        Y = jnp.diag(jnp.minimum(self.b / jnp.sum(F, axis=0), 1))
        F = F @ Y
        err_a, err_b = self.a - jnp.sum(F, axis=1), self.b - jnp.sum(F, axis=0)
        return F + jnp.outer(err_a, err_b) / jnp.sum(err_a)

    def cost(self, P: jnp.array) -> jnp.array:
        """Compute the cost of an approximate transport plan P.

        Args:
            P (jnp.array): approximate transport plan of shape (n,n)

        Returns:
            jnp.array: cost of the transport plan
        """
        return jnp.sum(self.geom.cost_matrix * self.round(P))

    def violation(self, P: jnp.array) -> jnp.array:
        """Compute total constraint violation of approximate transport plan P."""
        P_rounded = self.round(P)
        constraint_values = jnp.sum(
            self.constraints * P_rounded[:, :, None], axis=(0, 1)
        )

        eq_violation = jnp.sum(
            jnp.abs(jnp.minimum(constraint_values[: self.K], 0.0))
        )
        ineq_violation = jnp.sum(jnp.abs(constraint_values[self.K :]))

        return eq_violation + ineq_violation

## Entropic Relaxation

Motivated by the results in {cite}`tang:24`, we solve an entropic relaxation of the Constrained Optimal Transport (COT) problem. This relaxation provides an exponentially close solution to the original COT problem. The entropic relaxation is formulated as follows:

$$
\min_{P \in \mathcal{U}(a, b), \, s \in \mathbb{R}_+^K}  \langle C, P \rangle + \varepsilon H(P, s_1, \dots, s_K)  \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle = s_k \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$

Here, $ H(P, s_1, \dots, s_K) $ represents the entropy regularization term for a transport plan $P$ and scalars $s_1, \dots, s_K$, which is computed as $H(P, s_1, \dots, s_K) = \sum_{1 \leq i,j \leq n} P_{ij} \log P_{ij} + \sum_{k=1}^{K} s_k \log s_k$.

### Primal-Dual Formulation

The corresponding primal-dual formulation of the entropic relaxation is:

$$
\max_{u, v \in \mathbb{R}^n; \, h \in \mathbb{R}^{K+L}} \min_{P, s} \mathcal{L}_\varepsilon (u, v, h; P, s)
$$

where the Lagrangian $ \mathcal{L}_\varepsilon  $ is defined as:

$$
\mathcal{L}_\varepsilon (u, v, h; P, s) := \varepsilon \langle P, \log P \rangle + \langle C, P \rangle - \langle u, P\mathbf{1} - a \rangle - \langle v, P^T\mathbf{1} - b \rangle + \varepsilon \sum_{k=1}^K s_k \log s_k + \sum_{k=1}^K h_k s_k + \sum_{m=1}^{K+L} h_m \langle D_m, P \rangle
$$

We define the Lyapunov function as:

$$
f(u, v, h) = \min_{P, s} \mathcal{L}_\varepsilon (u, v, h; P, s).
$$

The corresponding intermediate transport plan $ P_\varepsilon(u, v, h) $ is given by:

$$
P_\varepsilon (u,v,h) = \exp \left( \frac{1}{\varepsilon} \left( -C + \sum_{m=1}^{K+L} h_m D_m + u \mathbf{1}^T + \mathbf{1} v^T \right) - 1 \right).
$$

The Lyapunov function $ f(u, v, h) $ has the following explicit formulation:

$$
f(u, v, t) = -\varepsilon \sum_{1 \leq i, j \leq n} \exp \left( \frac{1}{\varepsilon} \left( -C_{ij} + \sum_{m=1}^{K+L} h_m (D_m)_{ij} + u_i + v_j \right) - 1 \right)
+ \sum_{i=1}^n u_i a_i + \sum_{j=1}^n v_j b_j - \varepsilon \sum_{k=1}^K \exp \left( - \frac{1}{\varepsilon} h_k - 1 \right)
$$

By the minimax theorem, solving the entropic relaxation problem is equivalent to maximizing $f$. After maximizing for $(u, v, h)$, the corresponding optimal transport plan $ P_\varepsilon^* $ is a smoothed approximation of the original transport matrix $P^*$. As $\varepsilon \to 0$, the solution $ P_\varepsilon^*$ converges to the optimal solution of the original CLP problem.

### Constrained Sinkhorn Solver

The `ConstrainedSinkhorn` class is an implementation of the Sinkhorn algorithm tailored for regularized optimal transport with additional constraints. It is designed to solve constrained linear problem (CLP) problems. 

The algorithm iteratively computes a transport plan that satisfies the marginal constraints (given by `a` and `b`) while minimizing the transportation cost subject to the constraints. The solver utilizes both Sinkhorn scaling steps and Newton steps to ensure convergence, with additional flexibility provided by a dual-optimization approach.

The state of the solver is encapsulated in the `ConstrainedSinkhornState` dataclass. This state holds the current values of the dual variables $ u $, $ v $, and $ h $, as well as information about the solver's convergence and error.

#### Fields:
- `u`: The dual variable corresponding to the row scaling (size $ n $).
- `v`: The dual variable corresponding to the column scaling (size $ n $).
- `h`: The dual variable associated with the constraints (size $ K + L$).
- `error`: (Optional) The current error of the solution, used for convergence checking.
- `converged`: Boolean flag indicating whether the solver has converged.


In [132]:
def _sparsify(P: jnp.ndarray, rho: jnp.array) -> jnp.ndarray:
    """Sparsify the transport plan."""
    P = jnp.where(P > rho, P, 0.0)
    return P


@dataclass
class ConstrainedSinkhornState:
    """State of the Constrained Sinkhorn solver."""

    u: jnp.ndarray
    v: jnp.ndarray
    h: jnp.ndarray
    error: Optional[float] = None
    converged: bool = False


class ConstrainedSinkhorn:
    """
    Constrained Sinkhorn solver for regularized optimal transport with additional constraints.
    """

    def __init__(
        self,
        use_sns: bool = False,
        N1: int = 20,
        N2: int = 80,
        threshold: float = 1e-3,
    ):
        self.use_sns = use_sns
        self.N1 = N1
        self.N2 = N2 if use_sns else 0
        self.threshold = threshold

    def init(self, cot_lp) -> ConstrainedSinkhornState:
        """Initialize dual variables."""
        n = cot_lp.a.shape[0]
        u = jnp.zeros(n)
        v = jnp.zeros(n)
        h = jnp.zeros(cot_lp.constraints.shape[2])
        return ConstrainedSinkhornState(u=u, v=v, h=h)

    def _compute_transport(
        self, cot_lp, u: jnp.ndarray, v: jnp.ndarray, h: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute intermediate transport plan."""
        constraints = cot_lp.constraints
        K = cot_lp.K
        modulation = jnp.einsum("ijk, k -> ij", constraints[:, :, :K], h[:K])
        print(modulation)
        base_transport = cot_lp.geom.transport_from_potentials(u, v)
        print(base_transport)
        return (
            base_transport
            * jnp.exp(modulation / cot_lp.geom.epsilon)
            / jnp.exp(1)
        )

    def _scaling_step(
        self, cot_lp, u: jnp.ndarray, v: jnp.ndarray, h: jnp.ndarray, axis: int
    ) -> jnp.ndarray:
        """Perform a scaling step (row or column)."""
        transp = self._compute_transport(cot_lp, u, v, h)
        if axis == 1:
            marg = jnp.sum(transp, axis=1)
            target = cot_lp.a
        else:
            marg = jnp.sum(transp, axis=0)
            target = cot_lp.b
        return cot_lp.geom.epsilon * (jnp.log(target) - jnp.log(marg))

    def _lyapunov(
        self,
        cot_lp,
        u: jnp.ndarray,
        v: jnp.ndarray,
        h: jnp.ndarray,
        sns: bool = False,
    ) -> jnp.ndarray:
        """Compute Lyapunov function."""
        transp = self._compute_transport(cot_lp, u, v, h)
        epsilon = cot_lp.geom.epsilon
        K = cot_lp.K
        lyapunov = (
            -epsilon * jnp.sum(transp)
            + jnp.sum(u * cot_lp.a)
            + jnp.sum(v * cot_lp.b)
            - epsilon * jnp.sum(jnp.exp(-(1 / epsilon) * h[:K] - 1))
        )
        if sns:
            lyapunov -= 0.5 * jnp.sum(u - v) ** 2
        return lyapunov

    def _projech_v_perp(self, z: jnp.ndarray, n: int) -> jnp.ndarray:
        """Project onto v⊥, v = (1_n, -1_n, 0)"""
        v = jnp.concatenate(
            [jnp.ones(n), -jnp.ones(n), jnp.zeros(z.shape[0] - 2 * n)]
        )
        return z - jnp.dot(z, v) / jnp.dot(v, v) * v

    def _approximate_hessian(
        self,
        lyap,
        transp: jnp.ndarray,
        u: jnp.ndarray,
        v: jnp.ndarray,
        h: jnp.ndarray,
        rho: float = 0.0,
    ) -> jnp.ndarray:
        """Compute approximate Hessian for the SNS Loop."""
        diag_px = jnp.diag(transp.sum(axis=1))
        diag_py = jnp.diag(transp.sum(axis=0))

        P_sparse = _sparsify(transp, rho)

        hess_uh = jax.hessian(lyap, (0, 2))(u, v, h)  # ∇²_{u,a}
        hess_vh = jax.hessian(lyap, (1, 2))(u, v, h)  # ∇²_{v,a}
        hess_hh = jax.hessian(lyap, (2, 2))(u, v, h)  # ∇²_{a,a}

        row1 = jnp.hstack([diag_px, P_sparse, hess_uh])
        row2 = jnp.hstack([P_sparse.T, diag_py, hess_vh])
        row3 = jnp.hstack([hess_uh.T, hess_vh.T, hess_hh])

        return jnp.vstack([row1, row2, row3])

    def _regularize_hessian(self, hess: jnp.ndarray, n: int) -> jnp.ndarray:
        v = jnp.concatenate(
            [jnp.ones(n), -jnp.ones(n), jnp.zeros(hess.shape[0] - 2 * n)]
        )
        return hess - v @ v.T

    def _tosparse(mat: jnp.ndarray) -> jsparse.BCOO:
        """Convert a dense matrix to sparse format."""
        return jsparse.BCOO.fromdense(mat)

    def _todense(mat: jsparse.BCOO) -> jnp.ndarray:
        """Convert a sparse matrix to dense format."""
        return mat.todense()

    def _solve_linear_system(
        self, A: jnp.ndarray, y: jnp.ndarray
    ) -> tuple[jnp.ndarray, int]:
        """Solve the linear system Ax = y using conjugate gradient."""
        A = self._tosparse(A)
        x, info = cg(A, y, tol=1e-10, maxiter=1000)
        return x, info

    def _newton_step(
        self,
        cot_lp,
        u: jnp.ndarray,
        v: jnp.ndarray,
        convergence_tol: float = 1e-6,
        max_iter: int = 20,
    ) -> jnp.ndarray:
        """Compute Newton step for updating h variables."""

        def objective(w: jnp.ndarray) -> jnp.ndarray:
            return self._lyapunov(
                cot_lp, u + w[:1] * jnp.ones_like(u), v, w[1:]
            )

        grad_fn = jax.grad(objective)
        hess_fn = jax.hessian(objective)
        newton_wrap = NewtonWrapper(obj=objective, grad=grad_fn, hess=hess_fn)

        w = jnp.ones(cot_lp.constraints.shape[2] + 1)

        for _ in range(max_iter):
            grad = grad_fn(w)
            hess = hess_fn(w)
            delta = -jnp.linalg.solve(hess, grad)
            alpha = self._backtracking_line_search(newton_wrap, w, delta)
            w += alpha * delta

            if jnp.linalg.norm(delta) < convergence_tol:
                break

        if jnp.linalg.norm(delta) >= convergence_tol:
            raise ValueError("Newton method did not converge.")

        return w

    def _backtracking_line_search(
        self,
        newton_wrap,
        t: jnp.ndarray,
        delta: jnp.ndarray,
        alpha_start: float = 1.0,
        rho: float = 0.5,
        c: float = 1e-4,
        max_line_search_iter: int = 20,
    ) -> float:
        """Backtracking line search for Newton step."""
        alpha = alpha_start
        grad = newton_wrap.grad(t)
        slope = jnp.dot(grad, delta)

        for _ in range(max_line_search_iter):
            if (
                newton_wrap.obj(t + alpha * delta)
                <= newton_wrap.obj(t) + c * alpha * slope
            ):
                break
            alpha *= rho

        if alpha < 1e-10:
            raise ValueError("Line search failed.")
        return alpha

    def _one_iteration_sinkhorn(
        self, cot_lp, state: ConstrainedSinkhornState
    ) -> ConstrainedSinkhornState:
        """Perform one iteration of constrained Sinkhorn algorithm.

        Args:
            cot_lp: Constrained linear problem.
            state: Current state of the solver.
            iteration: Current iteration number.

        Returns:
            ConstrainedSinkhornState: Updated state of the solver.
        """

        u, v, h = state.u, state.v, state.h

        # Scale rows
        u += self._scaling_step(cot_lp, u, v, h, axis=1)

        # Scale columns
        v += self._scaling_step(cot_lp, u, v, h, axis=0)

        # Newton step
        w = self._newton_step(cot_lp, u, v)
        u += w[0] * jnp.ones_like(u)
        h += w[1:]

        return ConstrainedSinkhornState(u=u, v=v, h=h)

    def _one_iteration_newton(
        self,
        cot_lp,
        state: ConstrainedSinkhornState,
        rho: float = 1e-6,
    ) -> ConstrainedSinkhornState:
        """Perform one iteration of the Newton method.

        Args:
            cot_lp: Constrained linear problem.
            state: Current state of the solver.
            iteration: Current iteration number.

        Returns:
            ConstrainedSinkhornState: Updated state of the solver.
        """

        u, v, h, n = state.u, state.v, state.h, state.u.shape[0]

        def f_tilde(u_, v_, h_):
            return self._lyapunov(cot_lp, u_, v_, h_, sns=True)

        newton_wrap = NewtonWrapper(
            f_tilde, jax.grad(f_tilde), jax.hessian(f_tilde)
        )

        transp = self._compute_transport(cot_lp, u, v, h)

        H = -(1 / cot_lp.geom.epsilon) * self._approximate_hessian(
            f_tilde, transp, u, v, h, rho
        )
        H = self._regularize_hessian(H, n=u.shape[0])

        z = jnp.concatenate([u, v, h])
        delta_z, info = self._solve_linear_system(H, -newton_wrap.grad(u, v, h))
        if info != 0:
            print("CG did not converge.")

        alpha = self._backtracking_line_search(newton_wrap, z, delta_z)
        z += alpha * delta_z

        return ConstrainedSinkhornState(u=z[:n], v=z[n : 2 * n], h=z[2 * n :])

    def __call__(self, cot_lp, callback=None) -> ConstrainedSinkhornState:
        """Run the constrained Sinkhorn solver.
        Args:
            cot_lp: Constrained linear problem.
        Returns:
            ConstrainedSinkhornState: Final state of the solver.
        """
        state = self.init(cot_lp)
        total_iterations = self.N1 + self.N2

        for iteration in trange(total_iterations, desc="Sinkhorn Epochs"):
            if iteration < self.N1:
                state = self._one_iteration_sinkhorn(cot_lp, state)
            elif iteration < self.N1 + self.N2 and self.use_sns:
                state = self._one_iteration_newton(cot_lp, state)

            if callback is not None:
                callback(state)

        return ConstrainedSinkhornState(
            u=state.u, v=state.v, h=state.h, converged=True
        )


class ConstrainedSinkhornRunner:
    def __init__(self, solver, cot_lp, store_matrices=False):
        self.solver = solver
        self.store_matrices = store_matrices

        self.transport_matrices = []
        self.costs = []
        self.violations = []
        self._cot_lp = cot_lp

    def _callback(self, state):
        matrix = self.solver._compute_transport(
            self._cot_lp, state.u, state.v, state.h
        )
        rounded = self._cot_lp.round(matrix)

        self.costs.append(self._cot_lp.cost(rounded))
        self.violations.append(self._cot_lp.violation(rounded))

        if self.store_matrices:
            self.transport_matrices.append(rounded.astype(jnp.float32))

    def run(self, cot_lp):
        self.transport_matrices.clear()
        self.costs.clear()
        self.violations.clear()
        self.final_state = None

        self._cot_lp = cot_lp
        self.final_state = self.solver(cot_lp, callback=self._callback)

### Difficulties

I wanted to implement the second algorithm of {cite}`tang:24`, I encountered some difficulties as jax seems not to be fully compatible with sparse matrix representation. 

## Numerical Experiments

### Random Assignment Problem 

In [None]:
# TODO simply and remove sinkhorn_sns

In [85]:
def generate_random_assignment_problem(
    n=500, epsilon=0.01, seed=42, tI=0.5, tE=0.5
):

    key = jax.random.key(seed)

    C = jax.random.uniform(key=key, minval=0, maxval=1, shape=(n, n))
    DI = jax.random.uniform(key=key, minval=0, maxval=1, shape=(n, n))
    DE = jax.random.uniform(key=key, minval=0, maxval=1, shape=(n, n))

    a = jnp.ones(n) / n
    b = jnp.ones(n) / n

    D1 = (DI - tI) / n
    D2 = (DE - tE) / n

    geom = Geometry(cost_matrix=C, epsilon=epsilon)

    return ConstrainedLinearProblem(
        geom=geom, a=a, b=b, constraints=jnp.stack([D1, D2], axis=2), K=1
    )

In [123]:
def run_sinkforn(n=500, epsilon=0.01):
    cot_lp = generate_random_assignment_problem(n=n, epsilon=epsilon)

    # Algorithm 1: Basic Sinkhorn
    sinkhorn_solver = ConstrainedSinkhorn(N1=3, threshold=1e-7)
    runner1 = ConstrainedSinkhornRunner(
        solver=sinkhorn_solver, cot_lp=cot_lp, store_matrices=True
    )
    runner1.run(cot_lp)
    return runner1

In [119]:
def run_sinkhorn_sns(n=500, epsilon=0.01):
    cot_lp = generate_random_assignment_problem(n=n, epsilon=epsilon)

    # Algorithm 2: Sinkhorn-Newton-Sparse
    newton_solver = ConstrainedSinkhorn(
        threshold=1e-7,
        use_sns=True,
        N1=20,  # 20 Sinkhorn warmup steps
        N2=80,  # Newton steps after
    )
    runner2 = ConstrainedSinkhornRunner(
        solver=newton_solver, cot_lp=cot_lp, store_matrices=False
    )
    runner2.run(cot_lp)

    return runner2

In [102]:
cot_lp = generate_random_assignment_problem(n=500, epsilon=0.01)
cot_lp.constraints.shape

(500, 500, 2)

In [133]:
runner1 = run_sinkforn(n=500, epsilon=0.01)

Sinkhorn Epochs:   0%|          | 0/3 [00:00<?, ?it/s]

[[-0.  0.  0. ...  0. -0. -0.]
 [ 0. -0. -0. ... -0.  0. -0.]
 [-0. -0. -0. ...  0.  0.  0.]
 ...
 [ 0.  0. -0. ... -0. -0. -0.]
 [ 0.  0. -0. ... -0.  0. -0.]
 [-0. -0.  0. ...  0. -0. -0.]]
[[2.1650169e-12 5.4472860e-31 2.1565559e-38 ... 0.0000000e+00
  1.9458725e-01 7.2392836e-16]
 [0.0000000e+00 3.7714716e-07 6.6030156e-17 ... 4.5822802e-17
  9.9644900e-35 2.7383852e-05]
 [1.8161704e-04 1.3900607e-19 4.1368245e-17 ... 1.3783516e-23
  1.0402777e-34 1.3883547e-26]
 ...
 [1.8268673e-25 5.9871477e-31 7.1876061e-16 ... 8.6292865e-17
  4.5884278e-21 1.2446233e-17]
 [2.9497292e-25 0.0000000e+00 3.7062878e-04 ... 4.2354602e-13
  2.2470801e-37 4.3548018e-01]
 [1.9013572e-20 8.2777811e-15 8.6234069e-37 ... 2.8374440e-23
  8.7695345e-15 1.4574665e-07]]
[[-0.  0.  0. ...  0. -0. -0.]
 [ 0. -0. -0. ... -0.  0. -0.]
 [-0. -0. -0. ...  0.  0.  0.]
 ...
 [ 0.  0. -0. ... -0. -0. -0.]
 [ 0.  0. -0. ... -0.  0. -0.]
 [-0. -0.  0. ...  0. -0. -0.]]
[[1.5313757e-15 3.8530294e-34 0.0000000e+00 ... 0.00

Sinkhorn Epochs:  33%|███▎      | 1/3 [00:03<00:07,  3.58s/it]

[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan

Sinkhorn Epochs:  67%|██████▋   | 2/3 [00:07<00:03,  3.59s/it]

[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan

Sinkhorn Epochs: 100%|██████████| 3/3 [00:10<00:00,  3.58s/it]

Traced<ConcreteArray([[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]], dtype=float32)>with<JVPTrace(level=4/0)> with
  primal = Traced<ConcreteArray([[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]], dtype=float32)>with<JVPTrace(level=2/0)> with
    primal = Array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)
    tangent = Traced<ShapedArray(float32[500,500])>with<BatchTrace(level=1/0)> with
      val = Array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan,




In [None]:
runner1.transport_matrices[0]

(500, 500)