### Install uot package

In [None]:
!pip install --upgrade uot

# OT experiments

Configure JAX (enable 64-bit precision)

In [None]:
import jax
# Turn on 64-bit (double-precision) support in JAX, which otherwise defaults
# to 32-bit types. This is done first, before any other JAX code runs, so the
# setting is in effect for every array created later in the notebook.
jax.config.update("jax_enable_x64", True)

Download dataset

In [None]:
from uot.core.dataset import download_dataset
# Fetch the data through the uot helper, called with its default arguments.
# NOTE(review): presumably this downloads the problem sets used by the
# experiment below - confirm the download location and whether it is cached.
download_dataset()

Import all required solvers and experiment utilities

In [None]:
from uot.algorithms.sinkhorn import jax_sinkhorn
from uot.algorithms.gradient_ascent import gradient_ascent
from uot.algorithms.lbfgs import lbfgs_ot
from uot.algorithms.lp import pot_lp
from uot.core.experiment import run_experiment
from uot.core.suites import time_precision_experiment

Define solvers and their params

In [None]:
# Keyword-argument sets for the regularised solvers: one {'epsilon': value}
# dict per regularisation strength, ordered from strongest to weakest.
epsilon_kwargs = [
    {'epsilon': eps}
    for eps in (100, 10, 1, 1e-1, 1e-3, 1e-6, 1e-9)
]

# Solver registry: display name -> (solver callable, list of kwargs dicts).
# NOTE(review): presumably each kwargs dict is one configuration to run, and
# an empty list means the solver takes no extra parameters - confirm against
# run_experiment.
solvers = {
    'pot-lp': (pot_lp, []),
    'lbfgs': (lbfgs_ot, epsilon_kwargs),
    'jax-sinkhorn': (jax_sinkhorn, epsilon_kwargs),
    # Gradient ascent performs very poorly for large regularisation, so the
    # epsilon values 100 and 10 are left out; 1e-9 is currently disabled too.
    'grad-ascent': (
        gradient_ascent,
        [{'epsilon': eps} for eps in (1, 1e-1, 1e-3, 1e-6)],
    ),
}

# Names of the solvers that use jax.jit. Each entry has to be a key of
# `solvers`, otherwise it can never match a solver that is actually run.
# The gradient-ascent solver is registered as 'grad-ascent'; the former entry
# 'optax-grad-ascent' did not correspond to any registered solver.
# NOTE(review): assumes gradient_ascent is JIT-compiled, as the old entry
# implied - confirm.
jit_algorithms = [
    'jax-sinkhorn',
    'grad-ascent',
    'lbfgs',
]


Define the problem sets:

In [None]:
# Problem sets to benchmark, as 3-tuples.
# NOTE(review): the fields look like (dimension, dataset name, size) - confirm
# against run_experiment.
problemset_names = [
    # Single-distribution 1D families (note: no size 128 here).
    *(
        (1, family, size)
        for family in ("gamma", "gaussian", "beta")
        for size in (32, 64, 256, 512, 1024, 2048)
    ),
    # Mixed 1D family; unlike the families above, this one includes size 128.
    *(
        (1, "gaussian|gamma|beta|cauchy", size)
        for size in (32, 64, 128, 256, 512, 1024, 2048)
    ),
    # 2D image-like sets at size 32 ...
    *(
        (2, image_set, 32)
        for image_set in (
            "WhiteNoise",
            "CauchyDensity",
            "GRFmoderate",
            "GRFrough",
            "GRFsmooth",
            "LogGRF",
            "LogitGRF",
            "MicroscopyImages",
            "Shapes",
        )
    ),
    # ... plus ClassicImages, which is the only 2D set listed at size 64.
    (2, "ClassicImages", 64),
    # 3D meshes.
    *((3, "3dmesh", size) for size in (1024, 2048)),
]

Run experiment:

In [None]:
# Run the time/precision suite with the solver registry and problem sets
# defined above. The result is kept in `df`, which is written to CSV below.
# NOTE(review): folds=1 presumably means a single repetition per
# configuration - confirm against run_experiment.
df = run_experiment(
    experiment=time_precision_experiment,
    problemsets_names=problemset_names,
    solvers=solvers,
    jit_algorithms=jit_algorithms,
    folds=1,
)

Save data:

In [None]:
# Write the experiment results to ot_experiments.csv; the path is relative,
# so the file is created in the current working directory.
df.to_csv("ot_experiments.csv")