# Optimal Transport Algorithms Analysis

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
from uot.analysis import get_agg_table, get_mean_comparison_table, get_std_comparison_table, display_mean_and_std, display_all_metrics

# Global figure styling. The style is applied first and the seaborn palette second,
# so the palette overrides the colour cycle that the ggplot style sets.
plt.style.use(style='ggplot')
sns.set_palette(palette="colorblind")

In [4]:
# Raw benchmark results; judging by the preview below, each row appears to be a
# single solver run on one source/target pair.
results_df = pd.read_csv(filepath_or_buffer='result_2025-05-01_02-52-55.csv')

# Sanity summary: overall table size, then the distinct solvers and datasets covered.
print(f"Shape of results: {results_df.shape}")
print(f"Algorithms: {results_df['name'].unique()}")
print(f"Datasets: {results_df['dataset'].unique()}")

results_df.head(n=5)

Shape of results: (5740, 34)
Algorithms: ['pot-lp' 'ott-jax-sinkhorn' 'jax-sinkhorn' 'optax-grad-ascent']
Datasets: ['32 1D gamma' '64 1D gamma' '256 1D gamma' '512 1D gamma' '1024 1D gamma'
 '2048 1D gamma' '32 1D gaussian' '64 1D gaussian' '256 1D gaussian'
 '512 1D gaussian' '1024 1D gaussian' '2048 1D gaussian' '32 1D beta'
 '64 1D beta' '256 1D beta' '512 1D beta' '1024 1D beta' '2048 1D beta'
 '32 1D beta_cauchy_gamma_gaussian' '64 1D beta_cauchy_gamma_gaussian'
 '128 1D beta_cauchy_gamma_gaussian' '256 1D beta_cauchy_gamma_gaussian'
 '512 1D beta_cauchy_gamma_gaussian' '1024 1D beta_cauchy_gamma_gaussian'
 '2048 1D beta_cauchy_gamma_gaussian' 'WhiteNoise 32x32'
 'CauchyDensity 32x32' 'GRFmoderate 32x32' 'GRFrough 32x32'
 'GRFsmooth 32x32' 'LogGRF 32x32' 'LogitGRF 32x32'
 'MicroscopyImages 32x32' 'Shapes 32x32' 'ClassicImages 32x32'
 '3D_Colored_Mesh_red_1024x1024pts' '3D_Colored_Mesh_red_2048x2048pts']


Unnamed: 0.1,Unnamed: 0,problem,source_measure_name,target_measure_name,source_gamma,target_gamma,dataset,time,cost_rerr,coupling_avg_err,...,source_color_name,source_num_points,target_mesh_name,target_color_channel,target_color_name,target_num_points,source_points,target_points,color_mode,name
0,3,Simple transport,32 1D gamma,32 1D gamma,"(np.float64(2.0), np.float64(6.0))","(np.float64(8.0), np.float64(2.72))",32 1D gamma,0.688336,1.557822e-16,1.516189e-19,...,,,,,,,,,,pot-lp
1,4,Simple transport,32 1D gamma,32 1D gamma,"(np.float64(2.0), np.float64(6.0))","(np.float64(2.0), np.float64(5.34))",32 1D gamma,0.574769,0.0,6.098637e-20,...,,,,,,,,,,pot-lp
2,5,Simple transport,32 1D gamma,32 1D gamma,"(np.float64(2.0), np.float64(6.0))","(np.float64(3.0), np.float64(1.41))",32 1D gamma,0.640088,2.884502e-16,3.684593e-19,...,,,,,,,,,,pot-lp
3,6,Simple transport,32 1D gamma,32 1D gamma,"(np.float64(2.0), np.float64(6.0))","(np.float64(5.0), np.float64(3.38))",32 1D gamma,0.579087,0.0,2.1879919999999996e-19,...,,,,,,,,,,pot-lp
4,7,Simple transport,32 1D gamma,32 1D gamma,"(np.float64(2.0), np.float64(6.0))","(np.float64(10.0), np.float64(4.03))",32 1D gamma,0.582283,0.0,1.3658409999999998e-19,...,,,,,,,,,,pot-lp


In [21]:
# Metric columns summarised per dataset: runtime, cost error and average coupling error.
metrics_to_analyze = [
    'time',
    'cost_rerr',
    'coupling_avg_err',
]

# For every metric, print a section marker, then render its per-dataset
# mean and standard-deviation tables.
for metric in metrics_to_analyze:
    print(f"\n## Analysis of {metric} ##")
    display_mean_and_std(results_df, metric)


