In [None]:
from heapq import heapify, heappop
from collections import deque
from timeit import repeat
import numpy as np
import pandas as pd

In [None]:
def proj_simplex(vec, rad=1):
    ''' Held, M., Wolfe, P. and Crowder, H.P., 1974.
    Validation of subgradient optimization.
    Mathematical programming, 6(1), pp.62-88.

    Sort-based Euclidean projection of ``vec`` onto the simplex
    {x : x >= 0, sum(x) = rad}.

    With ``u`` the entries sorted in decreasing order, the candidate threshold
    for the top ``k`` entries is (u_1 + ... + u_k - rad) / k; the last ``k``
    whose entry u_k is strictly above its candidate gives the threshold used.

    Parameters
    ----------
    vec : 1-d non-empty numpy array
    rad : radius of the simplex (the code needs rad > 0; otherwise no index
        qualifies and ``max`` raises ValueError)

    Returns
    -------
    numpy array, ``max(vec - threshold, 0)`` elementwise.
    '''
    assert all((isinstance(vec, np.ndarray), vec.ndim == 1, len(vec) >= 1)), (
        'imput must be a 1-d non empty numpy array')
    descending = np.sort(vec)[::-1]
    inv_counts = 1 / np.arange(1, len(vec) + 1)
    thresholds = inv_counts * (np.cumsum(descending) - rad)
    # Indices where the sorted entry still beats its candidate threshold.
    qualifying = np.flatnonzero(descending > thresholds)
    last = max(qualifying)
    return np.maximum(vec - thresholds[last], 0)

def proj_simplex_max_heap(vec, rad=1):
    ''' Van Den Berg, E. and Friedlander, M.P., 2009.
    Probing the Pareto frontier for basis pursuit solutions.
    SIAM Journal on Scientific Computing, 31(2), pp.890-912.

    Euclidean projection of ``vec`` onto the simplex {x : x >= 0, sum(x) = rad}.

    Entries are popped from a max-heap in decreasing order. A popped entry
    joins the active set while it is strictly above the threshold
    (sum(active) + entry - rad) / (len(active) + 1); the first entry that
    fails ends the search. If the heap runs out, every entry is active.
    The sum and count are maintained incrementally.

    Parameters
    ----------
    vec : 1-d non-empty numpy array
    rad : positive radius of the simplex

    Returns
    -------
    numpy array, ``max(vec - threshold, 0)`` elementwise.
    '''
    assert all((isinstance(vec, np.ndarray), vec.ndim == 1, len(vec) >= 1)), (
        'imput must be a 1-d non empty numpy array')
    assert rad > 0, 'radius must be positive'
    heap = [-ele for ele in vec]  # heapq supports MIN-heap: negate for a max-heap
    heapify(heap)
    total = -rad  # running sum(active) - rad
    count = 0     # running len(active)
    # Looping on `while heap` (instead of popping unconditionally) handles the
    # case where every entry stays active, and length-1 inputs (-> [rad]).
    while heap:
        val = -heappop(heap)
        trial_mean = (total + val) / (count + 1)
        if trial_mean >= val:
            break
        total += val
        count += 1
    # With rad > 0 the largest entry always qualifies, so count >= 1 here.
    return np.maximum(vec - total / count, 0)

def proj_simplex_quick(vec, rad=1):
    ''' Kiwiel, K.C., 2008.
    Breakpoint searching algorithms for the continuous quadratic knapsack problem.
    Mathematical Programming, 112(2), pp.473-491.

    Euclidean projection of ``vec`` onto the simplex {x : x >= 0, sum(x) = rad},
    with the threshold located by a quickselect-style search.

    At each step a pivot is drawn from the remaining candidates with
    ``np.random.choice`` (global RNG; a random pivot, not the median). The
    threshold is computed from the entries already known to be active plus the
    pivot, its ties and everything above it. If that threshold is below the
    pivot, those entries are accepted and the search continues among the
    smaller entries. Otherwise it continues among the larger ones.

    Returns
    -------
    numpy array, ``max(vec - threshold, 0)`` elementwise.
    '''
    assert all((isinstance(vec, np.ndarray), vec.ndim == 1, len(vec) >= 1)), (
        'imput must be a 1-d non empty numpy array')
    pending = vec  # never modified in place; boolean masks below return copies
    n_active, active_sum = 0, -rad  # count and (sum - rad) of accepted entries
    while pending.size > 0:
        pivot = np.random.choice(pending)
        above = pending[pending > pivot]
        below = pending[pending < pivot]
        n_ties = len(pending) - len(above) - len(below)
        cand_sum = active_sum + n_ties * pivot + above.sum()
        cand_count = n_active + n_ties + len(above)
        if cand_sum / cand_count < pivot:
            # pivot, its ties and all larger entries are active: look lower
            n_active, active_sum = cand_count, cand_sum
            pending = below
        else:
            # threshold is not below the pivot: only larger entries can be active
            pending = above
    return np.maximum(vec - active_sum / n_active, 0)

