In [2]:
import torch
import numpy as np

In [3]:
# Load both saved objects onto the CPU regardless of the device they were
# written from: the NumPy / SciPy / Plotly code in this notebook needs tensors
# that live in host memory.
# Index 0 keeps only the first entry of each loaded object (presumably the
# first sample of a batch - TODO confirm against the script that wrote them).
# NOTE(review): torch.load unpickles its input; only load files you trust.
ref = torch.load('./reference.pth', map_location='cpu')[0]
gen = torch.load('./generated.pth', map_location='cpu')[0]

In [4]:
# Display the (truncated) repr of the generated tensor as a quick sanity check
# of what was loaded.
gen

tensor([[ 0.1232,  0.0895,  0.3940],
        [-0.1349,  0.0364, -0.2384],
        [-0.1496, -0.0106,  0.1617],
        ...,
        [ 0.1418, -0.0292,  0.3018],
        [ 0.0322,  0.0862,  0.3742],
        [ 0.1125, -0.0010,  0.0494]])

In [5]:
import plotly.graph_objects as go

# NOTE(review): this module-level alias is not read by any visible cell; inside
# visualize_point_cloud the name `pc` refers to the function's own parameter.
# Kept in case a cell outside this view relies on it.
pc = ref

def visualize_point_cloud(pc):
    """Render a 3-D point cloud as an interactive Plotly scatter plot.

    Args:
        pc: Array-like that can be indexed as ``pc[:, k]`` and has at least
            three columns, read as x, y and z. Torch tensors are detached and
            copied to the CPU first, so GPU tensors and tensors that track
            gradients are accepted as well.

    Returns:
        None. The figure is shown with ``fig.show()``; it is deliberately not
        returned, so a call on the last line of a cell does not render the
        figure a second time.
    """
    # Hand Plotly a plain NumPy array instead of relying on implicit
    # tensor-to-array conversion.
    if hasattr(pc, 'detach'):
        pc = pc.detach().cpu().numpy()
    points = np.asarray(pc)

    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=2, opacity=0.8),
    )])

    fig.update_layout(scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
    ))

    fig.show()

In [6]:
visualize_point_cloud(ref)
visualize_point_cloud(gen)

In [7]:
import ot 

def calculate_emd(tensor1, tensor2):
    """Earth mover's distance between two uniformly weighted point sets.

    Every point of a set carries the same mass (``1 / number_of_points``) and
    the ground cost between two points is their Euclidean distance. The
    optimal-transport problem is solved with POT.

    Args:
        tensor1: Torch tensor of shape ``(n, d)``.
        tensor2: Torch tensor of shape ``(m, d)``.

    Returns:
        float: Total cost of the optimal transport plan.
    """
    # detach() allows tensors that track gradients to be converted; float64
    # makes the cost matrix share the dtype of the weight vectors below.
    arr1 = tensor1.detach().cpu().numpy().astype(np.float64)
    arr2 = tensor2.detach().cpu().numpy().astype(np.float64)

    n = arr1.shape[0]
    m = arr2.shape[0]
    # Uniform histograms: each set has total mass 1.
    a = np.ones((n,)) / n
    b = np.ones((m,)) / m

    cost_matrix = ot.dist(arr1, arr2, metric='euclidean')

    # emd2 returns the optimal cost itself, i.e. sum(plan * cost_matrix),
    # without the caller having to build and multiply out the n x m plan.
    return float(ot.emd2(a, b, cost_matrix))

In [8]:
from scipy.stats import wasserstein_distance

def calculate_emd_approx(tensor1, tensor2):
    """Cheap EMD proxy: mean of the per-axis 1-D Wasserstein distances.

    Each coordinate axis is compared on its own with
    ``scipy.stats.wasserstein_distance`` and the results are averaged over the
    axes of ``tensor1``. Because the axes are treated independently this is
    not the multi-dimensional EMD; with a Euclidean ground cost it cannot
    exceed it, since a per-axis difference is never larger than the full
    distance between two points.

    Args:
        tensor1: Torch tensor of shape ``(n, d)``.
        tensor2: Torch tensor of shape ``(m, d)``.

    Returns:
        Scalar average of the ``d`` one-dimensional distances.
    """
    # detach() so tensors that track gradients can be converted to NumPy.
    arr1 = tensor1.detach().cpu().numpy()
    arr2 = tensor2.detach().cpu().numpy()

    n_dims = arr1.shape[1]
    per_axis = [
        wasserstein_distance(arr1[:, dim], arr2[:, dim])
        for dim in range(n_dims)
    ]

    return sum(per_axis) / n_dims

In [9]:
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

def emd(tensor1, tensor2):
    """Mean matched distance under an optimal one-to-one point assignment.

    Builds the pairwise Euclidean distance matrix, solves the linear sum
    assignment problem on it and divides the total matched distance by the
    number of matched pairs, ``min(n, m)``. When the two sets differ in size
    the surplus points of the larger set stay unmatched, so the value is then
    a partial-matching cost.

    Args:
        tensor1: ``(n, d)`` torch tensor or array-like.
        tensor2: ``(m, d)`` torch tensor or array-like.

    Returns:
        Scalar: sum of matched distances divided by ``min(n, m)``.
    """
    def _to_array(points):
        # Tensors on a GPU or tracking gradients cannot be converted
        # implicitly by SciPy; plain arrays and lists pass through unchanged.
        if hasattr(points, 'detach'):
            points = points.detach().cpu().numpy()
        return np.asarray(points)

    d = cdist(_to_array(tensor1), _to_array(tensor2))
    rows, cols = linear_sum_assignment(d)
    # d has shape (n, m), so min(d.shape) is the number of matched pairs.
    return d[rows, cols].sum() / min(d.shape)

In [10]:
# Compare the three implementations on the same pair of clouds.
print("EMD 1:", calculate_emd(gen, ref))  # optimal transport via POT

# calculate_emd_approx already returns a scalar, so it is printed as-is.
emd_result = calculate_emd_approx(gen, ref)
print("EMD 2:", emd_result)  # mean of per-axis 1-D Wasserstein distances

print("EMD 3:", emd(gen, ref))  # linear sum assignment via SciPy


EMD 1: 0.05047609394864594
EMD 2: 0.022353968031931497
EMD 3: 0.05047608999655243
