In [13]:

from uot.suites import time_precision_suite, memory_suite
from algorithms.lp import pot_lp
from uot.experiment import run_experiment
from algorithms.sinkhorn import sinkhorn, ott_jax_sinkhorn, pot_sinkhorn
from algorithms.gradient_ascent import gradient_ascent

from uot.analysis import display_all_metrics

In [14]:

import jax
jax.config.update("jax_enable_x64", True)

# Time and precision exeriments

### Solvers

In [15]:
solvers = {
    'pot-lp': pot_lp,
    'pot-sinkhorn': pot_sinkhorn,
    'ott-jax-sinkhorn': ott_jax_sinkhorn,
    'jax-sinkhorn': sinkhorn,
    'optax-grad-ascent': gradient_ascent
}

### Gamma distribution test

In [16]:
problemset_names = [
    "32 1D gamma",
    "64 1D gamma",
    "128 1D gamma",
    "256 1D gamma",
    "512 1D gamma",
]

results = run_experiment(time_precision_suite, problemset_names, solvers)
display_all_metrics(results, ['time', 'cost_rerr', 'coupling_avg_err'])

Data loaded from ./datasets/1D/32_gamma.pkl
Data loaded from ./datasets/1D/64_gamma.pkl
Data loaded from ./datasets/1D/128_gamma.pkl
Data loaded from ./datasets/1D/256_gamma.pkl
Data loaded from ./datasets/1D/512_gamma.pkl


Running experiments: 100%|██████████| 2250/2250 [04:02<00:00,  9.27it/s] 


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,1.208947,38.041252,47.091609,29.905285,175.730683
1,256 1D gamma,5.113982,37.615001,85.779525,57.049949,686.761814
2,32 1D gamma,0.357941,33.498458,2.707509,1.776017,148.195499
3,512 1D gamma,23.09402,86.797441,278.310267,148.526062,508.467891
4,64 1D gamma,0.511143,29.987944,10.355947,7.417487,162.913892

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,0.139159,8.984125,10.114285,11.4789,115.873047
1,256 1D gamma,1.177394,11.348571,30.081807,22.272609,206.272277
2,32 1D gamma,0.06727,8.558252,0.903775,1.053039,169.306466
3,512 1D gamma,6.168626,17.419408,68.465538,59.249211,1015.854114
4,64 1D gamma,0.058017,8.863987,3.433301,3.039375,194.123112


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,0.0,1.696158,1.698597,1.695987,2.384849
1,256 1D gamma,0.0,0.236497,0.238349,0.236556,0.987403
2,32 1D gamma,0.0,0.278758,0.279324,0.278811,0.901695
3,512 1D gamma,0.0,1.572517,1.579075,1.572461,5.112568
4,64 1D gamma,0.0,0.303198,0.304107,0.303212,1.013034

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,0.0,5.091688,5.094097,5.091177,4.955431
1,256 1D gamma,0.0,0.77973,0.780961,0.780382,1.037086
2,32 1D gamma,0.0,0.702786,0.703075,0.70284,0.49557
3,512 1D gamma,0.0,6.504446,6.510021,6.504141,19.572943
4,64 1D gamma,0.0,0.912131,0.913117,0.912007,0.572315


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,0.0,8.8e-05,8.8e-05,8.8e-05,0.000115
1,256 1D gamma,0.0,2.6e-05,2.6e-05,2.6e-05,2.9e-05
2,32 1D gamma,0.0,0.000377,0.000377,0.000377,0.001573
3,512 1D gamma,0.0,7e-06,7e-06,7e-06,8e-06
4,64 1D gamma,0.0,0.000231,0.000231,0.000231,0.000441

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,128 1D gamma,0.0,2e-06,2e-06,2e-06,1e-05
1,256 1D gamma,0.0,1e-06,1e-06,1e-06,1e-06
2,32 1D gamma,0.0,8.5e-05,8.5e-05,8.5e-05,0.000503
3,512 1D gamma,0.0,0.0,0.0,0.0,0.0
4,64 1D gamma,0.0,2.6e-05,2.6e-05,2.6e-05,6.3e-05


### Beta distribution test

In [None]:
problemset_names = [
    "32 1D beta",
    "64 1D beta",
    "128 1D beta",
    "256 1D beta",
    "512 1D beta",
    "1024 1D beta",
]

results = run_experiment(time_precision_suite, problemset_names, solvers)

display_all_metrics(results, ['time', 'cost_rerr', 'coupling_avg_err'])

Data loaded from ./datasets/1D/32_beta.pkl
Data loaded from ./datasets/1D/64_beta.pkl
Data loaded from ./datasets/1D/128_beta.pkl
Data loaded from ./datasets/1D/256_beta.pkl
Data loaded from ./datasets/1D/512_beta.pkl


