In [1]:
import numpy as np

Fast dual proximal gradient algorithm:

https://web.iem.technion.ac.il/images/user-files/becka/papers/40.pdf

http://www.seas.ucla.edu/~vandenbe/236C/lectures/dualproxgrad.pdf

Nesterov’s optimal gradient methods:

https://www2.cs.uic.edu/~zhangx/teaching/agm.pdf

In [115]:
from numpy.linalg import norm

def FDPG_proj_on_intersection(x0, proj_list, max_iter=10000):
    """Projection onto the intersection of convex sets"""
    N = len(proj_list)
    shape = x0.shape
    w = np.broadcast_to(x0, [N, *shape]).copy()
    y = np.broadcast_to(x0, [N, *shape]).copy()
    y_old = y.copy()
    u_old = x0.copy()
    L = N + 0.1
    t = 1.
    for k in range(max_iter):
        u = x0 + np.sum(w, axis=0)
        for i in range(N):
            y[i,:] = w[i,:] - 1/L*(u - proj_list[i](u - L*w[i,:]))
        t_old = t
        t = 0.5*(1 + np.sqrt(1 + 4*t**2)) # reduces to the usual proximal gradient when t=1
        w = y + (t_old - 1)/t*(y - y_old)
        y_old = y
        if norm(u_old - u)/norm(u) < 1e-12:
            print("Converged in %d iterations!" % (k))
            break
        u_old = u
    return u

Projection onto the [Birkhoff polytope](https://en.wikipedia.org/wiki/Birkhoff_polytope) (set of the [doubly stochastic matrices](https://en.wikipedia.org/wiki/Doubly_stochastic_matrix)):

\begin{aligned}
& \underset{X}{\text{min}} & & \frac{1}{2}\|X - Y\|^2_F \\
& \text{s.t.} & & X1 = 1, X^T1 = 1, X_{ij} \geq 0
\end{aligned} 

In [116]:
def projection_RC1(Y):
    n = Y.shape[0]
    I = np.eye(n)
    M = 1/n*I + np.sum(Y)/n/n*I - 1/n*Y
    return Y + np.sum(M, axis=1, keepdims=True) - 1/n*np.sum(Y, axis=0, keepdims=True)

In [117]:
n = 4

np.random.seed(0)

Y = np.random.randn(n,n)

In [118]:
proj_list = [lambda X: projection_RC1(X), lambda X: np.maximum(0, X)]

X_proj = FDPG_proj_on_intersection(Y, proj_list)

Converged in 507 iterations!


In [119]:
print(X_proj)

[[ 2.75337594e-01 -1.61065605e-12  1.63342708e-01  5.61319698e-01]
 [ 6.22075049e-01 -2.86815016e-12  3.77924952e-01 -2.86809465e-12]
 [ 2.86846935e-12  5.61319698e-01  2.86820567e-12  4.38680302e-01]
 [ 1.02587358e-01  4.38680302e-01  4.58732341e-01 -1.61071156e-12]]


In [120]:
print(np.sum(X_proj, axis=0), np.sum(X_proj, axis=1))

[1. 1. 1. 1.] [1. 1. 1. 1.]
