In [None]:
from __future__ import division
import os
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl
import warnings

torch.set_default_dtype(torch.float64)
warnings.filterwarnings('ignore')
np.random.seed(1234)
torch.manual_seed(1234)

%matplotlib inline
%load_ext autoreload
%autoreload 




















































































#### Creating directory for image outputs


In [None]:
# Create the folder that receives the saved figures. exist_ok=True makes the cell
# safe to re-run and avoids the check-then-create race of a separate isdir test.
os.makedirs('SinkhornVersions_images', exist_ok=True)

#### To compute distance matrix


In [None]:
def distmat(x,y):
    """Pairwise squared Euclidean distances between the columns of x and of y.

    Relies on the expansion |p - q|^2 = |p|^2 + |q|^2 - 2<p, q>.
    For 2-D inputs x of shape (d, n) and y of shape (d, m) the result is (n, m).
    """
    sq_norms_x = np.sum( x**2,0 )
    sq_norms_y = np.sum( y**2,0 )
    inner = x.transpose().dot(y)
    return sq_norms_x[:,None] + sq_norms_y[None,:] - 2*inner

def torchdistmat(x,y):
    """Torch counterpart of distmat: squared Euclidean distances between columns.

    For 2-D tensors x of shape (d, n) and y of shape (d, m) the result is (n, m),
    computed as |p|^2 + |q|^2 - 2<p, q>.
    """
    sq_norms_x = torch.sum( x**2,0 )
    sq_norms_y = torch.sum( y**2,0 )
    inner = torch.matmul(x.t(),y)
    return sq_norms_x[:,None] + sq_norms_y[None,:] - 2*inner

#### Normalize vector

In [None]:
def normalize(a):
    """Return the NumPy array `a` divided by the sum of its entries."""
    return a/np.sum( a )


def torchnormalize(a):
    """Return the torch tensor `a` divided by the sum of its entries."""
    return a/torch.sum( a )

#### Compute P and plot

In [None]:
def GetP(u,K,v):
    """Scale the rows of K by u and its columns by v via broadcasting.

    For 1-D u and v, entry (i, j) of the result is u[i] * K[i, j] * v[j].
    """
    row_scale = u[:,None]
    col_scale = v[None,:]
    return row_scale*K*col_scale

def plotp(x, col,plt, scale=200, edgecolors="k"):
    """Scatter-plot the points stored column-wise in x.

    Row 0 of x supplies the horizontal coordinates and row 1 the vertical ones.
    `plt` is the object whose `scatter` method is called; its return value is
    handed back to the caller unchanged.
    """
    xs = x[0,:]
    ys = x[1,:]
    return plt.scatter(
        xs,
        ys,
        s=scale,
        edgecolors=edgecolors,
        c=col,
        cmap='plasma',
        linewidths=2,
    )

In [None]:
import computational_OT

In [None]:
# Number of points in the source cloud (N[0]) and in the target cloud (N[1]).
N = [ 400, 500 ]
# Source cloud: N[0] points drawn uniformly in the square [-0.5, 0.5)^2, one point per column.
x = np.random.rand( 2,N[0] )-0.5
# Target cloud in polar form: angle uniform in [0, 2*pi), radius uniform in [0.8, 1.0).
# The order of the random draws matters for reproducibility under the global NumPy seed.
theta = 2*np.pi*np.random.rand( 1,N[1] )
r = 0.8+.2*np.random.rand( 1,N[1] )
# Convert to Cartesian coordinates; y has shape (2, N[1]) with one point per column.
y = np.vstack( ( r*np.cos( theta ),r*np.sin( theta ) ) )

In [None]:
# Creating a NumPy array and a torch tensor from the same literal
numpy_array = np.array([1.23456789])
torch_tensor = torch.tensor([1.23456789])

# Printing with default settings to compare how each library displays the value
print(numpy_array)
print(torch_tensor)
# Elementwise comparison across the two libraries.
# NOTE(review): presumably evaluates to True because the torch default dtype was set to
# float64 at the top of the notebook -- confirm.
print(numpy_array==torch_tensor)

