<a href="https://colab.research.google.com/github/shaabhishek/cvxpylayer_qp/blob/main/Differentiable_QP_Layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install cvxpylayers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import numpy as np

import scipy
import cvxpy as cvx
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer

In [3]:
# Short aliases for the linear-algebra primitives used to assemble the KKT system.
D = jnp.diag
outer = jnp.outer


def ccat(x):
    """Concatenate a sequence of arrays side by side (along axis 1)."""
    return jnp.concatenate(x, axis=1)


def rcat(x):
    """Concatenate a sequence of arrays on top of each other (along axis 0)."""
    return jnp.concatenate(x, axis=0)


def mv(A, b):
    """Matrix-vector product of a 2-D ``A`` with a 1-D ``b``."""
    return jnp.einsum("ij,j->i", A, b)

def make_problem_cvxpy(n_vars, n_eq, n_ineq):
    """Build the parametrised QP ``min 0.5*||L z||^2 + c @ z  s.t.  A z == b, G z <= h``.

    Args:
        n_vars: length of the decision variable ``z``.
        n_eq: number of rows of ``A`` / entries of ``b``.
        n_ineq: number of rows of ``G`` / entries of ``h``.

    Returns:
        ``(prob, z, constraints, params)`` where ``constraints`` is
        ``[equality, inequality]`` and ``params`` is ``[L, c, A, b, G, h]``.

    Raises:
        AssertionError: if ``prob.is_dpp()`` is false.
    """
    z = cvx.Variable(n_vars)

    # Every piece of problem data is a CVXPY parameter so it can be swapped
    # in later without rebuilding the problem.
    L = cvx.Parameter((n_vars, n_vars))
    c = cvx.Parameter(n_vars)
    A = cvx.Parameter((n_eq, n_vars))
    b = cvx.Parameter(n_eq)
    G = cvx.Parameter((n_ineq, n_vars))
    h = cvx.Parameter(n_ineq)

    quadratic_term = 0.5 * cvx.sum_squares(L @ z)
    linear_term = c @ z
    objective = cvx.Minimize(quadratic_term + linear_term)

    equality = A @ z == b
    inequality = G @ z <= h
    constraints = [equality, inequality]

    prob = cvx.Problem(objective, constraints)
    assert prob.is_dpp()
    return prob, z, constraints, [L, c, A, b, G, h]

def make_problem_cvxpy_batch(n_vars, n_eq, n_ineq, n_batch):
    """Build one CVXPY problem that stacks ``n_batch`` independent QPs.

    Every variable and parameter gets a leading batch axis; the objective is
    the sum of the per-instance objectives ``0.5*||L[i] z[i]||^2 + c[i] @ z[i]``.

    Args:
        n_vars: number of decision variables per instance.
        n_eq: number of equality constraints per instance.
        n_ineq: number of inequality constraints per instance.
        n_batch: number of stacked instances.

    Returns:
        ``(prob, z, constraints, params)``. ``constraints`` lists the
        ``n_batch`` equality constraints first, then the ``n_batch``
        inequality constraints; ``params`` is ``[L, c, A, b, G, h]``.

    Raises:
        AssertionError: if ``prob.is_dpp()`` is false.
    """
    z = cvx.Variable((n_batch, n_vars))

    L = cvx.Parameter((n_batch, n_vars, n_vars))
    c = cvx.Parameter((n_batch, n_vars))
    A = cvx.Parameter((n_batch, n_eq, n_vars))
    b = cvx.Parameter((n_batch, n_eq))
    G = cvx.Parameter((n_batch, n_ineq, n_vars))
    h = cvx.Parameter((n_batch, n_ineq))

    instances = range(n_batch)

    # Sum of the per-instance objectives.
    cost = sum(
        0.5 * cvx.sum_squares(L[k] @ z[k]) + c[k] @ z[k] for k in instances
    )
    objective = cvx.Minimize(cost)

    # All equalities first, then all inequalities.
    equalities = [A[k] @ z[k] == b[k] for k in instances]
    inequalities = [G[k] @ z[k] <= h[k] for k in instances]
    constraints = equalities + inequalities

    prob = cvx.Problem(objective, constraints)
    assert prob.is_dpp()
    return prob, z, constraints, [L, c, A, b, G, h]

