Entropic Regularization of Optimal Transport
============================================

*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_matlab/) for details about how to install the toolboxes.

This numerical tour presents the general methodology of regularizing the
optimal transport (OT) linear program using entropy. This allows one to
derive fast computational algorithms based on iterative projections
according to a Kullback-Leibler divergence.

In [None]:
from __future__ import division

import numpy as np

import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Helpers

In [None]:
"""To compute distance matrix"""
def distmat(x,y):
    """Return the matrix of squared Euclidean distances between two point sets.

    Points are stored as columns: ``x`` has shape ``(d, n)`` and ``y`` has
    shape ``(d, m)``. The result has shape ``(n, m)`` and entry ``[i, j]`` is
    ``||x[:, i] - y[:, j]||**2``, computed through the expansion
    ``||x||^2 + ||y||^2 - 2 <x, y>``.
    """
    sq_x = np.sum(x**2, 0)[:, None]
    sq_y = np.sum(y**2, 0)[None, :]
    cross = x.transpose().dot(y)
    # The expansion can cancel to a tiny negative number for (nearly)
    # coincident points; a squared distance is never negative, so clip at 0.
    return np.maximum(sq_x + sq_y - 2*cross, 0)

"""To Normalise a vector"""
def normalize(a):
    """Return ``a`` divided by the sum of its entries, so the result sums to 1."""
    return a/np.sum(a)

"""To Compute P"""
def GetP(u,K,v):
    """Return the coupling ``P = diag(u) K diag(v)``.

    ``u`` scales the rows of ``K`` and ``v`` scales its columns, i.e.
    ``P[i, j] = u[i] * K[i, j] * v[j]``. Both 1-D vectors and column vectors
    of shape ``(n, 1)`` are accepted for ``u`` and ``v``.
    """
    # Force u to a column and v to a row so broadcasting always scales rows by
    # u and columns by v; a bare `u * K * v.T` only does that for column
    # vectors, because `.T` is a no-op on 1-D arrays.
    row_scale = np.asarray(u).reshape(-1, 1)
    col_scale = np.asarray(v).reshape(1, -1)
    return row_scale * K * col_scale

In [None]:
# Number of points in each cloud: N[0] sizes the cloud x and the weights a,
# N[1] sizes the cloud y and the weights b.
N=[400,400]

In [None]:
# First cloud: N[0] points (stored as columns) drawn uniformly from the
# square [-0.5, 0.5) x [-0.5, 0.5).
x = np.random.rand(2, N[0]) - 0.5

# Second cloud: N[1] points given in polar form, with a uniform angle in
# [0, 2*pi) and a uniform radius in [0.8, 1.0), then converted to Cartesian
# coordinates and stacked as a (2, N[1]) array.
theta = 2 * np.pi * np.random.rand(1, N[1])
r = 0.8 + .2 * np.random.rand(1, N[1])
y = np.concatenate((r * np.cos(theta), r * np.sin(theta)), axis=0)

In [None]:
def plotp(x, col,plt, scale=200, edgecolors="k"):
    """Scatter the columns of ``x`` as 2-D points and return the scatter artist.

    ``x[0, :]`` and ``x[1, :]`` are the horizontal and vertical coordinates,
    ``col`` is forwarded as the colour argument, ``plt`` is the object whose
    ``scatter`` method is called, ``scale`` is the marker size and
    ``edgecolors`` the marker outline colour.
    """
    horizontal = x[0, :]
    vertical = x[1, :]
    marker_options = {
        "s": scale,
        "edgecolors": edgecolors,
        "c": col,
        "cmap": 'plasma',
        "linewidths": 2,
    }
    return plt.scatter(horizontal, vertical, **marker_options)

In [None]:

# Draw both clouds on one square figure: x in blue, y in red.
plt.figure(figsize=(10, 10))
for cloud, colour in ((x, 'b'), (y, 'r')):
    plotp(cloud, colour, plt)

# Hide the axes and frame the view on the extent of y, with a 0.1 margin.
plt.axis("off")
y_horizontal = y[0, :]
y_vertical = y[1, :]
plt.xlim(np.min(y_horizontal) - .1, np.max(y_horizontal) + .1)
plt.ylim(np.min(y_vertical) - .1, np.max(y_vertical) + .1)

plt.show()

In [None]:
# Cost matrix between the two clouds, as returned by distmat: one row per
# column of x and one column per column of y.
C = distmat(x,y)

In [None]:
# Uniform weight vectors: every entry of a is 1/N[0] and every entry of b is
# 1/N[1], so each vector sums to one.
a = np.full(N[0], 1.0 / N[0])
b = np.full(N[1], 1.0 / N[1])

In [None]:
import computational_OT

