Entropic Regularization of Optimal Transport
============================================

*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_matlab/) for details about how to install the toolboxes.

This numerical tour presents the general methodology for regularizing the
optimal transport (OT) linear program using entropy. This allows one to
derive fast computational algorithms based on iterative projections
according to the Kullback-Leibler divergence.

In [None]:
from __future__ import division

import numpy as np

import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Helpers

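In [None]:
def distmat(x,y):
    """Pairwise squared Euclidean distances between the columns of x and y."""
    return np.sum(x**2,0)[:,None] + np.sum(y**2,0)[None,:] - 2*x.transpose().dot(y)

def normalize(a):
    """Normalize a vector so that its entries sum to 1."""
    return a/np.sum(a)

def GetP(u,K,v):
    """Coupling P = diag(u) K diag(v) built from the scaling vectors u and v."""
    # Reshape so that u scales the rows and v scales the columns,
    # whether u and v are given as 1-D arrays or as column vectors.
    return np.reshape(u,(-1,1))*K*np.reshape(v,(1,-1))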

Entropic Regularization of Optimal Transport
--------------------------------------------
We consider two input histograms $a \in \Si_n$ and $b \in \Si_m$, where the simplex in $\RR^n$ is denoted
$$ \Si_n \eqdef \enscond{ a \in \RR_+^n }{ \sum_i a_i = 1 }.  $$
We consider the following discrete regularized transport problem
$$  W_\epsilon(a,b) \eqdef \umin{P \in U(a,b)} \dotp{C}{P} - \epsilon E(P),  $$
where the polytope of couplings is defined as
$$ U(a,b) \eqdef \enscond{P \in (\RR^+)^{n \times m}}{ P \ones_m = a, P^\top \ones_n = b },  $$
where $\ones_n \eqdef (1,\ldots,1)^\top \in \RR^n $,
and for $P \in \RR_+^{n \times m}$, we define its entropy as
$$ E(P) \eqdef -\sum_{i,j} P_{i,j} ( \log(P_{i,j}) - 1). $$


When $\epsilon=0$ one recovers the classical (discrete) optimal transport.
We refer to the monograph [Villani](#biblio) for more details about OT.
The idea of regularizing transport to allow for faster computation was
introduced in [Cuturi](#biblio).


Here the matrix $C \in (\RR^+)^{n \times m} $ defines the ground cost, i.e.
$C_{i,j}$ is the cost of moving mass from a bin indexed by $i$ to a bin indexed by $j$.


The regularized transportation problem can be re-written as a projection
$$ W_\epsilon(a,b) = \epsilon \umin{P \in U(a,b)} \KLdiv{P}{K}
	\qwhereq
	K_{i,j} \eqdef e^{ -\frac{C_{i,j}}{\epsilon} }  $$
of the Gibbs kernel $K$ according to the Kullback-Leibler divergence.
The Kullback-Leibler divergence between $P, K \in \RR_+^{n \times m}$ is
$$ \KLdiv{P}{K} \eqdef \sum_{i,j} P_{i,j} \pa{ \log\pa{ \frac{P_{i,j}}{K_{i,j}} } - 1}. $$
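
The identity between the regularized cost and the KL divergence can be checked numerically. The sketch below is only an illustration: it uses a small random cost matrix and the product coupling $a b^\top$ (an arbitrary choice of a strictly positive element of $U(a,b)$), and the helper names `entropy` and `kl_div` are introduced here, following the definitions of $E$ and $\KLdiv{\cdot}{\cdot}$ given above.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, epsilon = 4, 5, 0.1

a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
C = rng.random((n, m))          # illustrative ground cost
K = np.exp(-C / epsilon)        # Gibbs kernel

# Product coupling a b^T: strictly positive, with marginals a and b.
P = np.outer(a, b)

def entropy(P):
    # E(P) = -sum_ij P_ij (log P_ij - 1)
    return -np.sum(P * (np.log(P) - 1.0))

def kl_div(P, K):
    # KL(P|K) = sum_ij P_ij (log(P_ij / K_ij) - 1)
    return np.sum(P * (np.log(P / K) - 1.0))

lhs = np.sum(C * P) - epsilon * entropy(P)   # <C, P> - epsilon E(P)
rhs = epsilon * kl_div(P, K)                 # epsilon KL(P|K)
print(lhs, rhs)
```

Since the two objectives agree for every coupling, they share the same minimizer over $U(a,b)$.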


This interpretation of regularized transport as a KL projection and its numerical
applications are detailed in [BenamouEtAl](#biblio).


Given a convex set $\Cc \subset \RR^N$, the projection according to the Kullback-Leibler divergence is defined as
$$ \KLproj_\Cc(\xi) = \uargmin{ \pi \in \Cc } \KLdiv{\pi}{\xi}. $$
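
For the affine set of matrices with prescribed row sums, $\{P : P \ones = a\}$, the first-order optimality condition gives a closed form for this projection: each row of $\xi$ is rescaled so that it sums to the corresponding entry of $a$. The sketch below illustrates this on random data; the function names `kl_proj_rows` and `kl_div` are chosen here for illustration.

```python
import numpy as np

def kl_div(P, K):
    # KL(P|K) = sum_ij P_ij (log(P_ij / K_ij) - 1), as defined in the text
    return np.sum(P * (np.log(P / K) - 1.0))

def kl_proj_rows(xi, a):
    """KL projection of xi onto {P : P 1 = a}: rescale each row of xi."""
    return xi * (a / xi.sum(axis=1))[:, None]

rng = np.random.default_rng(0)
n, m = 4, 5
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
xi = rng.random((n, m)) + 0.1    # strictly positive matrix to project

P_star = kl_proj_rows(xi, a)
P_other = np.outer(a, b)         # another matrix with row sums equal to a

print(kl_div(P_star, xi), kl_div(P_other, xi))
```

The projection onto the column constraint is obtained in the same way by rescaling columns, which is what makes alternating projections cheap.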

# I. Introduction to Sinkhorn's Algorithm
-----------------------------------------------------------------------

A fundamental remark is that the optimality condition of the entropic regularized problem shows that the optimal coupling $P_\epsilon$ necessarily has the form 
$$P_\epsilon = \diag{u} K \diag{v}$$
where the Gibbs kernel is defined as
$$K \eqdef e^{-\frac{C}{\epsilon}}.$$
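
A short derivation sketch of this form, introducing dual variables $f \in \RR^n$ and $g \in \RR^m$ (names chosen here for illustration) for the two marginal constraints. The Lagrangian reads
$$ \mathcal{L}(P,f,g) = \dotp{C}{P} - \epsilon E(P) - \dotp{f}{P \ones_m - a} - \dotp{g}{P^\top \ones_n - b}. $$
Using the definition of $E$, setting the derivative with respect to $P_{i,j}$ to zero gives
$$ C_{i,j} + \epsilon \log(P_{i,j}) - f_i - g_j = 0
\quad\Longrightarrow\quad
P_{i,j} = e^{\frac{f_i}{\epsilon}} \, e^{-\frac{C_{i,j}}{\epsilon}} \, e^{\frac{g_j}{\epsilon}}, $$
so that one can take $u_i = e^{f_i/\epsilon}$ and $v_j = e^{g_j/\epsilon}$.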

One thus needs to find two positive scaling vectors $u \in \RR_+^n$ and $v \in \RR_+^m$ such that the following two equalities hold
$$P \ones  = u \odot (K v) = a 
\qandq
P^\top \ones  = v \odot (K^\top u) = b.$$

Sinkhorn's algorithm alternates between solving these two equations, and reads
$$u \longleftarrow \frac{a}{K v} \qandq v \longleftarrow \frac{b}{K^\top u}.$$
This algorithm was shown to converge to a solution of the entropic regularized problem by [Sinkhorn](#biblio).
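
The two updates translate almost literally into NumPy. The following is a minimal sketch on a small random problem, with a fixed number of iterations and no stopping criterion; the function name `sinkhorn`, the parameter `n_iter` and the toy data are assumptions made for illustration, and this is not the implementation used later in this notebook.

```python
import numpy as np

def sinkhorn(K, a, b, n_iter=500):
    """Alternate the two scaling updates u <- a / (K v), v <- b / (K^T u)."""
    v = np.ones(K.shape[1])
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u, v

rng = np.random.default_rng(0)
n, m, epsilon = 5, 7, 0.5
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
C = rng.random((n, m))           # illustrative ground cost
K = np.exp(-C / epsilon)         # Gibbs kernel

u, v = sinkhorn(K, a, b)
P = u[:, None] * K * v[None, :]  # P = diag(u) K diag(v)
print(np.abs(P.sum(axis=1) - a).sum(), np.abs(P.sum(axis=0) - b).sum())
```

For small $\epsilon$ the entries of $K$ can underflow, in which case a log-domain variant is usually preferred; that is outside the scope of this sketch.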

Transport Between Point Clouds
------------------------------
We first test the method for two input measures that are uniform measures
(i.e. constant histograms) supported on two point clouds
(that do not necessarily have the same size).


We thus first generate two point clouds $x=(x_i)_{i=1}^{n}$ and $y=(y_j)_{j=1}^{m}$,
where $x_i, y_j \in \RR^2$.


Number of points in each cloud, $N=(n,m)$.

In [None]:
X = np.random.rand(100)
Y = np.random.rand(100)
s = np.argsort(X)
print("Sorting indices: \n", s)
print("Sorted sample:\n", X[s])

In [None]:
N = [400,400]

Dimension of the clouds.

In [None]:
d = 2

Point cloud $x$, of $n$ points inside a square.

In [None]:
x = np.random.rand(2,N[0])-.5

Point cloud $y$, of $m$ points inside an annulus.

In [None]:
theta = 2*np.pi*np.random.rand(1,N[1])
r = .8 + .2*np.random.rand(1,N[1])
y = np.vstack((np.cos(theta)*r,np.sin(theta)*r))

Shortcut for displaying point clouds.

In [None]:
plotp = lambda x,col: plt.scatter(x[0,:], x[1,:], s=200, edgecolors="k", c=col, linewidths=2)

Display of the two clouds.

In [None]:
plt.figure(figsize=(10,10))

plotp(x, 'b')
plotp(y, 'r')

plt.axis("off")
plt.xlim(np.min(y[0,:])-.1,np.max(y[0,:])+.1)
plt.ylim(np.min(y[1,:])-.1,np.max(y[1,:])+.1)

plt.show()

Cost matrix $C_{i,j} = \norm{x_i-y_j}^2$.

In [None]:

C = distmat(x,y)
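
The helper `distmat` relies on the expansion $\norm{x_i-y_j}^2 = \norm{x_i}^2 + \norm{y_j}^2 - 2\dotp{x_i}{y_j}$. As a sanity check, the sketch below compares it with `scipy.spatial.distance.cdist` on small random clouds; the arrays `x_demo` and `y_demo` are illustrative and store one point per column, as in this notebook.

```python
import numpy as np
from scipy.spatial.distance import cdist

def distmat(x, y):
    # Same formula as the helper defined at the top of the notebook.
    return np.sum(x**2, 0)[:, None] + np.sum(y**2, 0)[None, :] - 2 * x.T.dot(y)

rng = np.random.default_rng(0)
x_demo = rng.random((2, 6))   # 6 points in R^2, one per column
y_demo = rng.random((2, 4))   # 4 points in R^2, one per column

C_fast = distmat(x_demo, y_demo)
# cdist expects one point per row, hence the transposes.
C_ref = cdist(x_demo.T, y_demo.T, metric="sqeuclidean")
print(np.max(np.abs(C_fast - C_ref)))
```

The expanded formula can return tiny negative values for nearly coincident points because of rounding, which is harmless here since $C$ only enters through $e^{-C/\epsilon}$.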

Target histograms $(a,b)$, here uniform histograms.

In [None]:
a = np.ones(N[0])/N[0]
b = np.ones(N[1])/N[1]

In [None]:
import computational_OT

Gibbs Kernel $K$.

Initialization of $v=\ones_{m}$ ($u$ does not need to be
initialized).

Regularization strength $\epsilon>0$.

__Exercise 1__

Implement Sinkhorn algorithm.
Display the evolution of the constraints satisfaction errors
$$ \norm{ P \ones - a }_1 \qandq \norm{ P^\top \ones - b }_1 $$
(you need to think about how to compute these residuals from $(u,v)$ alone).
Display the constraint violation errors in a log plot.
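
One possible way to obtain these residuals from $(u,v)$ alone uses $P \ones = u \odot (K v)$ and $P^\top \ones = v \odot (K^\top u)$, so that only matrix-vector products are needed and $P$ is never formed. The sketch below checks this against the explicit coupling for arbitrary positive scalings; the name `marginal_errors` and the toy data are assumptions for illustration.

```python
import numpy as np

def marginal_errors(u, K, v, a, b):
    """L1 constraint violations computed without forming P = diag(u) K diag(v)."""
    err_a = np.sum(np.abs(u * (K @ v) - a))
    err_b = np.sum(np.abs(v * (K.T @ u) - b))
    return err_a, err_b

rng = np.random.default_rng(0)
n, m, epsilon = 5, 7, 0.5
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
K = np.exp(-rng.random((n, m)) / epsilon)

# Arbitrary positive scalings, e.g. some intermediate Sinkhorn iterate.
u = rng.random(n) + 0.5
v = rng.random(m) + 0.5

err_a, err_b = marginal_errors(u, K, v, a, b)

# Reference values obtained from the explicit coupling.
P = u[:, None] * K * v[None, :]
err_a_ref = np.sum(np.abs(P.sum(axis=1) - a))
err_b_ref = np.sum(np.abs(P.sum(axis=0) - b))
print(err_a, err_a_ref, err_b, err_b_ref)
```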

In [None]:
epsilon = .01
K = np.exp(-C/epsilon)
u=np.ones(N[0])
v = np.ones(N[1])



SOptimizer=computational_OT.Sinkhorn(K,a,b,u,v,epsilon)
out=SOptimizer._update(maxiter=1000)
#
# Plot
plt.figure(figsize = (12,12))

plt.subplot(2,1,1),
plt.title("$||P1 -a||_1$")
plt.plot( np.asarray(out['error_a']), linewidth = 2)
plt.yscale( 'log')
plt.ylabel("Error in log scale")
plt.xlabel("Number of iterations")
plt.legend(["Sample size: "+str(i)+" and Epsilon="+str(epsilon) for i in N],loc="upper right")

plt.subplot(2,1,2)
plt.title("$||P^T 1 -b||_1$")
plt.plot( np.asarray(out['error_b']), linewidth = 2)
plt.yscale( 'log')
plt.ylabel("Error in log scale")
plt.xlabel("Number of iterations")
plt.legend(["Sample size: "+str(i)+" and Epsilon="+str(epsilon) for i in N],loc="upper right")
plt.show()

# II. Second order iteration: The Newton-Raphson scheme

In [None]:
epsilon = .01
K = np.exp(-C/epsilon)
u=np.ones(N[0])
v = np.ones(N[1])


sinkhorn_run=200
SOptimizer=computational_OT.Sinkhorn(K,a,b,u,v,epsilon)

outS=SOptimizer._update(maxiter=sinkhorn_run)
count=sinkhorn_run
maxiteration=1000

outN=None
# Scaling vectors reached by the Sinkhorn iterations run so far
Su,Sv=outS['u'],outS['v']
while True :
    print("Checking at: ",count)
    basincheck=computational_OT.Newtonbasincheck(K,a,b,Su,Sv,epsilon)
    # Switch to Newton once inside its basin of attraction,
    # or once the Sinkhorn iteration budget is exhausted.
    if basincheck._Checkbasinofattraction() or count>=maxiteration:
        # Newton is started from the dual potentials epsilon*log(u), epsilon*log(v)
        X = np.hstack( (outS['u'],outS['v']) )
        X = epsilon*np.log(X)
        NOptimizer= computational_OT.NewtonRaphson(X,K,a,b,epsilon)
        outN=NOptimizer._update(maxiter=10,debug=False)
        break

    outS=SOptimizer._update(maxiter=sinkhorn_run)
    Su,Sv=outS['u'],outS['v']
    count +=sinkhorn_run
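
To make the idea of a second-order iteration concrete, here is a self-contained sketch of a Newton step on the dual of the entropic problem, written in the potentials $f = \epsilon \log u$ and $g = \epsilon \log v$. It is an illustration under stated assumptions and not the `computational_OT` implementation: the names `newton_step` and `marginal_error` are hypothetical, no line search is used, and the Hessian, which is singular along the direction $(\ones, -\ones)$, is handled with a least-squares solve.

```python
import numpy as np

def coupling(f, g, C, epsilon):
    # P_ij = exp((f_i + g_j - C_ij) / epsilon)
    return np.exp((f[:, None] + g[None, :] - C) / epsilon)

def marginal_error(f, g, C, a, b, epsilon):
    P = coupling(f, g, C, epsilon)
    return np.abs(P.sum(axis=1) - a).sum() + np.abs(P.sum(axis=0) - b).sum()

def newton_step(f, g, C, a, b, epsilon):
    """One full Newton step on the concave dual objective in (f, g)."""
    P = coupling(f, g, C, epsilon)
    r, c = P.sum(axis=1), P.sum(axis=0)
    grad = np.concatenate((a - r, b - c))
    hess = -np.block([[np.diag(r), P], [P.T, np.diag(c)]]) / epsilon
    # The Hessian is singular; lstsq returns the minimum-norm solution.
    step = np.linalg.lstsq(hess, -grad, rcond=None)[0]
    return f + step[:f.size], g + step[f.size:]

rng = np.random.default_rng(0)
n, m, epsilon = 5, 6, 0.5
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
C = rng.random((n, m))
K = np.exp(-C / epsilon)

# Warm start with two Sinkhorn iterations, then switch to Newton.
u, v = np.ones(n), np.ones(m)
for _ in range(2):
    u = a / (K @ v)
    v = b / (K.T @ u)
f, g = epsilon * np.log(u), epsilon * np.log(v)

err_before = marginal_error(f, g, C, a, b, epsilon)
for _ in range(5):
    f, g = newton_step(f, g, C, a, b, epsilon)
err_after = marginal_error(f, g, C, a, b, epsilon)
print(err_before, err_after)
```

Each Newton step requires solving an $(n+m) \times (n+m)$ linear system, which is why it is only worth using once the cheaper Sinkhorn iterations have brought the iterate close to the solution.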



In [None]:
# Plot
plt.figure(figsize = (12,12))

plt.subplot(2,1,1),
plt.title("$||P1 -a||_1$")
plt.plot( np.asarray(out['error_a']),label=r'Sinkhorn for $\epsilon=$' + str(epsilon), linewidth = 2)
plt.plot( np.asarray(outS['error_a']+outN['error_a']),label=r'Hybrid method for $\epsilon=$'+ str(epsilon), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')

plt.subplot(2,1,2)
plt.title("$||P^T 1 -b||_1$")
plt.plot( np.asarray(out['error_b']), label=r'Sinkhorn for $\epsilon=$' + str(epsilon), linewidth = 2)
plt.plot( np.asarray(outS['error_b']+outN['error_b']),label=r'Hybrid method for $\epsilon=$'+ str(epsilon), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')

plt.show()


# III. Explore the sparsity of the matrices involved

Compute the final matrix $P$.

In [None]:
P = np.dot(np.dot(np.diag(outS['u']),K),np.diag(outS['v']))
Q1 = np.sort(P, axis=0)
Q2 = np.sort(P, axis=1)

Histogram of the entries of $P$, in log-log scale.

In [None]:
plt.hist( P.flatten(), 20)
plt.xscale( 'log')
plt.yscale( 'log')
plt.show()
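
One possible way to summarize how concentrated the mass of a coupling is, complementary to the histogram, is the share of entries needed to carry a given fraction of the total mass. The helper `mass_concentration` and the two toy matrices below are introduced here purely for illustration.

```python
import numpy as np

def mass_concentration(P, fraction=0.99):
    """Share of entries (largest first) needed to reach `fraction` of sum(P)."""
    p = np.sort(P.ravel())[::-1]                 # entries in decreasing order
    cum = np.cumsum(p) / p.sum()                 # cumulative share of the mass
    k = np.searchsorted(cum, fraction) + 1       # number of entries required
    return k / p.size

# Nearly diagonal coupling: the mass sits on 10 of the 100 entries.
P_peaked = np.eye(10) / 10 + 1e-9
# Uniform coupling: the mass is spread over all entries.
P_flat = np.full((10, 10), 1.0 / 100)

ratio_peaked = mass_concentration(P_peaked)
ratio_flat = mass_concentration(P_flat)
print(ratio_peaked, ratio_flat)
```

A small ratio indicates that the coupling is numerically sparse, even though all its entries are strictly positive.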

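Display $P$, together with $Q_1$ (entries sorted within each column) and $Q_2$ (entries sorted within each row).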

In [None]:
plt.figure()
plt.imshow(P);
plt.figure()
plt.imshow(Q1);
plt.figure()
plt.imshow(Q2);

In [None]:
P_xx = np.dot(P, P.T)
P_yy = np.dot(P.T, P)
plt.figure()
plt.imshow( P_xx );
plt.figure()
plt.imshow( P_yy );
plt.show()

plt.figure()
plt.hist( P_xx.flatten(), 20, cumulative=True, density=True)
plt.xscale( 'log')
plt.ylim( (0,1) )
plt.show()
plt.hist( P_yy.flatten(), 20, cumulative=False)
plt.xscale( 'log')
plt.yscale( 'log')
plt.show()

Reordering with the reverse Cuthill-McKee algorithm.

In [None]:
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

def invert_permutation(p):
    """Return an array s with which np.array_equal(arr[p][s], arr) is True.
    The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
    """
    p = np.asanyarray(p) # in case p is a tuple, etc.
    s = np.empty_like(p)
    s[p] = np.arange(p.size)
    return s

cut_offx=1e-7
cut_offy=1e-7
P_xx_   = P_xx*( P_xx > cut_offx)
P_xx_csr = csr_matrix(P_xx_)
perm_x = reverse_cuthill_mckee(P_xx_csr)
invp_x = invert_permutation(perm_x)

P_yy_ = P_yy*( P_yy > cut_offy)
P_yy_csr = csr_matrix(P_yy_)
perm_y = reverse_cuthill_mckee(P_yy_csr)
invp_y = invert_permutation(perm_y)

# Reorder rows and columns with the reverse Cuthill-McKee permutations.
# np.ix_ builds the open mesh needed to index rows and columns jointly.
P_ = P[np.ix_(perm_x, perm_y)]
P_xx_ = P_xx[np.ix_(perm_x, perm_x)]
P_yy_ = P_yy[np.ix_(perm_y, perm_y)]

size=10
plt.figure( figsize=(size,size))
plt.imshow( P_xx_ );
plt.figure( figsize=(size,size))
plt.imshow( P_yy_ );
plt.figure( figsize=(size,size))
plt.imshow( P_ );
plt.figure( figsize=(size,size))
plt.imshow( P_xx );
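
To see what the reordering does on a case where the answer is known, the sketch below shuffles a tridiagonal matrix and applies `reverse_cuthill_mckee` to it. The helper `bandwidth` and the toy matrix are assumptions made for this illustration.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

def bandwidth(A):
    """Largest distance |i - j| over the nonzero entries of A."""
    i, j = np.nonzero(A)
    return int(np.max(np.abs(i - j))) if i.size else 0

n = 12
# Tridiagonal (banded) symmetric matrix ...
A = np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
# ... whose structure is hidden by a random symmetric permutation.
rng = np.random.default_rng(0)
shuffle = rng.permutation(n)
A_shuffled = A[np.ix_(shuffle, shuffle)]

# The reordering tries to bring the nonzeros back close to the diagonal.
perm = reverse_cuthill_mckee(csr_matrix(A_shuffled), symmetric_mode=True)
A_reordered = A_shuffled[np.ix_(perm, perm)]

bw_before = bandwidth(A_shuffled)
bw_after = bandwidth(A_reordered)
print(bw_before, bw_after)
```

For the thresholded matrices `P_xx_` and `P_yy_`, a small bandwidth after reordering would indicate that points can be ordered so that mass is mostly exchanged between neighbors in that ordering.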

## SVD and low rank approximation

In [None]:
# Separate names are used so that the Sinkhorn scalings u and v are not overwritten
U_svd,s,Vt_svd = np.linalg.svd(P)

plt.figure()
plt.title("Decay of singular values")
plt.plot(s)
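
The decay of the singular values controls how well the matrix can be approximated at low rank: by the Eckart-Young theorem, the Frobenius error of the best rank-$k$ approximation equals the square root of the sum of the squared discarded singular values. A minimal sketch on a random matrix (the name `low_rank_approx` and the toy data are illustrative):

```python
import numpy as np

def low_rank_approx(P, k):
    """Best rank-k approximation of P in Frobenius norm, via truncated SVD."""
    U, s, Vt = np.linalg.svd(P, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k, :], s

rng = np.random.default_rng(0)
P_demo = rng.random((8, 6))
k = 2

P_k, s_demo = low_rank_approx(P_demo, k)
err = np.linalg.norm(P_demo - P_k, "fro")
# Error predicted from the discarded singular values.
err_predicted = np.sqrt(np.sum(s_demo[k:] ** 2))
print(err, err_predicted)
```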

### Clustering from Euclidean distance ###

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt

print("Shape:")
print(x.shape)

Z = linkage(x.T, 'centroid')
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z)

W = np.array( Z, dtype=int)
#print(W)

count = x.shape[1]
clusters = [ [i] for i in range(count) ]
for merge in W:
    i,j,k,l = merge
    new_cluster = clusters[i] + clusters[j]
    clusters.append( new_cluster )
    
p1 = np.array( clusters[-1] )
print(p1.shape)
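
The loop above replays the merges stored in the linkage matrix: row $t$ of `Z` merges the clusters with indices `Z[t,0]` and `Z[t,1]` into a new cluster with index $n+t$, so the last cluster lists all points in an order where points merged early sit next to each other. The sketch below wraps this in a function on a small random cloud; the name `leaf_order` and the data are illustrative. SciPy also provides `scipy.cluster.hierarchy.leaves_list`, which serves a similar purpose.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

def leaf_order(Z):
    """Ordering of the points obtained by replaying the merges of a linkage matrix."""
    n = Z.shape[0] + 1                      # number of original points
    clusters = [[i] for i in range(n)]      # singletons have indices 0..n-1
    for left, right in Z[:, :2].astype(int):
        # Each merge creates a new cluster, appended at index n + t.
        clusters.append(clusters[left] + clusters[right])
    return np.array(clusters[-1])

rng = np.random.default_rng(0)
points = rng.random((10, 2))                # one point per row, as linkage expects
Z_demo = linkage(points, "centroid")
order = leaf_order(Z_demo)
print(order)
```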

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt

print("Shape:")
print(y.shape)

Z = linkage(y.T, 'centroid')
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z)

W = np.array( Z, dtype=int)
#print(W)

count = y.shape[1]
clusters = [ [i] for i in range(count) ]
for merge in W:
    i,j,k,l = merge
    new_cluster = clusters[i] + clusters[j]
    clusters.append( new_cluster )
    
p2 = np.array( clusters[-1] )
print(p2.shape)

In [None]:
# Reorder the rows of P with p1 and its columns with p2
clustered_P = P[np.ix_(p1, p2)]

plt.figure()
plt.imshow(clustered_P);

clustered_P_xx = np.dot( clustered_P, clustered_P.T )
clustered_P_yy = np.dot( clustered_P.T, clustered_P )
plt.figure()
plt.imshow( clustered_P_xx );
plt.figure()
plt.imshow( clustered_P_yy );

### Clustering from the optimal transport ###

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt

print("Shape:")
print( P_xx.shape)

Z = linkage(P_xx, 'complete')
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z)

W = np.array( Z, dtype=int)

count = x.shape[1]
clusters = [ [i] for i in range(count) ]
for merge in W:
    i,j,k,l = merge
    new_cluster = clusters[i] + clusters[j]
    clusters.append( new_cluster )
    
p1 = np.array( clusters[-1] )
print(p1.shape)

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt

print("Shape:")
print(y.shape)

Z = linkage(P_yy, 'complete')
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z)

W = np.array( Z, dtype=int)
#print(W)

count = y.shape[1]
clusters = [ [i] for i in range(count) ]
for merge in W:
    i,j,k,l = merge
    new_cluster = clusters[i] + clusters[j]
    clusters.append( new_cluster )
    
p2 = np.array( clusters[-1] )
print(p2.shape)

In [None]:
# Reorder the rows of P with p1 and its columns with p2
clustered_P = P[np.ix_(p1, p2)]

plt.figure()
plt.imshow(clustered_P);

clustered_P_xx = np.dot( clustered_P, clustered_P.T )
clustered_P_yy = np.dot( clustered_P.T, clustered_P )
plt.figure()
plt.imshow( clustered_P_xx );
plt.figure()
plt.imshow( clustered_P_yy );