def solve_problem_cvxpy(prob, x, constraints, params, data):
    """Load ``data`` into ``params``, solve ``prob`` with SCS, return primal and dual values.

    Args:
        prob: the CVXPY problem to solve.
        x: the CVXPY variable whose value is returned.
        constraints: constraint list; only the first two entries are read.
            For a problem built by ``make_problem_cvxpy`` these are the
            equality and the inequality constraint, in that order.
        params: CVXPY parameters, paired positionally with ``data``.
        data: values assigned to ``params`` (each converted with ``np.array``).

    Returns:
        ``(x.value, constraints[0].dual_value, constraints[1].dual_value)``.
    """
    for param, value in zip(params, data):
        param.value = np.array(value)

    solver_options = dict(solver=cvx.SCS, verbose=False, warm_start=True, eps=1e-6)
    prob.solve(**solver_options)

    first, second = constraints[0], constraints[1]
    return x.value, first.dual_value, second.dual_value

def make_problem_cvxpylayer(prob, x, params):
    """Wrap ``prob`` in a ``CvxpyLayer`` that maps ``params`` to the single variable ``x``."""
    return CvxpyLayer(prob, parameters=params, variables=[x])

def solve_problem_cvxpylayer(cvxpylayer, data):
    """Call ``cvxpylayer`` on the unpacked ``data`` with solver tolerance ``eps=1e-6``.

    Returns whatever the layer call returns.
    """
    solver_args = {"eps": 1e-6}
    return cvxpylayer(*data, solver_args=solver_args)

@jax.jit
def compute_grads_manual(z_star, nu_star, lambda_star, data, grad_l):
    """Backpropagate a loss gradient through the QP solution via its KKT system.

    The QP is ``min_z 0.5*||L z||^2 + c @ z  s.t.  A z == b, G z <= h``.

    Args:
        z_star: primal solution, shape ``(n_vars,)``.
        nu_star: multipliers of the equality constraints, shape ``(n_eq,)``.
        lambda_star: multipliers of the inequality constraints, shape ``(n_ineq,)``.
        data: tuple ``(L, c, A, b, G, h)`` of problem data.
        grad_l: gradient of the loss with respect to the stacked vector
            ``(z, lambda, nu)``, shape ``(n_vars + n_ineq + n_eq,)``.

    Returns:
        ``(grad_L, grad_c, grad_A, grad_b, grad_G, grad_h)``, each with the
        shape of the corresponding entry of ``data``.
    """
    (L, c, A, b, G, h) = data

    # Take the problem sizes from the arguments themselves instead of relying
    # on notebook-level globals that are only defined in a later cell.
    n_vars = z_star.shape[0]
    n_eq = nu_star.shape[0]
    n_ineq = lambda_star.shape[0]

    # 0.5*||L z||^2 == 0.5 * z^T (L^T L) z, so the Hessian is L^T L.  This is
    # also the Q that the grad_L formula below differentiates.  (For a
    # symmetric L it coincides with L @ L.T.)
    Q = L.T @ L

    # Jacobian of the KKT conditions with respect to (z, lambda, nu).
    M = jnp.block(
        [
            [Q, G.T, A.T],
            [jnp.diag(lambda_star) @ G, jnp.diag(G @ z_star - h), jnp.zeros((n_ineq, n_eq))],
            [A, jnp.zeros((n_eq, n_ineq)), jnp.zeros((n_eq, n_eq))],
        ]
    )

    n_total = n_vars + n_ineq + n_eq
    assert M.shape == (n_total, n_total)

    # Backward pass: a single linear solve with the transposed KKT Jacobian.
    eq_37 = -jnp.linalg.solve(M.T, grad_l)
    d_z = eq_37[:n_vars]
    d_lambda = eq_37[n_vars:n_vars + n_ineq]
    d_nu = eq_37[n_vars + n_ineq:]

    # Chain rule through Q = L^T L applied to the symmetric gradient w.r.t. Q.
    grad_L = L @ (jnp.outer(d_z, z_star) + jnp.outer(z_star, d_z))
    grad_c = d_z
    grad_A = jnp.outer(d_nu, z_star) + jnp.outer(nu_star, d_z)
    grad_b = -d_nu
    grad_G = jnp.outer(lambda_star, d_z) + jnp.diag(lambda_star) @ jnp.outer(d_lambda, z_star)
    grad_h = -(lambda_star * d_lambda)
    return grad_L, grad_c, grad_A, grad_b, grad_G, grad_h


