In [None]:
import numpy as np
import pandas as pd
from proximal_operators import (
    proj_simplex, proj_simplex_quick,
    proj_simplex_active_set, proj_simplex_condat)
from utils import gen_sym_psd, time_one_func, gen_expes_dims

In [None]:
# Implementations under comparison. Order matters: the first entry
# (proj_simplex) serves as the reference the others are checked against.
projs = list((proj_simplex, proj_simplex_quick,
              proj_simplex_active_set, proj_simplex_condat))

# Correctness

In [None]:
# Check every alternative implementation against the reference proj_simplex.
SEED = 0
np.random.seed(SEED)  # reproducible input; seeds the legacy global RNG used by later cells too


def metric(x):
    """Return the max-norm of ``x`` (largest absolute entry)."""
    return np.linalg.norm(x, np.inf)


DIM = 10000
RAD = 2
TOL = 1e-8  # max allowed entrywise disagreement with the reference
imp = np.random.randn(DIM)
ref = proj_simplex(imp, RAD)

errors = {proj.__name__: metric(ref - proj(imp, RAD)) for proj in projs[1:]}

# Assumes the target set is {x >= 0, sum(x) = RAD} -- the same quantities the
# original cell displayed for manual inspection.
assert (ref >= 0).all(), 'reference projection has negative entries'
assert np.isclose(ref.sum(), RAD), 'reference projection does not sum to RAD'
assert all(err <= TOL for err in errors.values()), errors

report = {'negative entries in ref': (ref < 0).sum(), 'sum of ref': ref.sum()}
report.update(('max |ref - {}|'.format(name), err) for name, err in errors.items())
pd.Series(report)

# Performance

In [None]:
%%timeit
proj_simplex(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_quick(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_active_set(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_condat(np.random.randn(DIM), RAD)

In [None]:
exps = gen_expes_dims(0, 6, 15, 10)

# Time every implementation on its own copy of the experiment set
# (presumably time_one_func consumes or mutates it -- TODO confirm).
timings = {
    proj.__name__: time_one_func(proj, exps.copy())
    for proj in projs}

# Keep the '50%' (median) row of each implementation's timing summary;
# assumes time_one_func returns a describe()-style frame.
median_times = pd.DataFrame({
    name: times.loc['50%', :] for name, times in timings.items()})

ax = median_times.plot(figsize=(20, 15), grid=True)
ax.set(title='Median run time per simplex projection', ylabel='median run time');