In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ot.greenkhorn.ot import OT as greenkhorn
from ot.apdamd.ot import OT as apdamd
from ot.tests.sample_problem import *
from ot.ot import cost
import timeit

In [8]:
n_iters = [10, 100, 1000, 10000, 100000]
my_problem = sample_problem(6)
eps = 4 * np.log(2)
timings_greenkhorn = []
timings_apdamd = []
std_timings_greenkhorn = []
std_timings_apdamd = []
transport_plans_greenkhorn = []
transport_plans_apdamd = []
costs_greenkhorn = []
costs_apdamd = []
for n_iter in n_iters:
    transport_plans_greenkhorn.append(greenkhorn(None, *my_problem[1:], eps=eps, iter_max=n_iter)[0])
    transport_plans_apdamd.append(apdamd(*my_problem, eps=eps, iter_max=n_iter)[0])
    costs_apdamd.append(cost(my_problem[1], transport_plans_apdamd[-1]))
    costs_greenkhorn.append(cost(my_problem[1], transport_plans_greenkhorn[-1]))
    timing_greenkhorn = %timeit -o -r 10 -n 2 greenkhorn( * my_problem, eps, iter_max=n_iter)
    timing_apdamd = %timeit -o -r 10 -n 2 apdamd( * my_problem, eps, iter_max=n_iter)
    timings_greenkhorn.append(np.mean(timing_greenkhorn.timings))
    std_timings_greenkhorn.append(np.std(timing_greenkhorn.timings))
    timings_apdamd.append(np.mean(timing_apdamd.timings))
    std_timings_apdamd.append(np.std(timing_apdamd.timings))

268 ms ± 13.4 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
213 ms ± 7.58 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
259 ms ± 3.61 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
223 ms ± 16.8 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
270 ms ± 10.9 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
218 ms ± 10.5 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
269 ms ± 3.44 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
227 ms ± 6.31 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
315 ms ± 3.94 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)
335 ms ± 6.92 ms per loop (mean ± std. dev. of 10 runs, 2 loops each)


In [9]:
costs_greenkhorn

[Array(0.89160837, dtype=float64),
 Array(0.89160848, dtype=float64),
 Array(0.89160848, dtype=float64),
 Array(0.89160848, dtype=float64),
 Array(0.89160848, dtype=float64)]

In [10]:
costs_apdamd

[Array(0.82340954, dtype=float64),
 Array(0.86017184, dtype=float64),
 Array(0.8698017, dtype=float64),
 Array(0.875011, dtype=float64),
 Array(0.87724696, dtype=float64)]

In [11]:
transport_plans_apdamd

[Array([[0.02542768, 0.03630895, 0.03190571, 0.03934938, 0.05336869,
         0.03317839],
        [0.03358976, 0.03031642, 0.04896419, 0.07180592, 0.04853549,
         0.0350032 ],
        [0.02884797, 0.02915187, 0.02624332, 0.03035604, 0.02843338,
         0.05156766],
        [0.04574484, 0.03417127, 0.05227405, 0.03296366, 0.03038488,
         0.0292658 ],
        [0.03591396, 0.03648228, 0.04477104, 0.08084175, 0.06524415,
         0.04753422],
        [0.06369135, 0.03586088, 0.1175258 , 0.04540515, 0.05302365,
         0.04681772]], dtype=float64),
 Array([[0.02757925, 0.0398853 , 0.03387303, 0.04221744, 0.05676026,
         0.03573348],
        [0.03496522, 0.03271577, 0.05004529, 0.07737039, 0.0519254 ,
         0.03569979],
        [0.03076347, 0.0323261 , 0.02833925, 0.0328643 , 0.03047974,
         0.0571462 ],
        [0.04914265, 0.03771263, 0.05573006, 0.03500948, 0.03125043,
         0.03096579],
        [0.0363882 , 0.03927063, 0.04375792, 0.08648605, 0.07004158,
    

In [12]:
transport_plans_greenkhorn

[Array([[0.02975439, 0.0450884 , 0.0369984 , 0.04563586, 0.05298275,
         0.04522874],
        [0.03594858, 0.03763171, 0.05578792, 0.08809108, 0.05153788,
         0.04524232],
        [0.02986651, 0.03257395, 0.02948416, 0.03223271, 0.02977038,
         0.05815053],
        [0.04975573, 0.04487792, 0.06443589, 0.03950931, 0.03270855,
         0.03869534],
        [0.03489276, 0.04221313, 0.04361416, 0.08494527, 0.06010857,
         0.06252665],
        [0.04831295, 0.03832026, 0.11008379, 0.0409789 , 0.04640855,
         0.04483978]], dtype=float64),
 Array([[0.0297544 , 0.0450884 , 0.0369984 , 0.04563587, 0.05298276,
         0.04522875],
        [0.03594859, 0.03763171, 0.0557879 , 0.08809109, 0.05153789,
         0.04524233],
        [0.02986651, 0.03257395, 0.02948416, 0.03223272, 0.02977039,
         0.05815053],
        [0.04975574, 0.04487793, 0.06443587, 0.03950932, 0.03270856,
         0.03869535],
        [0.03489276, 0.04221314, 0.04361415, 0.08494529, 0.06010858,
    