# Sinkhorn for Constrained OT Problems

In this tutorial, we implement a Sinkhorn-type algorithm for constrained optimal transport (OT) problems,
following the provided algorithm pseudocode (Algorithm 1).

In [1]:
## Imports
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud

The problem data `r`, `c`, `C`, `n`, `D` and `eta` are module-level (global) variables, so that every function below can access them.

In [2]:
# Shared problem data, read by the functions below as module-level names and
# assigned in the example cell: r, c (marginals), C (cost matrix), n (size),
# D (constraint matrix) and eta (regularization). A `global` statement at
# module scope has no effect, so no declaration is needed here.

In [3]:
def check_constrained_sinkhorn_inputs_consistancy(x_init, y_init, a_init, n_iter):
    """Validate the module-level problem data and the initial dual variables.

    Checks that C and D are (n, n), that r, c, x_init, y_init and a_init are
    (n, 1) column vectors, that eta is a Python float and that n_iter >= 0.

    Args:
        x_init: initial row dual variable.
        y_init: initial column dual variable.
        a_init: initial constraint dual variable.
        n_iter: number of outer Sinkhorn iterations requested.

    Raises:
        ValueError: if an array has the wrong shape or n_iter is negative.
        TypeError: if eta is not a Python float.
    """
    # Explicit raises instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently disable this validation.
    expected_shapes = (
        ("C", C, (n, n)),
        ("D", D, (n, n)),
        ("r", r, (n, 1)),
        ("c", c, (n, 1)),
        ("x_init", x_init, (n, 1)),
        ("y_init", y_init, (n, 1)),
        ("a_init", a_init, (n, 1)),
    )
    for name, value, shape in expected_shapes:
        actual = jnp.shape(value)
        if actual != shape:
            raise ValueError(f"{name} must have shape {shape}, got {actual}")
    if not isinstance(eta, float):
        raise TypeError(f"eta must be a float, got {type(eta).__name__}")
    if n_iter < 0:
        raise ValueError(f"n_iter must be non-negative, got {n_iter}")

In [4]:
def compute_P(x, y, a):
    """Return the transport plan induced by the dual variables (x, y, a).

    P[i, j] = exp(eta * (-C[i, j] + sum_m a[m] * D[m, j] + x[i] + y[j]) - 1)

    Reads the module-level C, D and eta.

    Args:
        x: (n, 1) dual variable of the row marginal r.
        y: (n, 1) dual variable of the column marginal c.
        a: (n, 1) dual variable of the constraints.

    Returns:
        (n, n) array P.
    """
    # a.T @ D is the (1, n) row vector whose j-th entry is sum_m a[m] * D[m, j];
    # broadcasting it over the rows adds it to every column j, as a single
    # matrix product rather than n separate array operations.
    constraint_term = jnp.transpose(a) @ D
    # (n, 1) + (1, n) broadcasts to x 1^T + 1 y^T.
    return jnp.exp(eta * (-C + constraint_term + x + jnp.transpose(y)) - 1.0)

# Implementation of Algorithm 1

In [None]:
def constrained_sinkhorn(x_init, y_init, a_init, n_iter=100):
    """Run the constrained Sinkhorn iteration (Algorithm 1).

    Each outer iteration:
      1. rescales x so that the row sums of P match r,
      2. rescales y so that the column sums of P match c,
      3. solves the (a, t) sub-problem with `optimize` and shifts x by t * 1.

    Args:
        x_init, y_init, a_init: (n, 1) initial dual variables.
        n_iter: number of outer iterations.

    Returns:
        The final dual variables (x, y, a).
    """
    check_constrained_sinkhorn_inputs_consistancy(x_init, y_init, a_init, n_iter)
    ones_col = jnp.ones((n, 1))
    log_r = jnp.log(r)
    log_c = jnp.log(c)

    x, y, a = x_init, y_init, a_init
    for _ in range(n_iter):
        # Row-marginal update.
        row_sums = jnp.dot(compute_P(x, y, a), ones_col)
        x = x + (log_r - jnp.log(row_sums)) / eta

        # Column-marginal update, using the plan refreshed after the x update.
        col_sums = jnp.dot(jnp.transpose(compute_P(x, y, a)), ones_col)
        y = y + (log_c - jnp.log(col_sums)) / eta

        # Constraint update: maximize over (a, t), then fold t back into x.
        a, shift = optimize(x, y, a)
        x = x + shift * ones_col

    return x, y, a