## Analysis of time ##


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.594633,7.870374,5.226564,2.089505
1,64 1D gamma,0.790995,14.012934,10.026857,2.871272
2,256 1D gamma,6.355724,26.249865,14.483131,30.109811
3,512 1D gamma,22.816457,172.120468,82.834946,6.432182
4,1024 1D gamma,157.69658,957.39834,442.800824,21.84527
5,2048 1D gamma,534.489982,2441.057092,1029.648571,82.963372
6,32 1D gaussian,0.619189,20.695721,18.424764,3.535471
7,64 1D gaussian,0.895952,31.765082,24.44181,3.220959
8,256 1D gaussian,9.143174,76.938644,40.325279,23.01116
9,512 1D gaussian,51.896794,549.214355,246.651783,336.474754

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.032123,2.357363,2.060558,0.694811
1,64 1D gamma,0.076263,5.748085,4.565121,1.94458
2,256 1D gamma,2.354984,12.870873,7.056626,54.645216
3,512 1D gamma,6.477157,57.381559,28.308996,0.170846
4,1024 1D gamma,61.653819,414.129539,184.943365,0.210033
5,2048 1D gamma,68.973041,827.719543,383.483836,0.805262
6,32 1D gaussian,0.128549,13.746587,12.764321,3.988006
7,64 1D gaussian,0.095658,25.404665,19.753893,1.950078
8,256 1D gaussian,3.755596,81.255309,41.229854,33.513025
9,512 1D gaussian,26.198619,406.214635,175.552062,639.193166



## Analysis of cost_rerr ##


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.0,0.808631,0.80813,0.940784
1,64 1D gamma,0.0,0.349908,0.348742,0.997926
2,256 1D gamma,0.0,2.551291,2.548138,0.998016
3,512 1D gamma,0.0,1.672507,1.667075,1.0
4,1024 1D gamma,0.0,0.673603,0.667033,1.0
5,2048 1D gamma,0.0,22.585105,22.565076,1.0
6,32 1D gaussian,0.0,0.00882,0.008264,0.994718
7,64 1D gaussian,0.0,0.012751,0.011964,0.998803
8,256 1D gaussian,0.0,0.017849,0.01567,0.999338
9,512 1D gaussian,0.0,0.037756,0.035161,0.999882

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.0,1.870216,1.870111,0.392142
1,64 1D gamma,0.0,0.941185,0.939342,0.002912
2,256 1D gamma,0.0,12.53509,12.52774,0.006438
3,512 1D gamma,0.0,5.368442,5.359606,0.0
4,1024 1D gamma,0.0,2.432297,2.414477,0.0
5,2048 1D gamma,0.0,123.496417,123.475707,0.0
6,32 1D gaussian,0.0,0.026579,0.025772,0.008764
7,64 1D gaussian,0.0,0.02982,0.029121,0.002014
8,256 1D gaussian,0.0,0.046394,0.043185,0.002563
9,512 1D gaussian,0.0,0.113058,0.108517,0.000375



## Analysis of coupling_avg_err ##


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.0,0.000336,0.000336,0.001096
1,64 1D gamma,0.0,0.000216,0.000216,0.000264
2,256 1D gamma,0.0,2.4e-05,2.4e-05,1.6e-05
3,512 1D gamma,0.0,7e-06,7e-06,4e-06
4,1024 1D gamma,0.0,2e-06,2e-06,1e-06
5,2048 1D gamma,0.0,0.0,0.0,0.0
6,32 1D gaussian,0.0,0.000266,0.000266,0.001223
7,64 1D gaussian,0.0,0.000191,0.000191,0.000269
8,256 1D gaussian,0.0,2e-05,2e-05,1.7e-05
9,512 1D gaussian,0.0,7e-06,7e-06,4e-06

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,32 1D gamma,0.0,0.00011,0.00011,0.000218
1,64 1D gamma,0.0,3.7e-05,3.7e-05,3e-05
2,256 1D gamma,0.0,2e-06,2e-06,2e-06
3,512 1D gamma,0.0,0.0,0.0,0.0
4,1024 1D gamma,0.0,0.0,0.0,0.0
5,2048 1D gamma,0.0,0.0,0.0,0.0
6,32 1D gaussian,0.0,0.000111,0.000111,0.000204
7,64 1D gaussian,0.0,5.2e-05,5.2e-05,3.3e-05
8,256 1D gaussian,0.0,9e-06,9e-06,3e-06
9,512 1D gaussian,0.0,1e-06,1e-06,0.0


