In [17]:
import numpy as np
import pandas as pd
from IPython.display import display, HTML

from uot.analysis import get_agg_table, get_comparison_table, get_mean_comparison_table, get_std_comparison_table

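# Comparison of OT algorithms

1. Our own Sinkhorn implementation in JAX (`jax_sinkhorn`)
2. OTT-JAX Sinkhorn (`ottjax_sinkhorn`)
3. POT Sinkhorn (`pot_sinkhorn`)
4. POT linear-programming (exact) solver (`pot_lp`)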

In [2]:
# Raw per-run benchmark results, one CSV per solver.
sinkhorn_df = pd.read_csv("./results/sinkhorn.csv")
ott_jax_sinkhorn_df = pd.read_csv("./results/ott_jax_sinkhorn.csv")
pot_lp_df = pd.read_csv("./results/pot_lp.csv")
pot_sinkhorn_df = pd.read_csv("./results/pot_sinkhorn.csv")

# Discard the row labelled 0 from the two JAX-based result sets.
# NOTE(review): presumably a warm-up / compilation run — confirm.
sinkhorn_df = sinkhorn_df.drop(index=0)
ott_jax_sinkhorn_df = ott_jax_sinkhorn_df.drop(index=0)

In [3]:
proper_order = ['32 1D Gaussians', '64 1D Gaussians', '128 1D Gaussians', '512 1D Gaussians',
                '4x4 2D Gaussians', '8x8 2D Gaussians', '16x16 2D Gaussians', '32x32 2D Gaussians']


def make_complexity_order(df: pd.DataFrame, order=None) -> pd.DataFrame:
    """Return a copy of ``df`` sorted by problem complexity.

    ``df['name']`` is converted to an ordered categorical whose categories are
    ``order`` (defaults to ``proper_order``), and the result is sorted by it.
    Names not listed in ``order`` become NaN and sort last.

    The input frame is left untouched; the result carries the categorical
    ``name`` column.
    """
    categories = proper_order if order is None else order
    ordered_names = pd.Categorical(df['name'], categories=categories, ordered=True)
    # `assign` returns a new frame, so the caller's `df` is not mutated.
    return df.assign(name=ordered_names).sort_values('name')


# Backward-compatible alias for the original (misspelled) name used by later cells.
make_comlexity_order = make_complexity_order

In [4]:
def drop_first_run_per_test(df: pd.DataFrame, test_names) -> pd.DataFrame:
    """Return a copy of ``df`` without the first row of each test in ``test_names``.

    "First" means first by position among rows with the same ``name``.
    Rows whose ``name`` is not in ``test_names`` are kept.
    NOTE(review): presumably the first run is a warm-up / JIT run — confirm.
    """
    is_first_run = df.groupby("name").cumcount().eq(0) & df["name"].isin(test_names)
    return df.loc[~is_first_run]


# Each frame is filtered using its OWN rows. The previous loop looked the rows
# up in ott_jax_sinkhorn_df and dropped those labels from sinkhorn_df.
# Note: re-running this cell drops another row per test.
ott_jax_sinkhorn_df = drop_first_run_per_test(ott_jax_sinkhorn_df, proper_order)
sinkhorn_df = drop_first_run_per_test(sinkhorn_df, proper_order)


In [5]:
# Aggregate each solver's raw runs and sort rows from simplest to hardest problem
# (evaluated in the same order as before: jax, ott-jax, pot-lp, pot-sinkhorn).
sinkhorn_agg, ott_jax_sinkhorn_agg, pot_lp_agg, pot_sinkhorn_agg = (
    make_comlexity_order(get_agg_table(raw_results))
    for raw_results in (sinkhorn_df, ott_jax_sinkhorn_df, pot_lp_df, pot_sinkhorn_df)
)

# Solver label -> aggregated table; the key order fixes the column order of
# the comparison tables below.
agg_dfs = dict(
    jax_sinkhorn=sinkhorn_agg,
    ottjax_sinkhorn=ott_jax_sinkhorn_agg,
    pot_sinkhorn=pot_sinkhorn_agg,
    pot_lp=pot_lp_agg,
)

In [27]:
HIGHLIGHT_MIN_PROPS = 'color:white; font-weight:bold; background-color:darkblue;'


def _min_highlighted_html(table: pd.DataFrame, caption: str, props: str) -> str:
    """Render ``table`` to HTML with each row's minimum highlighted and a caption on top.

    Column 0 is excluded from the minimum search. It is assumed to hold the
    dataset label (NOTE(review): relies on the layout returned by
    ``uot.analysis`` comparison helpers — confirm).
    """
    return (
        table.style
        .highlight_min(axis=1, subset=table.columns[1:], props=props)
        .set_caption(caption)
        .to_html()
    )


def display_mean_and_std(agg_dfs, column: str, props: str = HIGHLIGHT_MIN_PROPS) -> None:
    """Display the mean and std comparison tables for ``column`` side by side.

    Parameters
    ----------
    agg_dfs : dict
        Mapping of solver label -> aggregated results. It is passed unchanged
        to ``get_mean_comparison_table`` / ``get_std_comparison_table``.
    column : str
        Metric to compare (e.g. ``"time"``, ``"cost_rerr"``).
    props : str, optional
        CSS applied to the per-row minimum in both tables.

    Returns
    -------
    None
        The tables are rendered via ``IPython.display.display``.
    """
    # Captions make it explicit which of the two side-by-side tables is which.
    mean_html = _min_highlighted_html(get_mean_comparison_table(agg_dfs, column), "mean", props)
    std_html = _min_highlighted_html(get_std_comparison_table(agg_dfs, column), "std", props)

    combined_html = f"""
    <h3 style="text-align:center;">{column} mean and std</h3>
    <div style="display: flex; justify-content: space-around;">
        <div>{mean_html}</div>
        <div>{std_html}</div>
    </div>
    """
    display(HTML(combined_html))

In [28]:
# Runtime per solver and dataset.
display_mean_and_std(agg_dfs, column="time")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,3.029323,25.769162,36.393653,1.229622
6,64 1D Gaussians,4.663676,25.114262,25.90488,1.298753
0,128 1D Gaussians,4.642429,25.339008,25.67188,1.086245
5,512 1D Gaussians,4.590975,24.942304,25.382326,1.008376
4,4x4 2D Gaussians,0.211581,1.231791,27.968145,0.248466
7,8x8 2D Gaussians,1.659655,11.380104,34.038168,0.492081
1,16x16 2D Gaussians,13.332098,44.709472,28.916623,5.075447
3,32x32 2D Gaussians,92.626207,722.758903,231.699549,265.767628

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,3.087803,19.415881,67.441683,0.233704
6,64 1D Gaussians,3.810828,19.413713,21.927879,0.31477
0,128 1D Gaussians,3.839975,19.592192,21.615207,0.181916
5,512 1D Gaussians,3.866606,19.23132,21.034795,0.131215
4,4x4 2D Gaussians,0.856471,1.733046,48.678334,0.02817
7,8x8 2D Gaussians,3.20218,16.846556,50.189865,0.066862
1,16x16 2D Gaussians,8.077172,19.215422,44.816723,1.658169
3,32x32 2D Gaussians,55.679145,286.483446,86.905831,166.830121


In [29]:
# Error of the transport cost per solver and dataset.
display_mean_and_std(agg_dfs, column="cost_rerr")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.013884,0.014323,0.091737,0.0
6,64 1D Gaussians,0.013569,0.013998,0.091737,0.0
0,128 1D Gaussians,0.013569,0.013998,0.091737,0.0
5,512 1D Gaussians,0.013569,0.013998,0.091737,0.0
4,4x4 2D Gaussians,5.6e-05,3.2e-05,0.037282,0.0
7,8x8 2D Gaussians,8.4e-05,2.8e-05,0.00566,0.0
1,16x16 2D Gaussians,0.008965,0.00916,0.009178,0.0
3,32x32 2D Gaussians,0.029846,0.030382,0.031915,0.0

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.023216,0.023945,0.191663,0.0
6,64 1D Gaussians,0.023039,0.023764,0.191663,0.0
0,128 1D Gaussians,0.023039,0.023764,0.191663,0.0
5,512 1D Gaussians,0.023039,0.023764,0.191663,0.0
4,4x4 2D Gaussians,0.000203,9.9e-05,0.104954,0.0
7,8x8 2D Gaussians,0.000108,3.2e-05,0.031167,0.0
1,16x16 2D Gaussians,0.018365,0.018406,0.018197,0.0
3,32x32 2D Gaussians,0.081884,0.081934,0.081981,0.0


In [30]:
# Average coupling error per solver and dataset.
display_mean_and_std(agg_dfs, column="coupling_avg_err")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.000103,0.000103,0.000112,0.0
6,64 1D Gaussians,0.000101,0.000101,0.000112,0.0
0,128 1D Gaussians,0.000101,0.000101,0.000112,0.0
5,512 1D Gaussians,0.000101,0.000101,0.000112,0.0
4,4x4 2D Gaussians,0.000231,0.000231,0.000713,0.0
7,8x8 2D Gaussians,0.000105,0.000105,0.000118,1e-06
1,16x16 2D Gaussians,9e-06,9e-06,9e-06,0.0
3,32x32 2D Gaussians,1e-06,1e-06,1e-06,0.0

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,4.8e-05,4.8e-05,5.4e-05,0.0
6,64 1D Gaussians,5e-05,5e-05,5.4e-05,0.0
0,128 1D Gaussians,5e-05,5e-05,5.4e-05,0.0
5,512 1D Gaussians,5e-05,5e-05,5.4e-05,0.0
4,4x4 2D Gaussians,0.000406,0.000406,0.001666,0.0
7,8x8 2D Gaussians,4.1e-05,4.1e-05,7e-05,4e-06
1,16x16 2D Gaussians,2e-06,2e-06,2e-06,1e-06
3,32x32 2D Gaussians,0.0,0.0,0.0,0.0


In [31]:
# Memory usage per solver and dataset.
display_mean_and_std(agg_dfs, column="memory")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.0,0.0,0.0,0.0
6,64 1D Gaussians,0.0,0.002841,0.0,0.0
0,128 1D Gaussians,0.0,0.0,0.0,0.0
5,512 1D Gaussians,0.0,0.0,0.0,0.0
4,4x4 2D Gaussians,0.0,0.0,0.0,0.0
7,8x8 2D Gaussians,0.0,0.0,0.002778,0.0
1,16x16 2D Gaussians,0.008523,0.011364,0.011111,0.0
3,32x32 2D Gaussians,4.325107,5.471591,10.644444,20.138194

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.0,0.0,0.0,0.0
6,64 1D Gaussians,0.0,0.018844,0.0,0.0
0,128 1D Gaussians,0.0,0.0,0.0,0.0
5,512 1D Gaussians,0.0,0.0,0.0,0.0
4,4x4 2D Gaussians,0.0,0.0,0.0,0.0
7,8x8 2D Gaussians,0.0,0.0,0.018634,0.0
1,16x16 2D Gaussians,0.056533,0.075378,0.074536,0.0
3,32x32 2D Gaussians,4.518531,8.035345,9.228398,18.972729