In [6]:
def compute_gradient(x, y, a):
    """
    Compute gradient w.r.t. a and t.

    Gradient of f(x + t * 1, y, a) at t = 0, where P = compute_P(x, y, a):
      d f / d a = exp(-eta * a - 1) - D @ (P^T 1)
      d f / d t = sum(r) - sum(P)

    compute_P adds sum_m a[m] * D[m, j] to column j, so the derivative with
    respect to a[m] involves the column sums of P (P^T 1), not the row sums.

    Returns:
        (n + 1, 1) array: the n entries for a followed by the entry for t.
    """
    P = compute_P(x, y, a)
    col_sums = jnp.sum(P, axis=0).reshape(-1, 1)  # P^T 1, shape (n, 1)
    row_sums = jnp.sum(P, axis=1).reshape(-1, 1)  # P 1, shape (n, 1)

    grad_a = jnp.exp(-eta * a - 1.0) - D @ col_sums
    grad_t = jnp.sum(r - row_sums)

    return jnp.vstack([grad_a, grad_t])

Implementation of $f(x + t\mathbb{1}, y, a)$ using equation (7).

In [7]:
def f_at(a, t, x, y):
    """Evaluate the dual objective with x shifted by t: f(x + t * 1, y, a)."""
    return f(x + t * jnp.ones((n, 1)), y, a)


def f(x, y, a):
    """Dual objective of the entropic constrained OT problem (equation (7)).

        f(x, y, a) = -1/eta * sum_ij P_ij + <x, r> + <y, c>
                     - 1/eta * sum_k exp(-eta * a_k - 1),

    with P = compute_P(x, y, a). Building P through compute_P keeps the
    objective consistent with the plan used by the Sinkhorn updates; its
    gradient in x is r - P 1, in y is c - P^T 1, and in a_k is
    exp(-eta * a_k - 1) minus the constraint term.

    Returns:
        Scalar objective value.
    """
    plan = compute_P(x, y, a)
    slack = jnp.exp(-eta * a - 1.0)
    return (-jnp.sum(plan) / eta
            + jnp.sum(x * r)
            + jnp.sum(y * c)
            - jnp.sum(slack) / eta)

In [8]:
def compute_hessian(x, y, a):
    """Hessian of f(x + t * 1, y, a) with respect to (a, t), at t = 0.

    With P = compute_P(x, y, a), s = exp(-eta * a - 1) and col = P^T 1:
      d2f / da da^T = -eta * (D diag(col) D^T + diag(s))
      d2f / da dt   = -eta * D col
      d2f / dt2     = -eta * sum(P)

    f is concave, so this matrix is negative definite (s > 0 and sum(P) > 0).
    The Newton step solve(H, -grad) is then an ascent direction.

    Returns:
        (n + 1, n + 1) array ordered as (a_1, ..., a_n, t).
    """
    P = compute_P(x, y, a)
    col_sums = jnp.sum(P, axis=0)              # P^T 1, shape (n,)
    slack = jnp.exp(-eta * a - 1.0).ravel()    # shape (n,)

    # (D * col_sums) scales column j of D by col_sums[j], i.e. D diag(col).
    H_aa = -eta * ((D * col_sums) @ jnp.transpose(D) + jnp.diag(slack))
    H_at = -eta * (D @ col_sums)               # shape (n,)
    H_tt = -eta * jnp.sum(P)

    return jnp.block([
        [H_aa, H_at[:, None]],
        [H_at[None, :], jnp.reshape(H_tt, (1, 1))],
    ])

In [9]:
def newton_step(x, y, a, t):
    """
    Perform one Newton step on (a, t).

    The gradient and Hessian of f(x + t * 1, y, a) are evaluated at the current
    sub-problem iterate (a, t), and the Newton system H delta = -grad is solved.

    Returns:
        (delta_a, delta_t): delta_a has shape (n, 1); delta_t is the last entry
        of the solution, shape (1,).
    """
    # Fold the current t into x. Evaluating at x alone would return the t = 0
    # gradient and Hessian on every Newton step after the first.
    x_shifted = x + t * jnp.ones_like(x)
    grad = compute_gradient(x_shifted, y, a)
    H = compute_hessian(x_shifted, y, a)

    delta = jnp.linalg.solve(H, -grad)

    return delta[:-1], delta[-1]

Implementation of the standard backtracking line search scheme (Boyd and Vandenberghe, 2004), adapted to maximization.

