In [None]:
import os
import sys

# Make the parent of the notebook's working directory importable so that the
# project-local `uot` package (imported below) resolves without installation.
sibling_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

if sibling_path not in sys.path:
    sys.path.insert(0, sibling_path)

# Force JAX onto the CPU. These variables must be set BEFORE `jax` is first
# imported; they are read when JAX initialises its configuration.
# JAX_PLATFORMS is the current setting; JAX_PLATFORM_NAME is the older,
# deprecated spelling, kept for compatibility with older JAX releases.
# Setting them in os.environ (not only via jax.config) means processes started
# by multiprocessing inherit them as well.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# Enable 64-bit floats in JAX (float32 is the default otherwise).
jax.config.update("jax_enable_x64", True)

import ot
import numpy as np

import multiprocessing as mp

# Use "spawn" so each worker starts in a fresh interpreter; presumably chosen to
# avoid forking after JAX has started its threads (JAX warns that os.fork() is
# incompatible with its multithreaded runtime).
# force=True keeps this cell re-runnable: without it, executing the cell a second
# time in the same kernel raises "RuntimeError: context has already been set".
mp.set_start_method('spawn', force=True)

from sklearn.datasets import load_digits

from uot.algorithms.sinkhorn import jax_sinkhorn
from uot.algorithms.gradient_ascent import gradient_ascent
from uot.algorithms.lbfgs import lbfgs_ot
from uot.algorithms.lp import pot_lp

from uot.mnist_classification.count_pairwise_distances import compute_distances_for_all_solvers

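Prepare data

Load the 8×8 digit images as normalized pixel-intensity histograms `X`, and build the ground cost matrix `C` between pixel positions.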

In [None]:
digits = load_digits()
# Image grid size, taken from the data rather than hard-coded (8 x 8 for this dataset).
height, width = digits.images.shape[1:]

# Turn every image into a strictly positive probability histogram over its pixels.
# The small offset is added BEFORE normalising, so each row sums to 1 and has no
# exact zeros; an all-zero image would become uniform instead of dividing by zero.
PIXEL_EPS = 1e-12
X = digits.data + PIXEL_EPS
X /= X.sum(axis=1, keepdims=True)

# Ground cost between pixels: squared Euclidean distance between (row, col)
# coordinates, scaled to [0, 1]. "ij" indexing plus row-major ravel matches the
# pixel order of `digits.data`.
row, col = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
points = np.column_stack([row.ravel(), col.ravel()])
C = ot.dist(points, points, metric="sqeuclidean").astype('float64')
C /= C.max()


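Prepare solvers

Register the optimal-transport solvers to compare and the regularization values `epsilons` to sweep over.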

In [None]:
# Solvers to compare, keyed by the name handed to the distance computation.
# NOTE(review): 'lbfs' looks like a typo for 'lbfgs'. It is left unchanged because
# the key is passed on to compute_distances_for_all_solvers and may be used for
# lookups or export file names; confirm before renaming.
solvers = dict([
    ('sinkhorn', jax_sinkhorn),
    ('grad-ascent', gradient_ascent),
    ('lbfs', lbfgs_ot),
    ('lp', pot_lp),
])

# Regularisation values to sweep over (presumably the entropic epsilon; confirm
# against the solver implementations).
epsilons = [0.1, 0.01]

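Run computation

Compute the distances between all pairs of digits for every solver and regularization value, using several worker processes. This step is slow: the run below processed about 850 distance computations per second and estimated roughly 3 h 40 min remaining.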

In [None]:
# Long-running step (hours): all pairwise distances for every solver and epsilon.
NUM_PROCESSES = 6
MAX_ITER = 10_000
EXPORT_FOLDER = "../classification"

# Pass a plain list of solver names: a dict_keys view cannot be pickled (relevant
# if the names are sent to "spawn" worker processes) and cannot be indexed.
compute_distances_for_all_solvers(
    X,
    C,
    list(solvers),
    epsilons=epsilons,
    num_processes=NUM_PROCESSES,
    max_iter=MAX_ITER,
    export_folder=EXPORT_FOLDER,
)

Computing all pairwise distances:   1%|          | 68884/11295942 [01:40<3:40:57, 846.84it/s]