def proj_simplex_active_set(vec, rad=1):
    ''' Michelot, C., 1986.
    A finite algorithm for finding the projection of a point onto the canonical simplex of R^n.
    Journal of Optimization Theory and Applications, 50(1), pp.195-200.

    Euclidean projection of ``vec`` onto the simplex {x : x >= 0, sum(x) = rad}.

    Starts with every entry as a candidate and the threshold
    (sum(candidates) - rad) / len(candidates). It then repeatedly keeps only
    the candidates strictly above the threshold and recomputes it, stopping
    when a pass removes nothing.

    Parameters
    ----------
    vec : 1-d non-empty numpy array
    rad : positive radius of the simplex

    Returns
    -------
    numpy array, ``max(vec - threshold, 0)`` elementwise.
    '''
    assert all((isinstance(vec, np.ndarray), vec.ndim == 1, len(vec) >= 1)), (
        'imput must be a 1-d non empty numpy array')
    # rad > 0 keeps the threshold below the largest candidate, so the
    # candidate set never becomes empty (no 0/0 below).
    assert rad > 0, 'radius must be positive'
    active = vec
    # ndarray.sum() instead of the builtin sum(): no per-element Python loop.
    mean = (active.sum() - rad) / active.size
    while True:
        kept = active[active > mean]
        mean = (kept.sum() - rad) / kept.size
        if kept.size == active.size:
            break
        active = kept
    return np.maximum(vec - mean, 0)

def proj_simplex_condat(vec, rad=1):
    ''' Condat, L., 2016.
    Fast projection onto the Simplex and the l1 Ball.
    Mathematical Programming, 158(1), pp.575-585.

    Euclidean projection of ``vec`` onto the simplex {x : x >= 0, sum(x) = rad}.

    Keeps a candidate set ``largest`` and its threshold
    ``mean = (sum(largest) - rad) / len(largest)``, updated incrementally:

    1. One pass over ``vec`` grows the candidate set. It restarts from a single
       entry when that entry alone gives a threshold at least as high; old
       candidates are parked in ``reservoir``.
    2. Parked entries above the threshold are re-added.
    3. Entries not above the threshold are pruned until a full pass removes
       nothing.

    Parameters
    ----------
    vec : 1-d non-empty numpy array
    rad : positive radius of the simplex

    Returns
    -------
    numpy array, ``max(vec - mean, 0)`` elementwise.
    '''
    assert all((isinstance(vec, np.ndarray), vec.ndim == 1, len(vec) >= 1)), (
        'imput must be a 1-d non empty numpy array')
    assert rad > 0, 'radius must be positive'
    largest = [vec[0]]
    reservoir = []
    mean = largest[0] - rad
    for val in vec[1:]:
        if val > mean:
            mean = mean + (val - mean) / (len(largest) + 1)
            if mean > val - rad:
                largest.append(val)
            else:
                # The updated threshold is not above val - rad (the threshold
                # of {val} alone): restart from val, park old candidates.
                reservoir.extend(largest)
                largest = [val]
                mean = val - rad
    for val in reservoir:
        if val > mean:
            largest.append(val)
            mean = mean + (val - mean) / len(largest)
    # Removing an entry y from the set changes the threshold to
    # mean + (mean - y) / new_size. A new list is built on each pass rather than
    # deleting from the list being iterated, which would skip the entry that
    # follows each deletion. Stop once a full pass removes nothing.
    size = len(largest)
    while True:
        kept = []
        for val in largest:
            if val <= mean:
                size -= 1
                mean = mean + (mean - val) / size
            else:
                kept.append(val)
        if len(kept) == len(largest):
            break
        largest = kept
    return np.maximum(vec - mean, 0)

# Projection implementations exercised by the correctness and timing cells.
# The sort-based `proj_simplex` is listed first as the reference implementation;
# `proj_simplex_max_heap` is deliberately left out of this list.
projs = [proj_simplex, proj_simplex_quick, proj_simplex_active_set, proj_simplex_condat]

# Correctness

In [None]:
def metric(x):
    '''Sup-norm (largest absolute entry) of ``x``.'''
    return np.linalg.norm(x, np.inf)

