In [1]:
from __future__ import division
import os
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl
import warnings

torch.set_default_dtype(torch.float64)
warnings.filterwarnings('ignore')
np.random.seed(1234)
torch.manual_seed(1234)

%matplotlib inline
%load_ext autoreload
%autoreload 

In [None]:
import torch

# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to PyTorch's current setting so torch.set_num_threads always receives an int.
NUM_THREADS = os.cpu_count() or torch.get_num_threads()
# Thread count before the change ...
print("Number of used CPU threads: ", torch.get_num_threads())
torch.set_num_threads(NUM_THREADS)
# ... and after it.
print("Number of used CPU threads: ", torch.get_num_threads())




















































































#### Creating directory for image outputs


In [None]:
# Make sure the folder that receives the saved figures exists; an already
# existing directory is left untouched.
os.makedirs('SinkhornVersions_images', exist_ok=True)

#### To compute distance matrix


In [None]:
def distmat(x,y):
    """Pairwise squared Euclidean distances between the columns of ``x`` and ``y``.

    Uses the expansion |xi|^2 + |yj|^2 - 2 <xi, yj>: the squares are summed over
    axis 0, so entry (i, j) of the result relates column i of ``x`` to column j
    of ``y``.
    """
    sq_norm_x = np.sum( x**2, 0 )
    sq_norm_y = np.sum( y**2, 0 )
    gram = x.transpose().dot( y )
    return sq_norm_x[:,None] + sq_norm_y[None,:] - 2*gram

def torchdistmat(x,y):
    """PyTorch counterpart of ``distmat``: pairwise squared Euclidean distances.

    The squares are summed over dimension 0, so entry (i, j) of the result
    relates column i of ``x`` to column j of ``y``.
    """
    sq_norm_x = torch.sum( x**2, 0 )
    sq_norm_y = torch.sum( y**2, 0 )
    gram = torch.matmul( x.t(), y )
    return sq_norm_x[:,None] + sq_norm_y[None,:] - 2*gram

#### Normalize vector

In [None]:
def normalize(a):
    """Return ``a`` divided by the sum of its entries (``np.sum(a)``)."""
    return a/np.sum( a )
def torchnormalize(a):
    """Return the tensor ``a`` divided by the sum of its entries (``torch.sum(a)``)."""
    return a/torch.sum( a )

#### Compute P and plot

In [None]:
def GetP(u,K,v):
    """Scale the rows of ``K`` by ``u`` and its columns by ``v``.

    Entry (i, j) of the result is ``u[i] * K[i, j] * v[j]``, obtained by
    broadcasting ``u`` as a column and ``v`` as a row.
    """
    row_scaled = u[:,None]*K
    return row_scaled*v[None,:]

def plotp(x, col,plt, scale=200, edgecolors="k"):
    """Scatter-plot the 2-D points stored column-wise in ``x``.

    Row 0 of ``x`` supplies the horizontal coordinates and row 1 the vertical
    ones. ``plt`` is the object whose ``scatter`` method is called (it shadows
    any module-level ``plt`` inside this function); the value returned by
    ``scatter`` is passed back to the caller.
    """
    horizontal = x[0,:]
    vertical = x[1,:]
    marker_style = dict( s=scale, edgecolors=edgecolors, c=col, cmap='plasma', linewidths=2 )
    return plt.scatter( horizontal, vertical, **marker_style )

In [None]:
import computational_OT

In [None]:
# Number of points in the first (x) and second (y) cloud.
N = [ 400, 500 ]
n_first, n_second = N
# First cloud: uniform samples shifted to the square [-0.5, 0.5)^2, one point per column.
x = np.random.rand( 2, n_first ) - 0.5
# Second cloud: uniform angle in [0, 2*pi) and radius in [0.8, 1.0), i.e. a ring.
# The draw order (x, then theta, then r) is kept so seeded runs give the same data.
theta = 2*np.pi*np.random.rand( 1, n_second )
r = 0.8 + .2*np.random.rand( 1, n_second )
y = np.concatenate( [ r*np.cos( theta ), r*np.sin( theta ) ], axis=0 )

In [None]:
# Iteration budget passed to every solver run.
iterations_count = 3 * 10**4
# Regularisation strengths to sweep, from largest to smallest.
epsilons = [ 1e-1, 5e-3, 1e-3, 7e-4, 5e-4 ]

#### Experiment functions