Create data (i.e. the parameters L, c, A, b, G, h, where L is the symmetric square root of the quadratic cost matrix Q)

In [4]:
def get_data(n_vars, n_eq, n_ineq, seed=None):
    """Sample a random QP instance ``(L, c, A, b, G, h)``.

    ``L`` is the symmetric square root of a random positive semi-definite
    matrix ``Q``; ``c``, ``A`` and ``b`` are standard normal.  The inequality
    data is ``G = eye(n_ineq, n_vars)`` and ``h = 0``.

    Args:
        n_vars: number of decision variables.
        n_eq: number of equality constraints (rows of ``A``).
        n_ineq: number of inequality constraints (rows of ``G``).
        seed: optional seed for a private ``RandomState``.  When ``None``
            (default) the global ``np.random`` state is used, as before.

    Returns:
        Tuple ``(L, c, A, b, G, h)`` of real NumPy arrays with shapes
        ``(n_vars, n_vars)``, ``(n_vars,)``, ``(n_eq, n_vars)``, ``(n_eq,)``,
        ``(n_ineq, n_vars)`` and ``(n_ineq,)``.

    Raises:
        AssertionError: if ``Q`` is not PSD (up to round-off) or ``L @ L.T``
            does not reproduce it.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    R = rng.randn(n_vars, n_vars)
    Q = R @ R.T

    # A single symmetric eigendecomposition provides both the PSD check and a
    # square root that is real and symmetric by construction.
    eigvals, eigvecs = np.linalg.eigh(Q)
    assert np.all(eigvals >= -1e-8 * max(1.0, np.trace(Q)))
    L = (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))) @ eigvecs.T
    assert np.allclose(Q, L @ L.T)

    c = rng.randn(n_vars)
    A = rng.randn(n_eq, n_vars)
    b = rng.randn(n_eq)

    # Size G and h by n_ineq so they always match the (n_ineq, n_vars)
    # parameter shapes; for n_ineq == n_vars this is the identity and zeros.
    G = np.eye(n_ineq, n_vars)
    h = np.zeros(n_ineq)
    return (L, c, A, b, G, h)

def get_data_batch(n_vars, n_eq, n_ineq, n_batch, seed=None):
    """Sample ``n_batch`` random QP instances stacked along a leading batch axis.

    Each ``L[i]`` is the symmetric square root of a random positive
    semi-definite ``Q[i]``; ``c``, ``A`` and ``b`` are standard normal.  Every
    instance uses ``G[i] = eye(n_ineq, n_vars)`` and ``h[i] = 0``.

    Args:
        n_vars: number of decision variables per instance.
        n_eq: number of equality constraints per instance.
        n_ineq: number of inequality constraints per instance.
        n_batch: number of instances.
        seed: optional seed for a private ``RandomState``.  When ``None``
            (default) the global ``np.random`` state is used, as before.

    Returns:
        Tuple ``(L, c, A, b, G, h)`` of real NumPy arrays with shapes
        ``(n_batch, n_vars, n_vars)``, ``(n_batch, n_vars)``,
        ``(n_batch, n_eq, n_vars)``, ``(n_batch, n_eq)``,
        ``(n_batch, n_ineq, n_vars)`` and ``(n_batch, n_ineq)``.

    Raises:
        AssertionError: if any ``Q[i]`` is not PSD (up to round-off) or
            ``L[i] @ L[i].T`` does not reproduce it.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    R = rng.randn(n_batch, n_vars, n_vars)
    Q = np.einsum("bij,bkj->bik", R, R)

    # Batched symmetric eigendecomposition: checks every instance (not only
    # the first) and yields real, symmetric square roots.
    eigvals, eigvecs = np.linalg.eigh(Q)
    tol = 1e-8 * np.maximum(1.0, np.trace(Q, axis1=1, axis2=2))
    assert np.all(eigvals >= -tol[:, None])
    sqrt_vals = np.sqrt(np.clip(eigvals, 0.0, None))
    L = (eigvecs * sqrt_vals[:, None, :]) @ np.swapaxes(eigvecs, -1, -2)
    assert np.allclose(Q, L @ np.swapaxes(L, -1, -2))

    c = rng.randn(n_batch, n_vars)
    A = rng.randn(n_batch, n_eq, n_vars)
    b = rng.randn(n_batch, n_eq)

    # Size G and h by n_ineq so they always match the per-instance
    # (n_ineq, n_vars) shape; for n_ineq == n_vars this is identity / zeros.
    G = np.tile(np.eye(n_ineq, n_vars), (n_batch, 1, 1))
    h = np.zeros((n_batch, n_ineq))

    return (L, c, A, b, G, h)

