Entropic Regularization of Optimal Transport
============================================

*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_matlab/) for details about how to install the toolboxes.

This numerical tour presents the general methodology of regularizing the
optimal transport (OT) linear program using entropy. This makes it possible to
derive fast computational algorithms based on iterative projections
with respect to the Kullback-Leibler divergence.

In [None]:
from __future__ import division

import numpy as np
import time
import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl

import warnings
warnings.filterwarnings('ignore')
np.random.seed(1234)

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Helpers

In [None]:
def distmat(x,y):
    """Return the matrix of pairwise squared Euclidean distances.

    The columns of ``x`` and ``y`` are treated as points: entry ``[i, j]`` of
    the result is ``|x[:, i]|^2 + |y[:, j]|^2 - 2 <x[:, i], y[:, j]>``.
    """
    sq_norms_x = np.sum(x**2, 0)[:, None]
    sq_norms_y = np.sum(y**2, 0)[None, :]
    cross_terms = x.transpose().dot(y)
    return sq_norms_x + sq_norms_y - 2*cross_terms

def normalize(a):
    """Divide ``a`` by the sum of its entries, so that the result sums to one."""
    total = np.sum(a)
    return a/total

def GetP(u,K,v):
    """Return the coupling ``diag(u) K diag(v)``.

    Parameters
    ----------
    u : array with as many entries as ``K`` has rows
        Row scalings, given either as a 1-D array or as a column vector.
    K : 2-D array
        Kernel matrix.
    v : array with as many entries as ``K`` has columns
        Column scalings, given either as a 1-D array or as a column vector.

    Returns
    -------
    2-D array with entry ``[i, j]`` equal to ``u[i] * K[i, j] * v[j]``.
    """
    # Force explicit column / row shapes: a bare ``u*K*(v.T)`` only yields
    # diag(u) K diag(v) when u and v already are column vectors, because a
    # 1-D array broadcasts along the last axis (and ``.T`` is a no-op on it).
    row_scaling = np.reshape(u, (-1, 1))
    col_scaling = np.reshape(v, (1, -1))
    return row_scaling*K*col_scaling

def plotp(x, col,plt, scale=200, edgecolors="k"):
    """Scatter-plot the columns of ``x`` as 2-D points.

    ``plt`` is any object exposing a matplotlib-style ``scatter`` method (the
    notebook passes both ``pyplot`` and individual axes). ``col`` is forwarded
    as the ``c`` argument, ``scale`` as the marker size ``s``. The value
    returned by ``scatter`` is passed back to the caller.
    """
    abscissas, ordinates = x[0,:], x[1,:]
    marker_options = dict(
        s=scale,
        edgecolors=edgecolors,
        c=col,
        cmap='plasma',
        linewidths=2,
    )
    return plt.scatter(abscissas, ordinates, **marker_options)

In [None]:
# Number of sample points: N[0] for the source cloud x, N[1] for the target cloud y.
N = [400, 400]

In [None]:
# Source cloud: N[0] points drawn uniformly in the square [-0.5, 0.5)^2,
# stored one point per column.
x = np.random.rand(2, N[0]) - 0.5

# Target cloud: N[1] points in polar form, with a uniform angle in [0, 2*pi)
# and a uniform radius in [0.8, 1.0), i.e. an annulus around the origin.
theta = 2*np.pi*np.random.rand(1, N[1])
r = 0.8 + .2*np.random.rand(1, N[1])
y = np.concatenate((r*np.cos(theta), r*np.sin(theta)), axis=0)

In [None]:

# Draw both point clouds in one figure: x in blue, y in red.
plt.figure(figsize=(10, 10))
for cloud, colour in ((x, 'b'), (y, 'r')):
    plotp(cloud, colour, plt)

# Hide the axes and frame the picture on the extent of y with a 0.1 margin.
plt.axis("off")
plt.xlim(np.min(y[0, :]) - .1, np.max(y[0, :]) + .1)
plt.ylim(np.min(y[1, :]) - .1, np.max(y[1, :]) + .1)

