# Constrained Optimal Transport

This notebook is a tutorial on using {class}`OTT's Sinkhorn solver <ott.solvers.linear.sinkhorn.Sinkhorn>` to solve constrained optimal transport (COT) problems. It relies heavily on {cite}`tang:24` and the ott-jax package.

## Notations

Throughout the notebook, we define the following:

- **Transport Matrix Set**: Given $a, b \in \mathbb{R}^n$, which represent densities, let:
  $$
  \mathcal{U}(a, b) = \left\{ P \in \mathbb{R}_+^{n \times n} : P\mathbf{1} = a, \, P^T\mathbf{1} = b \right\}
  $$
  denote the set of transport matrices.

- **Entry-wise Inner Product**: For $C, P \in \mathbb{R}^{n \times n}$, the entry-wise inner product is defined as:
  $$
  \langle C, P \rangle = \sum_{1 \leq i, j \leq n} C_{ij} P_{ij}
  $$

- **Entropy Regularization Term**: For $P \in \mathbb{R}_+^{n \times n}$ and nonnegative scalars $s_1, \dots, s_K$, the entropy regularization term is:
  $$
  H(P, s_1, \dots, s_K) = \sum_{1 \leq i, j \leq n} P_{ij} \log P_{ij} + \sum_{k=1}^K s_k \log s_k
  $$
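
The three definitions above can be sketched in a few lines of JAX. This is a minimal illustration rather than part of the OTT API: the helper names `inner`, `entropy_term` and `in_transport_set` are hypothetical, and the convention $0 \log 0 = 0$ is an assumption made here so that plans with zero entries are handled.

```python
import jax.numpy as jnp


def inner(C, P):
    """Entry-wise inner product <C, P>."""
    return jnp.sum(C * P)


def entropy_term(P, s):
    """H(P, s) = sum_ij P_ij log P_ij + sum_k s_k log s_k.

    Assumption of this sketch: 0 log 0 is treated as 0.
    """

    def xlogx(x):
        safe = jnp.where(x > 0, x, 1.0)  # avoid log(0)
        return jnp.where(x > 0, x * jnp.log(safe), 0.0)

    return jnp.sum(xlogx(P)) + jnp.sum(xlogx(s))


def in_transport_set(P, a, b, atol=1e-6):
    """Check P >= 0, P 1 = a and P^T 1 = b up to a tolerance."""
    nonneg = jnp.all(P >= 0)
    rows = jnp.allclose(P.sum(axis=1), a, atol=atol)
    cols = jnp.allclose(P.sum(axis=0), b, atol=atol)
    return bool(nonneg & rows & cols)


a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])
P = jnp.outer(a, b)  # independent coupling, lies in U(a, b)
s = jnp.array([0.5])

idx = jnp.arange(3)
C = jnp.abs(idx[:, None] - idx[None, :]).astype(jnp.float32)

print(in_transport_set(P, a, b), inner(C, P), entropy_term(P, s))
```

The independent coupling $ab^T$ is a convenient test plan because it belongs to $\mathcal{U}(a, b)$ whenever $a$ and $b$ each sum to one.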

## Packages Import

In [2]:
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.geometry.geometry import Geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

## Utils

In [None]:
@dataclass
class NewtonWrapper:
    """Wrapper for the Newton solver.
    It bundles the callables passed to the line search so that the
    Newton solver can use a consistent interface.

    Attributes:
        obj: Callable returning the objective function value.
        grad: Callable returning the gradient of the objective function.
        hess: Callable returning the Hessian of the objective function.
    """

    obj: Any
    grad: Any
    hess: Any

## Constrained Optimal Transport (COT)

Given $a, b \in \mathbb{R}^n$, a cost matrix $C \in \mathbb{R}^{n \times n}$, and matrices $D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} \in \mathbb{R}^{n \times n}$, the corresponding constrained optimal transport problem (COT) is formulated as:

$$
\min_{P \in \mathcal{U}(a, b)} \langle C, P \rangle \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle \geq 0 \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$
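
To make the constraint structure concrete, here is a small sketch that stacks constraint matrices along the last axis, giving shape $(n, n, K+L)$, and evaluates $\langle D_m, P \rangle$ for a candidate plan. The helpers `constraint_values` and `is_feasible`, the secondary cost `C2`, and the threshold `tau` are illustrative assumptions, not taken from {cite}`tang:24`. The inequality row encodes $\langle C_2, P \rangle \leq \tau$ as $\langle \tau \mathbf{1}\mathbf{1}^T - C_2, P \rangle \geq 0$, which is valid because the entries of $P$ sum to one.

```python
import jax.numpy as jnp


def constraint_values(Ds, P):
    """Return the vector (<D_m, P>)_m for Ds of shape (n, n, K+L)."""
    return jnp.tensordot(Ds, P, axes=([0, 1], [0, 1]))


def is_feasible(Ds, P, K, atol=1e-6):
    """First K entries must be >= 0, the remaining L must be == 0."""
    vals = constraint_values(Ds, P)
    ineq_ok = jnp.all(vals[:K] >= -atol)
    eq_ok = jnp.all(jnp.abs(vals[K:]) <= atol)
    return bool(ineq_ok & eq_ok)


n = 3
a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])
P = jnp.outer(a, b)  # candidate plan in U(a, b)

idx = jnp.arange(n)
C2 = jnp.abs(idx[:, None] - idx[None, :]).astype(jnp.float32)
tau = 1.0

# Inequality (K = 1): <C2, P> <= tau, written as <tau * 11^T - C2, P> >= 0.
D_ineq = tau * jnp.ones((n, n)) - C2
# Equality (L = 1): no mass may move from source 0 to target 2.
D_eq = jnp.zeros((n, n)).at[0, 2].set(1.0)

Ds = jnp.stack([D_ineq, D_eq], axis=-1)  # shape (n, n, K+L)
vals = constraint_values(Ds, P)
feasible = is_feasible(Ds, P, K=1)
print(vals, feasible)
```

In this toy instance the independent coupling satisfies the inequality but violates the equality, since it places positive mass on the forbidden entry.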

In [None]:
class ConstrainedLinearProblem:
    """
    This class implements a constrained linear problem with a cost matrix,
    constraints, and a regularization parameter epsilon. It provides methods
    to round an approximate transport matrix to ensure it lies in the set U(a,b)
    and to compute the cost of an approximate transport plan.

    Attributes:
        a (jnp.array): distribution a of shape (n,)
        b (jnp.array): distribution b of shape (n,)
        epsilon (float): regularization parameter
        cost_matrix (jnp.array): cost matrix of shape (n,n)
        constraints (dict): constraints for the transport problem, with keys
            "K" (number of inequality constraints) and "matrices"
            (jnp.array of shape (n,n,K+L))

    Methods:
        round(F: jnp.array) -> jnp.array:
            Round an approximate transport matrix F to ensure it lies in
            the set U(a,b)={F: F1=a, F^T1=b, F>=0}.
        cost(P: jnp.array) -> jnp.array:
            Compute the cost of an approximate transport plan P.
    """

    def __init__(
        self,
        a: jnp.array,
        b: jnp.array,
        cost_matrix: jnp.array,
        epsilon: float,
        constraints: dict = None,
    ):
        self.a = a
        self.b = b
        self.epsilon = epsilon
        self.cost_matrix = cost_matrix
        self.constraints = constraints

    def round(self, F: jnp.array) -> jnp.array:
        """Round an approximate transport matrix F to ensure it lies in
        the set U(a,b)={F: F1=a, F^T1=b, F>=0}.

        Args:
            F (jnp.array): approximate transport matrix of shape (n,n)

        Returns:
            jnp.array: rounded transport matrix lying in U(a,b)
        """
        X = jnp.diag(jnp.minimum(self.a / jnp.sum(F, axis=1), 1))
        F = X @ F
        Y = jnp.diag(jnp.minimum(self.b / jnp.sum(F, axis=0), 1))
        F = F @ Y
        err_a, err_b = self.a - jnp.sum(F, axis=1), self.b - jnp.sum(F, axis=0)
        return F + jnp.outer(err_a, err_b) / jnp.sum(err_a)

    def cost(self, P: jnp.array) -> jnp.array:
        """Compute the cost of an approximate transport plan P.

        Args:
            P (jnp.array): approximate transport plan of shape (n,n)

        Returns:
            jnp.array: cost of the transport plan
        """
        # Use the rounding method of this class, not Python's built-in round().
        return jnp.sum(self.cost_matrix * self.round(P))

In [6]:
a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])

F = jnp.ones((len(a), len(b)))  # or use random values if you prefer

jnp.sum(a).shape

()
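
As a quick sanity check of the rounding step, the sketch below applies the same three operations (row scaling, column scaling, rank-one correction) as a standalone function to the all-ones matrix and verifies the marginals. The function name `round_to_marginals` and the guard against a zero total error are additions of this sketch, not part of the class above.

```python
import jax.numpy as jnp


def round_to_marginals(F, a, b):
    """Project a nonnegative matrix F onto U(a, b) by scaling and correction."""
    # Scale rows down so that no row sum exceeds a.
    F = F * jnp.minimum(a / F.sum(axis=1), 1.0)[:, None]
    # Scale columns down so that no column sum exceeds b.
    F = F * jnp.minimum(b / F.sum(axis=0), 1.0)[None, :]
    # Distribute the missing mass with a rank-one correction.
    err_a = a - F.sum(axis=1)
    err_b = b - F.sum(axis=0)
    total = err_a.sum()
    # Guard (assumption of this sketch): skip the correction when nothing is missing.
    safe_total = jnp.where(total > 0, total, 1.0)
    correction = jnp.where(total > 0, jnp.outer(err_a, err_b) / safe_total, 0.0)
    return F + correction


a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])
F = jnp.ones((len(a), len(b)))

P = round_to_marginals(F, a, b)
print(P.sum(axis=1), P.sum(axis=0))
```

Both scaling steps only shrink entries, so the remaining errors `err_a` and `err_b` are nonnegative and carry the same total mass, which is why the rank-one correction restores both marginals at once.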

## Entropic Relaxation

Motivated by the results in {cite}`tang:24`, we solve an entropic relaxation of the COT problem, whose solution is exponentially close to that of the original COT problem. The entropic relaxation is given by:

$$
\min_{P \in \mathcal{U}(a, b), \, s \in \mathbb{R}_+^K} \left( \langle C, P \rangle + \varepsilon H(P, s_1, \dots, s_K) \right) \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle = s_k \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$

### Primal-Dual Formulation

The associated primal-dual problem is given by:

$$
\max_{u, v \in \mathbb{R}^n; \, t \in \mathbb{R}^{K+L}} \min_{P, s} L(u, v, t; P, s)
$$

where the Lagrangian $L(u, v, t; P, s)$ is defined as:

$$
L(u, v, t; P, s) := \varepsilon \langle P, \log P \rangle + \langle C, P \rangle - \langle u, P\mathbf{1} - a \rangle - \langle v, P^T\mathbf{1} - b \rangle + \varepsilon \sum_{k=1}^K s_k \log s_k + \sum_{k=1}^K t_k s_k - \sum_{m=1}^{K+L} t_m \langle D_m, P \rangle
$$

We define the Lyapunov function by $f(u, v, t) = \min_{P, s} L(u, v, t; P, s)$. By the minimax theorem, solving the entropic relaxation is equivalent to maximizing $f$, which admits the following formulation:

$$
f(u, v, t) = -\varepsilon \sum_{1 \leq i, j \leq n} \exp \left( \frac{1}{\varepsilon} \left( -C_{ij} + \sum_{m=1}^{K+L} t_m (D_m)_{ij} + u_i + v_j \right) - 1 \right)
+ \sum_{i=1}^n u_i a_i + \sum_{j=1}^n v_j b_j - \varepsilon \sum_{k=1}^K \exp \left( - \frac{1}{\varepsilon} t_k - 1 \right)
$$

with the intermediate transport plan defined as:

$$ P_\varepsilon = \exp \left( \frac{1}{\varepsilon} \left( -C + \sum_{m=1}^{K+L} t_m D_m + u\mathbf{1}^T + \mathbf{1}v^T \right) - 1 \right)$$
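
The closed form of $f$ can be checked numerically. The sketch below implements $P_\varepsilon$ and $f$ directly with `jax.numpy` (independently of the OTT geometry classes) and compares automatic differentiation against the gradients one obtains by hand: $\nabla_u f = a - P_\varepsilon \mathbf{1}$ and $\partial f / \partial t_m = \exp(-t_m/\varepsilon - 1)\,\mathbb{1}[m \leq K] - \langle D_m, P_\varepsilon \rangle$. The function names `plan` and `dual` and the toy data are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def plan(u, v, t, C, Ds, eps):
    """Intermediate plan P_eps for duals (u, v, t); Ds has shape (n, n, K+L)."""
    modulation = jnp.tensordot(Ds, t, axes=([2], [0]))  # sum_m t_m D_m
    return jnp.exp((-C + modulation + u[:, None] + v[None, :]) / eps - 1.0)


def dual(u, v, t, a, b, C, Ds, K, eps):
    """Dual (Lyapunov) function f(u, v, t)."""
    P = plan(u, v, t, C, Ds, eps)
    slack = jnp.exp(-t[:K] / eps - 1.0)  # optimal s_k for the inequality duals
    return -eps * jnp.sum(P) + jnp.dot(u, a) + jnp.dot(v, b) - eps * jnp.sum(slack)


# Toy data (illustrative only): K = 1 inequality and L = 1 equality constraint.
n, K, eps = 3, 1, 0.5
a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])
idx = jnp.arange(n)
C = jnp.abs(idx[:, None] - idx[None, :]).astype(jnp.float32)
Ds = jnp.stack([1.0 - C, jnp.zeros((n, n)).at[0, 2].set(1.0)], axis=-1)

u = jnp.array([0.1, -0.2, 0.0])
v = jnp.zeros(n)
t = jnp.array([0.1, -0.2])

grad_u, grad_t = jax.grad(dual, argnums=(0, 2))(u, v, t, a, b, C, Ds, K, eps)

P = plan(u, v, t, C, Ds, eps)
slack_full = jnp.concatenate(
    [jnp.exp(-t[:K] / eps - 1.0), jnp.zeros(Ds.shape[2] - K)]
)
expected_grad_u = a - P.sum(axis=1)
expected_grad_t = slack_full - jnp.tensordot(Ds, P, axes=([0, 1], [0, 1]))
print(grad_u, expected_grad_u)
print(grad_t, expected_grad_t)
```

Setting these gradients to zero recovers the marginal constraint $P_\varepsilon \mathbf{1} = a$ and the constraint equations $\langle D_k, P_\varepsilon \rangle = s_k$ and $\langle D_{K+l}, P_\varepsilon \rangle = 0$ of the entropic relaxation.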

In [None]:
class ConstrainedSinkhorn:
    """
    This class implements the constrained Sinkhorn algorithm for solving
    the optimal transport problem with constraints. It provides methods
    to compute the intermediate transport plan, the Lyapunov function,
    and the Newton step for the constrained Sinkhorn problem.

    Attributes:
        cot_lp (ConstrainedLinearProblem): instance of the constrained linear problem
        geometry (Geometry): geometry object for the transport problem
        K (int): number of inequality constraints
        Ds (jnp.array): constraint matrices of shape (n,n,K+L)

    Methods:
        intermediate_transport(u: jnp.array, v: jnp.array, h: jnp.array) -> jnp.array:
            Compute the intermediate transport plan for the constrained Sinkhorn problem.
        lyapunov(u: jnp.array, v: jnp.array, h: jnp.array) -> jnp.array:
            Compute the Lyapunov function for the constrained Sinkhorn problem.
        newton_step(u: jnp.array, v: jnp.array, convergence_tol: float = 1e-6, max_iter: int = 20):
            Compute the Newton step for the constrained Sinkhorn problem.
        backtracking_line_search(newton_wrap: NewtonWrapper, t, delta, alpha_start: float = 1.0,
            rho: float = 0.5, c: float = 1e-4, max_line_search_iter: int = 20):
            Perform backtracking line search to find a suitable step size.
    """

    def __init__(self, cot_lp):
        self.cot_lp = cot_lp
        self.geometry = Geometry(
            cost_matrix=cot_lp.cost_matrix, epsilon=cot_lp.epsilon
        )
        self.K = cot_lp.constraints["K"]
        self.Ds = jnp.copy(cot_lp.constraints["matrices"])  # shape (n, n, K+L)

    def intermediate_transport(
        self, u: jnp.array, v: jnp.array, h: jnp.array
    ) -> jnp.array:
        """Compute the intermediate transport plan for the constrained Sinkhorn problem.

        Args:
            u (jnp.array): dual scaling for a of shape (n,)
            v (jnp.array): dual scaling for b of shape (n,)
            h (jnp.array): constraint dual variables of shape (K+L,)

        Returns:
            jnp.array: intermediate transport plan of shape (n,n)
        """
        # Sum over all K+L constraints, as in the formula for the
        # intermediate transport plan.
        modulation = jnp.tensordot(self.Ds, h, axes=([2], [0]))
        return (
            self.geometry.transport_from_potentials(u, v)
            * jnp.exp(modulation / self.geometry.epsilon)
            / jnp.exp(1)
        )

    def lyapunov(self, u: jnp.array, v: jnp.array, h: jnp.array) -> jnp.array:
        """Compute the Lyapunov function for the constrained Sinkhorn problem.

        Args:
            u (jnp.array): dual scaling for a of shape (n,)
            v (jnp.array): dual scaling for b of shape (n,)
            h (jnp.array): constraint dual variables of shape (K+L,)

        Returns:
            jnp.array: Lyapunov function value of shape ()
        """
        # Compute the Lyapunov (objective) function
        transp = self.intermediate_transport(u, v, h)
        return (
            -self.geometry.epsilon * jnp.sum(transp)
            + jnp.sum(u * self.cot_lp.a)
            + jnp.sum(v * self.cot_lp.b)
            - self.geometry.epsilon
            * jnp.sum(jnp.exp(-(1 / self.geometry.epsilon) * h[: self.K] - 1))
        )

    def newton_step(
        self,
        u: jnp.array,
        v: jnp.array,
        convergence_tol: float = 1e-6,
        max_iter: int = 20,
    ):
        """Compute the Newton step for the constrained Sinkhorn problem.

        Args:
            u (jnp.array): dual scaling for a of shape (n,)
            v (jnp.array): dual scaling for b of shape (n,)
            convergence_tol (float): tolerance for convergence
            max_iter (int): maximum number of iterations

        Returns:
            jnp.array: optimized vector of shape (K+L+1,); the first entry is
                the shift applied to u, the rest are the constraint duals.
        """

        def objective(w: jnp.array) -> jnp.array:
            # The Lyapunov function is maximized, so we minimize its negative;
            # this matches the sufficient-decrease test of the line search.
            return -self.lyapunov(u + w[:1] * jnp.ones_like(u), v, w[1:])

        grad_fn = jax.grad(objective)
        hess_fn = jax.hessian(objective)

        newton_wrap = NewtonWrapper(
            obj=objective,
            grad=grad_fn,
            hess=hess_fn,
        )

        # w[0] is the shift of u, w[1:] are the K+L constraint dual variables.
        w = jnp.ones(self.Ds.shape[2] + 1)

        for _ in range(max_iter):
            grad = grad_fn(w)
            hess = hess_fn(w)

            # Solve for the Newton direction
            delta = -jnp.linalg.solve(hess, grad)

            # Line search
            alpha = self.backtracking_line_search(newton_wrap, w, delta)

            # Update w
            w = w + alpha * delta

            if jnp.linalg.norm(delta) < convergence_tol:
                break

        return w

    def backtracking_line_search(
        self,
        newton_wrap: NewtonWrapper,
        t,
        delta,
        alpha_start: float = 1.0,
        rho: float = 0.5,
        c: float = 1e-4,
        max_line_search_iter: int = 20,
    ):
        """Perform backtracking line search to find a suitable step size.

        Args:
            newton_wrap (NewtonWrapper): wrapper for objective, gradient, and Hessian
            t (jnp.array): current point
            delta (jnp.array): Newton direction
            alpha_start (float): initial step size
            rho (float): reduction factor for step size
            c (float): sufficient decrease condition constant
            max_line_search_iter (int): maximum number of iterations for line search

        Returns:
            float: step size that satisfies the sufficient decrease condition
        """
        alpha = alpha_start
        grad = newton_wrap.grad(t)
        slope = jnp.dot(grad, delta)

        for _ in range(max_line_search_iter):
            if (
                newton_wrap.obj(t + alpha * delta)
                <= newton_wrap.obj(t) + c * alpha * slope
            ):
                break
            alpha *= rho
        return alpha
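
The interplay between the Newton direction and the backtracking (Armijo) test can be seen on a toy problem. The sketch below is a standalone, simplified variant and not the class method above: the function `damped_newton` and the test objective are assumptions chosen because the minimizer is known in closed form. Maximizing a concave function such as the Lyapunov function is the same as minimizing its negative, so the sufficient-decrease test is stated for a convex objective `g`.

```python
import jax
import jax.numpy as jnp


def damped_newton(g, w0, tol=1e-6, max_iter=50, rho=0.5, c=1e-4, max_ls=30):
    """Minimize a smooth convex function g with Newton steps and backtracking."""
    grad_fn = jax.grad(g)
    hess_fn = jax.hessian(g)
    w = w0
    for _ in range(max_iter):
        grad = grad_fn(w)
        delta = -jnp.linalg.solve(hess_fn(w), grad)  # Newton direction
        slope = jnp.dot(grad, delta)  # negative for a descent direction
        alpha = 1.0
        for _ in range(max_ls):
            # Armijo sufficient-decrease condition
            if g(w + alpha * delta) <= g(w) + c * alpha * slope:
                break
            alpha *= rho
        w = w + alpha * delta
        if jnp.linalg.norm(delta) < tol:
            break
    return w


# Toy objective: g(w) = sum(exp(w)) - <target, w>, minimized at w = log(target).
target = jnp.array([0.5, 2.0])


def g(w):
    return jnp.sum(jnp.exp(w)) - jnp.dot(target, w)


w_star = damped_newton(g, jnp.zeros(2))
print(w_star, jnp.log(target))
```

The toy objective has the same exponential-minus-linear shape as the negated dual function, which is why a full Newton step is usually accepted once the iterate is close to the optimum.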

In [None]:
def sinkhorn_with_constraints(
    geom, a, b, f_objective, u_init, v_init, t_init, N, eta
):
    """Implements Sinkhorn-type algorithm under linear constraint.

    Args:
        geom: instance of ConstrainedGeometry (not defined in this notebook);
            it must provide update_dual_t, apply_lse_kernel and lyapunov
        a: left marginal (target row sums)
        b: right marginal (target column sums)
        f_objective: function that solves the maximization problem in the
            t-step (currently unused in the body)
        u_init, v_init: initial dual variables
        t_init: initial constraint dual variable
        N: number of iterations
        eta: regularization parameter (inverse of epsilon)

    Returns:
        The final dual variables (u, v, t).
    """
    u, v, t = u_init, v_init, t_init

    for _ in range(N):
        # Update the cost matrix with the current t
        geom.update_dual_t(t)

        # Row update
        log_P1 = geom.apply_lse_kernel(u, v) - 1  # log(P @ 1)
        u += (jnp.log(a) - log_P1) / eta

        # Column update
        log_PT1 = geom.apply_lse_kernel(u, v, axis=1) - 1  # log(P^T @ 1)
        v += (jnp.log(b) - log_PT1) / eta

        # Constraint dual update (problem-specific)
        t, h = geom.lyapunov(u, v, t, a, b)

        # Shift u by h
        u += h * jnp.ones_like(u)

    return u, v, t
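
The row and column updates in the loop above are ordinary Sinkhorn updates applied to the cost shifted by the constraint term. The sketch below writes them out in plain JAX with the constraint duals held fixed, using the $-1$ convention of $P_\varepsilon$ and stating the step in terms of $\varepsilon$ rather than its inverse. The names `log_plan` and `sinkhorn_uv` and the optional `modulation` argument (standing in for $\sum_m t_m D_m$) are assumptions of this sketch.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_plan(u, v, C, eps, modulation):
    """log P_eps for fixed constraint term `modulation` = sum_m t_m D_m."""
    return (-C + modulation + u[:, None] + v[None, :]) / eps - 1.0


def sinkhorn_uv(a, b, C, eps, n_iter=200, modulation=None):
    """Alternate exact maximization of the dual over u and over v."""
    if modulation is None:
        modulation = jnp.zeros_like(C)
    u = jnp.zeros_like(a)
    v = jnp.zeros_like(b)
    for _ in range(n_iter):
        # Row update: afterwards the row sums of P_eps equal a.
        log_rows = logsumexp(log_plan(u, v, C, eps, modulation), axis=1)
        u = u + eps * (jnp.log(a) - log_rows)
        # Column update: afterwards the column sums of P_eps equal b.
        log_cols = logsumexp(log_plan(u, v, C, eps, modulation), axis=0)
        v = v + eps * (jnp.log(b) - log_cols)
    return u, v


a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])
idx = jnp.arange(3)
C = jnp.abs(idx[:, None] - idx[None, :]).astype(jnp.float32)
eps = 0.5

u, v = sinkhorn_uv(a, b, C, eps)
P = jnp.exp(log_plan(u, v, C, eps, jnp.zeros_like(C)))
print(P.sum(axis=1), P.sum(axis=0))
```

Working with `logsumexp` instead of exponentiating and summing keeps the updates numerically stable for small $\varepsilon$.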

## Numerical Experiments