In [4]:
import os
# Set before `import jax` below so the variable is already in the process
# environment when JAX is loaded. The value "cpu" presumably pins JAX to the
# CPU backend so every solver is timed on the same device — TODO confirm.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
# Switch on JAX's `jax_enable_x64` flag (presumably so arrays default to
# 64-bit floats for the precision comparisons below — confirm against JAX docs).
jax.config.update("jax_enable_x64", True)

from uot.suites import time_precision_suite, memory_suite
from algorithms.lp import pot_lp
from uot.experiment import run_experiment, get_problemset
from algorithms.sinkhorn import jax_sinkhorn, ott_jax_sinkhorn, pot_sinkhorn
from algorithms.gradient_ascent import gradient_ascent

from uot.analysis import display_all_metrics

# Time and precision experiments

### Solvers

In [5]:
# Solvers under comparison. Each label is the column header used in the
# result tables; each value is the solver callable imported at the top of
# the notebook. Insertion order is kept identical to the table column order.
solvers = dict([
    ('pot-lp', pot_lp),
    ('ott-jax-sinkhorn', ott_jax_sinkhorn),
    ('jax-sinkhorn', jax_sinkhorn),
    ('optax-grad-ascent', gradient_ascent),
])

### Gaussian distributions 

In [6]:
# Problem-set names follow the "<N> 1D <distribution>" pattern; the leading
# number is presumably the number of support points — TODO confirm.
# NOTE(review): unlike the gamma/beta/mixed cells, N=32 is not included
# here — confirm this is intentional.
problemset_names = [f"{size} 1D gaussian" for size in (64, 128, 256, 512, 1024)]

# Run every solver in `solvers` on every problem set, then show the timing
# and error metrics.
gaussian_results = run_experiment(
    time_precision_suite,
    problemset_names,
    solvers,
)
display_all_metrics(
    gaussian_results,
    ['time', 'cost_rerr', 'coupling_avg_err'],
)

Data loaded from ./datasets/1D/64_gaussian.pkl
Data loaded from ./datasets/1D/128_gaussian.pkl
Data loaded from ./datasets/1D/256_gaussian.pkl
Data loaded from ./datasets/1D/512_gaussian.pkl
Data loaded from ./datasets/1D/1024_gaussian.pkl


Running experiments: 100%|██████████| 900/900 [03:11<00:00,  4.70it/s] 


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,357.991049,1502.146643,933.980013,24.343907
1,128 1D gaussian,1.620076,52.591901,40.05068,1.034888
2,256 1D gaussian,8.270964,120.229836,91.348399,1.442905
3,512 1D gaussian,46.087175,299.434903,224.591235,3.961085
4,64 1D gaussian,0.743639,23.219479,18.177861,0.80144

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,234.839027,930.618534,548.296434,0.8066
1,128 1D gaussian,0.300572,20.098345,15.627784,0.08023
2,256 1D gaussian,2.945729,43.493692,31.529589,0.13066
3,512 1D gaussian,20.676042,112.95173,90.478066,0.303049
4,64 1D gaussian,0.198314,6.122241,5.176676,0.090456


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,0.0,0.049985,0.048357,1.0
1,128 1D gaussian,0.0,0.244717,0.24358,1.0
2,256 1D gaussian,0.0,0.044073,0.042924,1.0
3,512 1D gaussian,0.0,0.462196,0.458394,1.0
4,64 1D gaussian,0.0,0.185188,0.184168,0.99994

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,0.0,0.171225,0.1709,0.0
1,128 1D gaussian,0.0,1.125787,1.123868,0.0
2,256 1D gaussian,0.0,0.121506,0.120837,0.0
3,512 1D gaussian,0.0,1.290672,1.284691,0.0
4,64 1D gaussian,0.0,0.792894,0.79189,0.000259


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,0.0,2e-06,2e-06,1e-06
1,128 1D gaussian,0.0,8.9e-05,8.9e-05,6.1e-05
2,256 1D gaussian,0.0,2.6e-05,2.6e-05,1.5e-05
3,512 1D gaussian,0.0,7e-06,7e-06,4e-06
4,64 1D gaussian,0.0,0.000245,0.000245,0.000244

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gaussian,0.0,0.0,0.0,0.0
1,128 1D gaussian,0.0,1e-06,1e-06,0.0
2,256 1D gaussian,0.0,0.0,0.0,0.0
3,512 1D gaussian,0.0,0.0,0.0,0.0
4,64 1D gaussian,0.0,1.4e-05,1.4e-05,0.0


