# Gromov-Wasserstein Baseline Comparison

In this notebook, we give a simple baseline comparison of the quantized Gromov-Wasserstein (qGW) algorithm against standard Gromov-Wasserstein (GW), in terms of both matching quality and compute time.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import ot
import random

import pickle

from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.datasets import make_blobs

from scipy.sparse import coo_matrix

import time

from random import sample, uniform

from quantizedGW import *

# Experiment

We will construct toy datasets using the `sklearn` function `make_blobs`. These will consist of 2D point clouds with a varying number of points. Each point cloud is considered as a metric measure space (mm-space) with Euclidean distance and uniform measure.

Given two point clouds, we match them using both the GW algorithm and the qGW algorithm, which output couplings $\mu_{GW}$ and $\mu_{qGW}$, respectively. We also construct the product coupling $\mu_{prod}$ of the two uniform measures. Writing $\mathcal{L}(\mu)$ for the GW loss of a coupling $\mu$, we use $\mu_{prod}$ as the putative maximizer of the GW loss and $\mu_{GW}$ as the putative minimizer, and define the relative error of $\mu_{qGW}$ as
$$
\text{rel. error} = \frac{\mathcal{L}(\mu_{qGW})-\mathcal{L}(\mu_{GW})}{\mathcal{L}(\mu_{prod}) - \mathcal{L}(\mu_{GW})}
$$

For each dataset size, we repeat this trial several times (`num_trials`) and report the average relative error and the average compute time. We do this for each of several sampling rates in the qGW algorithm.

Set parameters

In [None]:
# Experiment parameters.
means = [200,400,600,800,1000,1200,1400,1600,1800,2000] # Average number of points per dataset
variance = 0.2 # Relative spread: sizes are drawn uniformly from [(1-variance)*mean, (1+variance)*mean]
num_trials = 5 # How many trials to run for each size
sample_rates = [0.1,0.2,0.3,0.4,0.5] # Fraction of points used as qGW anchors

# Seed both RNGs the experiment draws from: `random` (dataset sizes, anchor
# subsets) and NumPy's global state (used by make_blobs when random_state=None),
# so that Restart & Run All reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

Run experiment

In [None]:
# Benchmark loop. For every target size in `means`, draw `num_trials` random
# pairs of blob clouds; solve GW once per pair and qGW once per sample rate,
# recording wall-clock time, GW loss, and relative error (see the formula above).
results_gw = {}   # mean size -> {'time', 'loss'} averaged over trials
results_qgw = {}  # mean size -> {'time', 'loss', 'relative loss'}, one entry per sample rate

n_features1 = 2  # both clouds are 2D point clouds
n_features2 = 2

for mean_points in means:
    print('Starting',mean_points,'points...')

    times_gw = []
    losses_gw = []

    # Rows index sample rates, columns index trials.
    times_qgw = np.zeros([len(sample_rates),num_trials])
    losses_qgw = np.zeros([len(sample_rates),num_trials])
    relative_losses = np.zeros([len(sample_rates),num_trials])

    # Cloud sizes are drawn uniformly from [(1-variance)*mean, (1+variance)*mean].
    size_low = (1-variance)*mean_points
    size_high = (1+variance)*mean_points

    for j in range(num_trials):

        # Create datasets: Euclidean distance matrices with uniform measures.
        num_points1 = int(random.uniform(size_low, size_high))
        num_points2 = int(random.uniform(size_low, size_high))

        X1 = make_blobs(n_samples=num_points1, n_features=n_features1)[0]
        Dist1 = euclidean_distances(X1)

        X2 = make_blobs(n_samples=num_points2, n_features=n_features2)[0]
        Dist2 = euclidean_distances(X2)

        p1 = ot.unif(num_points1)
        p2 = ot.unif(num_points2)

        # Loss of the product coupling: the upper reference for the relative error.
        product_loss = gwloss_init(Dist1,Dist2,p1,p2,p1[:,None]*p2[None,:])

        ## GW coupling (the solver log was never used, so it is not requested).
        # perf_counter is the monotonic, high-resolution clock meant for timing.
        start = time.perf_counter()
        coup_gw = ot.gromov.gromov_wasserstein(
            Dist1, Dist2, p1, p2, 'square_loss', verbose=False)
        times_gw.append(time.perf_counter() - start)

        gw_loss = gwloss_init(Dist1,Dist2,p1,p2,coup_gw)
        losses_gw.append(gw_loss)
        # Normalizer of the relative error; only zero in degenerate cases.
        loss_gap = product_loss - gw_loss

        ## quantized GW with random subset selection
        for i, rate in enumerate(sample_rates):
            # Same anchor count in both clouds, bounded by the smaller cloud.
            samples = int(rate*min(num_points1, num_points2))
            # `sample` already returns distinct indices; sort for a deterministic order.
            node_subset1 = sorted(sample(range(num_points1), samples))
            node_subset2 = sorted(sample(range(num_points2), samples))

            start = time.perf_counter()
            coup_comp = compressed_gw_point_cloud(Dist1,Dist2,p1,p2,
                                                  node_subset1,node_subset2,
                                                  verbose = False,return_dense = True)
            times_qgw[i,j] = time.perf_counter() - start

            quantized_loss = gwloss_init(Dist1,Dist2,p1,p2,coup_comp)
            losses_qgw[i,j] = quantized_loss

            # NaN (rather than a crash/inf) flags a degenerate pair with no loss gap.
            relative_losses[i,j] = (quantized_loss - gw_loss)/loss_gap if loss_gap != 0 else np.nan

        print('Trial',j,'done')

    results_gw[mean_points] = {'time':np.mean(times_gw),'loss':np.mean(losses_gw)}
    results_qgw[mean_points] = {'time':np.mean(times_qgw, axis = 1),
                                'loss':np.mean(losses_qgw, axis = 1),
                                'relative loss':np.mean(relative_losses, axis = 1)}

