### In this notebook we study the performance of the semi-dual damped Newton method implemented with the JAX framework for Python.

In [None]:
import os
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
import numpy as np
from __future__ import division
import time 
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline 
%load_ext autoreload                                                                                                                                                                                            
%autoreload 

In [None]:
import computational_OT

# Create the directory that receives the figures saved by the plotting cells.
# `exist_ok=True` makes the call idempotent and removes the check-then-create
# race; the intermediate 'Images' directory is created automatically.
os.makedirs('Images/DampedNewton_SemiDual_images_JAX', exist_ok=True)

In [None]:
"""To compute distance matrix"""
def distmat(x,y):
    return jnp.sum( x**2,0 )[:,None] + jnp.sum( y**2,0 )[None,:] - 2*x.transpose().dot(y)
"""To Normalise a vector"""
normalize = lambda a: a/jnp.sum( a )
"""To Compute P"""
def GetP(u,K,v):
    return u[:,None]*K*v[None,:]
def plotp(x, col,plt, scale=200, edgecolors="k"):
    """Scatter the 2-D points stored column-wise in ``x``.

    Row 0 of ``x`` gives the abscissas and row 1 the ordinates. ``col`` is
    forwarded as the marker colour, ``scale`` as the marker size and ``plt``
    is the object whose ``scatter`` method is called. The value returned by
    ``plt.scatter`` is passed back to the caller.
    """
    abscissas = x[0,:]
    ordinates = x[1,:]
    return plt.scatter(
        abscissas,
        ordinates,
        s=scale,
        edgecolors=edgecolors,
        c=col,
        cmap='plasma',
        linewidths=2,
    )

In [None]:
def generate_data(N, seed=0):
    """Generate the two point clouds used in the experiments.

    Parameters
    ----------
    N : sequence of two ints
        ``N[0]`` is the number of source points, ``N[1]`` the number of
        target points.
    seed : int, optional
        Seed of the JAX PRNG key (default 0).

    Returns
    -------
    x : array of shape ``(2, N[0])``
        Points drawn uniformly in the square ``[-0.5, 0.5)^2``.
    y : array of shape ``(2, N[1])``
        Points in an annulus: uniform angle in ``[0, 2*pi)`` and uniform
        radius in ``[0.8, 1.0)``.
    """
    key = jax.random.PRNGKey(seed)
    # One independent sub-key per random draw. Reusing a key for two draws
    # returns the same uniform sample, which would tie the radius to the angle.
    key_x, key_theta, key_r = jax.random.split(key, 3)
    x = jax.random.uniform(key = key_x, shape = (2,N[0]) )-0.5
    theta = 2*jnp.pi*jax.random.uniform(key = key_theta, shape =  (1,N[1]) )
    r = 0.8+.2*jax.random.uniform(key = key_r, shape =  (1,N[1]) )
    y = jnp.vstack( ( r*jnp.cos( theta ),r*jnp.sin( theta ) ) )
    return x,y

In [None]:
# Problem sizes: N[0] source points (x) and N[1] target points (y).
N = [ 500,  600 ]
x , y = generate_data(N)

In [None]:
# Line-search parameters forwarded unchanged to the solver.
# NOTE(review): presumably rho is the backtracking shrink factor and c the
# sufficient-decrease constant -- confirm against computational_OT.
rho = 0.95
c = 0.1
# One entry per epsilon: transport plan, solver output, wall-clock time in
# seconds, and rescaled Hessian.
DampedNewtonP = []
results_DampedNewton  = []
times_DampedNewton    = []
Hessians_DampedNewton = []
#Cost matrix
C = distmat(x,y)
# a and b: uniform marginals stored as column vectors
a = normalize(jnp.ones(shape =(N[0],1)))
b = normalize(jnp.ones(shape =(N[1],1)))
epsilons = [  0.5, 0.03 ]
for eps in epsilons:
    # Line Search
    print("Damped Newton for epsilon="+str(eps)+":")
    #Kernel
    K = jnp.exp(-C/eps)
    # Every epsilon starts from the zero potential.
    f = 0*a
    print("Doing for (",N[0],N[1],").")
    print( " |- Iterating" )
    start = time.time()
    out = computational_OT.Damped_Newton_SemiDual_JAX._update(K,a,b,f,eps,rho,c, maxiter = 50 , debug = True)
    # Stop the clock right after the solver returns so bookkeeping is not timed.
    end = time.time()
    results_DampedNewton.append(out)
    times_DampedNewton.append(end-start)
    print( " |- Computing P")
    DampedNewtonP.append(GetP(jnp.exp(out['potential_f']/eps),K,jnp.exp(out['potential_g']/eps)))
    print( " |- Recording (unstabilized) Hessian \n")
    mat  = -eps*out['Hessian']
    # Symmetric rescaling D^{-1/2} H D^{-1/2} with D = diag(a). The row and
    # column factors need explicit broadcasting: multiplying by the 1-D vector
    # on both sides would scale the columns twice and leave the rows untouched.
    inv_sqrt_a = 1/jnp.sqrt( a ).flatten()
    mat = inv_sqrt_a[:,None]*mat*inv_sqrt_a[None,:]
    Hessians_DampedNewton.append( mat )