In [5]:
def _is_grid_dataset(name: str) -> bool:
    """Return True when ``name`` ends with a ``<rows>x<cols>`` grid size, e.g. 'Shapes 32x32'.

    Names without spaces are checked as a whole, so a mesh name such as
    '3D_Colored_Mesh_red_1024x1024pts' does not match (trailing 'pts').
    """
    rows, sep, cols = name.rsplit(' ', 1)[-1].partition('x')
    return bool(sep) and rows.isdigit() and cols.isdigit()


dataset_names = results_df['dataset'].unique()

one_dim_datasets = [ds for ds in dataset_names if '1D' in ds]
# 2D datasets are recognised by a trailing grid size of any resolution; matching
# the literal 'x32' would silently drop grids of any other size.
two_dim_datasets = [ds for ds in dataset_names if '2D' in ds or _is_grid_dataset(ds)]
mesh_datasets = [ds for ds in dataset_names if 'Mesh' in ds]

dataset_groups = {
    '1D Datasets': one_dim_datasets,
    '2D Datasets': two_dim_datasets,
    '3D Mesh Datasets': mesh_datasets
}

# The plotting cells only iterate over these groups, so report any dataset
# that falls into none of them instead of skipping it silently.
grouped_datasets = set(one_dim_datasets) | set(two_dim_datasets) | set(mesh_datasets)
unassigned_datasets = [ds for ds in dataset_names if ds not in grouped_datasets]
if unassigned_datasets:
    print(f"Warning: datasets not assigned to any group (they will not be plotted): {unassigned_datasets}")

algorithms = results_df['name'].unique()

In [29]:
import os
import matplotlib.pyplot as plt
import seaborn as sns

def _draw_kde_panel(ax, per_algo, dataset, title_label, axis_label):
    """Overlay one KDE per algorithm on ``ax`` using a log-scaled y axis.

    An algorithm whose values are all identical gets a dashed vertical line at
    that value instead. A KDE of zero-variance data is undefined, and seaborn
    would skip it with a warning, so the algorithm would vanish from the panel.
    Algorithms with no values are left out.

    Args:
        ax: matplotlib Axes to draw on.
        per_algo: mapping algorithm name -> Series of metric values (NaNs removed).
        dataset: dataset name, used in the title and in console messages.
        title_label: metric name used in the panel title.
        axis_label: metric name (with units, if any) used on the x axis.
    """
    ax.set_yscale('log')
    for algo, values in per_algo.items():
        if values.empty:
            continue
        if values.nunique() == 1:
            print(f"Algorithm: {algo}, Dataset: {dataset}, Count: {len(values)} - Constant value")
            ax.axvline(values.iloc[0], label=f"{algo} (constant)", linestyle='--')
        else:
            sns.kdeplot(values, ax=ax, label=algo, fill=True, alpha=0.3)

    ax.set_title(f'KDE of {title_label} for {dataset}')
    ax.set_xlabel(axis_label)
    # kdeplot draws a probability density, not a count.
    ax.set_ylabel('Density')
    # legend() with no labelled artists only emits a "No artists with labels" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


def _draw_box_panel(ax, per_algo, dataset, title_label, axis_label):
    """Draw one box (with mean marker) per algorithm that has data, using a log-scaled y axis.

    Args:
        ax: matplotlib Axes to draw on.
        per_algo: mapping algorithm name -> Series of metric values (NaNs removed).
        dataset: dataset name, used in the title and in console messages.
        title_label: metric name used in the panel title.
        axis_label: metric name (with units, if any) used on the y axis.
    """
    plot_data = []
    labels = []
    for algo, values in per_algo.items():
        if values.empty:
            print(f"Algorithm: {algo}, Dataset: {dataset}, Count: {len(values)} - No data")
            continue
        plot_data.append(values)
        labels.append(algo)

    if plot_data:
        ax.boxplot(plot_data, tick_labels=labels, showmeans=True)
    ax.set_title(f'Box Plot of {title_label} for {dataset}')
    ax.set_ylabel(axis_label)
    # NOTE(review): exact-zero values (e.g. some pot-lp cost errors in the preview)
    # cannot be drawn on a log axis, although they still enter the box statistics.
    ax.set_yscale('log')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')