# Plotting the Results

In [None]:
# Reshape the per-size summaries into plot-ready series:
# gw_times[k] is the mean GW time at means[k]; qgw_times[r][k] and
# rel_losses[r][k] are the qGW time / relative error at sample_rates[r], means[k].
gw_times = [results_gw[size]['time'] for size in means]
qgw_times = []
rel_losses = []
for rate_idx in range(len(sample_rates)):
    qgw_times.append([results_qgw[size]['time'][rate_idx] for size in means])
    rel_losses.append([results_qgw[size]['relative loss'][rate_idx] for size in means])

In [None]:
# Compute-time curves: one line per qGW sample rate plus the GW baseline.
# NOTE: every `[:-1]` slice drops the largest size in `means` from the chart.
x = means[:-1]
labels = means[:-1]
fontsize = 18

fig, ax = plt.subplots(figsize=(10, 6))
for rate, series in zip(sample_rates, qgw_times):
    ax.plot(x, series[:-1], label=str(rate))
ax.plot(x, gw_times[:-1], '--', label='GW')

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlabel('Avg. Points per Dataset', fontsize=fontsize)
ax.set_ylabel('Avg. Time (s)', fontsize=fontsize)
ax.legend(loc="upper left", fontsize=fontsize)
ax.set_title('Matching Datasets: Compute Time', fontsize=fontsize)

plt.show()

In [None]:
# Relative-error curves (in percent) per qGW sample rate; GW itself is the 0% baseline.
# NOTE: every `[:-1]` slice drops the largest size in `means` from the chart.
x = means[:-1]
labels = means[:-1]
fontsize = 18

fig, ax = plt.subplots(figsize=(10, 6))
for rate, errors in zip(sample_rates, rel_losses):
    ax.plot(x, [100 * err for err in errors[:-1]], label=str(rate))
ax.plot(x, [0] * len(gw_times[:-1]), '--', label='GW')

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlabel('Avg. Points per Dataset', fontsize=fontsize)
ax.set_ylabel('Avg. Relative Error (%)', fontsize=fontsize)
ax.legend(loc="upper right", fontsize=fontsize)
ax.set_title('Matching Datasets: Relative Error Against GW', fontsize=fontsize)

plt.show()

# Matching Figure

To verify that qGW's low relative error reflects high-quality matchings, we plot the matchings obtained by GW and qGW.

To visualize a matching, we color the source point cloud (say, by distance to a given point) and transfer the color to the target point cloud using the coupling matrix for each method. The transferred color is the weighted average of the colors matched to a given vertex, with weights coming from the coupling matrix values.

In [None]:
# Two fresh, equally sized 2D blob clouds for the visual matching comparison.
num_points1 = num_points2 = 2000
n_features1 = n_features2 = 2

# Source cloud and its Euclidean distance matrix.
X1, y = make_blobs(n_samples=num_points1, n_features=n_features1)
Dist1 = euclidean_distances(X1)

# Target cloud and its Euclidean distance matrix.
X2, y = make_blobs(n_samples=num_points2, n_features=n_features2)
Dist2 = euclidean_distances(X2)