### Gamma distribution test

In [7]:
# Problem-set names follow the "<N> 1D <distribution>" pattern; the leading
# number is presumably the number of support points — TODO confirm.
problemset_names = [f"{size} 1D gamma" for size in (32, 64, 128, 256, 512, 1024)]

# Run every solver in `solvers` on every problem set, then show the timing
# and error metrics.
gamma_results = run_experiment(
    time_precision_suite,
    problemset_names,
    solvers,
)
display_all_metrics(
    gamma_results,
    ['time', 'cost_rerr', 'coupling_avg_err'],
)

Data saved to ./datasets/1D/32_gamma.pkl
Data saved to ./datasets/1D/64_gamma.pkl
Data saved to ./datasets/1D/128_gamma.pkl
Data saved to ./datasets/1D/256_gamma.pkl
Data saved to ./datasets/1D/512_gamma.pkl
Data saved to ./datasets/1D/1024_gamma.pkl


Running experiments:   0%|          | 0/1080 [00:00<?, ?it/s]

Running experiments: 100%|██████████| 1080/1080 [02:17<00:00,  7.87it/s] 


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,112.881127,1289.255104,768.730255,22.543157
1,128 1D gamma,1.270168,42.794181,32.76772,1.007666
2,256 1D gamma,4.606151,59.731953,45.735795,1.414802
3,32 1D gamma,1.201355,10.976747,4.683642,3.669318
4,512 1D gamma,19.458231,269.909056,200.943544,3.920479
5,64 1D gamma,0.47442,10.659167,8.395256,2.032102

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,54.457399,436.883952,265.28851,0.792315
1,128 1D gamma,0.173519,9.044845,6.859305,0.117371
2,256 1D gamma,1.002349,31.825137,21.43351,0.133944
3,32 1D gamma,6.052601,46.006462,13.523791,18.648357
4,512 1D gamma,3.677243,78.185477,54.018782,0.228601
5,64 1D gamma,0.047817,4.124028,3.69883,1.164566


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,0.0,0.622494,0.615545,1.0
1,128 1D gamma,0.0,0.639553,0.637421,1.0
2,256 1D gamma,0.0,0.575926,0.573645,1.0
3,32 1D gamma,0.0,0.214929,0.214262,0.955678
4,512 1D gamma,0.0,0.188696,0.184802,1.0
5,64 1D gamma,0.0,2.022703,2.021884,0.982248

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,0.0,1.739608,1.734005,0.0
1,128 1D gamma,0.0,2.278202,2.275524,0.0
2,256 1D gamma,0.0,1.55129,1.550281,0.0
3,32 1D gamma,0.0,0.744632,0.744359,0.13918
4,512 1D gamma,0.0,0.365253,0.362345,0.0
5,64 1D gamma,0.0,11.67091,11.670243,0.062798


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,0.0,2e-06,2e-06,1e-06
1,128 1D gamma,0.0,8.8e-05,8.8e-05,6.1e-05
2,256 1D gamma,0.0,2.6e-05,2.6e-05,1.5e-05
3,32 1D gamma,0.0,0.000338,0.000338,0.00117
4,512 1D gamma,0.0,7e-06,7e-06,4e-06
5,64 1D gamma,0.0,0.000233,0.000233,0.000264

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D gamma,0.0,0.0,0.0,0.0
1,128 1D gamma,0.0,2e-06,2e-06,0.0
2,256 1D gamma,0.0,0.0,0.0,0.0
3,32 1D gamma,0.0,0.000124,0.000124,0.000195
4,512 1D gamma,0.0,0.0,0.0,0.0
5,64 1D gamma,0.0,2.3e-05,2.3e-05,2.3e-05


