# Point Cloud Matching

This notebook demonstrates the basic functionality of the quantized Gromov-Wasserstein (qGW) algorithm for matching low-dimensional point clouds.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
import ot
import random

import pickle

from sklearn.metrics.pairwise import euclidean_distances
from sklearn.datasets import make_blobs

from scipy.sparse import coo_matrix

import time

from random import sample, uniform

import pywavefront

from quantizedGW import *

## Point Cloud Data

We'll use point clouds from the [CAPOD database](https://sites.google.com/site/pgpapadakis/home/CAPOD). There are 15 classes of shapes and 12 samples per class. Loading requires `pywavefront`. We'll immediately compute the pairwise distance matrix for the point cloud.
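The loading cells below build the `.obj` path from the class and sample numbers. As a minimal sketch of that convention, assuming the 12 samples of each class are numbered consecutively across the whole database, the hypothetical helper `capod_obj_path` (not part of `quantizedGW` or `pywavefront`) makes the index arithmetic explicit:

```python
def capod_obj_path(shape_class, shape_sample, root='./data/CAPOD', samples_per_class=12):
    """Hypothetical helper: build the path to one CAPOD model file.

    Assumes models are numbered consecutively, so class c holds models
    (c-1)*samples_per_class + 1 through c*samples_per_class.
    """
    model_index = (shape_class - 1) * samples_per_class + shape_sample
    return '{}/class{}/m{}.obj'.format(root, shape_class, model_index)


# The two shapes used in this notebook: class 8, samples 10 and 1.
path_first = capod_obj_path(8, 10)
path_second = capod_obj_path(8, 1)
print(path_first)
print(path_second)
```

The `root` default mirrors the relative path used in this notebook; adjust it if the data lives elsewhere.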

In [None]:
shape_class = 8
shape_sample = 10

In [None]:
path='./data/CAPOD/class'+str(shape_class)+'/m'+str((shape_class-1)*12+shape_sample)+'.obj'
scene = pywavefront.Wavefront(path)

X1 = np.array(scene.vertices)
Dist1 = euclidean_distances(X1)

print('Number of points:', len(X1))

In [None]:
fig = plt.figure(figsize=(10,10))

ax = fig.add_subplot(111, projection='3d')
ax.scatter(X1[:,0],-X1[:,1],X1[:,2], marker='o', s=20, c='goldenrod', alpha=0.2)
ax.view_init(elev=10., azim=10)
# ax.set_axis_off()
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])

plt.show()

Next we pick another shape and compute its distance matrix. We also create probability vectors for each shape (we'll use uniform measure).

In [None]:
shape_class = 8
shape_sample = 1

path='./data/CAPOD/class'+str(shape_class)+'/m'+str((shape_class-1)*12+shape_sample)+'.obj'
scene = pywavefront.Wavefront(path)

X2 = np.array(scene.vertices)
Dist2 = euclidean_distances(X2)

print('Number of points:', len(X2))

p1 = ot.unif(len(X1))
p2 = ot.unif(len(X2))

In [None]:
fig = plt.figure(figsize=(10,10))

ax = fig.add_subplot(111, projection='3d')
ax.scatter(X2[:,0],-X2[:,1],X2[:,2], marker='o', s=20, c='goldenrod', alpha=0.2)
ax.view_init(elev=10., azim=10)
# ax.set_axis_off()
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])

plt.show()

## Matching Datasets

We now compute probabilistic matchings between the datasets. First, this is done with the standard Gromov-Wasserstein algorithm, using `ot.gromov.gromov_wasserstein` from the Python Optimal Transport (POT) package, which is imported as `ot`.

**Warning:** The computation becomes quite long if the shapes you are matching are larger than ~2k points. 
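To make the objective concrete: with `'square_loss'`, a coupling is scored by how much it distorts pairwise distances, summed over all pairs of matched pairs. The sketch below evaluates that score by brute force on a toy example. The function name `gw_square_loss` is an illustration, not a POT function, and the four-index array it builds needs memory proportional to n²m², so it is only suitable for tiny inputs.

```python
import numpy as np


def gw_square_loss(D1, D2, coupling):
    """Brute-force GW objective with squared loss (illustration only).

    Computes sum_{i,j,k,l} (D1[i,k] - D2[j,l])**2 * T[i,j] * T[k,l].
    """
    # diff[i, j, k, l] = D1[i, k] - D2[j, l]
    diff = D1[:, None, :, None] - D2[None, :, None, :]
    return np.einsum('ijkl,ij,kl->', diff ** 2, coupling, coupling)


# Three points on a line at positions 0, 1, 2.
D = np.array([[0.0, 1.0, 2.0],
              [1.0, 0.0, 1.0],
              [2.0, 1.0, 0.0]])

# The identity matching preserves every distance.
T_identity = np.eye(3) / 3
# Swapping the first two points distorts the distances to the third point.
T_swapped = np.eye(3)[[1, 0, 2]] / 3

loss_identity = gw_square_loss(D, D, T_identity)
loss_swapped = gw_square_loss(D, D, T_swapped)
print(loss_identity, loss_swapped)
```

The identity coupling has zero loss, while the swapped coupling is penalized; the solver searches for a coupling with low loss among all couplings with the prescribed marginals.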

In [None]:
start = time.time()
coup, log = ot.gromov.gromov_wasserstein(Dist1, Dist2, p1, p2, 
                                        'square_loss', verbose=False, log=True)
time_gw = time.time() - start
print('GW Compute Time:',time_gw)

We can visualize a matching via *color transfer*: we color the source point cloud (by, say, distance to a fixed point), then transfer this coloring to the target point cloud. The color of a point in the target point cloud is the weighted average of the colors of the points which match to it under the matching, with weights coming from the coupling.
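The weighted average described above can be written as a single matrix product. The sketch below uses a hypothetical helper `transfer_colors` on a toy coupling; it assumes every target point receives some mass (no all-zero column), since otherwise the normalization divides by zero.

```python
import numpy as np


def transfer_colors(coupling, source_colors):
    """Color of target j = coupling-weighted mean of the source colors.

    Assumes every column of `coupling` has positive total mass.
    """
    coupling = np.asarray(coupling, dtype=float)
    source_colors = np.asarray(source_colors, dtype=float)
    column_mass = coupling.sum(axis=0)          # mass received by each target point
    return (coupling.T @ source_colors) / column_mass


# Toy example: 3 source points, 2 target points.
# Sources 0 and 1 both match target 0; source 2 matches target 1.
toy_coupling = np.array([[0.25, 0.0],
                         [0.25, 0.0],
                         [0.0, 0.5]])
toy_colors = np.array([0.0, 2.0, 5.0])

toy_transferred = transfer_colors(toy_coupling, toy_colors)
print(toy_transferred)  # target 0 averages colors 0 and 2; target 1 inherits color 5
```

This computes the same quantity as the list comprehension used in the plotting cells, just without the Python-level loop over target points.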

The figure below shows that GW did a good job of matching points.

In [None]:
point = 1000
c1 = Dist1[point,:]
c2 = [np.dot(coup[:,j],c1)/np.sum(coup[:,j]) for j in range(Dist2.shape[0])]

fig = plt.figure(figsize=(15,10))

ax = fig.add_subplot(121, projection='3d')
ax.scatter(X1[:,0],-X1[:,1],X1[:,2], marker='o', s=20, c=c1, alpha=0.15)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Source Point Cloud')

ax = fig.add_subplot(122, projection='3d')
ax.scatter(X2[:,0],-X2[:,1],X2[:,2], marker='o', s=20, c=c2, alpha=0.2)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Target Point Cloud with Color Transferred by GW \n Compute Time: {}s'.format(np.round(time_gw,2)))
plt.show()

Next we compute a matching using the qGW algorithm. The function takes subsets of the source and target point clouds as input. We sample randomly at a user-defined rate. 
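As a sketch of the sampling step, the hypothetical helper `sample_indices` below draws a fixed fraction of point indices uniformly without replacement using NumPy's random generator. The helper name, the `seed` argument, and the sorting are illustrative choices, not part of `quantizedGW`; the notebook's own cell uses `random.sample` instead.

```python
import numpy as np


def sample_indices(n_points, rate, seed=None):
    """Draw int(rate * n_points) distinct indices from range(n_points)."""
    rng = np.random.default_rng(seed)
    n_samples = max(1, int(rate * n_points))    # keep at least one representative
    chosen = rng.choice(n_points, size=n_samples, replace=False)
    return sorted(int(i) for i in chosen)


subset = sample_indices(1000, 0.25, seed=0)
print(len(subset), subset[:5])
```

Passing a seed makes the subset, and therefore the resulting matching, reproducible between runs.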

In [None]:
sample_size = .1

samples = int(sample_size*len(X1))

node_subset1 = list(set(sample(list(range(X1.shape[0])),samples)))
node_subset2 = list(set(sample(list(range(X2.shape[0])),samples)))

start = time.time()
coup_qgw = compressed_gw_point_cloud(Dist1,Dist2,p1,p2,
                                      node_subset1,node_subset2,
                                      verbose = True,return_dense = True)
time_qgw = time.time()-start

print('qGW Compute Time:',time_qgw)

Plotting the color transfer for `qGW` shows that the matching also picks up the structure of the point cloud, with a much faster compute time. The improvement in computation speed increases with the size of the point clouds.

In [None]:
point = 1000
c1 = Dist1[point,:]
c2 = [np.dot(coup_qgw[:,j],c1)/np.sum(coup_qgw[:,j]) for j in range(Dist2.shape[0])]

fig = plt.figure(figsize=(15,10))

ax = fig.add_subplot(121, projection='3d')
ax.scatter(X1[:,0],-X1[:,1],X1[:,2], marker='o', s=20, c=c1, alpha=0.15)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Source Point Cloud')

ax = fig.add_subplot(122, projection='3d')
ax.scatter(X2[:,0],-X2[:,1],X2[:,2], marker='o', s=20, c=c2, alpha=0.2)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Target Point Cloud with Color Transferred by qGW \n Compute Time: {}s'.format(np.round(time_qgw,2)))
plt.show()

For comparison, entropy-regularized GW (erGW) with a large regularization coefficient can handle larger datasets than GW and runs faster, but the color transfer quality is diminished.

In [None]:
epsilon = 100

start = time.time()
coup_er, log = ot.gromov.entropic_gromov_wasserstein(Dist1, Dist2, p1, p2,
                                                  'square_loss', epsilon=epsilon, 
                                                   log=True, verbose=False)
time_er = time.time() - start
print('erGW Compute Time:',time_er)

point = 1000
c1 = Dist1[point,:]
c2 = [np.dot(coup_er[:,j],c1)/np.sum(coup_er[:,j]) for j in range(Dist2.shape[0])]

fig = plt.figure(figsize=(15,10))

ax = fig.add_subplot(121, projection='3d')
ax.scatter(X1[:,0],-X1[:,1],X1[:,2], marker='o', s=20, c=c1, alpha=0.15)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Source Point Cloud')

ax = fig.add_subplot(122, projection='3d')
ax.scatter(X2[:,0],-X2[:,1],X2[:,2], marker='o', s=20, c=c2, alpha=0.2)
ax.view_init(elev=10., azim=10)
ax.grid(True)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
plt.title('Target Point Cloud with Color Transferred by erGW \n Compute Time: {}s'.format(np.round(time_er,2)))
plt.show()

## Quantifying Matching Quality

We can quantify the ability of qGW to find good matchings as follows. Given a point cloud $X$, we permute the order of its points and perturb the points with noise to get a new point cloud $\widetilde{X}$. There is a ground truth optimal matching of $X$ and $\widetilde{X}$. Let $\mu$ be a coupling of $X$ and $\widetilde{X}$. Given $x \in X$, there is a ground truth match $\widetilde{x} \in \widetilde{X}$ and a matched point returned from $\mu$ as
$$
\widetilde{y} := \operatorname*{argmax}_{\widetilde{z} \in \widetilde{X}} \mu(x,\widetilde{z}).
$$
We compute the distortion score of $\mu$ as
$$
\frac{1}{|X|} \sum_{x \in X} \|\widetilde{x} - \widetilde{y}\|^2,
$$
where $\widetilde{x}$ and $\widetilde{y}$ both depend on $x$. The score is zero exactly when the coupling sends every point to its ground truth match.
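A small worked example may help: the score compares each point's ground truth match with the match read off the coupling, both taken in the perturbed cloud. The sketch below is a vectorized illustration on a three-point toy cloud; the helper name `distortion_score` and the `truth` index array are assumptions made for the example, separate from the functions defined in the next cell.

```python
import numpy as np


def distortion_score(X_tilde, coupling, truth):
    """Mean squared distance between true and recovered matches.

    X_tilde  : (n, d) perturbed point cloud
    coupling : (n, n) coupling; row i describes where source point i is sent
    truth    : (n,) index in X_tilde of the ground truth match of each source point
    """
    matched = np.argmax(coupling, axis=1)        # recovered match for each source point
    diffs = X_tilde[truth] - X_tilde[matched]
    return np.mean(np.sum(diffs ** 2, axis=1))


X_tilde = np.array([[0.0, 0.0],
                    [1.0, 0.0],
                    [0.0, 2.0]])
truth = np.array([0, 1, 2])

# A perfect coupling recovers every match.
perfect = np.eye(3) / 3
# This coupling confuses the second and third points.
confused = np.eye(3)[[0, 2, 1]] / 3

score_perfect = distortion_score(X_tilde, perfect, truth)
score_confused = distortion_score(X_tilde, confused, truth)
print(score_perfect, score_confused)
```

In the confused case the two mismatched points each contribute a squared distance of 5, so the average over three points is 10/3.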


In [None]:
def perturbPointCloud(X,noise = 0.1):
    
    perm = np.random.permutation(np.eye(len(X)))
    X_pert = np.matmul(perm,X) + noise*(np.random.rand(X.shape[0],X.shape[1]) - np.random.rand(X.shape[0],X.shape[1]))
    
    return X_pert, perm

def matching_distortion(X1,X2,matching,perm):
    
    dis = 0
    
    for j in range(len(X1)):
        dis += np.linalg.norm(X2[matching[j],:] - X2[np.argmax(perm[:,j]),:])**2
        
    return dis/X1.shape[0]

For `X1` above, let's perturb it and compute the distortion scores of GW, qGW and erGW.

In [None]:
X1_pert, perm = perturbPointCloud(X1,noise = .01*np.max(Dist1))
Dist1_pert = euclidean_distances(X1_pert,X1_pert)

start = time.time()
coup, log = ot.gromov.gromov_wasserstein(Dist1, Dist1_pert, p1, p1, 
                                        'square_loss', verbose=False, log=True)
matching_gw = [np.argmax(coup[j,:]) for j in range(len(X1))]

print('GW Done in {} seconds'.format(time.time()-start))

start = time.time()

sample_size = .1 # Increase to increase quality and compute time
samples = int(sample_size*len(X1))
node_subset1 = list(set(sample(list(range(X1.shape[0])),samples)))
node_subset2 = list(set(sample(list(range(X1_pert.shape[0])),samples))) # sample from the perturbed cloud, not X2

coup_qgw = compressed_gw_point_cloud(Dist1,Dist1_pert,p1,p1,
                                      node_subset1,node_subset2,
                                      verbose = False,return_dense = True)
matching_qgw = [np.argmax(coup_qgw[j,:]) for j in range(len(X1))]

print('qGW Done in {} seconds'.format(time.time()-start))

start = time.time()

epsilon = 100 # Decrease to increase quality, but increase compute time
coup_er, log = ot.gromov.entropic_gromov_wasserstein(Dist1, Dist1_pert, p1, p1,
                                                  'square_loss', epsilon=epsilon, 
                                                   log=True, verbose=False)
matching_er = [np.argmax(coup_er[j,:]) for j in range(len(X1))]

print('erGW Done in {} seconds'.format(time.time()-start))


print('Distortion Scores:')
print('GW:{}'.format(matching_distortion(X1,X1_pert,matching_gw,perm)))
print('qGW:{}'.format(matching_distortion(X1,X1_pert,matching_qgw,perm)))
print('erGW:{}'.format(matching_distortion(X1,X1_pert,matching_er,perm)))

We see that the distortion score of qGW is generally close to that of GW, and for larger datasets qGW sometimes scores better than GW.