In [52]:
import numpy as np
import pandas as pd
from IPython.display import display, HTML

from uot.analysis import get_agg_table, get_comparison_table, get_mean_comparison_table, get_std_comparison_table

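# Comparison of OT algorithms

1. Our own Sinkhorn implementation in JAX (`jax_sinkhorn`)
2. OTT-JAX Sinkhorn (`ottjax_sinkhorn`)
3. POT Sinkhorn (`pot_sinkhorn`)
4. POT linear-programming (exact) solver (`pot_lp`)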

In [53]:
# Raw per-run benchmark results, one CSV per solver.
sinkhorn_df = pd.read_csv("./results/sinkhorn.csv")
ott_jax_sinkhorn_df = pd.read_csv("./results/ott_jax_sinkhorn.csv")
pot_lp_df = pd.read_csv("./results/pot_lp.csv")
pot_sinkhorn_df = pd.read_csv("./results/pot_sinkhorn.csv")

# Discard the row labelled 0 from the two JAX-based result sets only.
# NOTE(review): presumably this row is a JIT-compilation warm-up run — confirm
# against the benchmark runner.
ott_jax_sinkhorn_df = ott_jax_sinkhorn_df.drop(index=0)
sinkhorn_df = sinkhorn_df.drop(index=0)

In [54]:
# Display order for the benchmark datasets: 1D problems by size, then 2D grids by size.
proper_order = ['32 1D Gaussians', '64 1D Gaussians', '128 1D Gaussians', '512 1D Gaussians',
                '4x4 2D Gaussians', '8x8 2D Gaussians', '16x16 2D Gaussians', '32x32 2D Gaussians']


def make_comlexity_order(df, categories=None):
    """Return a copy of ``df`` with rows sorted by problem complexity.

    The ``name`` column is converted to an ordered categorical so that sorting
    follows the given dataset order instead of alphabetical order.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with a ``name`` column holding dataset names. It is not modified.
    categories : sequence of str, optional
        Dataset order to sort by. Defaults to ``proper_order``.

    Returns
    -------
    pd.DataFrame
        A new frame whose ``name`` column is an ordered categorical, sorted by it.
        The original index labels are kept. Names not listed in ``categories``
        become NaN and sort last; a warning is emitted for them.
    """
    import warnings

    order = proper_order if categories is None else list(categories)

    # pd.Categorical silently maps unknown values to NaN, which would hide a
    # dataset name typo or a newly added benchmark — surface it instead.
    unknown = sorted(map(str, set(df['name'].dropna()) - set(order)))
    if unknown:
        warnings.warn(f"dataset names not in sort order (will become NaN): {unknown}")

    ordered_name = pd.Categorical(df['name'], categories=order, ordered=True)
    # assign() returns a new frame, so the caller's df keeps its original column.
    return df.assign(name=ordered_name).sort_values('name')

In [55]:
def _first_run_index(df, test_name):
    """Return the index label of the first row of ``df`` whose ``name`` equals ``test_name``.

    The result is an empty index when ``df`` has no such row, which makes a
    subsequent ``drop`` a no-op.
    """
    return df[df.name == test_name].iloc[:1].index


# Drop the first remaining run of every test case from both JAX result frames.
# NOTE(review): presumably these are JIT warm-up runs that would skew timings — confirm.
# Each row is looked up in the same frame it is dropped from; the two CSVs do not
# necessarily share a row layout, so reusing one frame's index for the other
# would remove the wrong rows.
# Caution: the frames are mutated in place, so re-running this cell drops one more
# run per test case — re-run from the loading cell instead.
for test_name in proper_order:
    for result_df in (ott_jax_sinkhorn_df, sinkhorn_df):
        result_df.drop(index=_first_run_index(result_df, test_name), inplace=True)


In [56]:
# Aggregate each solver's raw runs and order the rows by problem complexity.
sinkhorn_agg, ott_jax_sinkhorn_agg, pot_lp_agg, pot_sinkhorn_agg = (
    make_comlexity_order(get_agg_table(raw_df))
    for raw_df in (sinkhorn_df, ott_jax_sinkhorn_df, pot_lp_df, pot_sinkhorn_df)
)

# Solver label -> aggregated table; insertion order sets the column order
# of the comparison tables displayed below.
agg_dfs = dict(
    jax_sinkhorn=sinkhorn_agg,
    ottjax_sinkhorn=ott_jax_sinkhorn_agg,
    pot_sinkhorn=pot_sinkhorn_agg,
    pot_lp=pot_lp_agg,
)

In [57]:
def _highlight_min_html(table, props):
    """Render ``table`` as HTML with the smallest value in each row highlighted.

    Only columns after the first are compared; the first column is assumed to
    hold the dataset label (as in the rendered comparison tables).
    """
    return table.style.highlight_min(axis=1, subset=table.columns[1:], props=props).to_html()


def display_mean_and_std(agg_dfs, column: str,
                         props: str = 'color:white; font-weight:bold; background-color:darkblue;'):
    """Display the mean and std comparison tables for ``column`` side by side.

    Parameters
    ----------
    agg_dfs : dict
        Mapping of solver label to its aggregated results table. It is passed
        unchanged to ``get_mean_comparison_table`` / ``get_std_comparison_table``.
    column : str
        Metric to compare, e.g. ``"time"`` or ``"cost_rerr"``.
    props : str, optional
        CSS applied to the smallest value in each row of both tables.

    Returns
    -------
    None
        The combined HTML is rendered with ``display``; nothing is returned.
    """
    mean_html = _highlight_min_html(get_mean_comparison_table(agg_dfs, column), props)
    std_html = _highlight_min_html(get_std_comparison_table(agg_dfs, column), props)

    combined_html = f"""
    <h3 style="text-align:center;">{column} mean and std</h3>
    <div style="display: flex; justify-content: space-around;">
        <div>{mean_html}</div>
        <div>{std_html}</div>
    </div>
    """
    display(HTML(combined_html))

In [58]:
# Runtime per dataset (lower is better). NOTE(review): units are set by the benchmark runner — state them here.
display_mean_and_std(agg_dfs, column="time")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,7.474264,5.670013,71.734608,0.547151
6,64 1D Gaussians,30.515417,17.738918,64.301571,0.91625
0,128 1D Gaussians,233.703858,238.91224,101.929921,14.804785
5,512 1D Gaussians,404.410245,618.300572,203.851415,95.564234
4,4x4 2D Gaussians,148.439084,3.321776,61.460383,0.418318
7,8x8 2D Gaussians,18.439351,21.700288,88.682318,0.789312
1,16x16 2D Gaussians,104.265853,142.514011,130.266816,8.841312
3,32x32 2D Gaussians,833.05922,1310.870771,568.57673,331.411305

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,8.293957,3.090734,129.87495,0.111575
6,64 1D Gaussians,29.457112,12.913335,101.017257,0.195617
0,128 1D Gaussians,202.128167,192.935387,125.099379,4.759199
5,512 1D Gaussians,266.922372,444.602782,172.848899,38.790335
4,4x4 2D Gaussians,341.625026,4.605224,96.716357,0.084878
7,8x8 2D Gaussians,28.635419,29.779425,112.120314,0.126694
1,16x16 2D Gaussians,57.22518,54.858979,47.628101,2.780088
3,32x32 2D Gaussians,455.066133,573.860129,291.567407,174.601392


In [59]:
# Cost error per dataset. NOTE(review): pot_lp is 0 everywhere, so it appears to be the reference — confirm.
display_mean_and_std(agg_dfs, column="cost_rerr")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.008582,0.008787,0.091269,0.0
6,64 1D Gaussians,0.012594,0.012912,0.091592,0.0
0,128 1D Gaussians,0.014005,0.014752,0.089634,0.0
5,512 1D Gaussians,0.013879,0.014989,0.086192,0.0
4,4x4 2D Gaussians,5.6e-05,3.2e-05,0.037282,0.0
7,8x8 2D Gaussians,8.4e-05,2.8e-05,0.00566,0.0
1,16x16 2D Gaussians,0.008965,0.00916,0.009178,0.0
3,32x32 2D Gaussians,0.029846,0.030382,0.031915,0.0

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.014961,0.015326,0.202708,0.0
6,64 1D Gaussians,0.021305,0.02184,0.193792,0.0
0,128 1D Gaussians,0.023721,0.024859,0.18539,0.0
5,512 1D Gaussians,0.023353,0.025014,0.179671,0.0
4,4x4 2D Gaussians,0.000203,9.9e-05,0.104954,0.0
7,8x8 2D Gaussians,0.000108,3.2e-05,0.031167,0.0
1,16x16 2D Gaussians,0.018365,0.018406,0.018197,0.0
3,32x32 2D Gaussians,0.081884,0.081934,0.081981,0.0


In [60]:
# Presumably the average error of the transport plan (coupling) — confirm the exact definition.
display_mean_and_std(agg_dfs, column="coupling_avg_err")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.000272,0.000272,0.000465,0.0
6,64 1D Gaussians,0.000184,0.000184,0.00022,0.0
0,128 1D Gaussians,2.1e-05,2.1e-05,2.2e-05,0.0
5,512 1D Gaussians,6e-06,6e-06,6e-06,0.0
4,4x4 2D Gaussians,0.000231,0.000231,0.000713,0.0
7,8x8 2D Gaussians,0.000105,0.000105,0.000118,1e-06
1,16x16 2D Gaussians,9e-06,9e-06,9e-06,0.0
3,32x32 2D Gaussians,1e-06,1e-06,1e-06,0.0

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.000163,0.000163,0.00046,0.0
6,64 1D Gaussians,9.7e-05,9.7e-05,0.000116,0.0
0,128 1D Gaussians,1e-05,1e-05,1e-05,0.0
5,512 1D Gaussians,2e-06,2e-06,2e-06,0.0
4,4x4 2D Gaussians,0.000406,0.000406,0.001666,0.0
7,8x8 2D Gaussians,4.1e-05,4.1e-05,7e-05,4e-06
1,16x16 2D Gaussians,2e-06,2e-06,2e-06,1e-06
3,32x32 2D Gaussians,0.0,0.0,0.0,0.0


In [61]:
# Memory usage per dataset. NOTE(review): units are set by the benchmark runner — state them here.
display_mean_and_std(agg_dfs, column="memory")

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.0,0.0,0.002778,0.0
6,64 1D Gaussians,0.0,0.0,0.0,0.0
0,128 1D Gaussians,0.019886,0.0,0.030556,0.0
5,512 1D Gaussians,0.008523,0.059659,0.144444,0.097222
4,4x4 2D Gaussians,0.0,0.0,0.0,0.0
7,8x8 2D Gaussians,0.0,0.0,0.0,0.0
1,16x16 2D Gaussians,0.0,0.0,0.0,0.0
3,32x32 2D Gaussians,5.170898,0.363636,4.661111,38.180556

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.0,0.0,0.018634,0.0
6,64 1D Gaussians,0.0,0.0,0.0,0.0
0,128 1D Gaussians,0.093165,0.0,0.133664,0.0
5,512 1D Gaussians,0.041744,0.313792,0.695948,0.652186
4,4x4 2D Gaussians,0.0,0.0,0.0,0.0
7,8x8 2D Gaussians,0.0,0.0,0.0,0.0
1,16x16 2D Gaussians,0.0,0.0,0.0,0.0
3,32x32 2D Gaussians,5.27736,1.685656,9.156221,6.333564