### Beta distribution test

In [8]:
# Problem-set names follow the "<N> 1D <distribution>" pattern; the leading
# number is presumably the number of support points — TODO confirm.
problemset_names = [f"{size} 1D beta" for size in (32, 64, 128, 256, 512, 1024)]

# Run every solver in `solvers` on every problem set, then show the timing
# and error metrics.
beta_results = run_experiment(
    time_precision_suite,
    problemset_names,
    solvers,
)
display_all_metrics(
    beta_results,
    ['time', 'cost_rerr', 'coupling_avg_err'],
)

Data saved to ./datasets/1D/32_beta.pkl
Data saved to ./datasets/1D/64_beta.pkl
Data saved to ./datasets/1D/128_beta.pkl
Data saved to ./datasets/1D/256_beta.pkl
Data saved to ./datasets/1D/512_beta.pkl
Data saved to ./datasets/1D/1024_beta.pkl


Running experiments:   0%|          | 0/1080 [00:00<?, ?it/s]

  result_code_string = check_result(result_code)
Running experiments: 100%|██████████| 1080/1080 [06:51<00:00,  2.62it/s]


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,150.52624,103.453865,92.878146,8120.218196
1,128 1D beta,1.161747,4.101102,3.023232,17.010203
2,256 1D beta,4.147168,7.287126,6.353246,40.955855
3,32 1D beta,0.549274,1.048439,0.284057,0.843553
4,512 1D beta,27.492978,21.539853,18.827903,423.382786
5,64 1D beta,0.444442,1.680563,0.829477,2.091086

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,59.820486,16.050534,8.561255,8172.46462
1,128 1D beta,0.344017,1.054765,0.483739,2.303183
2,256 1D beta,1.861686,2.282814,1.267411,38.133129
3,32 1D beta,0.15886,0.083615,0.0602,0.103365
4,512 1D beta,14.630383,6.185702,2.8343,311.110778
5,64 1D beta,0.082904,0.261148,0.130873,0.271662


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,0.062459,3.008671,3.00868,0.942094
1,128 1D beta,0.0,1.967492,1.96746,0.42832
2,256 1D beta,0.0,2.743754,2.743734,0.835049
3,32 1D beta,0.0,1.022184,1.022193,0.77826
4,512 1D beta,0.0,2.418802,2.418797,0.895738
5,64 1D beta,0.0,1.746617,1.746593,0.492457

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,0.14676,13.074842,13.074844,0.081472
1,128 1D beta,0.0,3.594856,3.594864,0.257692
2,256 1D beta,0.0,8.847186,8.847188,0.192268
3,32 1D beta,0.0,3.959924,3.959923,1.874744
4,512 1D beta,0.0,5.694884,5.694893,0.117031
5,64 1D beta,0.0,3.667326,3.667337,0.736606


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,0.0,2e-06,2e-06,1e-06
1,128 1D beta,0.0,6.7e-05,6.7e-05,7.6e-05
2,256 1D beta,0.0,2.4e-05,2.4e-05,1.6e-05
3,32 1D beta,0.0,0.000182,0.000182,0.001122
4,512 1D beta,0.0,6e-06,6e-06,4e-06
5,64 1D beta,0.0,0.000194,0.000194,0.000302

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta,1e-06,0.0,0.0,0.0
1,128 1D beta,0.0,1.7e-05,1.7e-05,8e-06
2,256 1D beta,0.0,2e-06,2e-06,1e-06
3,32 1D beta,0.0,0.000143,0.000143,0.000286
4,512 1D beta,0.0,0.0,0.0,0.0
5,64 1D beta,0.0,3.1e-05,3.1e-05,4.7e-05


### Mixed distributions

