In [54]:
import numpy as np
import pandas as pd

from uot.analysis import get_agg_table, get_comparison_table

# Comparison of OT algorithms

1. Custom Sinkhorn implementation in JAX
2. OTT-JAX Sinkhorn
3. POT Sinkhorn
4. POT LP (exact linear-programming solver)

In [55]:
# Raw per-run benchmark results, one CSV per solver.
pot_lp_df = pd.read_csv("./results/pot_lp.csv")
pot_sinkhorn_df = pd.read_csv("./results/pot_sinkhorn.csv")

# The row labelled 0 is discarded for the two JAX-based solvers only.
# NOTE(review): presumably a JIT warm-up run — confirm against the benchmark script.
sinkhorn_df = pd.read_csv("./results/sinkhorn.csv").drop(index=0)
ott_jax_sinkhorn_df = pd.read_csv("./results/ott_jax_sinkhorn.csv").drop(index=0)

In [56]:
proper_order = ['32 1D Gaussians', '64 1D Gaussians', '128 1D Gaussians', '512 1D Gaussians',
                '4x4 2D Gaussians', '8x8 2D Gaussians', '16x16 2D Gaussians', '32x32 2D Gaussians']


def make_comlexity_order(df, order=None):
    """Return a copy of ``df`` with rows sorted by dataset complexity.

    The ``name`` column is converted to an ordered categorical whose categories
    follow ``order``, and the rows are sorted by that column. The input frame
    is left unmodified.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``name`` column holding dataset labels.
    order : list of str, optional
        Category order to apply. Defaults to ``proper_order``: the 1D datasets
        by size, then the 2D grids by size.

    Returns
    -------
    pandas.DataFrame
        New frame sorted by ``name``. Labels that are not in ``order`` become
        NaN and are sorted last.
    """
    if order is None:
        order = proper_order
    ordered_names = pd.Categorical(df['name'], categories=order, ordered=True)
    # assign() builds a new frame, so the caller's DataFrame keeps its original column.
    return df.assign(name=ordered_names).sort_values('name')


# Correctly spelled alias; the original name is kept for existing callers.
make_complexity_order = make_comlexity_order

In [57]:
def drop_first_row_per_test(df, test_names):
    """Return a copy of ``df`` without the first row of each test in ``test_names``.

    For every label in ``test_names``, the first row (in current row order)
    whose ``name`` equals that label is removed. Labels with no matching rows
    are skipped, and rows whose ``name`` is not listed are kept.
    NOTE(review): presumably this discards a warm-up run per dataset — confirm
    against the benchmark script.
    """
    first_row_labels = [
        label
        for test_name in test_names
        for label in df.loc[df['name'] == test_name].index[:1]
    ]
    return df.drop(index=first_row_labels)


# Each frame is filtered using its own row labels.
# Not idempotent: re-running this cell without re-running the load cell drops
# another row per test.
ott_jax_sinkhorn_df = drop_first_row_per_test(ott_jax_sinkhorn_df, proper_order)
sinkhorn_df = drop_first_row_per_test(sinkhorn_df, proper_order)


In [58]:
# Aggregate every solver's raw runs and order datasets from simplest to most complex.
# The keys become the column labels of the comparison tables below.
agg_dfs = {
    solver: make_comlexity_order(get_agg_table(raw_df))
    for solver, raw_df in {
        "jax_sinkhorn": sinkhorn_df,
        "ottjax_sinkhorn": ott_jax_sinkhorn_df,
        "pot_sinkhorn": pot_sinkhorn_df,
        "pot_lp": pot_lp_df,
    }.items()
}

# Per-solver handles to the same aggregated frames, for direct inspection.
sinkhorn_agg = agg_dfs["jax_sinkhorn"]
ott_jax_sinkhorn_agg = agg_dfs["ottjax_sinkhorn"]
pot_sinkhorn_agg = agg_dfs["pot_sinkhorn"]
pot_lp_agg = agg_dfs["pot_lp"]

In [59]:
# Per-dataset "time" metric for every solver, one column per solver.
# NOTE(review): cells look like mean±std over runs — confirm in uot.analysis.
time_comparison = get_comparison_table(agg_dfs, "time")
time_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,3.02932±3.0878,25.76916±19.41588,36.39365±67.44168,1.22962±0.2337
6,64 1D Gaussians,4.66368±3.81083,25.11426±19.41371,25.90488±21.92788,1.29875±0.31477
0,128 1D Gaussians,4.64243±3.83998,25.33901±19.59219,25.67188±21.61521,1.08624±0.18192
5,512 1D Gaussians,4.59097±3.86661,24.9423±19.23132,25.38233±21.0348,1.00838±0.13122
4,4x4 2D Gaussians,0.21158±0.85647,1.23179±1.73305,27.96815±48.67833,0.24847±0.02817
7,8x8 2D Gaussians,1.65965±3.20218,11.3801±16.84656,34.03817±50.18987,0.49208±0.06686
1,16x16 2D Gaussians,13.3321±8.07717,44.70947±19.21542,28.91662±44.81672,5.07545±1.65817
3,32x32 2D Gaussians,92.62621±55.67914,722.7589±286.48345,231.69955±86.90583,265.76763±166.83012


In [60]:
# Per-dataset "cost_rerr" metric (presumably relative error of the OT cost) for every solver.
# NOTE(review): pot_lp reports 0.0 everywhere, suggesting it is the reference — confirm in uot.analysis.
cost_rerr_comparison = get_comparison_table(agg_dfs, 'cost_rerr')
cost_rerr_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_sinkhorn,pot_lp
2,32 1D Gaussians,0.01388±0.02322,0.01432±0.02395,0.09174±0.19166,0.0±0.0
6,64 1D Gaussians,0.01357±0.02304,0.014±0.02376,0.09174±0.19166,0.0±0.0
0,128 1D Gaussians,0.01357±0.02304,0.014±0.02376,0.09174±0.19166,0.0±0.0
5,512 1D Gaussians,0.01357±0.02304,0.014±0.02376,0.09174±0.19166,0.0±0.0
4,4x4 2D Gaussians,6e-05±0.0002,3e-05±0.0001,0.03728±0.10495,0.0±0.0
7,8x8 2D Gaussians,8e-05±0.00011,3e-05±3e-05,0.00566±0.03117,0.0±0.0
1,16x16 2D Gaussians,0.00897±0.01836,0.00916±0.01841,0.00918±0.0182,0.0±0.0
3,32x32 2D Gaussians,0.02985±0.08188,0.03038±0.08193,0.03192±0.08198,0.0±0.0