In [None]:
# Marginal-constraint violation per iteration, one curve per epsilon.
plt.figure(figsize = (12,5))
plt.title("$||P1 -a||_1+||P^T 1 -b||_1$")
for i, result in enumerate(results_DampedNewton):
    error = jnp.asarray(result['error'])
    # Raw string: '\e' is not a valid Python escape sequence.
    plt.plot( error, label = r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend(loc = "upper right")
plt.yscale( 'log' )
plt.tight_layout()
plt.savefig("Images/DampedNewton_SemiDual_images_JAX/ErrorLinesearchNewton.pdf", format = 'pdf')
plt.show()
print("\n Error plots can increase! The error is not the objective function!")

In [None]:
# Objective value per iteration, one curve per epsilon.
plt.figure(figsize = (12,5))
plt.title("Objective Function")
for i, result in enumerate(results_DampedNewton):
    value = jnp.asarray(result['objectives'])
    # Raw string: '\e' is not a valid Python escape sequence.
    plt.plot( value,label=r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Objective value")
plt.yscale('log')
plt.legend()
plt.savefig("Images/DampedNewton_SemiDual_images_JAX/ObjectiveLineSearchNewton.pdf", format = 'pdf')
plt.show()

In [None]:
# Step size recorded under 'linesearch_steps' at each iteration, one curve per
# epsilon. The axis is linear, so the y-label does not mention a log scale.
plt.figure(figsize = (20,7))
plt.subplot(2,1,1)
plt.title("Alpha")
for i, result in enumerate(results_DampedNewton):
    # Raw string: '\e' is not a valid Python escape sequence.
    plt.plot( jnp.asarray(result["linesearch_steps"]),label=r'Damped Newton for $\epsilon=$'+ str(epsilons[i]), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Alpha")
plt.legend()
plt.savefig("Images/DampedNewton_SemiDual_images_JAX/AlphaLineSearchNewton.pdf", format = 'pdf')
plt.show()

In [None]:
def print_spectral_statistics(mat, stabilize=False):
    """Print spectral statistics of a symmetric matrix and return its spectrum.

    Parameters
    ----------
    mat : 2-D square array
        Matrix handed to ``jnp.linalg.eigh``.
    stabilize : bool, optional
        When True, a rank-one term built from the normalised all-ones vector
        is added and then subtracted before diagonalising (see note below).

    Returns
    -------
    eig : 1-D array
        Eigenvalues in ascending order.
    v : 2-D array
        Eigenvectors as columns, ordered like ``eig``.

    Prints the ten smallest and ten largest eigenvalues and the ratio of the
    largest to the smallest eigenvalue ("Condition number").
    """
    if stabilize:
        # The size is read from the matrix itself, not from a notebook global.
        n = mat.shape[0]
        ones = jnp.ones( (n,1) )/jnp.sqrt( n )
        # NOTE(review): the "min" and "max" corrections use the same vector,
        # so they cancel and `mat` is unchanged up to round-off. Kept for
        # backward compatibility; confirm which directions should be shifted.
        mat = mat + jnp.dot( ones, ones.T )
        mat = mat - jnp.dot( ones, ones.T )
    eig, v = jnp.linalg.eigh( mat )
    order = jnp.argsort(eig)
    eig = eig[order]
    v   = v[:, order]

    print( "List of smallest eigenvalues: ", eig[:10])
    print( "List of largest  eigenvalues: ", eig[-10:])
    # Signed ratio of the extreme eigenvalues; no absolute value is taken.
    condition_number = jnp.max(eig)/jnp.min(eig)
    print("Condition number: ", condition_number)
    return eig,v

In [None]:
# Diagonalise the Hessian recorded for every epsilon of the first run and keep
# the eigenvalues and eigenvectors for the histogram cell.
eigs, eigvecs = [], []
for i, eps in enumerate(epsilons):
    print("Spectral statistics of Hessian for epsilon="+str(eps))
    Hessian = Hessians_DampedNewton[i]
    spectrum, basis = print_spectral_statistics( Hessian, stabilize=False)
    eigs.append(spectrum)
    eigvecs.append(basis)
    print("")


In [None]:
# One eigenvalue histogram per epsilon, sharing the y axis.
# squeeze=False always returns a 2-D array of axes, so the loop also works
# when `epsilons` holds a single value.
fig, axes = plt.subplots(figsize=(12,3), nrows=1, ncols=len(epsilons), sharey=True, squeeze=False)
# Figure-level title: plt.title() would target only the last subplot and be
# overwritten by that subplot's own set_title() below.
fig.suptitle("Histogram of eigenvalues.")
for i, axis in enumerate(axes[0]):
    axis.hist( eigs[i], 50)
    # Raw string: '\e' is not a valid Python escape sequence.
    axis.set_title( r" $\epsilon$: "+str(epsilons[i]))
    axis.set_xlabel("Eigenvalues")
    axis.set_yscale( "log" )
plt.subplots_adjust(wspace=0,hspace=0)
plt.tight_layout()
plt.savefig("Images/DampedNewton_SemiDual_images_JAX/eigenhistunstabilized.pdf", format = 'pdf')
plt.show()

In [None]:
def build_preconditioners( num_eigs,modified_Hessian, ansatz=True ):
    """Extract a null vector and low-lying eigenvectors of a symmetric matrix.

    Parameters
    ----------
    num_eigs : int
        Number of eigenvectors to return, taken right after the lowest one
        (sorted positions ``1 .. num_eigs``).
    modified_Hessian : 2-D square array
        Matrix handed to ``jnp.linalg.eigh``.
    ansatz : bool, optional
        If True (default) the null vector is the normalised all-ones vector;
        otherwise it is the computed eigenvector of the smallest eigenvalue.

    Returns
    -------
    null_vector : 1-D array of length ``n``
    precond_vectors : list of ``num_eigs`` 1-D arrays of length ``n``

    Raises
    ------
    ValueError
        If ``num_eigs`` is negative or exceeds ``n - 1``. JAX clamps
        out-of-range indices, so an oversized request would otherwise silently
        return the last eigenvector several times.
    """
    # The size is read from the matrix itself, not from a notebook global.
    n = modified_Hessian.shape[0]
    if not 0 <= num_eigs <= n - 1:
        raise ValueError(
            "num_eigs must lie in [0, %d] for a %dx%d matrix, got %r" % (n - 1, n, n, num_eigs)
        )
    # Diagonalize and order the eigenvectors by ascending eigenvalue
    eigenvalues, eigenvectors = jnp.linalg.eigh( modified_Hessian )
    order = jnp.argsort( eigenvalues )
    eigenvectors = eigenvectors[:, order]
    # Form null vector
    if ansatz:
        null_vector = jnp.ones( n )/jnp.sqrt( n )
    else:
        null_vector = eigenvectors[:, 0]
    # Form other vectors: skip position 0, which belongs to the null vector
    precond_vectors = [ eigenvectors[:, k] for k in range(1, num_eigs+1) ]
    return null_vector, precond_vectors

In [None]:
# Number of eigenvectors (beyond the lowest one) requested for the preconditioner.
num_eigs = 25
# Built from the Hessian recorded for the last epsilon of the first run;
# ansatz=False takes the null vector from the computed eigen-decomposition.
null_vector, precond_vectors = build_preconditioners( num_eigs, Hessians_DampedNewton[-1], ansatz=False )

In [None]:
# Line-search parameters forwarded unchanged to the solver.
rho = 0.95
c = 0.1
# True: every epsilon starts from the same zero potentials.
# False: each epsilon is warm-started from the previous solution.
reset_starting_point    = True
final_modified_Hessians = []
# One entry per epsilon: transport plan, solver output, wall-clock time in ms.
DampedNewtonP           = []
results_DampedNewton    = []
times_DampedNewton      = []
precond_epsilons = [ 0.5, 0.1 ]
f, g = None, None
# Cost matrix
C = distmat(x,y)
# a and b: uniform marginals stored as column vectors
a = normalize(jnp.ones(shape =(N[0],1)))
b = normalize(jnp.ones(shape =(N[1],1)))
for eps in precond_epsilons:
    #Kernel
    K = jnp.exp(-C/eps)
    # Line Search
    print( "Damped Newton for epsilon="+str(eps)+":" )
    if (f is None) or (g is None):
        f, g = a*0, b*0
    print( "Doing for (",N[0],N[1],").")
    print( " |- Iterating" )
    start = time.time()
    out = computational_OT.Damped_Newton_precond_SemiDual_JAX._update( K, a, b, f, g, eps, rho, c, null_vector, precond_vectors[:], maxiter = 50,
                                                                       iterative_inversion = 30, version = 2, debug = False, optType = 'cg')
    # Stop the clock right after the solver returns so bookkeeping is not timed.
    end = time.time()
    results_DampedNewton.append( out )
    times_DampedNewton.append(1e3*(end-start))
    print( " |- Computing P" )
    if not reset_starting_point:
        # Warm start: `out` is read as a mapping everywhere else in this
        # notebook, so the potentials come from its keys, not from `out.x`.
        # NOTE(review): the reshape to the column layout of a/b assumes the
        # solver returns flat potentials of matching size -- confirm.
        f = jnp.reshape( out['potential_f'], a.shape )
        g = jnp.reshape( out['potential_g'], b.shape )
    DampedNewtonP.append( GetP(jnp.exp(out['potential_f']/eps),K,jnp.exp(out['potential_g']/eps)) )
    #final_modified_Hessians.append( Optimizer.modified_Hessian )

In [None]:
# Marginal-constraint violation per iteration for the preconditioned runs.
plt.figure(figsize = (20,7))
plt.title("$||P1 -a||_1+||P^T1 -b||_1$")
for i, result in enumerate(results_DampedNewton):
    error = np.asarray(result['error'])
    # Raw string: '\e' is not a valid Python escape sequence.
    plt.plot( error, label = r'Damped Newton for $\epsilon=$'+ str(precond_epsilons[i]), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Error in log-scale")
plt.legend()
plt.yscale('log')
plt.savefig("Images/DampedNewton_SemiDual_images_JAX/ErrorDampedNewtonwithPrecond_final_cg.png")
plt.show()
print("\n Error plots can increase! The error is not the objective function!")

In [None]:
# Step size per iteration for the preconditioned runs (linear axis).
plt.figure(figsize = (20,7))
plt.subplot(2,1,1)
plt.title("Alpha")
for i, result in enumerate(results_DampedNewton):
    # These runs were produced with `precond_epsilons`, so the legend reads
    # from that list rather than from the `epsilons` of the first run.
    plt.plot( np.asarray(result["linesearch_steps"]),label = r'Damped Newton for $\epsilon=$'+ str(precond_epsilons[i]), linewidth = 2)
plt.xlabel("Number of iterations")
plt.ylabel("Alpha")
plt.legend()
plt.show()

In [None]:
# Legend labels, indexed by position within a 'timings' record.
# NOTE(review): presumably one label per timed stage of the solver; verify the
# order against computational_OT.
text = [
        "Preconditioning 1: Form E data",
        "Preconditioning 2: Form P data",
        "Form preconditioning functions",
        "Invert the linear system for p_k",
        "Unwinding",
        "Complete code block"
        ]

plt.figure( figsize = (20,10) )
# Number of entries per record, read from the first record of the first run.
n_stages = len(results_DampedNewton[0]['timings'][0])
for j in range( n_stages ):
    # Average of entry j over all timing records, one value per epsilon.
    values = [
        sum( record[j] for record in run['timings'] )/len(run['timings'])
        for run in results_DampedNewton
    ]
    if len(values) == len(precond_epsilons):
        plt.plot( precond_epsilons, np.asarray(values), label=text[j],linewidth = 2 )
        plt.legend( loc='upper left' )
plt.xlabel( "Epsilons" )
plt.ylabel( "Time in ms" )
plt.show()