def _plot_metric_distributions(datasets, results_df, algorithms, save_path,
                               column, title_label, axis_label, file_suffix):
    """Save a KDE + box-plot figure of one metric column for each dataset.

    This is the shared implementation behind the three public
    ``plot_*_distributions`` functions.

    Args:
        datasets: values of ``results_df['dataset']`` to plot; one figure each.
        results_df: results table with ``'dataset'``, ``'name'`` and ``column`` columns.
        algorithms: values of ``results_df['name']`` to compare.
        save_path: output directory (created if it does not exist).
        column: metric column to plot.
        title_label: metric name used in the panel titles.
        axis_label: metric name (with units, if any) used on the value axes.
        file_suffix: each figure is saved as ``f"{dataset}_{file_suffix}.png"``.
    """
    os.makedirs(save_path, exist_ok=True)

    for dataset in datasets:
        dataset_data = results_df[results_df['dataset'] == dataset]
        # Select each algorithm's values once and share them between both panels.
        # NaNs are dropped because they would turn the box-plot statistics into NaN.
        per_algo = {
            algo: dataset_data.loc[dataset_data['name'] == algo, column].dropna()
            for algo in algorithms
        }

        fig, (ax_kde, ax_box) = plt.subplots(1, 2, figsize=(18, 5))
        try:
            _draw_kde_panel(ax_kde, per_algo, dataset, title_label, axis_label)
            _draw_box_panel(ax_box, per_algo, dataset, title_label, axis_label)
            fig.tight_layout()
            fig.savefig(os.path.join(save_path, f"{dataset}_{file_suffix}.png"))
        finally:
            # Close even when plotting fails, so a long loop does not accumulate open figures.
            plt.close(fig)


def plot_runtime_distributions(datasets, results_df, algorithms, save_path):
    """Save per-dataset KDE/box plots of the ``time`` column to ``save_path``.

    Files are named ``f"{dataset}_runtime_distributions.png"``.
    """
    _plot_metric_distributions(datasets, results_df, algorithms, save_path,
                               column='time',
                               title_label='Runtime',
                               axis_label='Runtime (seconds)',
                               file_suffix='runtime_distributions')


def plot_cost_err_distributions(datasets, results_df, algorithms, save_path):
    """Save per-dataset KDE/box plots of the ``cost_rerr`` column to ``save_path``.

    Files are named ``f"{dataset}_cost_err_distributions.png"``.
    """
    _plot_metric_distributions(datasets, results_df, algorithms, save_path,
                               column='cost_rerr',
                               title_label='Cost Error',
                               axis_label='Cost Error',
                               file_suffix='cost_err_distributions')


def plot_coupling_err_distributions(datasets, results_df, algorithms, save_path):
    """Save per-dataset KDE/box plots of the ``coupling_avg_err`` column to ``save_path``.

    Files are named ``f"{dataset}_coupling_err_distributions.png"``.
    """
    _plot_metric_distributions(datasets, results_df, algorithms, save_path,
                               column='coupling_avg_err',
                               title_label='Coupling Avg. Error',
                               axis_label='Coupling Avg. Error',
                               file_suffix='coupling_err_distributions')




In [30]:
# For each dataset group, write runtime, cost-error and coupling-error figures,
# with one output directory per metric.
for group_name in dataset_groups:
    datasets = dataset_groups[group_name]

    print(f"\n### {group_name} Runtime Analysis ###")
    plot_runtime_distributions(
        datasets, results_df, algorithms,
        save_path='runtime_distributions',
    )

    print(f"\n### {group_name} Cost Error Analysis ###")
    plot_cost_err_distributions(
        datasets, results_df, algorithms,
        save_path='cost_err_distributions',
    )

    print(f"\n### {group_name} Coupling Error Analysis ###")
    plot_coupling_err_distributions(
        datasets, results_df, algorithms,
        save_path='coupling_err_distributions',
    )


### 1D Datasets Runtime Analysis ###

### 1D Datasets Cost Error Analysis ###

### 1D Datasets Cost Error Analysis ###
Algorithm: optax-grad-ascent, Dataset: 512 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 512 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 1024 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 1024 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 2048 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 2048 1D gamma, Count: 42 - Constant value
Algorithm: optax-grad-ascent, Dataset: 1024 1D beta_cauchy_gamma_gaussian, Count: 25 - Constant value
Algorithm: optax-grad-ascent, Dataset: 1024 1D beta_cauchy_gamma_gaussian, Count: 25 - Constant value
Algorithm: optax-grad-ascent, Dataset: 2048 1D beta_cauchy_gamma_gaussian, Count: 25 - Constant value
Algorithm: optax-grad-ascent, Dataset: 2048 1D beta_cauchy_gamma_gaussian, Count: 25 - Con