In [None]:
def Experiment(algorithm, C, epsilons, N, iterations = iterations_count):
    """Run one Sinkhorn variant for every regularisation strength in ``epsilons``.

    Parameters
    ----------
    algorithm : str
        ``"Sinkhorn"`` uses ``computational_OT.Sinkhorn``,
        ``"Log_domainSinkhorn"`` uses ``computational_OT.Log_domainSinkhorn``;
        any other value falls through to
        ``computational_OT.torchLog_domainSinkhorn``.
    C : array or tensor
        Cost matrix. It is passed to ``np.exp`` in the first two branches and to
        ``torch.exp`` in the last one.
    epsilons : iterable of float
        Regularisation strengths; one solver run is performed per value.
    N : sequence of two ints
        Sizes of the two uniform marginals ``a`` (``N[0]``) and ``b`` (``N[1]``).
    iterations : int
        Iteration budget forwarded to the solver (``maxiter`` / ``niter``).

    Returns
    -------
    dict
        ``'results_list'``: solver outputs, one per epsilon;
        ``'time_stamps'``: wall-clock solver durations in milliseconds (for
        ``"Sinkhorn"`` a run whose errors contain NaN contributes no entry, so
        this list can be shorter than ``epsilons``);
        ``'Ps'``: matrices built with ``GetP`` from the returned potentials.
    """
    P        = []
    results  = []
    time_stamps    = []

    for eps in epsilons:
        print( str(algorithm)+ " for epsilon = "+str(eps)+":" )
        print("Doing for (",N[0],N[1],").")
        print( " |- Iterating")
        if algorithm== "Sinkhorn":
            # Uniform marginals and Gibbs kernel.
            a = normalize(np.ones(N[0]))
            b = normalize(np.ones(N[1]))
            K = np.exp(-C/eps)
            # Inflate the marginals to column vectors; they also serve as the
            # initial scalings u and v.
            a = a.reshape(a.shape[0],-1)
            b = b.reshape(b.shape[0],-1)
            u = a
            v = b
            start = time.time()
            Optimizer = computational_OT.Sinkhorn(
                                                    K,
                                                    a,
                                                    b,
                                                    u,
                                                    v,
                                                    eps
                                                    )
            out = Optimizer._update(maxiter = iterations)
            results.append( out )
            end = time.time()
            # Only record timings of runs that did not blow up to NaN.
            if not (np.isnan(np.linalg.norm(out["error_a"])) or np.isnan(np.linalg.norm(out["error_b"]))):
                # time.time() is in seconds; 1e3 converts to milliseconds.
                time_stamps.append( 1e3*( end-start ) )

            print( " |- Computing P")
            print( "" )
            P.append( GetP( np.exp(out['potential_f']/eps),K,np.exp(out['potential_g']/eps) ) )
        elif algorithm == "Log_domainSinkhorn":
            # Uniform marginals; K is only needed to assemble P afterwards.
            a = normalize(np.ones(N[0]))
            b = normalize(np.ones(N[1]))
            K = np.exp(-C/eps)
            start = time.time()
            Optimizer = computational_OT.Log_domainSinkhorn(
                                                                a,
                                                                b,
                                                                C,
                                                                eps
                                                                )
            out = Optimizer.update( niter = iterations )
            results.append( out )
            end = time.time()
            print( " |- Computing P")
            print( "" )
            P.append( GetP( np.exp(out['potential_f']/eps),K,np.exp(out['potential_g']/eps) ) )
            # Seconds -> milliseconds.
            time_stamps.append( 1e3*( end-start ) )
        else:
            # Torch variant: same set-up with tensors instead of arrays.
            a = torchnormalize(torch.ones(N[0]))
            b = torchnormalize(torch.ones(N[1]))
            K = torch.exp(-C/eps)
            start = time.time()
            Optimizer = computational_OT.torchLog_domainSinkhorn(
                                                                    a,
                                                                    b,
                                                                    C,
                                                                    eps
                                                                    )
            out = Optimizer.update( niter = iterations )
            results.append( out )
            end = time.time()
            print( " |- Computing P")
            print( "" )
            P.append( GetP( torch.exp(out['potential_f']/eps),K,torch.exp(out['potential_g']/eps) ))
            # Seconds -> milliseconds.
            time_stamps.append( 1e3*( end-start ) )
    return {
        'results_list': results,
        'time_stamps' : time_stamps,
        'Ps'          : P
    }
    

