# Constrained Optimal Transport

This notebook provides a tutorial on using {class}`OTT's Sinkhorn <ott.solvers.linear.sinkhorn.Sinkhorn>` machinery to solve constrained optimal transport (COT) problems. It relies heavily on the paper {cite}`tang:24` and on the ott-jax package.

## Packages Import

In [None]:
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.geometry.geometry import Geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

## Utils

In [None]:
@dataclass
class NewtonWrapper:
    """Wrapper for the Newton solver.
    This wrapper bundles the callables needed by the Newton solver and
    its line search behind a single, consistent interface.

    Attributes:
        obj: Callable returning the objective function value.
        grad: Callable returning the gradient of the objective function.
        hess: Callable returning the Hessian of the objective function.
    """

    obj: Any
    grad: Any
    hess: Any

## Constrained Linear Problem (CLP)

Given $a, b \in \mathbb{R}^n$, a cost matrix $C \in \mathbb{R}^{n \times n}$, and matrices $D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} \in \mathbb{R}^{n \times n}$, the corresponding constrained optimal transport problem (COT) is formulated as:

$$
\min_{P \in \mathcal{U}(a, b)} \langle C, P \rangle \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle \geq 0 \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$

Where:
- $ \mathcal{U}(a, b) $ is the set of transport plans $ P $ such that $ P1 = a $ and $ P^T1 = b $ (i.e., the marginals are fixed),
- $ \langle C, P \rangle $ represents the inner product between the cost matrix $C$ and the transport plan $P$, which is computed as $ \sum_{i,j} C_{ij} P_{ij} $,
- $ D_k $ are matrices defining the **inequality** constraints on the transport plan (i.e., $ \langle D_k, P \rangle \geq 0 $ for $ k = 1, \dots, K $),
- $ D_{K+l} $ are matrices representing **equality** constraints on the transport plan (i.e., $ \langle D_{K+l}, P \rangle = 0 $ for $ l = 1, \dots, L $).
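
To make the definitions above concrete, here is a minimal NumPy sketch that builds one inequality and one equality constraint and evaluates them on a transport plan. The marginals, the matrices `D1` and `D2`, and the helper names `constraint_values` and `is_feasible` are hypothetical choices made for illustration; they are not part of ott-jax.

```python
import numpy as np

# Toy marginals (hypothetical values chosen for illustration).
a = np.array([0.2, 0.5, 0.3])
b = np.array([0.4, 0.1, 0.5])
n = a.shape[0]

# D_1 (inequality): <D_1, P> = trace(P) - 0.2 * sum(P) >= 0.
D1 = np.eye(n) - 0.2 * np.ones((n, n))
# D_2 (equality): <D_2, P> = P[0, 0] - P[1, 1] = 0.
D2 = np.zeros((n, n))
D2[0, 0], D2[1, 1] = 1.0, -1.0

Ds = np.stack([D1, D2], axis=2)  # shape (n, n, K + L)
K = 1  # number of inequality constraints; here L = 1


def constraint_values(P, Ds):
    """Return the vector of inner products <D_m, P> for m = 1, ..., K + L."""
    return np.einsum("ijm,ij->m", Ds, P)


def is_feasible(P, a, b, Ds, K, tol=1e-8):
    """Check non-negativity, both marginals, and the K + L linear constraints."""
    vals = constraint_values(P, Ds)
    return bool(
        np.all(P >= -tol)
        and np.allclose(P.sum(axis=1), a, atol=tol)
        and np.allclose(P.sum(axis=0), b, atol=tol)
        and np.all(vals[:K] >= -tol)
        and np.all(np.abs(vals[K:]) <= tol)
    )


P = np.outer(a, b)  # independent coupling, always in U(a, b)
vals = constraint_values(P, Ds)
feasible = is_feasible(P, a, b, Ds, K)
print(vals, feasible)
```

In this toy setup the independent coupling $ a b^T $ satisfies the inequality constraint but violates the equality constraint, so it lies in $ \mathcal{U}(a, b) $ without being feasible for the COT problem.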

### Core Attributes of Constrained Linear Problem (CLP)

The core attributes of the constrained linear problem are implemented in the `ConstrainedLinearProblem` class. The class utilizes the `Geometry` class to define the cost matrix and the underlying geometry of the problem.

#### Key Attributes:

- **`a` (jnp.array)**: The first marginal distribution $ a $ with shape $ (n,) $.
- **`b` (jnp.array)**: The second marginal distribution $ b $ with shape $ (n,) $.
- **`geom` (Geometry)**: An object that encapsulates the geometry (e.g., distance or structure) of the problem. This can include additional information, such as the cost matrix $ C $.
- **`constraints` (dict, optional)**: The constraint data in the layout expected by the solver below: `constraints["matrices"]` stacks $ D_1, \dots, D_K, D_{K+1}, \dots, D_{K+L} $ into an array of shape $ (n, n, K+L) $, and `constraints["K"]` is the number $ K $ of inequality constraints.

This class provides methods to:
- Round an approximate transport matrix to ensure it lies in the feasible set $ \mathcal{U}(a, b) $.
- Compute the cost of an approximate transport plan, taking into account the cost matrix.
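
The rounding step can be illustrated in isolation. The sketch below is a NumPy version of the same three operations as the `round` method (scale rows down, scale columns down, then spread the missing mass with a rank-one correction), which appears to follow the rounding scheme of Altschuler et al. (2017). The function name `round_to_marginals` and the guard against an already feasible input are assumptions added for this example.

```python
import numpy as np


def round_to_marginals(F, a, b):
    """Round a positive matrix F onto U(a, b) = {P >= 0 : P1 = a, P^T1 = b}."""
    # Scale down every row whose sum exceeds the target marginal a.
    F = F * np.minimum(a / F.sum(axis=1), 1.0)[:, None]
    # Scale down every column whose sum exceeds the target marginal b.
    F = F * np.minimum(b / F.sum(axis=0), 1.0)[None, :]
    # Both residuals are non-negative and carry the same total mass.
    err_a = a - F.sum(axis=1)
    err_b = b - F.sum(axis=0)
    total = err_a.sum()
    if total <= 1e-15:  # F already matches the marginals; nothing to add
        return F
    # The rank-one correction restores both marginals at once.
    return F + np.outer(err_a, err_b) / total


a = np.array([0.2, 0.5, 0.3])
b = np.array([0.4, 0.1, 0.5])
F = np.ones((3, 3))  # crude starting matrix, far from U(a, b)
P = round_to_marginals(F, a, b)
print(P.sum(axis=1), P.sum(axis=0))
```

Because the two scaling passes only ever shrink entries, the residuals stay non-negative, so the corrected matrix remains entrywise non-negative.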

In [None]:
class ConstrainedLinearProblem:
    """
    This class implements a constrained linear problem with a cost matrix,
    constraints, and a regularization parameter epsilon. It provides methods
    to round an approximate transport matrix to ensure it lies in the set U(a,b)
    and to compute the cost of an approximate transport plan.

    Attributes:
        geom (Geometry): geometry providing the cost matrix of shape (n,n)
            and the regularization parameter epsilon
        a (jnp.array): distribution a of shape (n,)
        b (jnp.array): distribution b of shape (n,)
        constraints (dict): constraints for the transport problem, with
            "matrices" of shape (n,n,K+L) and "K", the number of
            inequality constraints

    Methods:
        round(F: jnp.array) -> jnp.array:
            Round an approximate transport matrix F to ensure it lies in
            the set U(a,b)={F: F1=a, F^T1=b, F>=0}.
        cost(P: jnp.array) -> jnp.array:
            Compute the cost of an approximate transport plan P.
    """

    def __init__(
        self,
        geom: Geometry,
        a: jnp.array,
        b: jnp.array,
        constraints: Optional[dict] = None,
    ):
        self.a = a
        self.b = b
        self.geom = geom
        self.constraints = constraints

    def round(self, F: jnp.array) -> jnp.array:
        """Round an approximate transport matrix F to ensure it lies in
        the set U(a,b)={F: F1=a, F^T1=b, F>=0}.

        Args:
            F (jnp.array): approximate transport matrix of shape (n,n)

        Returns:
            jnp.array: rounded transport matrix lying in U(a,b)
        """
        X = jnp.diag(jnp.minimum(self.a / jnp.sum(F, axis=1), 1))
        F = X @ F
        Y = jnp.diag(jnp.minimum(self.b / jnp.sum(F, axis=0), 1))
        F = F @ Y
        err_a, err_b = self.a - jnp.sum(F, axis=1), self.b - jnp.sum(F, axis=0)
        return F + jnp.outer(err_a, err_b) / jnp.sum(err_a)

    def cost(self, P: jnp.array) -> jnp.array:
        """Compute the cost of an approximate transport plan P.

        Args:
            P (jnp.array): approximate transport plan of shape (n,n)

        Returns:
            jnp.array: cost of the transport plan
        """
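        # The cost matrix lives on the geometry; `self.round` (not the Python
        # builtin `round`) projects P onto U(a,b) first.
        return jnp.sum(self.geom.cost_matrix * self.round(P))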

In [6]:
a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.4, 0.1, 0.5])

F = jnp.ones((len(a), len(b)))  # or use random values if you prefer

jnp.sum(a).shape

()

## Entropic Relaxation

Motivated by the results in {cite}`tang:24`, we solve an entropic relaxation of the Constrained Optimal Transport (COT) problem. This relaxation provides an exponentially close solution to the original COT problem. The entropic relaxation is formulated as follows:

$$
\min_{P \in \mathcal{U}(a, b), \, s \in \mathbb{R}_+^K}  \langle C, P \rangle + \varepsilon H(P, s_1, \dots, s_K)  \quad \text{subject to} \quad 
\begin{array}{cc}
\forall k = 1, \dots, K & \langle D_k, P \rangle = s_k \\
\forall l = 1, \dots, L & \langle D_{K+l}, P \rangle = 0
\end{array}
$$

In this formulation, $ H(P, s_1, \dots, s_K) $ represents the entropy regularization term, which is computed as:

$$
H(P, s_1, \dots, s_K) = \sum_{1 \leq i,j \leq n} P_{ij} \log P_{ij} + \sum_{k=1}^{K} s_k \log s_k.
$$

The entropy regularization term encourages smoothness in the transport plan $ P $ and in the scalars $ s_1, \dots, s_K $. As $\varepsilon$ becomes smaller, the relaxation becomes closer to the original COT problem.
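
As a small numerical illustration of the role of $\varepsilon$, the sketch below evaluates the relaxed objective $ \langle C, P \rangle + \varepsilon H(P, s) $ for a fixed plan and slack value. The cost matrix, the plan, the slack `s`, and the helper names `xlogx` and `relaxed_objective` are hypothetical; the convention $ 0 \log 0 = 0 $ is assumed.

```python
import numpy as np


def xlogx(x):
    """Elementwise x * log(x) with the convention 0 * log(0) = 0."""
    x = np.asarray(x, dtype=float)
    out = np.zeros_like(x)
    mask = x > 0
    out[mask] = x[mask] * np.log(x[mask])
    return out


def relaxed_objective(C, P, s, eps):
    """<C, P> + eps * H(P, s), with H as defined in the text."""
    return np.sum(C * P) + eps * (xlogx(P).sum() + xlogx(s).sum())


a = np.array([0.2, 0.5, 0.3])
b = np.array([0.4, 0.1, 0.5])
idx = np.arange(3)
C = np.abs(idx[:, None] - idx[None, :]).astype(float)  # toy cost |i - j|
P = np.outer(a, b)    # a plan in U(a, b)
s = np.array([0.08])  # hypothetical slack for a single inequality constraint

linear_cost = np.sum(C * P)
for eps in (1.0, 0.1, 1e-3):
    print(eps, relaxed_objective(C, P, s, eps) - linear_cost)
```

For a fixed pair $(P, s)$ the gap to the linear cost is exactly $ \varepsilon H(P, s) $, so it shrinks linearly in $\varepsilon$.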

### Primal-Dual Formulation

The corresponding primal-dual formulation of the entropic relaxation is:

$$
\max_{u, v \in \mathbb{R}^n; \, t \in \mathbb{R}^{K+L}} \min_{P, s} L(u, v, t; P, s)
$$

where the Lagrangian $ L(u, v, t; P, s) $ is defined as:

$$
L(u, v, t; P, s) := \varepsilon \langle P, \log P \rangle + \langle C, P \rangle - \langle u, P\mathbf{1} - a \rangle - \langle v, P^T\mathbf{1} - b \rangle + \varepsilon \sum_{k=1}^K s_k \log s_k + \sum_{k=1}^K t_k s_k - \sum_{m=1}^{K+L} t_m \langle D_m, P \rangle
$$

We define the Lyapunov function as:

$$
f(u, v, t) = \min_{P, s} L(u, v, t; P, s).
$$

The corresponding intermediate transport plan $ P_\varepsilon(u, v, t) $ is given by:

$$
P_\varepsilon (u,v,t) = \exp \left( \frac{1}{\varepsilon} \left( -C + \sum_{m=1}^{K+L} t_m D_m + u \mathbf{1}^T + \mathbf{1} v^T \right) - 1 \right).
$$
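
The formula for $ P_\varepsilon(u, v, t) $ translates directly into array code. The following NumPy sketch uses a hypothetical helper `transport_plan` and random toy data; it also checks that shifting $ u $ by a constant and $ v $ by the opposite constant leaves the plan unchanged.

```python
import numpy as np


def transport_plan(C, Ds, u, v, t, eps):
    """P_eps(u, v, t) = exp((-C + sum_m t_m D_m + u 1^T + 1 v^T) / eps - 1)."""
    modulation = np.tensordot(Ds, t, axes=([2], [0]))  # sum_m t_m D_m, shape (n, n)
    return np.exp((-C + modulation + u[:, None] + v[None, :]) / eps - 1.0)


rng = np.random.default_rng(0)
n, M = 4, 2  # M = K + L constraint matrices
C = rng.random((n, n))
Ds = rng.standard_normal((n, n, M))
u, v = rng.standard_normal(n), rng.standard_normal(n)
t = rng.standard_normal(M)
eps = 0.5

P = transport_plan(C, Ds, u, v, t, eps)
# Adding a constant to u and subtracting it from v does not change the plan.
P_shifted = transport_plan(C, Ds, u + 3.0, v - 3.0, t, eps)
print(np.abs(P - P_shifted).max())
```

This invariance means the dual variables are only determined up to the direction $ (\mathbf{1}_n, -\mathbf{1}_n, 0) $.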

The Lyapunov function $ f(u, v, t) $ has the following explicit formulation:

$$
f(u, v, t) = -\varepsilon \sum_{1 \leq i, j \leq n} \exp \left( \frac{1}{\varepsilon} \left( -C_{ij} + \sum_{m=1}^{K+L} t_m (D_m)_{ij} + u_i + v_j \right) - 1 \right)
+ \sum_{i=1}^n u_i a_i + \sum_{j=1}^n v_j b_j - \varepsilon \sum_{k=1}^K \exp \left( - \frac{1}{\varepsilon} t_k - 1 \right)
$$
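
Differentiating the explicit formula gives $ \partial f / \partial u = a - P_\varepsilon \mathbf{1} $ and $ \partial f / \partial t_m = -\langle D_m, P_\varepsilon \rangle + \mathbb{1}[m \leq K] \exp(-t_m / \varepsilon - 1) $, so a stationary point of $ f $ satisfies the marginal and linear constraints. The NumPy sketch below checks these two gradients against central finite differences on random toy data; the helper names `transport_plan`, `lyapunov`, and `finite_difference_grad` are hypothetical.

```python
import numpy as np


def transport_plan(C, Ds, u, v, t, eps):
    modulation = np.tensordot(Ds, t, axes=([2], [0]))
    return np.exp((-C + modulation + u[:, None] + v[None, :]) / eps - 1.0)


def lyapunov(C, Ds, K, a, b, u, v, t, eps):
    """Explicit form of f(u, v, t) from the text."""
    P = transport_plan(C, Ds, u, v, t, eps)
    slack = np.exp(-t[:K] / eps - 1.0)  # optimal s_k for the current t_k
    return -eps * P.sum() + u @ a + v @ b - eps * slack.sum()


def finite_difference_grad(fun, x, step=1e-6):
    """Central finite-difference approximation of the gradient of fun at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = step
        grad[i] = (fun(x + e) - fun(x - e)) / (2 * step)
    return grad


rng = np.random.default_rng(0)
n, K, L = 4, 1, 1
eps = 1.0
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
C = rng.random((n, n))
Ds = rng.standard_normal((n, n, K + L))
u, v = 0.1 * rng.standard_normal(n), 0.1 * rng.standard_normal(n)
t = 0.1 * rng.standard_normal(K + L)

P = transport_plan(C, Ds, u, v, t, eps)

# Analytic gradients derived from the explicit formula.
grad_u = a - P.sum(axis=1)
grad_t = -np.einsum("ijm,ij->m", Ds, P)
grad_t[:K] += np.exp(-t[:K] / eps - 1.0)

# Numerical gradients for comparison.
fd_u = finite_difference_grad(lambda x: lyapunov(C, Ds, K, a, b, x, v, t, eps), u)
fd_t = finite_difference_grad(lambda x: lyapunov(C, Ds, K, a, b, u, v, x, eps), t)
print(np.abs(fd_u - grad_u).max(), np.abs(fd_t - grad_t).max())
```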

By the minimax theorem, solving the entropic relaxation is equivalent to maximizing $f$. Given a maximizer $(u^*, v^*, t^*)$, the corresponding plan $ P_\varepsilon^* = P_\varepsilon(u^*, v^*, t^*) $ is a smoothed approximation of the optimal transport plan of the original problem. As $\varepsilon \to 0$, $ P_\varepsilon^* $ converges to an optimal solution of the original COT problem.

### Constrained Sinkhorn Solver

The `ConstrainedSinkhorn` class implements a Sinkhorn-type algorithm for entropy-regularized optimal transport with additional linear constraints, i.e., for the constrained linear problem (CLP) defined above.

The algorithm iteratively updates the dual variables so that the resulting transport plan matches the marginals `a` and `b` and satisfies the linear constraints while keeping the transport cost low. Each iteration combines Sinkhorn scaling steps on $ u $ and $ v $ with a Newton step on the constraint multipliers.
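
The scaling steps can be sketched without any constraint term. In the NumPy example below, a row update $ u \leftarrow u + \varepsilon (\log a - \log P\mathbf{1}) $ makes the row marginals exact, and alternating row and column updates drives both marginals to their targets. The helper `plan` and the toy data are hypothetical, and the constraint multipliers are left out for brevity.

```python
import numpy as np


def plan(C, u, v, eps):
    """Transport plan exp((u_i + v_j - C_ij) / eps - 1) without constraint terms."""
    return np.exp((u[:, None] + v[None, :] - C) / eps - 1.0)


a = np.array([0.2, 0.5, 0.3])
b = np.array([0.4, 0.1, 0.5])
idx = np.arange(3)
C = 0.5 * np.abs(idx[:, None] - idx[None, :])  # toy cost matrix
eps = 1.0

u = np.zeros(3)
v = np.zeros(3)

# One row scaling step: afterwards the row sums equal a up to round-off.
u = u + eps * (np.log(a) - np.log(plan(C, u, v, eps).sum(axis=1)))
row_error_after_one_step = np.abs(plan(C, u, v, eps).sum(axis=1) - a).max()

# Alternate row and column scaling steps.
for _ in range(200):
    u = u + eps * (np.log(a) - np.log(plan(C, u, v, eps).sum(axis=1)))
    v = v + eps * (np.log(b) - np.log(plan(C, u, v, eps).sum(axis=0)))

P = plan(C, u, v, eps)
print(row_error_after_one_step, np.abs(P.sum(axis=1) - a).max())
```

Each scaling step is an exact coordinate-wise maximization of $ f $ in $ u $ or in $ v $, which is why a single row step already matches the row marginals.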

The state of the solver is encapsulated in the `ConstrainedSinkhornState` dataclass. It holds the current values of the dual variables $ u $, $ v $, and $ h $ (the code's name for the multipliers $ t $ in the formulas above), as well as the solver's convergence status and error.

#### Fields:
- `u`: The dual variable corresponding to the row scaling (size $ n $).
- `v`: The dual variable corresponding to the column scaling (size $ n $).
- `h`: The dual variable associated with the constraints (size $ K + L $).
- `error`: (Optional) The current error of the solution, used for convergence checking.
- `converged`: Boolean flag indicating whether the solver has converged.


In [None]:
@dataclass
class ConstrainedSinkhornState:
    """State of the Constrained Sinkhorn solver."""

    u: jnp.ndarray
    v: jnp.ndarray
    h: jnp.ndarray
    error: Optional[float] = None
    converged: bool = False


class ConstrainedSinkhorn:
    """
    Constrained Sinkhorn solver for regularized optimal transport with additional constraints.
    """

    def __init__(
        self,
        use_sns: bool = False,
        N1: int = 20,
        N2: int = 80,
        threshold: float = 1e-3,
        min_iterations: int = 10,
        max_iterations: int = 2000,
        inner_iterations: int = 1,
        implicit_differentiation: bool = True,
    ):
        self.use_sns = use_sns
        self.N1 = N1
        self.N2 = N2
        self.threshold = threshold
        self.min_iterations = min_iterations
        self.max_iterations = max_iterations
        self.inner_iterations = inner_iterations
        self.implicit_differentiation = implicit_differentiation

    def init(self, cot_lp) -> ConstrainedSinkhornState:
        """Initialize dual variables."""
        n = cot_lp.a.shape[0]
        u = jnp.zeros(n)
        v = jnp.zeros(n)
        h = jnp.zeros(cot_lp.constraints["matrices"].shape[2])
        return ConstrainedSinkhornState(u=u, v=v, h=h)

    def _apply_transport(
        self, cot_lp, u: jnp.ndarray, v: jnp.ndarray, h: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute intermediate transport plan."""
        Ds = cot_lp.constraints["matrices"]
        K = cot_lp.constraints["K"]
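        # Sum over all K + L constraints, as in the formula for P_eps;
        # K is only needed for the slack term of the Lyapunov function.
        modulation = jnp.tensordot(Ds, h, axes=([2], [0]))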
        base_transport = cot_lp.geom.transport_from_potentials(u, v)
        return (
            base_transport
            * jnp.exp(modulation / cot_lp.geom.epsilon)
            / jnp.exp(1)
        )

    def _scaling_step(
        self, cot_lp, u: jnp.ndarray, v: jnp.ndarray, h: jnp.ndarray, axis: int
    ) -> jnp.ndarray:
        """Perform a scaling step (row or column)."""
        transp = self._apply_transport(cot_lp, u, v, h)
        if axis == 1:
            marg = jnp.sum(transp, axis=1)
            target = cot_lp.a
        else:
            marg = jnp.sum(transp, axis=0)
            target = cot_lp.b
        return cot_lp.geom.epsilon * (jnp.log(target) - jnp.log(marg))

    def _lyapunov(
        self, cot_lp, u: jnp.ndarray, v: jnp.ndarray, h: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute Lyapunov function."""
        transp = self._apply_transport(cot_lp, u, v, h)
        epsilon = cot_lp.geom.epsilon
        K = cot_lp.constraints["K"]
        return (
            -epsilon * jnp.sum(transp)
            + jnp.sum(u * cot_lp.a)
            + jnp.sum(v * cot_lp.b)
            - epsilon * jnp.sum(jnp.exp(-(1 / epsilon) * h[:K] - 1))
        )

    def _project_v_perp(self, z: jnp.ndarray) -> jnp.ndarray:
        """Project onto v⊥, v = (1_n, -1_n, 0)"""
        n = z.shape[0] // 3
        v = jnp.concatenate(
            [jnp.ones(n), -jnp.ones(n), jnp.zeros(z.shape[0] - 2 * n)]
        )
        return z - jnp.dot(z, v) / jnp.dot(v, v) * v

    def _newton_step(
        self,
        cot_lp,
        u: jnp.ndarray,
        v: jnp.ndarray,
        convergence_tol: float = 1e-6,
        max_iter: int = 20,
    ) -> jnp.ndarray:
        """Compute Newton step for updating h variables."""

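        def objective(w: jnp.ndarray) -> jnp.ndarray:
            # The Lyapunov function is maximized, so Newton's method and the
            # Armijo line search (both written for minimization) work on -f.
            return -self._lyapunov(
                cot_lp, u + w[:1] * jnp.ones_like(u), v, w[1:]
            )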

        grad_fn = jax.grad(objective)
        hess_fn = jax.hessian(objective)
        newton_wrap = NewtonWrapper(obj=objective, grad=grad_fn, hess=hess_fn)

        w = jnp.ones(cot_lp.constraints["matrices"].shape[2] + 1)

        for _ in range(max_iter):
            grad = grad_fn(w)
            hess = hess_fn(w)
            delta = -jnp.linalg.solve(hess, grad)
            alpha = self._backtracking_line_search(newton_wrap, w, delta)
            w = w + alpha * delta

            if jnp.linalg.norm(delta) < convergence_tol:
                break

        if jnp.linalg.norm(delta) >= convergence_tol:
            raise ValueError("Newton method did not converge.")

        return w

    def _backtracking_line_search(
        self,
        newton_wrap,
        t: jnp.ndarray,
        delta: jnp.ndarray,
        alpha_start: float = 1.0,
        rho: float = 0.5,
        c: float = 1e-4,
        max_line_search_iter: int = 20,
    ) -> float:
        """Backtracking line search for Newton step."""
        alpha = alpha_start
        grad = newton_wrap.grad(t)
        slope = jnp.dot(grad, delta)

        for _ in range(max_line_search_iter):
            if (
                newton_wrap.obj(t + alpha * delta)
                <= newton_wrap.obj(t) + c * alpha * slope
            ):
                break
            alpha *= rho

        if alpha < 1e-10:
            raise ValueError("Line search failed.")
        return alpha

    def _one_iteration_sinkhorn(
        self, cot_lp, state: ConstrainedSinkhornState, iteration: int
    ) -> ConstrainedSinkhornState:
        """Perform one iteration of constrained Sinkhorn algorithm.

        Args:
            cot_lp: Constrained linear problem.
            state: Current state of the solver.
            iteration: Current iteration number.

        Returns:
            ConstrainedSinkhornState: Updated state of the solver.
        """

        u, v, h = state.u, state.v, state.h

        # Scale rows
        u += self._scaling_step(cot_lp, u, v, h, axis=1)

        # Scale columns
        v += self._scaling_step(cot_lp, u, v, h, axis=0)

        # Newton step for constraint correction
        w = self._newton_step(cot_lp, u, v)
        u += w[0] * jnp.ones_like(u)
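        # The Newton objective is evaluated at h = w[1:], so w[1:] is the new
        # value of h rather than an increment.
        h = w[1:]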

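        # Assumed stopping criterion: L1 violation of both marginals by the
        # current plan. Without it, `state.error` would stay None forever.
        transp = self._apply_transport(cot_lp, u, v, h)
        error = float(
            jnp.sum(jnp.abs(jnp.sum(transp, axis=1) - cot_lp.a))
            + jnp.sum(jnp.abs(jnp.sum(transp, axis=0) - cot_lp.b))
        )

        return ConstrainedSinkhornState(u=u, v=v, h=h, error=error)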

    def _one_iteration_newton(
        self, cot_lp, state: ConstrainedSinkhornState, iteration: int
    ) -> ConstrainedSinkhornState:
        """Perform one iteration of the Newton method.

        Args:
            cot_lp: Constrained linear problem.
            state: Current state of the solver.
            iteration: Current iteration number.

        Returns:
            ConstrainedSinkhornState: Updated state of the solver.
        """

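        # Placeholder: the Newton update is not implemented in this notebook,
        # so the dual variables (and the last error) are returned unchanged.
        u, v, h = state.u, state.v, state.h

        return ConstrainedSinkhornState(u=u, v=v, h=h, error=state.error)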

    def __call__(self, cot_lp) -> ConstrainedSinkhornState:
        """Run the constrained Sinkhorn solver.
        Args:
            cot_lp: Constrained linear problem.
        Returns:
            ConstrainedSinkhornState: Final state of the solver.
        """
        state = self.init(cot_lp)

        def body_fn(state, iteration):
            if iteration < self.N1:
                new_state = self._one_iteration_sinkhorn(
                    cot_lp, state, iteration
                )
            elif iteration < self.N1 + self.N2 and self.use_sns:
                new_state = self._one_iteration_newton(cot_lp, state, iteration)
            else:
                new_state = state  # No more updates after N1 + N2
            return new_state, None

        def cond_fn(val):
            state, iteration = val
            # Short-circuiting `and`/`or` ensures `state.error` is never
            # compared with the threshold while it is still None.
            return iteration < self.max_iterations and (
                iteration < self.min_iterations
                or state.error is None
                or state.error > self.threshold
            )

        iteration = 0
        val = (state, iteration)

        while cond_fn(val):
            state, _ = body_fn(state, iteration)
            iteration += 1
            val = (state, iteration)

        converged = iteration < self.max_iterations
        return ConstrainedSinkhornState(
            u=state.u,
            v=state.v,
            h=state.h,
            error=state.error,
            converged=converged,
        )
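
The Newton update with backtracking line search used inside the solver can be illustrated on a toy problem. The sketch below maximizes a concave function by applying Newton's method with an Armijo condition to its negative. The objective $ f(w) = \langle b, w \rangle - \sum_i e^{w_i} $, whose maximizer is $ w^* = \log b $, and all function names are hypothetical stand-ins for the Lyapunov function and the solver's internals.

```python
import numpy as np

b = np.array([0.5, 1.0, 2.0])


def f(w):
    """Concave toy objective with maximizer w* = log(b)."""
    return b @ w - np.exp(w).sum()


def neg_f(w):
    return -f(w)


def grad_neg_f(w):
    return np.exp(w) - b


def hess_neg_f(w):
    return np.diag(np.exp(w))


def backtracking(obj, w, delta, slope, alpha=1.0, rho=0.5, c=1e-4, max_iter=30):
    """Shrink alpha until the Armijo sufficient-decrease condition holds."""
    for _ in range(max_iter):
        if obj(w + alpha * delta) <= obj(w) + c * alpha * slope:
            return alpha
        alpha *= rho
    raise ValueError("Line search failed.")


def newton_maximize(w0, tol=1e-6, max_iter=50):
    """Maximize f by running damped Newton steps on -f."""
    w = w0.copy()
    for _ in range(max_iter):
        g = grad_neg_f(w)
        delta = -np.linalg.solve(hess_neg_f(w), g)  # descent direction for -f
        if np.linalg.norm(delta) < tol:
            return w
        alpha = backtracking(neg_f, w, delta, g @ delta)
        w = w + alpha * delta
    raise ValueError("Newton method did not converge.")


w_star = newton_maximize(np.ones(3))
print(w_star, np.log(b))
```

Working with the negated objective keeps the standard Armijo inequality valid; applying that inequality to the objective being maximized would reject every ascent step.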

## Numerical Experiments