Running experiments: 100%|██████████| 1800/1800 [00:12<00:00, 144.56it/s]


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,1.34248,10.859378,2.445142,0.940138
1,256 1D beta,4.893817,11.221097,3.464174,1.786536
2,32 1D beta,0.832425,10.312576,2.018409,0.53777
3,512 1D beta,27.659463,12.160065,6.46003,4.713537
4,64 1D beta,0.710416,10.369939,2.167606,0.736638

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,0.274452,1.571982,0.340967,0.006687
1,256 1D beta,2.292534,1.289194,0.345698,0.012444
2,32 1D beta,0.286196,0.788802,0.231789,0.00973
3,512 1D beta,16.043054,1.122639,1.07714,0.040271
4,64 1D beta,0.081393,0.692938,0.260509,0.016112


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,0.0,3.340557,3.340557,3.340547
1,256 1D beta,0.0,1.869123,1.869127,1.869129
2,32 1D beta,0.0,0.395177,0.395191,0.395175
3,512 1D beta,0.0,12.724874,12.724875,12.724873
4,64 1D beta,0.0,1.115248,1.115266,1.115262

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,0.0,5.568081,5.56808,5.568082
1,256 1D beta,0.0,3.963574,3.963574,3.963572
2,32 1D beta,0.0,1.002651,1.002651,1.002638
3,512 1D beta,0.0,60.293957,60.293955,60.293957
4,64 1D beta,0.0,3.259394,3.259395,3.259393


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,0.0,7.8e-05,7.8e-05,7.8e-05
1,256 1D beta,0.0,2.4e-05,2.4e-05,2.4e-05
2,32 1D beta,0.0,0.000158,0.000158,0.000158
3,512 1D beta,0.0,7e-06,7e-06,7e-06
4,64 1D beta,0.0,0.00015,0.00015,0.00015

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta,0.0,7e-06,7e-06,7e-06
1,256 1D beta,0.0,1e-06,1e-06,1e-06
2,32 1D beta,0.0,0.000101,0.000101,0.000101
3,512 1D beta,0.0,0.0,0.0,0.0
4,64 1D beta,0.0,4.3e-05,4.3e-05,4.3e-05


### Different distributions

In [None]:
problemset_names = [
    "32 1D gaussian|uniform|gamma|beta|cauchy",
    "64 1D gaussian|uniform|gamma|beta|cauchy",
    "128 1D gaussian|uniform|gamma|beta|cauchy",
    "256 1D gaussian|uniform|gamma|beta|cauchy",
    "512 1D gaussian|uniform|gamma|beta|cauchy"
    "1024 1D gaussian|uniform|gamma|beta|cauchy"
]

results = run_experiment(time_precision_suite, problemset_names, solvers)
display_all_metrics(results, ['time', 'cost_rerr', 'coupling_avg_err'])

Data loaded from ./datasets/1D/32_beta_cauchy_gamma_gaussian_uniform.pkl
Data loaded from ./datasets/1D/64_beta_cauchy_gamma_gaussian_uniform.pkl
Data loaded from ./datasets/1D/128_beta_cauchy_gamma_gaussian_uniform.pkl
Data loaded from ./datasets/1D/256_beta_cauchy_gamma_gaussian_uniform.pkl
Data loaded from ./datasets/1D/512_beta_cauchy_gamma_gaussian_uniform.pkl


Running experiments: 100%|██████████| 1800/1800 [01:48<00:00, 16.52it/s] 


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,2.191435,157.066637,19.878175,15.004543
1,256 1D beta_cauchy_gamma_gaussian_uniform,8.666804,193.444227,51.329413,35.537702
2,32 1D beta_cauchy_gamma_gaussian_uniform,1.06597,137.093233,8.728282,6.973803
3,512 1D beta_cauchy_gamma_gaussian_uniform,47.731591,144.134729,97.037767,61.73884
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.878734,144.693439,13.960321,10.471092

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,0.381626,131.401705,15.579803,12.166646
1,256 1D beta_cauchy_gamma_gaussian_uniform,2.326342,143.560744,37.857492,24.973716
2,32 1D beta_cauchy_gamma_gaussian_uniform,0.355428,127.179413,6.596298,6.172215
3,512 1D beta_cauchy_gamma_gaussian_uniform,21.170248,119.901652,80.452492,51.115977
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.08127,107.834581,9.358084,7.511935


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.625102,0.622414,0.62187
1,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.039023,0.040196,0.039171
2,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.03146,0.031631,0.03147
3,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.036452,0.038032,0.036497
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.062431,0.06278,0.062334

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,3.543266,3.543949,3.543806
1,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.156295,0.156301,0.156261
2,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.100101,0.100336,0.100095
3,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.06485,0.06624,0.064342
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.263661,0.263694,0.263624


Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,8.1e-05,8e-05,8e-05
1,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,2.3e-05,2.3e-05,2.3e-05
2,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000237,0.000237,0.000237
3,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,7e-06,7e-06,7e-06
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000207,0.000207,0.000207

Unnamed: 0,dataset,pot-lp,pot-sinkhorn,ott-jax-sinkhorn,jax-sinkhorn
0,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,1e-05,8e-06,8e-06
1,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,4e-06,4e-06,4e-06
2,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000141,0.000141,0.000141
3,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.0,0.0,0.0
4,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,4.1e-05,4.1e-05,4.1e-05