DIM = 10000
RAD = 2
# Dedicated seeded generator: the check is reproducible on Restart & Run All.
imp = np.random.RandomState(0).randn(DIM)
ref = proj_simplex(imp, RAD)  # sort-based projection, used as the reference

[
    # Reference sanity: count of negative entries (expect 0) and total (expect RAD).
    [(ref < 0).sum(), ref.sum()],
    # Max deviation from the reference of every other implementation, by name
    # (selected by identity rather than by position in `projs`).
    {proj.__name__: metric(ref - proj(imp, RAD))
     for proj in projs if proj is not proj_simplex},
]

# Performance

In [None]:
%%timeit
proj_simplex(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_max_heap(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_quick(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_active_set(np.random.randn(DIM), RAD)

In [None]:
%%timeit
proj_simplex_condat(np.random.randn(DIM), RAD)

In [None]:
def gen_expes_dims(exp_start, exp_stop, n_dims, n_runs):
    '''Draw standard-normal test vectors for log-spaced dimensions.

    Parameters
    ----------
    exp_start, exp_stop : float
        Dimensions range from 10**exp_start to 10**exp_stop.
    n_dims : int
        Number of log-spaced points requested. After rounding to integers,
        duplicate dimensions are merged, so fewer keys may be returned.
    n_runs : int
        Number of independent vectors drawn per dimension.

    Returns
    -------
    dict
        Maps each integer dimension, in increasing order, to a list of
        ``n_runs`` arrays of shape (dim,) drawn with ``np.random.randn``.
    '''
    # Round before casting: astype(int) truncates, so 999.9999... would become 999.
    # np.unique drops duplicates that would otherwise silently collapse in the dict.
    dims = np.unique(np.rint(np.logspace(exp_start, exp_stop, n_dims)).astype(int))
    return {
        dim: [np.random.randn(dim) for _ in range(n_runs)]
        for dim in dims}

def partial_stack(frame, cols_to_stack, names):
    '''Turn the columns ``cols_to_stack`` of ``frame`` into long format.

    Every other column is kept as an identifier. Each row is expanded into one
    row per stacked column, holding that column's label and its value.

    Parameters
    ----------
    frame : pandas.DataFrame
    cols_to_stack : iterable of column labels, all present in ``frame``
    names : list of exactly two names, used for the new label column and the
        new value column, in that order

    Returns
    -------
    pandas.DataFrame whose columns are the identifier columns followed by
    ``names``.
    '''
    assert set(cols_to_stack) <= set(frame)
    assert isinstance(names, list) and len(names) == 2
    all_cols = frame.columns
    id_cols = list(all_cols[~all_cols.isin(cols_to_stack)])
    long_frame = frame.set_index(id_cols).stack().reset_index()
    long_frame.columns = id_cols + names
    return long_frame

def time_one_func(func, expes, n_repeats=5, n_execs=10):
    '''Benchmark ``func`` on every input of every experiment dimension.

    For each input array, ``timeit.repeat`` is run with ``repeat=n_repeats``
    and ``number=n_execs``. Each recorded value is therefore the total time,
    in seconds, of ``n_execs`` consecutive calls (not a per-call time). All
    values recorded for one dimension are pooled into a single column.

    Parameters
    ----------
    func : callable taking one input array
    expes : dict mapping a dimension to a list of input arrays
    n_repeats, n_execs : forwarded to ``timeit.repeat``

    Returns
    -------
    pandas.DataFrame of ``describe()`` statistics, one column per dimension.
    '''
    per_dim = []
    for dim, imputs in expes.items():
        # Default-argument binding makes the captured input explicit.
        samples = [
            repeat(lambda imput=imput: func(imput), number=n_execs, repeat=n_repeats)
            for imput in imputs]
        pooled = pd.DataFrame(samples).stack().reset_index(drop=True)
        per_dim.append(pooled.rename(dim))
    return pd.concat(per_dim, axis=1).describe()

In [None]:
exps = gen_expes_dims(0, 6, 15, 10)
# time_one_func only reads `exps`, so no defensive copy is needed.
timings = {proj.__name__: time_one_func(proj, exps) for proj in projs}

median_times = pd.DataFrame({
    name: times.loc['50%', :] for name, times in timings.items()})
# Dimensions are log-spaced over six decades: log-log axes keep every size readable.
ax = median_times.plot(figsize=(20, 15), grid=True, loglog=True)
ax.set_xlabel('dimension')
ax.set_ylabel('median total time of 10 calls (s)')  # time_one_func default n_execs=10
ax.set_title('Simplex projection: median timings');