In [3]:
from admm_lasso import admm_lasso, plot
import casadi as ca
import numpy as np
from scipy import sparse

ref: https://www.researchgate.net/publication/299465495_An_Augmented_Lagrangian_Based_Algorithm_for_Distributed_NonConvex_Optimization

In [6]:
def create_prox_func(Nx,f_func):
    '''
    Not used here.
    '''
    x = ca.SX.sym("x",Nx)
    v = ca.SX.sym("v",Nx)
    lambda_ = ca.SX.sym("lambda",1)
    
    f = f_func(x)
    f_prox = f + 1/2 * lambda_ * ca.norm_2(x - v) ** 2
    f_prox_func = ca.Function("f_prox_func", [lambda_, x, v], [f_prox])

    p = ca.vertcat(v,lambda_)
    
    # Define proximal solver
    solver_opt = {}
    solver_opt['print_time'] = False
    solver_opt['ipopt'] = {
        'max_iter': 500,
        'print_level': 1,
        'acceptable_tol': 1e-6,
        'acceptable_obj_change_tol': 1e-6
    }

    nlp = {'x':x, 'f':f_prox, 'p': p}
    solver = ca.nlpsol('solver', 'ipopt', nlp, solver_opt)
    return solver

$\min _{y_{i}} f_{i}\left(y_{i}\right)+\lambda_{i}^{\top} A_{i} y_{i}+\frac{p}{2}\left\|A_{i}\left(y_{i}-x_{i}\right)\right\|_{2}^{2} \quad$ s.t. $\quad h_{i}\left(y_{i}\right) \leq 0$

In [59]:
def create_subproblem(fi_func, Ai, rho):
    N_lambda_i, N_yi = np.shape(Ai)
    yi = ca.SX.sym("yi",N_yi)
    xi = ca.SX.sym("xi",N_yi)
    lambda_i = ca.SX.sym("lambda_i",N_lambda_i)
    
    fi = fi_func(yi) + lambda_i.T @ Ai @ yi + rho/2 * ca.norm_2(Ai @ (yi - xi))**2 
    p = ca.vertcat(lambda_i, xi)
    g = yi
    # Define proximal solver
    solver_opt = {}
    solver_opt['print_time'] = False
    solver_opt['ipopt'] = {
        'max_iter': 500,
        'print_level': 1,
        'acceptable_tol': 1e-6,
        'acceptable_obj_change_tol': 1e-6
    }

    nlp = {'x':yi, 'g':g, 'f':fi, 'p': p}
    solver = ca.nlpsol('solver', 'ipopt', nlp, solver_opt)
    return solver

In [60]:
def create_QP_problem(A_list,rho):
    N = len(A_list)
    x_plus_list = []
    y_i_list = []
    lambda_plus_list = []
    obj = 0
    g = 0
    for i in range(N):
        Ai = A_list[i]
        
        N_lambda_i, N_yi = np.shape(Ai)
        yi = ca.SX.sym("yi",N_yi)
        x_plus = ca.SX.sym("xi",N_yi)
        lambda_plus = ca.SX.sym("lambda_i",N_lambda_i)
        
        x_plus_list += [x_plus]
        y_i_list += [yi] 
        lambda_plus_list += [lambda_plus]
    
        obj += rho/2 * ca.norm_2(Ai @ (yi - x_plus))**2  - lambda_plus.T @ Ai @ yi
        g += Ai @ x_plus
    x = ca.vertcat(*x_plus_list)
    p = ca.vertcat(*(lambda_plus_list + y_i_list))
    # Define proximal solver
    solver_opt = {}
    solver_opt['print_time'] = False
    solver_opt['ipopt'] = {
        'max_iter': 500,
        'print_level': 1,
        'acceptable_tol': 1e-6,
        'acceptable_obj_change_tol': 1e-6
    }

    nlp = {'x':x, 'g':g, 'f':obj, 'p': p}
    solver = ca.nlpsol('solver', 'ipopt', nlp, solver_opt)
    return solver    

### Example 1
$\min _{x} x_{1} \cdot x_{2} \quad$ s.t. $\quad x_{1}-x_{2}=0$

Numerical problem with divergent $\lambda^{+} = - 2 \lambda$

In [76]:
eps = 1e-5
rho = 0.75
N_itermax = 100
A_list = []
fi_func_list = []

A = ca.DM([[1,-1]])
A_list += [A]
N = len(A_list)
b = ca.DM([0])

Nx = 2
x = ca.SX.sym("x",Nx)
fi_func = ca.Function("fi_func", [x], [x[0] * x[1]])
fi_func_list += [fi_func]

subsolver_list = []
# Define subproblem solvers
for i in range(N):
    Ai = A_list[i]
    fi_func = fi_func_list[i]
    subsolver_list += [create_subproblem(fi_func, Ai, rho)]