In [None]:
# Ask both libraries to display 8 digits of precision (display only).
np.set_printoptions(precision=8)
torch.set_printoptions(precision=8)

# Print the same two objects again with the new display settings.
print(numpy_array)
print(torch_tensor)

In [None]:
# Regularization strengths passed to each solver wrapper below, from largest to smallest.
epsilons = [0.1, 0.05, 0.01, 0.005, 0.001 ]

#### Sinkhorn

In [None]:
def Sinkhorn(epsilons,N,x,y,iterations = 1000):
    """Run computational_OT.Sinkhorn once for every regularization value.

    Parameters
    ----------
    epsilons : iterable of float
        Regularization strengths; one solve is performed per value.
    N : sequence of two ints
        Sizes of the uniform source (N[0]) and target (N[1]) marginals.
    x, y : NumPy arrays
        Point clouds handed to distmat (assumed to hold one point per
        column -- TODO confirm against callers).
    iterations : int, optional
        Forwarded to the optimizer as `maxiter`.

    Returns
    -------
    dict
        'results_list' : the optimizer output for each epsilon,
        'time_stamps'  : wall-clock duration of each solve in milliseconds,
        'Ps'           : coupling built from the returned potentials for each epsilon.
    """
    print("Sinkhorn.... ")
    SinkhornP         = []
    results_Sinkhorn  = []
    times_Sinkhorn    = []
    # Uniform marginals, reshaped into column vectors of shape (N[i], 1).
    a = normalize(np.ones(N[0]))
    a = a.reshape(a.shape[0],-1)
    b = normalize(np.ones(N[1]))
    b = b.reshape(b.shape[0],-1)
    # The cost matrix does not depend on epsilon, so it is computed once.
    C = distmat(x,y)
    for eps in epsilons:
        print( "Sinkhorn for epsilon = "+str(eps)+":" )
        # Gibbs kernel for the current regularization strength.
        K = np.exp(-C/eps)
        print("Doing for (",N[0],N[1],").")
        print( " |- Iterating")
        # Initial scalings are the marginals themselves (same array objects).
        # NOTE(review): if the optimizer mutates u or v in place, a and b would
        # change between epsilons -- confirm against computational_OT.Sinkhorn.
        u = a
        v = b
        start = time.perf_counter()
        Optimizer = computational_OT.Sinkhorn(K,a,b,u,v,eps)
        out       = Optimizer._update(maxiter = iterations)
        results_Sinkhorn.append(out)
        end = time.perf_counter()
        # perf_counter returns seconds; multiply by 1e3 to store milliseconds.
        times_Sinkhorn.append(1e3*(end-start))
        print( " |- Computing P")
        print( "" )
        # P = diag(exp(f/eps)) K diag(exp(g/eps)).
        SinkhornP.append(GetP(np.exp(out['potential_f']/eps),K,np.exp(out['potential_g']/eps)))
    return {
        'results_list': results_Sinkhorn,
        'time_stamps' : times_Sinkhorn,
        'Ps'          : SinkhornP
    }
    

#### Log-domain Sinkhorn