In [None]:
# Run 1000 Sinkhorn iterations on the kernel K = exp(-C / epsilon), starting
# from all-ones vectors u and v.
epsilon = .01
K = np.exp(-C / epsilon)
u = np.ones(N[0])
v = np.ones(N[1])

SOptimizer = computational_OT.Sinkhorn(K, a, b, u, v, epsilon)
out = SOptimizer._update(maxiter=1000)

# Plot out[2] and out[3] on two stacked log-scale panels, under the marginal
# error names given in the titles.
plt.figure(figsize=(12, 12))
panels = (
    (1, "$||P1 -a||_1$", out[2]),
    (2, "$||P^T 1 -b||_1$", out[3]),
)
for position, title, history in panels:
    plt.subplot(2, 1, position)
    plt.title(title)
    plt.plot(np.asarray(history), linewidth=2)
    plt.yscale('log')
    plt.ylabel("Error in log scale")
    plt.xlabel("Number of iterations")
    plt.legend(["Sample size: " + str(i) + " and Epsilon=" + str(epsilon) for i in N], loc="upper right")
plt.show()

In [None]:
# Hybrid run: 100 Sinkhorn iterations, whose output is then used to
# initialise 10 Newton-Raphson iterations on the same problem.
epsilon = .01
K = np.exp(-C/epsilon)
u=np.ones(N[0])
v = np.ones(N[1])


SOptimizer=computational_OT.Sinkhorn(K,a,b,u,v,epsilon)
outS=SOptimizer._update(maxiter=100)


# Stack outS[0] and outS[1] into one vector and map it through epsilon*log.
# NOTE(review): presumably outS[0], outS[1] are the Sinkhorn scalings (u, v)
# and X is the corresponding pair of dual potentials -- confirm against
# computational_OT.
X = np.hstack( (outS[0],outS[1]) )
X = epsilon*np.log(X)
NOptimizer=computational_OT.NewtonRaphson(X,K,a,b,epsilon)
outN=NOptimizer._update(maxiter=10, debug=False)

In [None]:
# Plot
# Compare the summed marginal errors of the pure Sinkhorn run (`out`) with
# those of the hybrid run (`outS` followed by `outN`) on one log-scale plot.
plt.figure(figsize = (12,6))

