In [99]:
import pandas as pd
import numpy as np

# Comparison of OT algorithms

1. Own implemented Sinkhorn with jax
2. OTT-JAX Sinkhorn
3. LP

In [100]:
sinkhorn_df = pd.read_csv("./results/sinkhorn.csv")
ott_jax_sinkhorn_df = pd.read_csv("./results/ott_jax_sinkhorn.csv")
pot_lp_df = pd.read_csv("./results/pot_lp_dfs.csv")

# sinkhorn_df = sinkhorn_df.rename(columns={'precision': "cost_rerr", "coupling_precision": "coupling_rerr"})
# ott_jax_sinkhorn_df = ott_jax_sinkhorn_df.rename(columns={'precision': "cost_rerr", "coupling_precision": "coupling_rerr"})
# pot_lp_df = pot_lp_df.rename(columns={'precision': "cost_rerr", "coupling_precision": "coupling_rerr"})

In [101]:
proper_order = ['32 1D Gaussians', '64 1D Gaussians', '128 1D Gaussians', '512 1D Gaussians',
                '4x4 2D Gaussians', '8x8 2D Gaussians', '16x16 2D Gaussians', '32x32 2D Gaussians'] 

In [102]:
def get_agg_table(df):
        df['time_ms'] = (df['time'] * 1000)
        df = df.groupby('name').agg(
                time_ms_mean = ('time_ms', 'mean'),
                time_ms_std = ('time_ms', lambda x: np.std(x, ddof=1)),
                cost_rerr_mean = ('cost_rerr', "mean"),
                cost_rerr_std = ('cost_rerr', lambda x: np.std(x, ddof=1)),
                coupling_rerr_mean = ('coupling_rerr', 'mean'),
                coupling_rerr_std = ('coupling_rerr', lambda x: np.std(x, ddof=1)),
                memory_mean = ('memory', 'mean'),
                memory_std = ('memory', lambda x: np.std(x, ddof=1))
        ).reset_index()
        df["name"] = pd.Categorical(df["name"], categories=proper_order, ordered=True)
        # df['time_ms_mean'] = df['time_ms_mean'].round(2)
        # df['time_ms_std'] = df['time_ms_std'].round(2)
        df = df.round(5)
        return df.sort_values('name')


def get_comparison_table(dfs: dict[str, pd.DataFrame], field: str):
        comparison_df = pd.DataFrame()
        comparison_df['dataset'] = proper_order

        for name, df in dfs.items():
                mean = df[f"{field}_mean"].astype(str)
                std = df[f"{field}_std"].astype(str)
                comparison_df[name] = mean + "±" + std
        return comparison_df


In [103]:
sinkhorn_agg = get_agg_table(sinkhorn_df)
ott_jax_sinkhorn_agg = get_agg_table(ott_jax_sinkhorn_df)
pot_lp_agg = get_agg_table(pot_lp_df)

agg_dfs = {
    "jax_sinkhorn": sinkhorn_agg,
    "ottjax_sinkhorn": ott_jax_sinkhorn_agg,
    "pot_lp": pot_lp_agg
}

In [104]:
time_comparison = get_comparison_table(agg_dfs, 'time_ms')
time_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_lp
0,32 1D Gaussians,144.15444±5.18416,348.14359±22.92946,0.83213±0.41966
1,64 1D Gaussians,218.52484±36.8524,401.79056±97.21802,4.87963±1.67143
2,128 1D Gaussians,153.79393±53.00704,367.27295±181.28288,0.86724±0.39631
3,512 1D Gaussians,1412.06348±51.23347,680.7835±103.21747,213.50249±113.98311
4,4x4 2D Gaussians,110.15127±31.64142,297.51464±92.24778,0.08012±0.02675
5,8x8 2D Gaussians,154.81129±66.37926,353.93474±24.20873,0.84594±0.42331
6,16x16 2D Gaussians,149.11±11.04561,343.35891±45.43005,0.84566±0.42347
7,32x32 2D Gaussians,126.3285±33.19661,330.11664±85.76198,0.34494±0.06595


In [105]:
cost_rerr_comparison = get_comparison_table(agg_dfs, "cost_rerr")
cost_rerr_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_lp
0,32 1D Gaussians,0.01379±0.02369,0.05906±0.11158,0.0±0.0
1,64 1D Gaussians,0.00943±0.01827,0.07847±0.14367,0.0±0.0
2,128 1D Gaussians,0.01379±0.02369,0.05906±0.11158,0.0±0.0
3,512 1D Gaussians,0.03236±0.08204,0.11533±0.26852,0.0±0.0
4,4x4 2D Gaussians,0.002±0.01327,0.00471±0.02336,0.0±0.0
5,8x8 2D Gaussians,0.01379±0.02369,0.05906±0.11158,0.0±0.0
6,16x16 2D Gaussians,0.01379±0.02369,0.05906±0.11158,0.0±0.0
7,32x32 2D Gaussians,1e-05±7e-05,0.02482±0.03204,0.0±0.0


In [106]:
coupling_rerr_comparison = get_comparison_table(agg_dfs, "coupling_rerr")
coupling_rerr_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_lp
0,32 1D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0
1,64 1D Gaussians,-0.0±0.0,-0.0±0.0,0.0±0.0
2,128 1D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0
3,512 1D Gaussians,-0.0±0.0,-0.0±0.0,0.0±0.0
4,4x4 2D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0
5,8x8 2D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0
6,16x16 2D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0
7,32x32 2D Gaussians,-0.0±0.0,0.0±0.0,0.0±0.0


In [107]:
memory_comparison = get_comparison_table(agg_dfs, "memory")
memory_comparison

Unnamed: 0,dataset,jax_sinkhorn,ottjax_sinkhorn,pot_lp
0,32 1D Gaussians,1.84253±0.06972,3.86944±0.122,0.0±0.0
1,64 1D Gaussians,0.91944±0.12824,9.56363±31.0773,0.00278±0.01863
2,128 1D Gaussians,1.83524±0.07563,3.90833±0.14202,0.0±0.0
3,512 1D Gaussians,1.23056±2.7327,3.11944±0.17055,21.28958±3.99694
4,4x4 2D Gaussians,1.83194±0.091,3.79167±0.14102,0.0±0.0
5,8x8 2D Gaussians,2.03082±1.31191,3.86389±0.1026,0.0±0.0
6,16x16 2D Gaussians,1.81389±0.06858,3.86111±0.13647,0.0±0.0
7,32x32 2D Gaussians,1.88316±0.08422,4.57222±5.00238,0.0±0.0