In [9]:
# Problem-set names follow the "<N> 1D <distributions>" pattern, with the
# distribution families joined by "|"; the leading number is presumably the
# number of support points — TODO confirm.
problemset_names = [
    f"{size} 1D gaussian|uniform|gamma|beta|cauchy"
    for size in (32, 64, 128, 256, 512, 1024)
]

# Run every solver in `solvers` on every problem set, then show the timing
# and error metrics.
mixed_results = run_experiment(
    time_precision_suite,
    problemset_names,
    solvers,
)
display_all_metrics(
    mixed_results,
    ['time', 'cost_rerr', 'coupling_avg_err'],
)

Data saved to ./datasets/1D/32_beta_cauchy_gamma_gaussian_uniform.pkl
Data saved to ./datasets/1D/64_beta_cauchy_gamma_gaussian_uniform.pkl
Data saved to ./datasets/1D/128_beta_cauchy_gamma_gaussian_uniform.pkl
Data saved to ./datasets/1D/256_beta_cauchy_gamma_gaussian_uniform.pkl
Data saved to ./datasets/1D/512_beta_cauchy_gamma_gaussian_uniform.pkl
Data saved to ./datasets/1D/1024_beta_cauchy_gamma_gaussian_uniform.pkl


Running experiments:   0%|          | 0/1080 [00:00<?, ?it/s]

  result_code_string = check_result(result_code)
Running experiments: 100%|██████████| 1080/1080 [02:17<00:00,  7.88it/s] 


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,406.074032,956.974667,636.087555,24.075163
1,128 1D beta_cauchy_gamma_gaussian_uniform,1.68178,37.910816,27.283237,8.964582
2,256 1D beta_cauchy_gamma_gaussian_uniform,9.244806,113.722486,75.028752,1.532115
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.458626,3.254438,2.237498,1.129831
4,512 1D beta_cauchy_gamma_gaussian_uniform,48.557019,291.766922,229.134025,4.142578
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.579997,11.088487,8.691046,2.313124

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,213.857592,999.039462,585.331071,0.947004
1,128 1D beta_cauchy_gamma_gaussian_uniform,0.272807,28.385865,18.785847,10.516851
2,256 1D beta_cauchy_gamma_gaussian_uniform,2.839813,95.942678,62.935578,0.238632
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.173692,2.132849,1.660951,0.346897
4,512 1D beta_cauchy_gamma_gaussian_uniform,23.00632,204.157143,161.285561,0.300524
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.101608,9.715593,6.915586,1.199475


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,0.00108,0.052668,0.051279,1.0
1,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.166721,0.166009,0.967726
2,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.268353,0.267678,1.0
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.013611,0.013355,0.973434
4,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.432557,0.430175,1.0
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.053555,0.053187,0.947292

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,0.007246,0.172551,0.171943,0.0
1,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.802229,0.80089,0.099628
2,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,1.548511,1.548593,0.0
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.041563,0.041364,0.060431
4,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,2.130267,2.124605,0.0
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.177908,0.17794,0.173488


Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,0.0,2e-06,2e-06,1e-06
1,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,7.7e-05,7.7e-05,6.6e-05
2,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,2.5e-05,2.5e-05,1.5e-05
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000241,0.000241,0.001357
4,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,7e-06,7e-06,4e-06
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000196,0.000196,0.00028

Unnamed: 0,dataset,pot-lp,ott-jax-sinkhorn,jax-sinkhorn,optax-grad-ascent
0,1024 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.0,0.0,0.0
1,128 1D beta_cauchy_gamma_gaussian_uniform,0.0,1.2e-05,1.2e-05,6e-06
2,256 1D beta_cauchy_gamma_gaussian_uniform,0.0,1e-06,1e-06,0.0
3,32 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.000136,0.000136,0.000182
4,512 1D beta_cauchy_gamma_gaussian_uniform,0.0,0.0,0.0,0.0
5,64 1D beta_cauchy_gamma_gaussian_uniform,0.0,4.6e-05,4.6e-05,3.4e-05
