# Constrained Optimal Transport

In this notebook, we show how to use {class}`OTT's Sinkhorn solver <ott.solvers.linear.sinkhorn.Sinkhorn>` to solve constrained optimal transport problems. This tutorial relies heavily on both the paper {cite}`tang:24` and the ott-jax package.

Throughout the notebook:
- given $a$, $b \in \mathbb{R}_+^n$, respectively the source and target densities, $\mathcal{U}(a,b)=\left\{P \in \mathbb{R}_+^{n\times n}:~P\mathbf{1}=a,~P^T\mathbf{1}=b\right\}$ denotes the corresponding set of transport matrices.
- given $C$, $P\in \mathbb{R}^{n\times n}$, $\langle C, P \rangle=\sum_{1\leq i,j\leq n}C_{ij}P_{ij}$ denotes the entrywise (Frobenius) inner product.
- given $P \in \mathbb{R}_+^{n\times n}$ and slack variables $s_1, \dots, s_K \geq 0$, $H(P,s_1,\dots,s_K)=\sum_{1\leq i,j\leq n}P_{ij}\log P_{ij} + \sum_{k=1}^K s_k \log s_k$ denotes the entropy regularization term (with the convention $0\log 0=0$).
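The notation above maps directly onto a few array operations. Below is a minimal JAX sketch with hypothetical helper names (they are not part of OTT). It relies on `jax.scipy.special.xlogy`, which returns $0$ when its first argument is $0$ and therefore implements the convention $0\log 0=0$:

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def inner(C, P):
    """Entrywise inner product <C, P> = sum_ij C_ij P_ij."""
    return jnp.sum(C * P)


def entropy_term(P, s):
    """H(P, s) = sum_ij P_ij log P_ij + sum_k s_k log s_k, with 0 log 0 = 0."""
    return jnp.sum(xlogy(P, P)) + jnp.sum(xlogy(s, s))


def in_transport_set(P, a, b, atol=1e-6):
    """Check (up to atol) whether P belongs to U(a, b)."""
    return bool(
        jnp.all(P >= 0)
        & jnp.allclose(P.sum(axis=1), a, atol=atol)
        & jnp.allclose(P.sum(axis=0), b, atol=atol)
    )


# Example: the independent coupling a b^T always lies in U(a, b).
a = b = jnp.array([0.5, 0.5])
P = jnp.outer(a, b)
C = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(inner(C, P), entropy_term(P, jnp.zeros(0)), in_transport_set(P, a, b))
```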

Given $a$, $b \in \mathbb{R}^n$, a cost matrix $C\in \mathbb{R}^{n\times n}$ and constraint matrices $D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} \in \mathbb{R}^{n\times n}$ (the first $K$ encoding inequality constraints, the last $L$ encoding equality constraints), the corresponding constrained optimal transport (COT) problem is defined as:
$$\min_{P \in \mathcal{U}(a,b)} \langle C, P \rangle \quad \text{subject to} \quad \begin{array}{cc}
\forall k=1,\dots,K & \langle D_k, P \rangle \geq 0\\
\forall l=1,\dots,L & \langle D_{K+l}, P \rangle =0 
\end{array} $$
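To make the constraints concrete, the sketch below stacks $D_1,\dots,D_{K+L}$ into one array of shape $(K+L,n,n)$, the first $K$ entries being the inequalities, and reports the cost $\langle C,P\rangle$ together with the largest constraint violation. Since every $P\in\mathcal{U}(a,b)$ has total mass $\sum_i a_i$ (equal to $1$ here), a bound such as $P_{03}\geq 0.05$ can be written as $\langle D,P\rangle\geq 0$ with $D=E_{03}-0.05\,\mathbf{1}\mathbf{1}^T$. The function name and the $4\times 4$ toy instance are hypothetical illustrations, not taken from {cite}`tang:24`:

```python
import jax.numpy as jnp


def cot_cost_and_violation(P, C, D, K):
    """Return <C, P> and the largest violation of the COT constraints.

    D stacks D_1, ..., D_{K+L} with shape (K+L, n, n); the first K are inequalities.
    """
    DP = jnp.einsum("mij,ij->m", D, P)  # <D_m, P> for every constraint
    ineq = jnp.maximum(-DP[:K], 0.0)  # violation of <D_k, P> >= 0
    eq = jnp.abs(DP[K:])  # violation of <D_{K+l}, P> = 0
    return jnp.sum(C * P), jnp.max(jnp.concatenate([ineq, eq]), initial=0.0)


# Toy instance: uniform marginals and cost C_ij = (i - j)^2 / 9.
n = 4
i = jnp.arange(n)
C = (i[:, None] - i[None, :]) ** 2 / 9.0
J = jnp.ones((n, n))
D = jnp.stack([
    jnp.zeros((n, n)).at[0, 3].set(1.0) - 0.05 * J,  # inequality: P_03 - 0.05 sum(P) >= 0
    jnp.eye(n) - 0.5 * J,  # equality: trace(P) - 0.5 sum(P) = 0
])
# Half of the mass stays in place, the other half is sent from i to n - 1 - i.
P = 0.5 * jnp.eye(n) / n + 0.5 * jnp.eye(n)[::-1] / n
print(cot_cost_and_violation(P, C, D, K=1))
```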

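Following {cite}`tang:24`, one can instead solve the entropic relaxation below, in which nonnegative slack variables $s_k$ turn the $K$ inequality constraints into equalities. Its approximation error with respect to the original COT decays exponentially as the regularization parameter $\eta$ increases: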

$$\min_{P \in \mathcal{U}(a,b),~s\in \mathbb{R}_+^{K}} \langle C, P \rangle + \frac{1}{\eta}H(P,s_1,\dots,s_K) \quad \text{subject to} \quad \begin{array}{cc}
\forall k=1,\dots,K & \langle D_k, P \rangle = s_k\\
\forall l=1,\dots,L & \langle D_{K+l}, P \rangle =0 
\end{array} $$
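Because the constraint $\langle D_k,P\rangle=s_k$ pins each slack to $P$, the relaxed objective can be evaluated from $P$ alone. A minimal sketch, assuming $P$ already satisfies the marginal and equality constraints; the function name and the $2\times 2$ data are hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def relaxed_objective(P, C, D, K, eta):
    """<C, P> + H(P, s) / eta, with each slack eliminated through s_k = <D_k, P>.

    Only the first K (inequality) matrices of D are used; P is assumed to already
    satisfy the marginal and equality constraints. Returns +inf if some s_k < 0,
    since the slacks must be nonnegative.
    """
    s = jnp.einsum("mij,ij->m", D[:K], P)
    value = jnp.sum(C * P) + (jnp.sum(xlogy(P, P)) + jnp.sum(xlogy(s, s))) / eta
    return jnp.where(jnp.all(s >= 0), value, jnp.inf)


# Hypothetical 2x2 example with a = b = (0.5, 0.5) and one inequality P_01 >= 0.1,
# written as <D_1, P> >= 0 with D_1 = E_01 - 0.1 * ones (using sum(P) = 1).
C = jnp.array([[0.0, 1.0], [1.0, 0.0]])
D = (jnp.zeros((2, 2)).at[0, 1].set(1.0) - 0.1 * jnp.ones((2, 2)))[None]
P = jnp.array([[0.35, 0.15], [0.15, 0.35]])
print(relaxed_objective(P, C, D, K=1, eta=10.0))
```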

Introducing dual variables $u, v \in \mathbb{R}^n$ for the marginal constraints and $t \in \mathbb{R}^{K+L}$ for the linear constraints, the associated primal-dual (saddle-point) problem reads:

$$\max_{u,v \in \mathbb{R}^n;t\in \mathbb{R}^{K+L}}\min_{P,s}L(u,v,t;P,s) := \frac{1}{\eta} \langle P, \log P \rangle + \langle C, P \rangle - \langle u, P\mathbf{1}-a \rangle - \langle v, P^T\mathbf{1}- b\rangle + \frac{1}{\eta}\sum_{k=1}^Ks_k\log s_k + \sum_{k=1}^K t_k s_k - \sum_{m=1}^{K+L}t_m\langle D_m, P\rangle$$
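The sign conventions of the Lagrangian can be checked numerically. The sketch below writes $L$ with the constraint term $-\sum_m t_m\langle D_m,P\rangle$ (so that its minimizer matches the closed form of $f$), stacks $D_1,\dots,D_{K+L}$ into one array, and uses `jax.grad` to confirm that $P_{ij}=\exp(\eta(-C_{ij}+\sum_m t_m(D_m)_{ij}+u_i+v_j)-1)$ and $s_k=\exp(-\eta t_k-1)$ make the gradients in $P$ and $s$ vanish. Since $L$ is strictly convex in $(P,s)$, this stationary point is the minimizer. The names `lagrangian` and `argmin_lagrangian` and the random data are hypothetical:

```python
import jax
import jax.numpy as jnp


def lagrangian(u, v, t, P, s, C, D, a, b, eta):
    """L(u, v, t; P, s) with D of shape (K+L, n, m) and slacks s of shape (K,)."""
    K = s.shape[0]
    DP = jnp.einsum("mij,ij->m", D, P)  # <D_m, P> for m = 1, ..., K+L
    return (
        jnp.sum(P * jnp.log(P)) / eta
        + jnp.sum(C * P)
        - u @ (P.sum(axis=1) - a)
        - v @ (P.sum(axis=0) - b)
        + jnp.sum(s * jnp.log(s)) / eta
        + t[:K] @ s
        - t @ DP
    )


def argmin_lagrangian(u, v, t, C, D, eta, K):
    """Candidate minimizers (P, s) of L, obtained by setting its gradients to zero."""
    M = -C + jnp.tensordot(t, D, axes=1) + u[:, None] + v[None, :]
    return jnp.exp(eta * M - 1.0), jnp.exp(-eta * t[:K] - 1.0)


# Illustrative random data: n = 3, K = 1 inequality and L = 1 equality constraint.
n, K, L, eta = 3, 1, 1, 2.0
keys = jax.random.split(jax.random.PRNGKey(0), 5)
C = jax.random.uniform(keys[0], (n, n))
D = jax.random.normal(keys[1], (K + L, n, n))
u, v = 0.1 * jax.random.normal(keys[2], (n,)), 0.1 * jax.random.normal(keys[3], (n,))
t = 0.1 * jax.random.normal(keys[4], (K + L,))
a = b = jnp.full(n, 1.0 / n)

P, s = argmin_lagrangian(u, v, t, C, D, eta, K)
gP, gs = jax.grad(lagrangian, argnums=(3, 4))(u, v, t, P, s, C, D, a, b, eta)
print(jnp.abs(gP).max(), jnp.abs(gs).max())  # both should be close to zero
```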

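We define $f(u,v,t)=\min_{P,s}L(u,v,t;P,s)$. Since $L$ is strictly convex in $(P,s)$, this minimum is attained where the gradients in $P$ and $s$ vanish, namely at
$$P_{ij}=\exp\left(\eta\left(-C_{ij}+\sum_{m=1}^{K+L}t_m(D_m)_{ij}+u_i+v_j\right)-1\right), \qquad s_k=\exp\left(-\eta t_k-1\right).$$
By the minimax theorem, solving the entropic relaxation is equivalent to maximizing $f$, and substituting these minimizers back into $L$ gives the closed form:
$$f(u,v,t)=-\frac{1}{\eta}\sum_{1\leq i,j\leq n}\exp{\left(\eta\left(-C_{ij}+\sum_{m=1}^{K+L}t_m(D_m)_{ij}+u_i+v_j\right)-1\right)} + \sum_{i=1}^n u_ia_i + \sum_{j=1}^n v_jb_j - \frac{1}{\eta}\sum_{k=1}^{K}\exp\left(-\eta t_k -1\right)$$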
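A useful consequence of this closed form is that the gradients of $f$ are exactly the constraint residuals of the relaxation: $\nabla_u f = a - P\mathbf{1}$, $\nabla_v f = b - P^T\mathbf{1}$, $\partial_{t_k} f = s_k - \langle D_k,P\rangle$ for $k\leq K$ and $\partial_{t_{K+l}} f = -\langle D_{K+l},P\rangle$, where $(P,s)$ are the minimizers of $L$. At a maximizer of $f$ all of them vanish, so the recovered $P$ is feasible for the relaxation. A sketch that checks these identities with `jax.grad` (the names `dual_f` and `primal_from_dual` and the random data are hypothetical):

```python
import jax
import jax.numpy as jnp


def dual_f(u, v, t, C, D, a, b, eta, K):
    """Dual objective f(u, v, t); D stacks D_1..D_{K+L} with shape (K+L, n, m)."""
    M = -C + jnp.tensordot(t, D, axes=1) + u[:, None] + v[None, :]
    return (
        -jnp.sum(jnp.exp(eta * M - 1.0)) / eta
        + u @ a
        + v @ b
        - jnp.sum(jnp.exp(-eta * t[:K] - 1.0)) / eta
    )


def primal_from_dual(u, v, t, C, D, eta, K):
    """Closed-form minimizers (P, s) of L for fixed dual variables (u, v, t)."""
    M = -C + jnp.tensordot(t, D, axes=1) + u[:, None] + v[None, :]
    return jnp.exp(eta * M - 1.0), jnp.exp(-eta * t[:K] - 1.0)


# Illustrative random data: n = 3, one inequality (K = 1) and one equality (L = 1).
n, K, L, eta = 3, 1, 1, 2.0
keys = jax.random.split(jax.random.PRNGKey(0), 5)
C = jax.random.uniform(keys[0], (n, n))
D = jax.random.normal(keys[1], (K + L, n, n))
u, v = 0.1 * jax.random.normal(keys[2], (n,)), 0.1 * jax.random.normal(keys[3], (n,))
t = 0.1 * jax.random.normal(keys[4], (K + L,))
a = b = jnp.full(n, 1.0 / n)

gu, gv, gt = jax.grad(dual_f, argnums=(0, 1, 2))(u, v, t, C, D, a, b, eta, K)
P, s = primal_from_dual(u, v, t, C, D, eta, K)
DP = jnp.einsum("mij,ij->m", D, P)  # <D_m, P>
expected_gt = (-DP).at[:K].add(s)  # s_k - <D_k, P> for k <= K, then -<D_{K+l}, P>
print(jnp.abs(gu - (a - P.sum(axis=1))).max(), jnp.abs(gt - expected_gt).max())
```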

## Package Imports

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
```

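```python
# Quick check of the clipping step used in `round` below: row sums, then min(., 1).
A = jnp.array([[1.0, 2.0, -2.5], [3.0, 4.0, 6.0]])
a = jnp.sum(A, axis=1)  # row sums: [0.5, 13.0]
X = jnp.minimum(a, 1)  # entrywise min(a_i, 1)
X
```

```
Array([0.5, 1. ], dtype=float32)
```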

## Utils

```python
def round(F: jax.Array, a: jax.Array, b: jax.Array) -> jax.Array:
    """Round an approximate transport matrix F so that it lies in
    the set U(a,b) = {F: F1 = a, F^T1 = b, F >= 0}.

    Args:
        F (jax.Array): nonnegative approximate transport matrix of shape (n, m)
        a (jax.Array): marginal a of shape (n,)
        b (jax.Array): marginal b of shape (m,), with sum(b) == sum(a)

    Returns:
        jax.Array: rounded transport matrix lying in U(a,b)
    """
    # Shrink rows whose mass exceeds a, i.e. F <- diag(X) F.
    X = jnp.minimum(a / jnp.sum(F, axis=1), 1.0)
    F = X[:, None] * F
    # Shrink columns whose mass exceeds b, i.e. F <- F diag(Y).
    Y = jnp.minimum(b / jnp.sum(F, axis=0), 1.0)
    F = F * Y[None, :]
    # The remaining deficits are nonnegative with equal total mass;
    # a rank-one correction restores both marginals exactly.
    err_a, err_b = a - jnp.sum(F, axis=1), b - jnp.sum(F, axis=0)
    denom = jnp.maximum(jnp.sum(err_a), jnp.finfo(F.dtype).tiny)  # avoid 0 / 0
    return F + jnp.outer(err_a, err_b) / denom


def cost(P: jax.Array, C: jax.Array, a: jax.Array, b: jax.Array) -> jax.Array:
    """Compute the cost <C, round(P, a, b)> of an approximate transport plan P,
    after rounding it onto U(a,b) so that the cost is that of a feasible plan.

    Args:
        P (jax.Array): approximate transport plan of shape (n, m)
        C (jax.Array): cost matrix of shape (n, m)
        a (jax.Array): marginal a of shape (n,)
        b (jax.Array): marginal b of shape (m,)

    Returns:
        jax.Array: scalar cost of the rounded transport plan
    """
    return jnp.sum(C * round(P, a, b))
```

## Solver Implementations for Constrained OT

### Sinkhorn-type Algorithm for Constrained OT
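A minimal sketch of a Sinkhorn-type scheme, assuming it is organized as block coordinate ascent on the dual $f$: $u$ and $v$ are updated by exact maximization (the usual log-domain Sinkhorn updates), and $t$, which only has $K+L$ entries, by one Newton step using its full $(K+L)\times(K+L)$ Hessian. The function names, the small ridge added to the Hessian, the crude step-size search and the $4\times 4$ toy instance are illustrative assumptions; this is neither OTT's API nor necessarily the exact algorithm of {cite}`tang:24`:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def dual_f(u, v, t, C, D, a, b, eta, K):
    """Dual objective f(u, v, t); D stacks D_1..D_{K+L}, the first K are inequalities."""
    M = -C + jnp.tensordot(t, D, axes=1) + u[:, None] + v[None, :]
    return (
        -jnp.sum(jnp.exp(eta * M - 1.0)) / eta
        + u @ a
        + v @ b
        - jnp.sum(jnp.exp(-eta * t[:K] - 1.0)) / eta
    )


@partial(jax.jit, static_argnames=("K",))
def sinkhorn_type_step(u, v, t, C, D, a, b, eta, K, ridge=1e-6):
    TD = jnp.tensordot(t, D, axes=1)  # sum_m t_m D_m
    # Exact maximization of f in u, then in v (log-domain Sinkhorn updates).
    u = (jnp.log(a) - logsumexp(eta * (-C + TD + v[None, :]) - 1.0, axis=1)) / eta
    v = (jnp.log(b) - logsumexp(eta * (-C + TD + u[:, None]) - 1.0, axis=0)) / eta
    # Newton ascent direction on the small (K+L)-dimensional t-block.
    g = jax.grad(dual_f, argnums=2)(u, v, t, C, D, a, b, eta, K)
    H = jax.hessian(dual_f, argnums=2)(u, v, t, C, D, a, b, eta, K)
    d = -jnp.linalg.solve(H - ridge * jnp.eye(t.shape[0]), g)
    # Crude step-size search: keep the step in {1, 1/2, ..., 1/512} with the largest f.
    alphas = 0.5 ** jnp.arange(10)
    vals = jax.vmap(lambda al: dual_f(u, v, t + al * d, C, D, a, b, eta, K))(alphas)
    vals = jnp.where(jnp.isfinite(vals), vals, -jnp.inf)
    return u, v, t + alphas[jnp.argmax(vals)] * d


def solve_cot_sinkhorn_type(C, D, a, b, K, eta=5.0, n_iter=500):
    u, v, t = jnp.zeros(C.shape[0]), jnp.zeros(C.shape[1]), jnp.zeros(D.shape[0])
    for _ in range(n_iter):
        u, v, t = sinkhorn_type_step(u, v, t, C, D, a, b, eta, K=K)
    M = -C + jnp.tensordot(t, D, axes=1) + u[:, None] + v[None, :]
    return jnp.exp(eta * M - 1.0)  # primal plan recovered from the dual variables


# Toy instance: uniform marginals, C_ij = (i - j)^2 / 9, P_03 >= 0.05 and trace(P) = 0.5,
# both written as <D, P> >= 0 or <D, P> = 0 using sum(P) = 1.
n = 4
i = jnp.arange(n)
C = (i[:, None] - i[None, :]) ** 2 / 9.0
a = b = jnp.full(n, 1.0 / n)
J = jnp.ones((n, n))
D = jnp.stack([jnp.zeros((n, n)).at[0, 3].set(1.0) - 0.05 * J, jnp.eye(n) - 0.5 * J])
P = solve_cot_sinkhorn_type(C, D, a, b, K=1)
print(P.sum(axis=1), P[0, 3], jnp.trace(P))
```

Updating $u$ and $v$ in closed form keeps each sweep as cheap as a Sinkhorn iteration, while the Newton step on $t$ only requires solving a $(K+L)\times(K+L)$ linear system, which stays inexpensive as long as the number of constraints is small.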

### Sinkhorn-Newton-Sparse (SNS) for Constrained OT

### Sinkhorn-Newton-Sparse with entropy regularization scheduling for Constrained OT

## Numerical Experiments