QP_solver = create_QP_problem(A_list, rho)
# Initial guess
xi_list = []
yi_list = []
lambda_i_list = []
lbx_list = []
ubx_list = []
for i in range(N):
    Ai = A_list[i]
    N_lambda_i, N_xi = np.shape(Ai)
    
    xi = np.random.randn(N_xi,1).flatten().tolist()
    lambda_i = np.random.randn(N_lambda_i,1).flatten().tolist()
    
    xi_list += [xi]
    lambda_i_list += [lambda_i]
    lbx_list += [[-ca.inf] * N_xi]
    ubx_list += [[ca.inf] * N_xi]
    yi_list += [[0] * N_xi]

nl_sub = {}

nl_qp = {}
nl_qp['lbg'] = b
nl_qp['ubg'] = b
nl_qp['lbx'] = sum(lbx_list,[])
nl_qp['ubx'] = sum(ubx_list,[])

# Track solution
yi_sol_list = []
x_plus_sol_list = []
lambda_i_sol_list = []
yi_sol_list += yi_list
x_plus_sol_list += [sum(xi_list,[])]
lambda_i_sol_list = [sum(lambda_i_list,[])]
for i in range(N_itermax):
    sum_Ay = 0
    for j in range(N):
        Ai = A_list[j]
        N_lambda_i, N_xi = np.shape(Ai)
        
        nl_sub['x0'] = yi_list[j]
        nl_sub['lbx'] = lbx_list[j]
        nl_sub['ubx'] = ubx_list[j]
        nl_sub['p'] = lambda_i_list[j] + xi_list[j]
        
        solver_subproblem = subsolver_list[j]
        yi_sol = solver_subproblem(**nl_sub)
        yi_list[j] = yi_sol['x'].full().flatten().tolist()
        yi_sol_list += [yi_list[j]]
        
        sum_Ay += Ai @ yi_sol['x']
        lambda_i_list[j] = (ca.DM(lambda_i_list[j]) + rho * Ai @ (ca.DM(yi_list[j]) - ca.DM(xi_list[j]))).full().flatten().tolist()
        lambda_i_sol_list += [lambda_i_list[j]]
    if ca.norm_1(sum_Ay - b) <= eps:
        break
    
    nl_qp['x0'] = sum(xi_list,[])    #  2D list to 1D
    nl_qp['p'] = sum(lambda_i_list,[]) + sum(xi_list,[])
    xi_sol = solver_subproblem(**nl_qp)
    # Update x_plus
    pos = 0
    for j in range(N):
        
        xi_plus_list = xi_sol['x'].full().flatten().tolist()
        list_len = len(xi_list[j])
        xi_list[j] = xi_plus_list[pos:pos+list_len]
        pos += list_len
    
    x_plus_sol_list += [xi_plus_list]

In [77]:
yi_sol_list

[[0, 0],
 [-2.637456625432172, 2.637456625432172],
 [5.274913250864344, -5.274913250864344],
 [-10.549826501728688, 10.549826501728688],
 [21.099653003457377, -21.099653003457377],
 [-42.19930600691475, 42.19930600691475],
 [84.3986120138295, -84.3986120138295],
 [-168.79722402765896, 168.79722402765896],
 [337.59444805531786, -337.5944480553179],
 [-675.1888961106363, 675.1888961106362],
 [1350.3777922212726, -1350.3777922212726],
 [-2700.7555844425456, 2700.755584442546],
 [5401.511168885094, -5401.511168885094],
 [-10803.022337770188, 10803.022337770188],
 [21606.044675540375, -21606.044675540375],
 [-43212.08935108075, 43212.08935108075],
 [86424.1787021615, -86424.1787021615],
 [-172848.357404323, 172848.357404323],
 [345696.714808646, -345696.714808646],
 [-691393.429617292, 691393.429617292],
 [1382786.859234584, -1382786.859234584],
 [-2765573.718469168, 2765573.718469168],
 [5531147.436938336, -5531147.436938336],
 [-11062294.873876672, 11062294.873876672],
 [22124589.74775334

In [78]:
x_plus_sol_list

[[1.1052711286698296, 2.6831826926406093],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0.0],
 [0.0, 0

In [79]:
lambda_i_sol_list

[[0.13529463973800127],
 [-2.637456625432172],
 [5.274913250864344],
 [-10.549826501728688],
 [21.099653003457377],
 [-42.19930600691475],
 [84.3986120138295],
 [-168.79722402765893],
 [337.5944480553179],
 [-675.1888961106364],
 [1350.3777922212726],
 [-2700.755584442546],
 [5401.511168885095],
 [-10803.022337770188],
 [21606.044675540375],
 [-43212.08935108075],
 [86424.1787021615],
 [-172848.357404323],
 [345696.714808646],
 [-691393.429617292],
 [1382786.859234584],
 [-2765573.718469168],
 [5531147.436938336],
 [-11062294.873876672],
 [22124589.747753344],
 [-44249179.49550669],
 [88498358.99101338],
 [-176996717.9820267],
 [353993435.9640532],
 [-707986871.9281065],
 [1415973743.856213],
 [-2831947487.712426],
 [5663894975.424852],
 [-11327789950.849705],
 [22655579901.69939],
 [-45311159803.39877],
 [90622319606.7975],
 [-181244639213.5949],
 [362489278427.18964],
 [-724978556854.3784],
 [1449957113708.7554],
 [-2899914227417.5093],
 [5799828454835.012],
 [-11599656909670.012],
 