# Save before show(): the figure is written to disk, then displayed.
plt.savefig("AnnulusvsSquare.png")
plt.show()

In [None]:
# Cost matrix: squared Euclidean distance between every column of x and every column of y.
C = distmat(x,y)

In [None]:
# Uniform weights on the two point clouds (1-D arrays that each sum to one).
a = np.ones(N[0])/N[0]
b = np.ones(N[1])/N[1]

In [None]:
import computational_OT

In [None]:
# Regularisation strength and the associated kernel exp(-C/epsilon).
epsilon = .06
K = np.exp(-C/epsilon)

# Initial scalings: all ones, one entry per point of each cloud.
u = np.ones(N[0])
v = np.ones(N[1])

SOptimizer = computational_OT.Sinkhorn(K, a, b, u, v, epsilon)
out = SOptimizer._update(maxiter=1000)

# Plot out[2] and out[3] on a log scale, one panel each; going by the panel
# titles they hold the marginal errors on a and on b respectively.
plt.figure(figsize=(12, 12))
panels = (
    ("$||P1 -a||_1$", out[2], N[0]),
    ("$||P^T 1 -b||_1$", out[3], N[1]),
)
for position, (title, errors, sample_size) in enumerate(panels, start=1):
    plt.subplot(2, 1, position)
    plt.title(title)
    plt.plot(np.asarray(errors), linewidth=2)
    plt.yscale('log')
    plt.ylabel("Error in log scale")
    plt.xlabel("Number of iterations")
    # Each panel holds a single curve, so it gets a single legend entry
    # (supplying one label per entry of N gives more labels than curves).
    plt.legend(["Sample size: "+str(sample_size)+" and Epsilon="+str(epsilon)], loc="upper right")
plt.show()

In [None]:
# Hybrid scheme, stage 1: a short Sinkhorn run from all-ones scalings.
epsilon = .06
K = np.exp(-C/epsilon)
u, v = np.ones(N[0]), np.ones(N[1])

SOptimizer = computational_OT.Sinkhorn(K, a, b, u, v, epsilon)
outS = SOptimizer._update(maxiter=110)

# Stage 2: stack the first two Sinkhorn outputs side by side, map them to the
# log domain (scaled by epsilon) and hand the result to Newton-Raphson as its
# starting point.
X = epsilon*np.log(np.hstack((outS[0].T, outS[1].T)))
NOptimizer = computational_OT.NewtonRaphson(X, K, a, b, epsilon)
outN = NOptimizer._update(maxiter=10, debug=False)



In [None]:
# Compare the total marginal error of the long Sinkhorn run (`out`) with the
# hybrid Sinkhorn + Newton run (`outS`, `outN`).
plt.figure(figsize = (12,6))
plt.title("$||P1 -a||_1 + ||P^T 1 -b||_1$")

error_sinkhorn = np.asarray(out[2]) + np.asarray(out[3])
# NOTE(review): the inner `+` acts on the raw histories before conversion to
# arrays; this presumably concatenates Python lists so that the Newton errors
# follow the Sinkhorn ones -- confirm the return types of the solvers.
error_hybrid   = np.asarray(outS[2]+outN[0]) + np.asarray(outS[3]+outN[1])