In [None]:
def log_domain_Sinkhorn(epsilons,N,x,y, iterations = 1000):
    """Run computational_OT.Log_domainSinkhorn once for every regularization value.

    Parameters
    ----------
    epsilons : iterable of float
        Regularization strengths; one solve is performed per value.
    N : sequence of two ints
        Sizes of the uniform source (N[0]) and target (N[1]) marginals.
    x, y : NumPy arrays
        Point clouds handed to distmat (assumed to hold one point per
        column -- TODO confirm against callers).
    iterations : int, optional
        Forwarded to the solver as `niter`.

    Returns
    -------
    dict
        'results_list' : the solver output for each epsilon,
        'time_stamps'  : wall-clock duration of each solve in milliseconds,
        'Ps'           : coupling built from the returned potentials for each epsilon.
    """
    print("Log domain Sinkhorn.... ")
    results_logSinkhorn = []
    times_logSinkhorn   = []
    logSinkhornP        = []
    # Uniform marginals (1-D arrays).
    a = normalize(np.ones(N[0]))
    b = normalize(np.ones(N[1]))
    # Cost matrix; independent of epsilon, so computed once.
    C = distmat(x,y)
    for eps in epsilons:
        print( "Log-domain Sinkhorn for epsilon = "+str(eps)+":" )
        print( "Doing for (",N[0],N[1],")." )
        print( " |- Iterating" )
        start = time.perf_counter()
        logsinkhorn = computational_OT.Log_domainSinkhorn(a,b,C,eps)
        output = logsinkhorn.update( niter = iterations )
        results_logSinkhorn.append( output )
        end = time.perf_counter()
        # perf_counter returns seconds; multiply by 1e3 to store milliseconds.
        times_logSinkhorn.append( 1e3*(end-start) )
        # P_ij = exp((f_i + g_j - C_ij)/eps): the same plan as
        # diag(exp(f/eps)) exp(-C/eps) diag(exp(g/eps)) used for plain Sinkhorn,
        # evaluated in the log domain so small eps does not underflow the kernel.
        # NOTE(review): assumes 'potential_f'/'potential_g' are 1-D dual potentials -- confirm.
        f = output['potential_f']
        g = output['potential_g']
        logSinkhornP.append( np.exp( (f[:,None] + g[None,:] - C)/eps ) )
    return {
        'results_list': results_logSinkhorn,
        'time_stamps' : times_logSinkhorn,
        'Ps'          : logSinkhornP
    }

#### torch log-domain Sinkhorn

In [None]:
def torchlog_domain_sinkhorn(epsilons,N,x,y, iterations = 1000):
    """Run computational_OT.torchLog_domainSinkhorn once for every regularization value.

    Parameters
    ----------
    epsilons : iterable of float
        Regularization strengths; one solve is performed per value.
    N : sequence of two ints
        Sizes of the uniform source (N[0]) and target (N[1]) marginals.
    x, y : torch tensors
        Point clouds handed to torchdistmat (assumed to hold one point per
        column -- TODO confirm against callers).
    iterations : int, optional
        Forwarded to the solver as `niter`.

    Returns
    -------
    dict
        'results_list' : the solver output for each epsilon,
        'time_stamps'  : wall-clock duration of each solve in milliseconds,
        'Ps'           : coupling built from the returned potentials for each epsilon.
    """
    print("Log domain Sinkhorn.... ")
    results_torchlogSinkhorn = []
    times_torchlogSinkhorn   = []
    torchlogSinkhornP        = []
    # Uniform marginals (1-D tensors).
    a = torchnormalize(torch.ones(N[0]))
    b = torchnormalize(torch.ones(N[1]))
    # Cost matrix; independent of epsilon, so computed once.
    C = torchdistmat(x,y)
    for eps in epsilons:
        print( "Log-domain Sinkhorn for epsilon = "+str(eps)+":" )
        print( "Doing for (",N[0],N[1],")." )
        print( " |- Iterating" )
        start = time.perf_counter()
        torchlogsinkhorn = computational_OT.torchLog_domainSinkhorn(a,b,C,eps)
        output = torchlogsinkhorn.update( niter = iterations )
        results_torchlogSinkhorn.append( output )
        end = time.perf_counter()
        # perf_counter returns seconds; multiply by 1e3 to store milliseconds.
        times_torchlogSinkhorn.append( 1e3*(end-start) )
        # P_ij = exp((f_i + g_j - C_ij)/eps), evaluated in the log domain so that
        # small eps does not underflow the kernel exp(-C/eps).
        # NOTE(review): assumes 'potential_f'/'potential_g' are 1-D dual potentials -- confirm.
        f = output['potential_f']
        g = output['potential_g']
        torchlogSinkhornP.append( torch.exp( (f[:,None] + g[None,:] - C)/eps ) )
    return {
        'results_list': results_torchlogSinkhorn,
        'time_stamps' : times_torchlogSinkhorn,
        'Ps'          : torchlogSinkhornP
    }