In [None]:
def l_cvxpylayer(data, prob, x, params):
    """Loss ``sum(z*)`` with ``z*`` obtained from a freshly built CvxpyLayer.

    Returns:
        ``(loss, z_star_layer)``; the solution is the auxiliary output.
    """
    layer = make_problem_cvxpylayer(prob, x, params)
    outputs = solve_problem_cvxpylayer(layer, data)
    (z_star_layer,) = outputs  # the layer is expected to return exactly one variable
    loss = jnp.sum(z_star_layer)
    return loss, z_star_layer

def l_manual(data, prob, x, constraints, params):
    """Loss ``sum(z*)`` and its hand-derived gradients with respect to ``data``.

    Solves the problem with CVXPY, then feeds the primal/dual solution to
    ``compute_grads_manual``.

    Returns:
        ``((loss, z_star), grads_manual)``.
    """
    z_star, nu_star, lambda_star = solve_problem_cvxpy(prob, x, constraints, params, data)

    # The loss is sum(z): its gradient is 1 for every entry of z and 0 for the
    # dual variables.  The blocks are stacked in the order (z, lambda, nu).
    grad_blocks = [
        jnp.ones_like(z_star),
        jnp.zeros_like(lambda_star),
        jnp.zeros_like(nu_star),
    ]
    grad_l = jnp.concatenate(grad_blocks)

    grads_manual = compute_grads_manual(z_star, nu_star, lambda_star, data, grad_l)
    loss = z_star.sum()
    return (loss, z_star), grads_manual


# Seed NumPy's global generator so the random QP instances, and every number
# printed further down, are reproduced on a fresh "Restart & Run All".
np.random.seed(0)

n_vars = 60
n_eq = 10
n_ineq = n_vars  # one inequality per variable
n_batch = 20
data = get_data(n_vars, n_eq, n_ineq)
data_batch = get_data_batch(n_vars, n_eq, n_ineq, n_batch)
prob, x, constraints, params = make_problem_cvxpy(n_vars, n_eq, n_ineq)

# Differentiate the cvxpylayers loss with respect to its first argument (the
# data tuple); the solution comes back as the auxiliary output.
l_cvxpylayer_value_and_grad = jax.value_and_grad(l_cvxpylayer, has_aux=True)
(_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)
(_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)

# Same layer on the batched data (every parameter carries a leading batch
# axis).  The result is bound to names so the cell does not dump large arrays.
(_, z_star_layer_batch), grads_cvxpylayers_batch = l_cvxpylayer_value_and_grad(data_batch, prob, x, params)


Compare solutions

In [6]:
# Show the CVXPY solution followed by the cvxpylayers solution.
for solution in (z_star, z_star_layer):
    print(solution)

# The two solvers must agree to an absolute tolerance of 1e-3; on failure the
# assertion message is the total absolute deviation.
solutions_match = np.allclose(z_star, z_star_layer, atol=1e-3)
assert solutions_match, np.sum(np.abs(z_star - z_star_layer))

[-4.92095910e-07 -9.03358617e-09 -3.02948103e-02 -5.77438299e-02
 -3.78558184e-06 -1.40480123e-06  6.35821865e-07  2.40269849e-06
 -3.37088648e-06 -2.87031143e-01  1.44705499e-06 -3.64549815e-06
 -1.34774959e-01 -4.66068713e-08 -2.05358857e-01 -3.52979741e-07
  3.14795604e-07 -4.74028675e-06  2.05294454e-06 -2.23238890e-06
 -1.66936156e-06  8.31118822e-07 -6.62853057e-02 -6.08125066e-02
 -2.94993391e-07 -2.65497135e-06  1.80870588e-06 -1.69735501e-07
 -1.56534947e-02  5.52173548e-07 -1.47002980e-01 -1.90311851e-01
 -1.02680867e-06 -6.70239021e-02 -1.95052416e-06 -2.50558895e-01
 -2.02113804e-06  1.13746806e-07 -1.73954006e-06 -1.36995157e-01
 -8.80179502e-02 -8.54248521e-04  1.45600071e-06 -4.50688332e-02
 -1.99705499e-01 -1.27820008e-01 -8.45903139e-02 -1.06381958e-06
 -1.11386409e-06 -1.47207936e-06 -2.86520682e-06 -5.60403558e-02
  8.77500607e-07 -1.72418224e-02 -1.64619940e-01 -1.90040470e-01
  2.44866859e-06 -2.27321116e-06 -9.77023309e-02 -1.72196622e-01]