#### Experiments

##### Sinkhorn

In [None]:
# Plain Sinkhorn on the NumPy squared-distance cost, swept over all epsilons.
outputSinkhorn = Experiment(
    algorithm="Sinkhorn",
    C=distmat(x,y),
    epsilons=epsilons,
    N=N,
)

In [None]:
# Convergence of plain Sinkhorn: summed marginal errors, one curve per epsilon.
plt.figure( figsize = (20,7) )
plt.title( "$||P1 -a||_1+||P1 -b||_1$" )
for eps, result in zip( epsilons, outputSinkhorn['results_list'] ):
    error = np.asarray( result['error_a'] ) + np.asarray( result['error_b'] )
    # Raw string: "\e" is not a valid Python escape sequence and triggers a
    # warning in a normal string literal; the label text itself is unchanged.
    plt.plot( error, label=r'Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale( 'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergenceSinkhornvaryingepsilon.png")
plt.show()

##### Log-domain Sinkhorn

In [None]:
# Log-domain Sinkhorn (NumPy) on the same squared-distance cost.
outputLogSinkhorn = Experiment(
    algorithm="Log_domainSinkhorn",
    C=distmat(x,y),
    epsilons=epsilons,
    N=N,
)

In [None]:
# Convergence of log-domain Sinkhorn: the solver's 'error' history per epsilon.
plt.figure( figsize = (20,7) )
plt.title( "LogSin$||P1-a||_1+||P1-b||_1$" )
for eps, result in zip( epsilons, outputLogSinkhorn['results_list'] ):
    error = np.asarray( result['error'] )
    # Raw string: "\e" is not a valid Python escape sequence and triggers a
    # warning in a normal string literal; the label text itself is unchanged.
    plt.plot( error, label=r'Log-domain Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale(  'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergenceLogSinkhornvaryingepsilons.png")
plt.show()

##### torch Log-domain Sinkhorn

In [None]:
# Convert the point clouds to tensors for the torch solver. torch.as_tensor
# wraps a NumPy array exactly like torch.from_numpy, but also accepts an
# existing tensor, so re-running this cell no longer raises a TypeError.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

In [None]:
# Log-domain Sinkhorn implemented with PyTorch, on the tensor cost matrix.
outputtorchLogSinkhorn = Experiment(
    algorithm="torchLog_domainSinkhorn",
    C=torchdistmat(x,y),
    epsilons=epsilons,
    N=N,
)

In [None]:
# Convergence of the PyTorch log-domain Sinkhorn: 'error' history per epsilon.
plt.figure( figsize = (20,7) )
# (The stray trailing comma after this call, which built a throwaway tuple, is removed.)
plt.title( "$||P1-a||_1+||P1-b||_1$" )
for eps, result in zip( epsilons, outputtorchLogSinkhorn['results_list'] ):
    error = result['error']
    # Raw string: "\e" is not a valid Python escape sequence and triggers a
    # warning in a normal string literal; the label text itself is unchanged.
    plt.plot( error, label=r'Log-domain Sinkhorn for $\epsilon=$' + str(eps), linewidth = 2 )
plt.yscale(  'log' )
plt.legend()
plt.savefig("SinkhornVersions_images/ConvergencetorchLogSinkhornvaryingepsilons.png")
plt.show()

#### Time plot

In [None]:
# Wall-clock time of each solver as a function of epsilon (log scale).
plt.figure(figsize = (10,4))
plt.title("Time plot for differents Sinkhorn versions")
timing_series = [
    ( outputSinkhorn['time_stamps'],         'Sinkhorn' ),
    ( outputLogSinkhorn['time_stamps'],      'Log-domain Sinkhorn' ),
    ( outputtorchLogSinkhorn['time_stamps'], 'Log-domain Sinkhorn using Pytorch' ),
]
for timings, series_label in timing_series:
    # x positions come from the series itself, so a solver with fewer recorded
    # timings than epsilons still plots instead of raising a shape error.
    plt.plot(list(range(len(timings))), timings, label = series_label, marker='o', linewidth = 2)
plt.legend()
plt.xticks(list(range(len(epsilons))), epsilons)
# Raw string: "\e" is not a valid Python escape sequence in a normal literal.
plt.xlabel(r"$\epsilon$")
plt.yscale('log')
plt.ylabel("Time in ms")
plt.savefig("SinkhornVersions_images/Timeplot.png")
plt.show()