In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from pathlib import Path
import ot
import scipy as sp
import matplotlib.pyplot as plt
from utils import LoadCloudPoint, DistanceProfile

In [2]:
import random

# Seed every RNG that point-cloud sampling might draw from. `random.seed`
# alone does not touch NumPy's global generator, so if
# `LoadCloudPoint.get_two_random_point_cloud` samples with NumPy the chosen
# frames (and every score below) would change on each run.
# NOTE(review): confirm which RNG utils.LoadCloudPoint actually uses.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)

lcp = LoadCloudPoint(filepath="datasets/csv_files/0005_Jogging001.csv")
source_pc, target_pc = lcp.get_two_random_point_cloud()

# Distance matrices for the two clouds, used as the matching input below;
# index 0 / 1 presumably follow the (source_pc, target_pc) argument order.
dp = DistanceProfile(source_pc, target_pc)
distance_matrix = dp.compute_L2_matrix()

In [3]:
np.shape(distance_matrix[0])

(26, 26)

In [4]:
np.shape(distance_matrix[1])

(26, 26)

In [5]:
from utils import compute_W_matrix_distance_matrix_input

# Match the two clouds from their L2 distance matrices. `map_matrix` is
# consumed below exactly like the OT plans G and T (plotting + accuracy).
W, map_matrix = compute_W_matrix_distance_matrix_input(
    distance_matrix[0],
    distance_matrix[1],
)

In [6]:
from utils import plot_3d_points_and_connections

In [7]:
# Optional: uncomment (and add `%autoreload 2`) to enable IPython's autoreload
# extension, so edits to the local `utils` / `accuracy` modules are picked up
# without restarting the kernel.
# %load_ext autoreload

# Distance profiles (L2) matched with W1 optimal transport

In [8]:
# Draw both clouds with the connections given by the L2-profile matching.
plot_3d_points_and_connections(source_pc, target_pc, map_matrix)

In [9]:
from accuracy import accuracy
# Score the L2-profile matching. NOTE(review): presumably the fraction of points
# mapped to their true counterpart (recorded outputs are multiples of 1/26) —
# confirm in accuracy.py.
accuracy(map_matrix)

0.46153846153846156

# Distance profiles (L1 norm) matched with W1 optimal transport

In [10]:
# Repeat the W1 matching, this time with L1-norm distance matrices.
distance_matrix = dp.compute_L1_matrix()
W, map_matrix = compute_W_matrix_distance_matrix_input(
    distance_matrix[0], distance_matrix[1]
)
plot_3d_points_and_connections(source_pc, target_pc, map_matrix)

In [11]:
# Score the matching built from L1 distance matrices.
accuracy(map_matrix)

0.34615384615384615

# Baseline: vanilla OT on raw point coordinates

In [12]:
# Baseline: match points directly on their 3D coordinates. The cost is the
# pairwise squared Euclidean distance (ot.dist's default metric).
M = ot.dist(source_pc, target_pc)

N = source_pc.shape[0]
# Uniform marginals, each sized by its own cloud so the problem stays valid
# when the two clouds have different numbers of points (the original sized
# `b` from the source cloud).
a = ot.unif(N)
b = ot.unif(target_pc.shape[0])
G = ot.solve(M, a, b).plan

In [13]:
# Draw both clouds with the connections given by the vanilla-OT plan G.
plot_3d_points_and_connections(source_pc, target_pc, G)

In [14]:
# Score the vanilla-OT plan G with the same metric as the matchings above.
accuracy(G)

0.3076923076923077

# Gromov-Wasserstein on L2 distance matrices

In [15]:
# Gromov-Wasserstein compares the intra-cloud distance matrices directly, so
# no cross-cloud cost matrix is needed.
distance_matrix = dp.compute_L2_matrix()
# Uniform weights derived from this cell's own inputs: the cell no longer
# depends on `a`/`b` from the vanilla-OT cell and handles clouds of different
# sizes.
p = ot.unif(distance_matrix[0].shape[0])
q = ot.unif(distance_matrix[1].shape[0])
T, logs = ot.gromov_wasserstein(
    distance_matrix[0], distance_matrix[1], p, q, loss_fun='square_loss', log=True
)
plot_3d_points_and_connections(source_pc, target_pc, T)

In [16]:
# Score the Gromov-Wasserstein plan T computed from L2 distance matrices.
accuracy(T)

0.38461538461538464

# Gromov-Wasserstein on L1 distance matrices

In [17]:
# Same Gromov-Wasserstein matching as above, with L1 distance matrices.
distance_matrix = dp.compute_L1_matrix()
# Uniform weights derived from this cell's own inputs: the cell no longer
# depends on `a`/`b` from the vanilla-OT cell and handles clouds of different
# sizes.
p = ot.unif(distance_matrix[0].shape[0])
q = ot.unif(distance_matrix[1].shape[0])
T, logs = ot.gromov_wasserstein(
    distance_matrix[0], distance_matrix[1], p, q, loss_fun='square_loss', log=True
)
fig = plot_3d_points_and_connections(source_pc, target_pc, T)
fig.show()
accuracy(T)

0.34615384615384615