[-4.92095547e-07 -9.0334

Profiling Code

In [10]:
# Time one CVXPY solve plus manual gradient computation (5 loops per run).
# The commented-out lines are the alternative benchmarks: the cvxpylayers
# value-and-grad call on a single instance, the manual path looped over the
# batch, and the cvxpylayers call on the whole batch.
%timeit -n 5 (_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)
# %timeit -n 5 (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)
# %timeit -n 5 [l_manual(_data, prob, x, constraints, params) for _data in zip(*data_batch)]
# %timeit -n 5 (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data_batch, prob, x, params)

941 ms ± 93.7 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
765 ms ± 89.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


Test Gradients

In [13]:
from functools import lru_cache
# Test gradients
# Size of the perturbation applied to a single parameter entry in the
# finite-difference checks below.
eps = 1e-2

def f_c(_c):
    """Loss of the QP when the linear cost ``c`` in ``data`` is replaced by ``_c``."""
    L, _, A, b, G, h = data
    (loss, _z_star), _grads = l_manual((L, _c, A, b, G, h), prob, x, constraints, params)
    return loss

def f_b(_b):
    """Loss of the QP when the equality right-hand side ``b`` in ``data`` is replaced by ``_b``."""
    L, c, A, _, G, h = data
    (loss, _z_star), _grads = l_manual((L, c, A, _b, G, h), prob, x, constraints, params)
    return loss

def f_h(_h):
    """Loss of the QP when the inequality right-hand side ``h`` in ``data`` is replaced by ``_h``."""
    L, c, A, b, G, _ = data
    (loss, _z_star), _grads = l_manual((L, c, A, b, G, _h), prob, x, constraints, params)
    return loss

def f_G(_G):
    """Loss of the QP when the inequality matrix ``G`` in ``data`` is replaced by ``_G``."""
    L, c, A, b, _, h = data
    (loss, _z_star), _grads = l_manual((L, c, A, b, _G, h), prob, x, constraints, params)
    return loss

# c
# print(grads_manual[1].round(3))
# print(grads_cvxpylayers[1].round(3))
# print(np.array([((f_c(data[1]+h_c) - f_c(data[1]-h_c)) / (2*h_c + 1e-6))[i] for i, h_c in enumerate(np.eye(n_vars)*eps)]).round(3))
# print()

# b
# print(grads_manual[3].round(3))
# print(grads_cvxpylayers[3].round(3))
# print(np.array([((f_b(data[3]+h_b) - f_b(data[3]-h_b)) / (2*h_b + 1e-10))[i] for i, h_b in enumerate(np.eye(n_eq)*eps)]).round(3))
# print()

# h
# print(grads_manual[5].round(3))
# print(grads_cvxpylayers[5].round(3))
# print(np.array([((f_h(data[5]+h_h) - f_h(data[5]-h_h)) / (2*h_h + 1e-10))[i] for i, h_h in enumerate(np.eye(n_ineq)*eps)]).round(3))

# G: compare row `i` of the manual gradient, the cvxpylayers gradient and a
# central finite-difference estimate.
i = 1
print(grads_manual[4][i].round(3))
print(grads_cvxpylayers[4][i].round(3))

@lru_cache(maxsize=n_vars*n_ineq)
def make_h_G(i,j):
    """Perturbation for G: zeros everywhere except ``eps`` at entry ``(i, j)``."""
    h_G = np.zeros_like(data[4])
    h_G[i,j]+=eps
    return h_G

def fd_grad_G(i, j):
    """Central finite-difference estimate of d(loss)/dG[i, j]."""
    h_G = make_h_G(i, j)
    # Divide by the scalar step directly; this is the same number the (i, j)
    # entry of `2*h_G + 1e-10` holds, without building a full matrix of
    # denominators just to read one element.
    return (f_G(data[4] + h_G) - f_G(data[4] - h_G)) / (2 * eps + 1e-10)

row = [fd_grad_G(i, j) for j in range(data[4].shape[1])]
print(np.array(row).round(3))

[ 0.     0.    -0.523 -0.144 -0.    -0.    -0.    -0.    -0.     0.461
 -0.    -0.     0.432 -0.     0.187 -0.    -0.    -0.    -0.    -0.
  0.    -0.     0.394 -0.022  0.     0.     0.    -0.     0.145  0.
  0.389  0.644 -0.    -0.009 -0.     0.548 -0.     0.    -0.    -0.405
 -0.447 -0.579 -0.    -0.693  0.723  0.201 -0.444 -0.    -0.     0.
  0.    -0.6   -0.    -0.517  0.573  0.666  0.     0.    -0.434 -0.007]
[-0.141  0.038 -1.103 -0.016 -0.053 -0.394  0.307  0.045 -0.024  0.305
 -0.097  0.188 -0.024 -0.201  0.49  -0.459 -0.219  0.089 -0.301 -0.58
  0.19  -0.06   0.463 -0.178  0.143  0.133 -0.22  -0.017 -0.359 -0.095
  0.445  0.19  -0.154  0.202  0.055  0.336 -0.243 -0.292 -0.204 -0.158
 -0.389 -1.023  0.492 -1.097  1.013  0.146 -0.421 -0.163 -0.17  -0.03
  0.034 -1.238 -0.154 -1.056  0.907  0.657 -0.255  0.033 -0.734  0.132]
[-0.    -0.001 -0.52  -0.145 -0.003  0.     0.     0.002 -0.003  0.461
 -0.002 -0.001  0.434  0.003  0.187 -0.002  0.    -0.001  0.001  0.
 -0.    -0.     0.

In [9]:
# Shapes of the six gradients from each method.
manual_shapes = [g.shape for g in grads_manual]
layer_shapes = [g.shape for g in grads_cvxpylayers]
print(manual_shapes)
print(layer_shapes)

# Rounded gradients: manual first, then cvxpylayers.
print(*(g.round(4) for g in grads_manual), sep='\n')
print()
print(*(g.round(4) for g in grads_cvxpylayers), sep='\n')

# Signed sum of the element-wise differences, one number per parameter.
# NOTE: positive and negative deviations can cancel in a signed sum.
signed_gaps = [np.sum(g1-g2) for g1,g2 in zip(grads_manual, grads_cvxpylayers)]
print(*signed_gaps, sep='\n')

[(60, 60), (60,), (10, 60), (10,), (60, 60), (60,)]
[(60, 60), (60,), (10, 60), (10,), (60, 60), (60,)]
[[ 0.      0.      0.0156 ...  0.      0.0191  0.0147]
 [-0.     -0.      0.0044 ... -0.     -0.0005 -0.0097]
 [ 0.      0.     -0.0004 ...  0.      0.0095  0.0232]
 ...
 [ 0.      0.     -0.0055 ...  0.     -0.0043  0.0006]
 [ 0.      0.      0.0199 ...  0.      0.0305  0.0333]
 [ 0.      0.      0.0263 ...  0.      0.0335  0.028 ]]
[-0.     -0.     -0.032  -0.0156 -0.     -0.      0.      0.     -0.
 -0.015   0.     -0.      0.0044 -0.     -0.0184 -0.      0.     -0.
  0.     -0.      0.      0.      0.0118 -0.0096 -0.     -0.      0.
 -0.      0.0056  0.      0.0004  0.008  -0.     -0.0097 -0.     -0.0054
 -0.      0.     -0.     -0.0404 -0.0359 -0.0309  0.     -0.043   0.0109
 -0.0069 -0.0352 -0.     -0.     -0.     -0.     -0.0396  0.     -0.0299
  0.0078  0.0092  0.     -0.     -0.0365 -0.0241]
[[ 0.     -0.     -0.2172 -0.0653 -0.     -0.     -0.     -0.     -0.
   0.1568 -0. 