plt.title("$||P1 -a||_1 + ||P^T 1 -b||_1$")
error_sinkhorn = np.asarray(out[2]) + np.asarray(out[3])
# NOTE(review): `outS[k] + outN[j]` is presumably list concatenation, i.e. the
# Newton history appended after the Sinkhorn history -- confirm that both are
# Python lists rather than arrays.
error_hybrid   = np.asarray(outS[2]+outN[0]) + np.asarray(outS[3]+outN[1])
# Raw strings: '\e' is not a valid escape sequence in a normal string literal.
plt.plot( error_sinkhorn, label=r'Sinkhorn for $\epsilon=$' + str(epsilon), linewidth = 2)
plt.plot( error_hybrid,label=r'Hybrid method for $\epsilon=$'+ str(epsilon), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')
plt.show()

In [None]:
# Repeat the hybrid run (100 Sinkhorn iterations, then 10 Newton-Raphson
# iterations started from the Sinkhorn output) for each value in `epsilons`,
# collecting the raw results in outS and outN.
epsilons = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
n = len(epsilons)
outS = []
outN = []
for i, eps in enumerate(epsilons):
    # NOTE: K is rebound on every pass, so after the loop it only holds the
    # kernel of the last epsilon.
    K = np.exp(-C / eps)
    u = np.ones(N[0])
    v = np.ones(N[1])

    SOptimizer = computational_OT.Sinkhorn(K, a, b, u, v, eps)
    sinkhorn_result = SOptimizer._update(maxiter=100)
    outS.append(sinkhorn_result)

    # Same initialisation as the single-epsilon hybrid run: stack the first
    # two Sinkhorn outputs and map them through eps * log.
    X = np.hstack((sinkhorn_result[0], sinkhorn_result[1]))
    X = eps * np.log(X)
    NOptimizer = computational_OT.NewtonRaphson(X, K, a, b, eps)
    outN.append(NOptimizer._update(maxiter=10, debug=False))
    


In [None]:
# Rebuild one coupling per epsilon from the Sinkhorn outputs and show a
# log-log histogram of its entries.
# NOTE(review): assumes outS[i][0] and outS[i][1] are the 1-D Sinkhorn
# scalings (u, v) for epsilons[i] -- confirm against computational_OT.
P=[]
Q1=[]
Q2=[]

for i in range(n):
    # Recompute the kernel for *this* epsilon. The global K left behind by the
    # sweep only matches epsilons[-1], so reusing it would pair every (u, v)
    # with the wrong kernel.
    K_i = np.exp(-C/epsilons[i])
    P_i = np.dot(np.dot(np.diag(outS[i][0]), K_i), np.diag(outS[i][1]))
    P.append(P_i)
    # Sort the current matrix only (not the growing list P). The leading
    # singleton axis keeps `Q1[i][0]` / `Q2[i][0]` valid as the 2-D sorted matrix.
    Q1.append(np.sort(P_i, axis=0)[np.newaxis])
    Q2.append(np.sort(P_i, axis=1)[np.newaxis])


for i in range(n):
    plt.title("For epsilons: "+ str(epsilons[i]))
    plt.hist( P[i].flatten(), 20)
    plt.xscale( 'log')
    plt.yscale( 'log')
    plt.show()


In [None]:
# Plot the summed marginal errors of the hybrid run for every epsilon on one
# log-scale figure.
plt.figure(figsize = (15,6))

plt.title("$||P1 -a||_1 + ||P^T 1 -b||_1$")
for i in range(n):
    # NOTE(review): `outS[i][k] + outN[i][j]` is presumably list concatenation
    # (Newton history appended after the Sinkhorn history) -- confirm.
    error_hybrid   = np.asarray(outS[i][2]+outN[i][0]) + np.asarray(outS[i][3]+outN[i][1])
    # Raw string: '\e' is not a valid escape sequence in a normal string literal.
    plt.plot( error_hybrid,label=r'Hybrid method for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)

plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')
plt.show()

In [None]:
# Rebuild one coupling per epsilon, together with its sorted versions: Q1
# sorts along axis 0 (within each column) and Q2 along axis 1 (within each row).
# NOTE(review): assumes outS[i][0] and outS[i][1] are the 1-D Sinkhorn
# scalings (u, v) for epsilons[i] -- confirm against computational_OT.
P=[]
Q1=[]
Q2=[]
for i in range(n):
    # Recompute the kernel for *this* epsilon. The global K left behind by the
    # sweep only matches epsilons[-1], so reusing it would pair every (u, v)
    # with the wrong kernel.
    K_i = np.exp(-C/epsilons[i])
    P_i = np.dot(np.dot(np.diag(outS[i][0]), K_i), np.diag(outS[i][1]))
    P.append(P_i)
    # Sort the current matrix only (not the growing list P). The leading
    # singleton axis keeps `Q1[i][0]` / `Q2[i][0]` valid as the 2-D sorted matrix.
    Q1.append(np.sort(P_i, axis=0)[np.newaxis])
    Q2.append(np.sort(P_i, axis=1)[np.newaxis])

In [None]:
# For each epsilon, show three images side by side: P[i], Q1[i][0] and Q2[i][0].
for i in range(n):
    fig, ax = plt.subplots(figsize=(20, 5), nrows=1, ncols=3)
    eps_text = str(epsilons[i])
    panels = (
        ("P and e: ", P[i]),
        ("Q1 and e: ", Q1[i][0]),
        ("Q2 and e: ", Q2[i][0]),
    )
    for axis, (prefix, image) in zip(ax, panels):
        axis.set_title(prefix + eps_text)
        axis.imshow(image)
    plt.show()




In [None]:
# For each epsilon, form P P^T and P^T P, then show each product as an image
# (first two panels) and as a 20-bin log-log histogram of its entries (last
# two panels).
for i in range(n):
    fig, ax = plt.subplots(figsize=(20, 5), nrows=1, ncols=4)
    eps_text = str(epsilons[i])
    P_xx = np.dot(P[i], P[i].T)
    P_yy = np.dot(P[i].T, P[i])
    products = (("P_xx", P_xx), ("P_yy", P_yy))

    for axis, (name, product) in zip(ax[:2], products):
        axis.set_title(name + " and e: " + eps_text)
        axis.imshow(product)

    for axis, (name, product) in zip(ax[2:], products):
        axis.set_title(name + " flattened and e: " + eps_text)
        axis.hist(product.flatten(), 20, cumulative=False)
        axis.set_xscale("log")
        axis.set_yscale("log")
    plt.show()



In [None]:
# Per-epsilon cut-off values handed to the Cuthill-McKee evaluation.
# NOTE(review): both lists hold 11 values while `epsilons` holds 10, so the
# last cut-off of each list is never used -- confirm whether that is intended.
cutoff_x = [1e4, 1e-4, 1e-6, 1e-7, 1e-7, 1e-8, 1e-8, 1e-8, 1e-8, 1e-8, 1e-8]
cutoff_y = [1e4, 1e-4, 1e-6, 1e-8, 1e-7, 4e-8, 4e-8, 2e-8, 2e-8, 1e-8, 1e-8]
for coupling, cut_x, cut_y, eps in zip(P, cutoff_x, cutoff_y, epsilons):
    CuthillMckee = computational_OT._Expcuthill_mckee(coupling)
    CuthillMckee._evaluate(cut_offx=cut_x, cut_offy=cut_y, epsilon=eps)
