### Add the parent directory to `sys.path` so the `uot` package can be imported

In [None]:
import sys
import os

# Directory one level above the current working directory; the `uot` package is
# expected to be importable from there (presumably the repo root — confirm).
sibling_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Prepend it (highest import priority), but only once so re-running the cell
# does not accumulate duplicate entries.
if sibling_path not in sys.path:
    sys.path.insert(0, sibling_path)


# OT experiments

Configure JAX to use 64-bit (double) precision

In [None]:
import jax

# JAX computes in float32 unless x64 mode is enabled. Switch it on here, before
# any JAX arrays are created, so the JAX-based solvers below run in float64.
jax.config.update(
    "jax_enable_x64",
    True,
)

All necessary imports

In [None]:
from uot.algorithms.sinkhorn import jax_sinkhorn
from uot.algorithms.gradient_ascent import gradient_ascent
from uot.algorithms.lbfgs import lbfgs_ot
from uot.algorithms.lp import pot_lp
from uot.core.experiment import run_experiment
from uot.core.suites import time_precision_experiment

Define the solvers and their parameters

In [None]:
# Regularization strengths swept for the epsilon-parameterized solvers,
# from 100 down to 1e-9.
epsilon_kwargs = [
    {'epsilon': 100},
    {'epsilon': 10},
    {'epsilon': 1},
    {'epsilon': 1e-1},
    {'epsilon': 1e-3},
    {'epsilon': 1e-6},
    {'epsilon': 1e-9},
]

# Solver name -> (solver callable, list of keyword-argument dicts to run it with).
# 'pot-lp' takes no epsilon, so its list is empty.
solvers = {
    'pot-lp': (pot_lp, []),
    'lbfgs': (lbfgs_ot, epsilon_kwargs),
    'jax-sinkhorn': (jax_sinkhorn, epsilon_kwargs),
    # Gradient ascent performs badly with large regularization, so only
    # epsilon in [1e-6, 1] is used; 1e-9 is deliberately excluded as well.
    'grad-ascent': (gradient_ascent, [{'epsilon': eps} for eps in (1, 1e-1, 1e-3, 1e-6)]),
}

# Solvers that use jax.jit. Names must match keys of `solvers`, otherwise the
# entry silently has no effect.
# NOTE(review): 'grad-ascent' replaces the stale 'optax-grad-ascent' entry,
# assuming uot.algorithms.gradient_ascent is the optax/JAX-jitted solver — confirm.
jit_algorithms = [
    'jax-sinkhorn', 'grad-ascent', 'lbfgs'
]

# Fail fast if a JIT entry refers to a solver that is not registered above.
_unknown_jit_algorithms = sorted(set(jit_algorithms) - set(solvers))
if _unknown_jit_algorithms:
    raise ValueError(
        f"jit_algorithms contains names not present in solvers: {_unknown_jit_algorithms}"
    )


Define the problem sets:

In [None]:
# Problem sizes run for every distribution family. The same grid is used for
# all families so results are comparable across them.
problem_sizes = (32, 64, 128, 256, 512, 1024, 2048)

# Distribution families; the last entry is a "|"-separated combination of
# families (presumably interpreted by the uot problem generator — confirm).
distribution_names = ("gamma", "gaussian", "beta", "gaussian|gamma|beta|cauchy")

# Every (family, size) combination, grouped by family in the order above,
# with sizes ascending within each family.
problemset_names = [
    ('distribution', name, size)
    for name in distribution_names
    for size in problem_sizes
]

Run the experiment:

In [None]:
# Run the time/precision experiment for every problem set and solver
# configuration defined above, with a single fold.
df = run_experiment(
    experiment=time_precision_experiment,
    problemsets_names=problemset_names,
    solvers=solvers,
    jit_algorithms=jit_algorithms,
    folds=1,
)

Save the results:

In [None]:
# Write the experiment results to a CSV file in the current working directory.
results_csv_path = "ot_experiments.csv"
df.to_csv(results_csv_path)