# Uniform probability measures on each cloud.
p1, p2 = ot.unif(num_points1), ot.unif(num_points2)

In [None]:
## GW Coupling: full-size solve, used as the reference matching.
# perf_counter is the monotonic, high-resolution clock intended for timing;
# the solver log was never used, so it is not requested.
start = time.perf_counter()
coup_gw = ot.gromov.gromov_wasserstein(
    Dist1, Dist2, p1, p2, 'square_loss', verbose=False)
time_gw = time.perf_counter() - start
print('GW Compute Time:',time_gw)

In [None]:
## quantized GW with random subset selection
sample_rate = .2  # fraction of points used as qGW anchors
# Bound the anchor count by the smaller cloud (as in the experiment loop) so
# `sample` never asks for more indices than cloud 2 has.
samples = int(sample_rate*min(num_points1, num_points2))
# `sample` already returns distinct indices; sort for a deterministic order.
node_subset1 = sorted(sample(range(num_points1), samples))
node_subset2 = sorted(sample(range(num_points2), samples))

# perf_counter is the monotonic, high-resolution clock intended for timing.
start = time.perf_counter()
coup_qgw = compressed_gw_point_cloud(Dist1,Dist2,p1,p2,node_subset1,node_subset2,verbose = True,return_dense = True)
time_qgw = time.perf_counter() - start
print('Compressed GW Compute Time:', time_qgw)

In [None]:
# GW loss of the qGW coupling, the GW coupling (reference minimizer), and the
# product coupling p1 ⊗ p2 (reference maximizer).
quantized_loss = gwloss_init(Dist1, Dist2, p1, p2, coup_qgw)
gw_loss = gwloss_init(Dist1, Dist2, p1, p2, coup_gw)
product_loss = gwloss_init(Dist1, Dist2, p1, p2, np.outer(p1, p2))

print('Loss with Compression:', quantized_loss)
print('Loss without Compression:', gw_loss)
print('Product Coupling Loss:', product_loss)

# Position of the qGW loss between the GW loss (0%) and the product loss (100%).
rel_error = (quantized_loss - gw_loss) / (product_loss - gw_loss) * 100

print('Relative Error w.r.t. product and optimal (percent):', rel_error)

In [None]:
# Source cloud, colored by Euclidean distance to one reference point; these
# colors (`c1`) are transferred to the target cloud in the next cells.
point = 1  # index of the reference point in the source cloud
c1 = Dist1[point, :]
xs, ys = X1[:, 0], X1[:, 1]
fontsize = 14

fig, ax1 = plt.subplots(figsize=(5, 5))
ax1.scatter(xs, ys, c=c1)
ax1.axis('equal')
ax1.set_title('Source Data', fontsize=fontsize)

plt.show()

In [None]:
# Target cloud colored through the GW coupling: target point j receives the
# coupling-weighted average of the source colors, i.e.
#   c2[j] = sum_i coup[i, j] * c1[i] / sum_i coup[i, j].
fig = plt.figure(figsize = (5,5))
ax1 = fig.add_subplot(111)

xs = X2[:,0]
ys = X2[:,1]

# One mat-vec plus a column-sum division replaces a Python loop of per-column dot products.
weights_gw = np.asarray(coup_gw)
c2 = (weights_gw.T @ c1) / weights_gw.sum(axis=0)

ax1.scatter(xs, ys, c = c2)
plt.axis('equal')
plt.title('Target Data, GW Matching \n Compute Time '+str(np.round(time_gw,2))+'s', fontsize = fontsize)

plt.show()

In [None]:
# Target cloud colored through the qGW coupling: target point j receives the
# coupling-weighted average of the source colors, i.e.
#   c2[j] = sum_i coup[i, j] * c1[i] / sum_i coup[i, j].
fig = plt.figure(figsize = (5,5))
ax1 = fig.add_subplot(111)

xs = X2[:,0]
ys = X2[:,1]

# One mat-vec plus a column-sum division replaces a Python loop of per-column
# dot products; np.asarray normalizes whatever dense array type
# compressed_gw_point_cloud returns.
weights_qgw = np.asarray(coup_qgw)
c2 = (weights_qgw.T @ c1) / weights_qgw.sum(axis=0)

ax1.scatter(xs, ys, c = c2)
plt.axis('equal')
plt.title('Target Data, Compressed GW Matching \n Compute Time '+str(np.round(time_qgw,2))+'s, '+ str(np.round(rel_error,1))+'% Rel. Error', fontsize = fontsize)

plt.savefig('Matching_Blobs_Target_Comp',dpi = 100)
plt.show()