# Raw strings: `\e` is not a valid escape sequence in a normal string literal.
plt.plot( error_sinkhorn, label=r'Sinkhorn for $\epsilon=$' + str(epsilon), linewidth = 2)
plt.plot( error_hybrid, label=r'Hybrid method for $\epsilon=$'+ str(epsilon), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')
plt.savefig("SinkhornNewton.png")
plt.show()

## Experiments for different epsilons

In [None]:
# Regularisation strengths to sweep, and per-index result stores:
#   outS[i] -- Sinkhorn output, outN[i] -- Newton-Raphson output, P[i] -- coupling.
epsilons=[0.01,0.02,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5]
n=len(epsilons)
outS={}
outN={}
P={}

for i, eps_i in enumerate(epsilons):
    K = np.exp(-C/eps_i)
    u = np.ones(N[0])
    v = np.ones(N[1])

    SOptimizer = computational_OT.Sinkhorn(K,a,b,u,v,eps_i)
    out1 = SOptimizer._update(maxiter=1000)
    outS[i] = out1

    # Newton-Raphson started from the Sinkhorn scalings in the log domain.
    X = np.hstack( (out1[0],out1[1]) )
    X = eps_i*np.log(X)
    NOptimizer = computational_OT.NewtonRaphson(X,K,a,b,eps_i)
    out2 = NOptimizer._update(maxiter=10, debug=False)
    outN[i] = out2

    # Coupling diag(u) K diag(v) and its two Gram matrices.
    P[i] = np.dot(np.dot(np.diag(out1[0]),K),np.diag(out1[1]))
    P_xx = np.dot(P[i], P[i].T)
    P_yy = np.dot(P[i].T, P[i])

    # One log-log histogram (20 bins) per matrix. The titles are raw strings
    # because `\e` is not a valid escape sequence in a normal literal.
    _,ax = plt.subplots(figsize=(20,5),nrows=1,ncols=3)
    panels = (
        (r"P$_{\epsilon}$ histogram for $\epsilon$: ", P[i]),
        (r"P$_{\epsilon}$P$^{T}_{\epsilon}$ flattened and $\epsilon$: ", P_xx),
        (r"P$^{T}_{\epsilon}$P$_{\epsilon}$ flattened and $\epsilon$: ", P_yy),
    )
    for axis, (title_prefix, matrix) in zip(ax, panels):
        axis.set_title(title_prefix+str(eps_i))
        axis.hist(matrix.flatten(), 20)
        axis.set_xscale("log")
        axis.set_yscale("log")
    plt.savefig("Phist"+str(i)+".png")
    plt.show()
    print("\n \n")
    

### Compute the cutoffs

In [None]:
# For every regularisation strength, take the q-quantile of the entries of
# P P^T and of P^T P as cutoff values.
q=0.8
cutoff_x=[]
cutoff_y=[]
for i in range(n):
    # Use the coupling of the *current* epsilon; indexing P[0] here would
    # give every epsilon the cutoffs of the first one.
    coupling = P[i]
    cutoff_x.append(np.quantile(np.dot(coupling,coupling.T).flatten(),q))
    cutoff_y.append(np.quantile(np.dot(coupling.T,coupling).flatten(),q))


## Cuthill-McKee

In [None]:
# Run the Cuthill-McKee experiment on each coupling of the sweep, passing the
# matching cutoffs and regularisation strength.
for i in range(n):
    CuthillMckee = computational_OT._Expcuthill_mckee(P[i])
    CuthillMckee._evaluate(
        cut_offx=cutoff_x[i],
        cut_offy=cutoff_y[i],
        epsilon=epsilons[i],
        index=i,
    )


In [None]:
# Total marginal error of the Sinkhorn runs, one curve per epsilon.
plt.figure(figsize = (15,6))
plt.title("$||P1 -a||_1 + ||P^T 1 -b||_1$")

for i in range(n):
    total_error = np.asarray(outS[i][2]) + np.asarray(outS[i][3])
    # Raw string: `\e` is not a valid escape sequence in a normal literal.
    plt.plot( total_error, label=r'Sinkhorn for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)

plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')
plt.savefig("Sinkhornvaryepsilon.png")
plt.show()

### Nested Dissection

In [None]:
# Run the nested-dissection experiment on each coupling of the sweep.
for i in range(n):
    nd = computational_OT.NestedDissection(P[i], stopdim=50)
    nd._evaluate(
        cutoff_x[i],
        cutoff_y[i],
        epsilons[i],
        index=i,
    )
    

### Damped Newton

In [None]:
# Parameters forwarded to the damped Newton solver.
rho=0.95
c1=0.05

# Per-epsilon results: coupling, raw solver output, wall-clock time, rescaled Hessian.
DampedNewtonP=[]
results_DampedNewton  = []
times_DampedNewton    = []
Hessians_DampedNewton = []

epsilons=[0.1,0.2, 0.3, 0.4, 0.5, 0.75, 1.0 ]

# The cost matrix does not depend on epsilon, so build it once.
C = distmat(x,y)

for eps in epsilons:
    print("Damped Newton for epsilon="+str(eps)+":")

    # Uniform marginals as column vectors. They are rebuilt on every pass
    # because f and g below alias them.
    a = normalize(np.ones(N[0]))
    a = a.reshape(a.shape[0],-1)
    b = normalize(np.ones(N[1]))
    b = b.reshape(b.shape[0],-1)

    # Kernel for the current regularisation strength.
    K = np.exp(-C/eps)

    # The marginals themselves serve as the starting point (f, g).
    f,g = a,b

    print("Doing for (",N[0],N[1],").")
    print( " |- Iterating")
    start=time.time()
    Optimizer=computational_OT.DampedNewton(K,a,b,f,g,eps,rho,c1)
    out=Optimizer._update(maxiter=100)
    results_DampedNewton.append(out)
    end=time.time()
    times_DampedNewton.append(end-start)
    print( " |- Computing P")
    DampedNewtonP.append(GetP(np.exp(out[0]/eps),K,np.exp(out[1]/eps)))
    print( " |- Recording (unstabilized) Hessian")
    mat  = -eps*Optimizer.Hessian
    # Symmetric rescaling D^{-1/2} mat D^{-1/2} with D = diag(a, b). `diag` is
    # 1-D, so it needs explicit row/column shapes: the bare product
    # diag*mat*diag would scale every column twice and no row at all.
    diag = 1/np.sqrt( np.vstack( (a,b) ) ).flatten()
    mat = diag[:, None]*mat*diag[None, :]
    Hessians_DampedNewton.append( mat )

    print( "" )


In [None]:
# Total marginal error of each damped Newton run (entries 2 and 3 of the
# solver output, summed), one curve per epsilon, on a log scale.
plt.figure(figsize = (20,7))
plt.title("$||P1 -a||_1+||P^T 1 -b||_1$")

for i, result in enumerate(results_DampedNewton):
    error = np.asarray(result[2])+np.asarray(result[3])
    # Raw string: `\e` is not a valid escape sequence in a normal literal.
    plt.plot( error, label=r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)

plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale( 'log')
plt.savefig("ErrorLinesearchNewton.png")
plt.show()

print("\n Error plots can increase! The error is not the objective function!")

In [None]:
# Entry 4 of each damped Newton output, plotted per epsilon under the title
# "Objective Function".
plt.figure(figsize = (20,7))
plt.title("Objective Function")

for i, result in enumerate(results_DampedNewton):
    # Raw string: `\e` is not a valid escape sequence in a normal literal.
    plt.plot( np.asarray(result[4]), label=r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)

plt.xlabel("Number of iterations")
plt.ylabel("Objective value")
plt.legend()
plt.savefig("ObjectiveLineSearchNewton.png")
plt.show()

In [None]:
# Entry 5 of each damped Newton output, plotted per epsilon under the title
# "Alpha" (presumably the step sizes chosen by the line search -- confirm
# against the solver).
plt.figure(figsize = (20,7))
plt.subplot(2,1,1)
plt.title("Alpha")

for i, result in enumerate(results_DampedNewton):
    # Raw string: `\e` is not a valid escape sequence in a normal literal.
    plt.plot( np.asarray(result[5]), label=r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)

plt.xlabel("Number of iterations")
# The y-axis is linear here, so the label does not claim a log scale.
plt.ylabel("Alpha")
plt.legend()
plt.savefig("AlphaLineSearchNewton.png")
plt.show()



In [None]:
def print_spectral_statistics(mat, stabilize=False):
    """Print and plot spectral statistics of the square matrix ``mat``.

    The eigenvalues are computed with ``np.linalg.eig`` and sorted. The ten
    smallest and ten largest are printed, followed by the ratio
    largest / smallest (reported as "Condition number"). A 50-bin histogram
    of the spectrum with a logarithmic count axis is then shown.

    Parameters
    ----------
    mat : 2-D square array
        Matrix to analyse. With ``stabilize=True`` its size must be
        ``N[0] + N[1]``, where ``N`` is the notebook-level list of sample sizes.
    stabilize : bool, optional
        When true, two rank-one corrections are applied first: the projector
        onto the normalised vector ``(1, ..., 1, -1, ..., -1)`` (``N[0]`` ones,
        ``N[1]`` minus ones) is added and the projector onto the normalised
        all-ones vector is subtracted.
    """
    if stabilize:
        # Stabilizing largest and smallest eigenvalue
        size = N[0] + N[1]
        scale = np.sqrt(size)
        min_vector = (np.hstack((np.ones(N[0]), -np.ones(N[1])))/scale).reshape((size, 1))
        max_vector = (np.hstack((np.ones(N[0]),  np.ones(N[1])))/scale).reshape((size, 1))
        mat = mat + np.dot(min_vector, min_vector.T)
        mat = mat - np.dot(max_vector, max_vector.T)

    # NOTE(review): ``eig`` is the general (non-symmetric) solver and may
    # return complex values; if ``mat`` is guaranteed symmetric,
    # ``np.linalg.eigh`` would be the natural choice -- confirm with callers.
    eig, _ = np.linalg.eig(mat)
    eig = np.sort(eig)

    print( "List of smallest eigenvalues: ", eig[:10])
    print( "List of largest  eigenvalues: ", eig[-10:])

    # Signed ratio of the extreme eigenvalues, not a ratio of magnitudes.
    condition_number = np.max(eig)/np.min(eig)
    print("Condition number: ", condition_number)

    plt.hist( eig, 50)
    plt.title( "Histogram of eigenvalues for Hessian")
    plt.xlabel( "Eigenvalues")
    plt.yscale( "log" )
    plt.show()

In [None]:
# Report the spectrum of the recorded Hessian for every epsilon of the damped
# Newton sweep (no stabilisation applied).
for i, eps in enumerate(epsilons):
    print("Spectral statistics of Hessian for epsilon="+str(eps))
    Hessian = Hessians_DampedNewton[i]
    print_spectral_statistics(Hessian, stabilize=False)
    print("")


### Annulus vs Rotated Annulus

In [None]:
# Sample sizes for the annulus experiment: N[0] source points, N[1] target points.
N=[400,400]

### Sampled Annulus

In [None]:
# Angles concentrated around three directions: a base value in {0, 1/3, 2/3}
# (random2) plus Gaussian noise of standard deviation 0.05 (random), wrapped
# into [0, 1) and then scaled to [0, 2*pi).
random = 0.05*np.random.normal(size=N[0])
random2 = np.floor(3*np.random.uniform(size=N[0]))/3
biased_unif = (random + random2) % 1
theta = 2*np.pi*biased_unif

# Points on the circle of radius r, one per column.
r = 1.0
X = np.vstack((np.cos(theta)*r, np.sin(theta)*r))

In [None]:
# Collects the angle-ordered couplings; one matrix is appended further down
# each time the permutation cell is executed.
Angleorderedmatrices=[]

### Rotated Annulus

In [None]:
# Target cloud: the same points as X with every angle shifted by pi/2.
Y=np.vstack((np.cos(theta+(np.pi/2))*r,np.sin(theta+(np.pi/2))*r))


In [None]:
# Side-by-side view: X (blue) on the left, Y (red) on the right, axes hidden.
_, ax = plt.subplots(figsize=(10,5), nrows=1, ncols=2)
for axis, cloud, colour in zip(ax, (X, Y), ('b', 'r')):
    plotp(cloud, colour, axis, scale=50, edgecolors=(0,0,0,0))
    axis.axis("off")

plt.savefig("SuperimposedAnnulusvsRotatedAnnulus1.png")
plt.show()

In [None]:
# Superimposed view: X (blue) and Y (red) on the same axes, axes hidden.
_, ax = plt.subplots(figsize=(8,8), nrows=1, ncols=1)
for cloud, colour in ((X, 'b'), (Y, 'r')):
    plotp(cloud, colour, ax, scale=50, edgecolors=(0,0,0,0))
ax.axis("off")

plt.savefig("SuperimposedAnnulusvsRotatedAnnulus.png")
plt.show()


In [None]:
import math

# Arc-length cost: C[i, j] = r * (angle between X[:, i] and Y[:, j]).
# The cosine of the angle is <x, y> / (|x| |y|), computed for all pairs at once.
norms_X = np.linalg.norm(X, axis=0)
norms_Y = np.linalg.norm(Y, axis=0)
cosines = np.dot(X.T, Y)/np.outer(norms_X, norms_Y)

# Rounding can push the cosine of (nearly) parallel vectors marginally outside
# [-1, 1], where arccos returns NaN; clip before taking the angle.
C = r*np.arccos(np.clip(cosines, -1.0, 1.0))

In [None]:
# a and b
# Uniform marginals, reshaped into column vectors of shape (N[0], 1) and (N[1], 1).
a = normalize(np.ones(N[0]))
a=a.reshape(a.shape[0],-1)
b = normalize(np.ones(N[1]))
b=b.reshape(b.shape[0],-1)

In [None]:
# Regularisation strength and the associated kernel exp(-C/epsilon).
epsilon = .2
K = np.exp(-C/epsilon)

# Initial scalings: all ones, one entry per point of each cloud.
u = np.ones(N[0])
v = np.ones(N[1])

SOptimizer = computational_OT.Sinkhorn(K, a, b, u, v, epsilon)
outS = SOptimizer._update(maxiter=1000)

# Plot outS[2] and outS[3] on a log scale, one panel each; going by the panel
# titles they hold the marginal errors on a and on b respectively.
plt.figure(figsize=(12, 12))
panels = (
    ("$||P1 -a||_1$", outS[2], N[0]),
    ("$||P^T 1 -b||_1$", outS[3], N[1]),
)
for position, (title, errors, sample_size) in enumerate(panels, start=1):
    plt.subplot(2, 1, position)
    plt.title(title)
    plt.plot(np.asarray(errors), linewidth=2)
    plt.yscale('log')
    plt.ylabel("Error in log scale")
    plt.xlabel("Number of iterations")
    # Each panel holds a single curve, so it gets a single legend entry
    # (supplying one label per entry of N gives more labels than curves).
    plt.legend(["Sample size: "+str(sample_size)+" and Epsilon="+str(epsilon)], loc="upper right")
plt.show()

In [None]:
# Coupling built from the first two Sinkhorn outputs, and its two Gram matrices.
# NOTE(review): the elementwise product below equals diag(u) K diag(v) only if
# outS[0] and outS[1] are column vectors -- confirm the shapes Sinkhorn returns.
Ptest = outS[0]*K*(outS[1].T)
P_xx = np.dot(Ptest, Ptest.T)
P_yy = np.dot(Ptest.T, Ptest)

# One log-log histogram (20 bins) per matrix.
_, ax = plt.subplots(figsize=(20,5), nrows=1, ncols=3)
histograms = (
    ("P histogram for epsilon: ", Ptest),
    ("P_xx flattened and e: ", P_xx),
    ("P_yy flattened and e: ", P_yy),
)
for axis, (title_prefix, matrix) in zip(ax, histograms):
    axis.set_title(title_prefix+str(epsilon))
    axis.hist(matrix.flatten(), 20)
    axis.set_xscale("log")
    axis.set_yscale("log")

plt.savefig("PhistAnnulusvsRotatedAnnulus"+str(epsilon)+".png")
plt.show()
print("\n \n")


In [None]:
# Cutoffs: the q-quantile of the entries of Ptest Ptest^T and Ptest^T Ptest.
# The whole coupling is used; Ptest[0] would select only its first row and
# reduce each product to a single scalar.
q=0.8
cutoff_x=np.quantile(np.dot(Ptest,Ptest.T).flatten(),q)
cutoff_y=np.quantile(np.dot(Ptest.T,Ptest).flatten(),q)


In [None]:
# Rank the sample points by angle.
angle_order = np.argsort(theta)
sorted_theta = theta[angle_order]

# Permutation matrix with perm_matrix[i, j] = 1 exactly when point i has rank
# j in the angular ordering. Built from argsort, so it is sized from the data
# and stays a true permutation even if two angles coincide.
n_points = theta.size
perm_matrix = np.zeros((n_points, n_points))
perm_matrix[angle_order, np.arange(n_points)] = 1

# Reorder the rows and columns of the coupling by angle and store the result.
P_ = np.dot(Ptest, perm_matrix)
P_ = np.dot(perm_matrix.T, P_)
Angleorderedmatrices.append(P_)



In [None]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
# Run the Cuthill-McKee experiment on the annulus coupling; the regularisation
# strength doubles as the run index.
CuthillMckee = computational_OT._Expcuthill_mckee(Ptest)
CuthillMckee._evaluate(
    cut_offx=cutoff_x,
    cut_offy=cutoff_y,
    epsilon=epsilon,
    index=epsilon,
)

In [None]:
# Both clouds on one figure: X in blue, Y in red.
plt.figure(figsize=(8,8))
for cloud, colour in ((X, 'b'), (Y, 'r')):
    plotp(cloud, colour, plt, scale=50, edgecolors=(0,0,0,0))

# Join X[:, i] to Y[:, j] wherever the coupling exceeds a fraction of its
# maximum: thick segments above 80 %, thin segments above 20 %.
for fraction, width in ((.8, 2), (.2, 1)):
    A = Ptest * (Ptest > np.max(Ptest)*fraction)
    i,j = np.where(A != 0)
    plt.plot([X[0,i],Y[0,j]],[X[1,i],Y[1,j]],'k',lw = width)

# Hide the axes and frame the picture on the extent of Y with a 0.1 margin.
plt.axis("off")
plt.xlim(np.min(Y[0,:])-.1,np.max(Y[0,:])+.1)
plt.ylim(np.min(Y[1,:])-.1,np.max(Y[1,:])+.1)
plt.savefig("FinalAnnulusvsRotatedAnnulus"+str(epsilon)+".png")
plt.show()

In [None]:
# Gallery of the angle-ordered couplings. The titles assume the matrices were
# accumulated by re-running the cells above with epsilon = 0.01, 0.02, 0.05,
# 0.1 and 0.2, in that order.
epsilon_labels = ["0.01", "0.02", "0.05", "0.1", "0.2"]
_, ax = plt.subplots(figsize=(20,5), nrows=1, ncols=len(epsilon_labels))

# zip stops at the shortest input, so a partially filled list is displayed
# instead of raising IndexError.
for axis, label, matrix in zip(ax, epsilon_labels, Angleorderedmatrices):
    axis.set_title("For epsilon="+label+".")
    axis.imshow(matrix)

# Blank the panels for which no matrix has been stored yet.
for axis in ax[len(Angleorderedmatrices):]:
    axis.axis("off")

plt.savefig("Angleorderedmatrices.png")
plt.show()
print("\n \n")