#### Experiments

##### Sinkhorn

In [None]:
# Run the plain Sinkhorn wrapper on the NumPy point clouds for every epsilon (default iteration cap).
outputSinkhorn = Sinkhorn(epsilons, N, x, y)

In [None]:
# Convergence of plain Sinkhorn: summed marginal errors per iteration, one curve per epsilon.
plt.figure( figsize = (20,7) )
# Raw strings keep the LaTeX backslashes intact; the second marginal of P is P^T 1.
plt.title( r"$||P1 - a||_1 + ||P^T1 - b||_1$" )
for eps, result in zip( epsilons, outputSinkhorn['results_list'] ):
    error = np.asarray( result['error_a'] ) + np.asarray( result['error_b'] )
    plt.plot( error, label=r'Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale( 'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergenceSinkhornvaryingepsilon.png")
plt.show()

##### Log-domain Sinkhorn

In [None]:
# Run the NumPy log-domain Sinkhorn wrapper for every epsilon (default iteration cap).
outputLogSinkhorn = log_domain_Sinkhorn(epsilons,N,x,y)

In [None]:
# Convergence of log-domain Sinkhorn: recorded error per iteration, one curve per epsilon.
plt.figure( figsize = (20,7) )
# Raw strings keep the LaTeX backslashes intact; the second marginal of P is P^T 1.
plt.title( r"$||P1 - a||_1 + ||P^T1 - b||_1$" )
for eps, result in zip( epsilons, outputLogSinkhorn['results_list'] ):
    error = np.asarray( result['error'] )
    plt.plot( error, label=r'Log-domain Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale( 'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergenceLogSinkhornvaryingepsilons.png")
plt.show()

##### torch Log-domain Sinkhorn

In [None]:
# Convert the point clouds to torch tensors for the torch solver.
# torch.as_tensor wraps a NumPy array without copying (like from_numpy) and returns
# an existing tensor unchanged, so re-running this cell no longer raises.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

In [None]:
# Run the torch log-domain Sinkhorn wrapper for every epsilon; x and y must already be torch tensors here.
outputtorchLogSinkhorn = torchlog_domain_sinkhorn(epsilons, N, x, y)

In [None]:
# Convergence of the torch log-domain Sinkhorn: recorded error per iteration, one curve per epsilon.
plt.figure( figsize = (20,7) )
# Raw strings keep the LaTeX backslashes intact; the second marginal of P is P^T 1.
plt.title( r"$||P1 - a||_1 + ||P^T1 - b||_1$" )
for eps, result in zip( epsilons, outputtorchLogSinkhorn['results_list'] ):
    error = result['error']
    plt.plot( error, label=r'Log-domain Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale( 'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergencetorchLogSinkhornvaryingepsilons.png")
plt.show()

#### Time plot

In [None]:
# Wall-clock time of each solver as a function of epsilon.
positions = list(range(len(epsilons)))
plt.figure(figsize = (20,7))
plt.title("Time plot for different Sinkhorn versions")
plt.plot(positions, outputSinkhorn['time_stamps'], label = 'Sinkhorn', marker='o', linewidth = 2)
plt.plot(positions, outputLogSinkhorn['time_stamps'], label = 'Log-domain Sinkhorn', marker='o', linewidth = 2)
plt.plot(positions, outputtorchLogSinkhorn['time_stamps'], label = 'Log-domain Sinkhorn using pytorch',marker='o', linewidth = 2)
plt.legend()
# The timings are stored in the order of `epsilons`, so the tick labels must use
# that same order (reversing them would mislabel every point).
plt.xticks(positions, epsilons)
plt.xlabel(r"$\epsilon$")
plt.ylabel("Time in ms")
plt.savefig("SinkhornVersions_images/Timeplot.png")
plt.show()