In [10]:
def backtracking_line_search(f, a, t, delta_a, delta_t, alpha=0.4, beta=0.7,
                             max_iter=100):
    """
    Backtracking (Armijo) line search for *maximizing* f(a, t).

    Starting from s = 1, s is multiplied by ``beta`` until the
    sufficient-increase condition

        f(a + s * delta_a, t + s * delta_t) >= f(a, t) + alpha * s * <grad f, delta>

    holds. This is Boyd & Vandenberghe's scheme with the inequality flipped for
    maximization. The directional derivative <grad f, delta> is computed with
    ``jax.jvp``, so ``f`` must be JAX-differentiable in (a, t).

    Args:
        f: objective f(a, t) returning a scalar (or size-1 array).
        a, t: current point (t may be a Python number).
        delta_a, delta_t: search direction.
        alpha: sufficient-increase fraction, in (0, 0.5).
        beta: shrink factor, in (0, 1).
        max_iter: maximum number of shrink steps.

    Returns:
        The accepted step size s (Python float). Returns 0.0 if the direction
        is not an ascent direction, or if no step satisfies the condition
        within ``max_iter`` shrinks.
    """
    # jax.jvp requires primals and tangents with identical shape and dtype.
    dtype = jnp.result_type(delta_a, delta_t, float)
    a_shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(delta_a))
    a = jnp.broadcast_to(jnp.asarray(a, dtype=dtype), a_shape)
    delta_a = jnp.broadcast_to(jnp.asarray(delta_a, dtype=dtype), a_shape)
    t_shape = jnp.broadcast_shapes(jnp.shape(t), jnp.shape(delta_t))
    t = jnp.broadcast_to(jnp.asarray(t, dtype=dtype), t_shape)
    delta_t = jnp.broadcast_to(jnp.asarray(delta_t, dtype=dtype), t_shape)

    f_current, slope = jax.jvp(f, (a, t), (delta_a, delta_t))
    f_current = jnp.sum(f_current)
    slope = jnp.sum(slope)

    # Not an ascent direction (or NaN): no positive step can increase f.
    if not slope > 0:
        return 0.0

    s = 1.0
    for _ in range(max_iter):
        f_new = jnp.sum(f(a + s * delta_a, t + s * delta_t))
        # Written as ">=" so that a NaN objective is rejected, not accepted.
        if f_new >= f_current + alpha * s * slope:
            return s
        s *= beta
    return 0.0

Compute $(a, t)$ with Newton steps (`newton_step`) combined with the standard backtracking line search.

In [11]:
def optimize(x, y, a, num_steps=10):
    """
    Run Newton optimization over (a, t).

    Approximately solves max_{a, t} f(x + t * 1, y, a) for fixed x and y with
    at most ``num_steps`` line-searched Newton steps, starting from t = 0.

    Returns:
        (a, t): the updated constraint duals and the shift t to add to x.
    """
    # x and y are fixed during the sub-problem, so the objective closure is
    # built once. It is named `objective` so it does not shadow the module-level f.
    def objective(a_, t_):
        return f_at(a_, t_, x, y)

    t = 0
    for _ in range(num_steps):
        delta_a, delta_t = newton_step(x, y, a, t)
        step_size = backtracking_line_search(objective, a, t, delta_a, delta_t)
        if step_size == 0:
            # A zero step leaves (a, t) unchanged, so every later iteration would
            # recompute the same direction and step; stop early.
            break

        a = a + step_size * delta_a
        t = t + step_size * delta_t

    return a, t

Example usage on a constrained problem

In [12]:
def sample_points(n=50, seed=0):
    """Draw two random probability vectors of shape (n, 1).

    Each vector is |z| / sum(|z|) for a standard-normal draw z. The two draws
    use independent sub-keys split from ``seed``, so the output is
    deterministic for a given (n, seed).

    Returns:
        Tuple (r, c) of (n, 1) arrays, each summing to one.
    """
    row_key, col_key = jax.random.split(jax.random.key(seed), 2)

    def _normalized_abs_normal(key):
        weights = jnp.abs(jax.random.normal(key, (n, 1)))
        return weights / jnp.sum(weights)

    return _normalized_abs_normal(row_key), _normalized_abs_normal(col_key)



In [None]:
# Problem size and two random marginals, each normalized to sum to one.
n=50
r,c = sample_points(n,seed=1)
# NOTE(review): the marginals r and c are also used as the (n, 1) point
# coordinates of the point cloud, so C is a cost between the marginal *values*
# (presumably with ott's default cost function) -- confirm this toy setup is intended.
geom = pointcloud.PointCloud(r,c)
C = geom.cost_matrix
# Constraint matrix: all ones, same shape as C.
D = jnp.ones_like(C)
# Regularization strength; must be a Python float to pass the input check.
eta = 1.
# All dual variables start at zero.
x_init = jnp.zeros((n,1))
y_init = jnp.zeros((n,1))
a_init = jnp.zeros((n,1))

x_out, y_out, a_out = constrained_sinkhorn(x_init, y_init, a_init)



This implementation seems to work but is very slow, because I did not manage to faithfully reproduce the authors' implementation of the step

$(a, t) \leftarrow \operatorname{argmax}_{\tilde{a}, \tilde{t}} f(x + \tilde{t}\mathbb{1}, y, \tilde{a})$

Hence, my implementation does not run in $O(n^2)$ time per iteration.