<a href="https://colab.research.google.com/github/shaabhishek/cvxpylayer_qp/blob/main/Differentiable_QP_Layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
!pip install cvxpylayers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [20]:
import numpy as np

import scipy
import cvxpy as cvx
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer

In [204]:
# Short aliases used when assembling the KKT system.
D = jnp.diag        # vector -> diagonal matrix
outer = jnp.outer   # (u, v) -> u v^T


def ccat(x):
    """Concatenate a sequence of arrays side by side (axis=1)."""
    return jnp.concatenate(x, axis=1)


def rcat(x):
    """Stack a sequence of arrays vertically (axis=0)."""
    return jnp.concatenate(x, axis=0)


def mv(A, b):
    """Matrix-vector product for a 2-D `A` and 1-D `b`."""
    return jnp.einsum("ij,j->i", A, b)

def make_problem_cvxpy(n_vars, n_eq, n_ineq):
    """Build the parametrized (DPP) quadratic program

        minimize    0.5 * ||L z||^2 + c^T z
        subject to  A z == b,  G z <= h

    Returns:
        (problem, z, [equality constraint, inequality constraint], [L, c, A, b, G, h])
    """
    z = cvx.Variable(n_vars)

    # Parameters are created in the order L, c, A, b, G, h.
    shapes = [(n_vars, n_vars), n_vars, (n_eq, n_vars), n_eq, (n_ineq, n_vars), n_ineq]
    L, c, A, b, G, h = [cvx.Parameter(shape) for shape in shapes]

    quadratic = 0.5 * cvx.sum_squares(L @ z)
    linear = c @ z
    objective = cvx.Minimize(quadratic + linear)

    constraints = [A @ z == b, G @ z <= h]
    prob = cvx.Problem(objective, constraints)
    # cvxpylayers requires a DPP-compliant parametrization.
    assert prob.is_dpp()
    return prob, z, constraints, [L, c, A, b, G, h]

def make_problem_cvxpy_batch(n_vars, n_eq, n_ineq, n_batch):
    """Build `n_batch` independent copies of the QP as one CVXPY problem.

    Index k of every batched parameter/variable belongs to instance k, and the
    objective is the sum of the per-instance objectives.

    Returns:
        (problem, z of shape (n_batch, n_vars), constraints, [L, c, A, b, G, h]),
        where `constraints` lists the n_batch equality constraints first and the
        n_batch inequality constraints after them.
    """
    z = cvx.Variable((n_batch, n_vars))

    L = cvx.Parameter((n_batch, n_vars, n_vars))
    c = cvx.Parameter((n_batch, n_vars))
    A = cvx.Parameter((n_batch, n_eq, n_vars))
    b = cvx.Parameter((n_batch, n_eq))
    G = cvx.Parameter((n_batch, n_ineq, n_vars))
    h = cvx.Parameter((n_batch, n_ineq))

    instances = range(n_batch)
    # sum() starts at 0 and adds left to right, exactly like an explicit accumulator.
    cost = sum(0.5 * cvx.sum_squares(L[k] @ z[k]) + c[k] @ z[k] for k in instances)
    objective = cvx.Minimize(cost)

    equalities = [A[k] @ z[k] == b[k] for k in instances]
    inequalities = [G[k] @ z[k] <= h[k] for k in instances]
    constraints = equalities + inequalities

    prob = cvx.Problem(objective, constraints)
    assert prob.is_dpp()
    return prob, z, constraints, [L, c, A, b, G, h]

def solve_problem_cvxpy(prob, x, constraints, params, data, eps=1e-6):
    """Load `data` into `params` and solve `prob` with SCS.

    Args:
        prob: the CVXPY problem to solve.
        x: the primal variable whose value is returned.
        constraints: constraint list in which constraints[0] is the equality
            constraint and constraints[1] the inequality constraint (the layout
            produced by `make_problem_cvxpy`).
        params: CVXPY parameters, matched positionally with `data`.
        data: parameter values (NumPy or JAX arrays).
        eps: SCS tolerance (default matches the previous hard-coded value).

    Returns:
        (x.value, dual of constraints[0], dual of constraints[1]).

    Raises:
        ValueError: if `params` and `data` have different lengths.
        RuntimeError: if SCS does not report an optimal (or optimal-inaccurate)
            solution, in which case the returned values would be meaningless.
    """
    params, data = list(params), list(data)
    if len(params) != len(data):
        # zip() would silently drop the extra entries.
        raise ValueError(f"got {len(params)} parameters but {len(data)} data arrays")
    for param, value in zip(params, data):
        # cvxpy needs NumPy values; the data may arrive as JAX arrays.
        param.value = np.array(value)

    prob.solve(solver=cvx.SCS, verbose=False, warm_start=False, eps=eps)
    if prob.status not in (cvx.OPTIMAL, cvx.OPTIMAL_INACCURATE):
        raise RuntimeError(f"SCS did not solve the problem (status: {prob.status})")
    return x.value, constraints[0].dual_value, constraints[1].dual_value

def make_problem_cvxpylayer(prob, x, params):
    """Wrap `prob` in a differentiable CvxpyLayer mapping `params` to the solution `x`."""
    return CvxpyLayer(prob, parameters=params, variables=[x])

def solve_problem_cvxpylayer(cvxpylayer, data):
    """Run the layer's forward pass on `data` with an SCS tolerance of 1e-6.

    Returns whatever the layer returns (a tuple of solution arrays).
    """
    solver_args = {"eps": 1e-6}
    return cvxpylayer(*data, solver_args=solver_args)

@jax.jit
def compute_grads_manual(z_star, nu_star, lambda_star, data, grad_l):
    """Backpropagate a loss through the QP solution via the KKT adjoint system.

    For  min 0.5*||L z||^2 + c^T z  s.t.  A z = b, G z <= h  (Hessian Q = L^T L),
    with primal z*, equality dual nu* and inequality dual lambda*, solve

        M^T [d_z; d_lambda; d_nu] = -grad_l,
        M = [[Q,                  G^T,            A^T],
             [diag(lambda*) G,    diag(G z* - h), 0  ],
             [A,                  0,              0  ]],

    and map the adjoints to parameter gradients (OptNet, Amos & Kolter 2017).

    Args:
        z_star: primal solution, shape (n_vars,).
        nu_star: dual of A z == b, shape (n_eq,).
        lambda_star: dual of G z <= h, shape (n_ineq,).
        data: tuple (L, c, A, b, G, h) of problem parameters.
        grad_l: upstream gradient dl/d(z, lambda, nu), concatenated in that order.

    Returns:
        (grad_L, grad_c, grad_A, grad_b, grad_G, grad_h).
    """
    (L, c, A, b, G, h) = data
    # Sizes come from the arrays (static under jit), not from module globals.
    n_vars = z_star.shape[0]
    n_ineq = lambda_star.shape[0]
    n_eq = nu_star.shape[0]

    # 0.5*||L z||^2 = 0.5 z^T (L^T L) z; L @ L.T only coincides with this for symmetric L.
    Q = L.T @ L

    M = jnp.block([
        [Q, G.T, A.T],
        # lambda[:, None] * G == diag(lambda) @ G without materializing the diagonal.
        [lambda_star[:, None] * G, jnp.diag(G @ z_star - h), jnp.zeros((n_ineq, n_eq))],
        [A, jnp.zeros((n_eq, n_ineq)), jnp.zeros((n_eq, n_eq))],
    ])
    n_total = n_vars + n_ineq + n_eq
    assert M.shape == (n_total, n_total)

    adjoint = -jnp.linalg.solve(M.T, grad_l)
    d_z = adjoint[:n_vars]
    d_lambda = adjoint[n_vars:n_vars + n_ineq]
    d_nu = adjoint[n_vars + n_ineq:]

    # dl/dQ = 0.5 (d_z z^T + z d_z^T); chain rule through Q = L^T L gives L (d_z z^T + z d_z^T).
    grad_L = L @ (jnp.outer(d_z, z_star) + jnp.outer(z_star, d_z))
    grad_c = d_z
    grad_A = jnp.outer(d_nu, z_star) + jnp.outer(nu_star, d_z)
    grad_b = -d_nu
    grad_G = jnp.outer(lambda_star, d_z) + jnp.outer(lambda_star * d_lambda, z_star)
    grad_h = -lambda_star * d_lambda
    return grad_L, grad_c, grad_A, grad_b, grad_G, grad_h


In [233]:
def get_data(n_vars, n_eq, n_ineq, rng=None):
    """Draw a random QP instance (L, c, A, b, G, h) for `make_problem_cvxpy`.

    Q = R R^T with Gaussian R is PSD; L is its symmetric PSD square root, so
    L @ L.T == L.T @ L == Q. The inequality block encodes z[k] <= 0 for the
    first n_ineq coordinates (G = I, h = 0 when n_ineq == n_vars).

    Args:
        n_vars, n_eq, n_ineq: problem dimensions.
        rng: optional np.random.Generator / RandomState. Defaults to the global
            legacy NumPy RNG, giving the same draws as np.random.randn.

    Raises:
        ValueError: if n_ineq > n_vars (the extra rows of G would be all zero).
    """
    if n_ineq > n_vars:
        raise ValueError(f"n_ineq ({n_ineq}) must not exceed n_vars ({n_vars})")
    rng = np.random if rng is None else rng

    R = rng.standard_normal((n_vars, n_vars))
    Q = R @ R.T
    # eigh suits a symmetric matrix: real eigenvalues and no complex residue from sqrtm.
    w, V = np.linalg.eigh(Q)
    scale = max(1.0, float(np.max(np.abs(w), initial=0.0)))
    assert np.all(w >= -1e-9 * scale)
    L = (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T
    assert np.allclose(Q, L @ L.T)

    c = rng.standard_normal(n_vars)
    A = rng.standard_normal((n_eq, n_vars))
    b = rng.standard_normal(n_eq)

    G = np.eye(n_ineq, n_vars)
    h = np.zeros(n_ineq)
    return (L, c, A, b, G, h)

def get_data_batch(n_vars, n_eq, n_ineq, n_batch, rng=None):
    """Draw `n_batch` random QP instances stacked along a leading batch axis.

    Each instance is built like `get_data`: Q_k = R_k R_k^T, L_k its symmetric
    PSD square root, and G_k/h_k encode z[j] <= 0 for the first n_ineq coordinates.

    Args:
        n_vars, n_eq, n_ineq: per-instance problem dimensions.
        n_batch: number of instances.
        rng: optional np.random.Generator / RandomState. Defaults to the global
            legacy NumPy RNG, giving the same draws as np.random.randn.

    Returns:
        (L, c, A, b, G, h) with shapes (n_batch, n_vars, n_vars), (n_batch, n_vars),
        (n_batch, n_eq, n_vars), (n_batch, n_eq), (n_batch, n_ineq, n_vars),
        (n_batch, n_ineq).

    Raises:
        ValueError: if n_ineq > n_vars (the extra rows of G would be all zero).
    """
    if n_ineq > n_vars:
        raise ValueError(f"n_ineq ({n_ineq}) must not exceed n_vars ({n_vars})")
    rng = np.random if rng is None else rng

    R = rng.standard_normal((n_batch, n_vars, n_vars))
    Q = R @ np.swapaxes(R, -1, -2)
    # eigh works on stacks of matrices; check every instance, not just the first.
    w, V = np.linalg.eigh(Q)
    scale = max(1.0, float(np.max(np.abs(w), initial=0.0)))
    assert np.all(w >= -1e-9 * scale)
    L = (V * np.sqrt(np.clip(w, 0.0, None))[..., None, :]) @ np.swapaxes(V, -1, -2)
    assert np.allclose(Q, L @ np.swapaxes(L, -1, -2))

    c = rng.standard_normal((n_batch, n_vars))
    A = rng.standard_normal((n_batch, n_eq, n_vars))
    b = rng.standard_normal((n_batch, n_eq))

    G = np.tile(np.eye(n_ineq, n_vars), (n_batch, 1, 1))
    h = np.zeros((n_batch, n_ineq))
    return (L, c, A, b, G, h)


def l_cvxpylayer(data, prob, x, params):
    """Loss sum(z*) computed through a differentiable CvxpyLayer.

    Returns (loss, z*) so it can be wrapped with
    jax.value_and_grad(..., has_aux=True). The layer is rebuilt on every call.
    """
    (z_star_layer,) = solve_problem_cvxpylayer(make_problem_cvxpylayer(prob, x, params), data)
    loss = jnp.sum(z_star_layer)
    return loss, z_star_layer

def l_manual(data, prob, x, constraints, params):
    """Loss sum(z*) from a direct cvxpy solve, with gradients from the manual KKT backward.

    Returns:
        ((loss, z*), grads) where grads = (grad_L, grad_c, grad_A, grad_b, grad_G, grad_h).
    """
    z_star, nu_star, lambda_star = solve_problem_cvxpy(prob, x, constraints, params, data)

    # Upstream gradient w.r.t. (z, lambda, nu): loss = sum(z), so ones for z and zeros for the duals.
    grad_l = jnp.concatenate([
        jnp.ones_like(z_star),
        jnp.zeros_like(lambda_star),
        jnp.zeros_like(nu_star),
    ])
    grads = compute_grads_manual(z_star, nu_star, lambda_star, data, grad_l)
    return (z_star.sum(), z_star), grads


# Problem sizes: one equality constraint and one inequality per variable.
n_vars = 50
n_eq = 1
n_ineq = n_vars
n_batch = 20

np.random.seed(0)  # make the random problem instances reproducible
data = get_data(n_vars, n_eq, n_ineq)
data_batch = get_data_batch(n_vars, n_eq, n_ineq, n_batch)
prob, x, constraints, params = make_problem_cvxpy(n_vars, n_eq, n_ineq)

l_cvxpylayer_value_and_grad = jax.value_and_grad(l_cvxpylayer, has_aux=True)
(_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)
(_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)
# Both paths must agree on the primal solution before their gradients are compared.
assert np.allclose(z_star, z_star_layer, atol=1e-3), np.sum(np.abs(z_star - z_star_layer))

In [236]:
# Single-instance timings: direct solve + manual KKT backward vs. cvxpylayers value_and_grad.
%time (_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)
%time (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)
# Batch timings: the manual path solves the instances one at a time in a Python loop,
# while the whole stacked batch is handed to the cvxpylayers path (presumably solved as
# one batched call by cvxpylayers — confirm). l_cvxpylayer also rebuilds its CvxpyLayer
# on every call, so that cost is included in its timing.
%timeit -n 5 [l_manual(_data, prob, x, constraints, params) for _data in zip(*data_batch)]
%timeit -n 5 (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data_batch, prob, x, params)

CPU times: user 18.8 ms, sys: 1.15 ms, total: 20 ms
Wall time: 23.5 ms
CPU times: user 64.2 ms, sys: 52.9 ms, total: 117 ms
Wall time: 58.6 ms
549 ms ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
619 ms ± 174 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [237]:
from functools import lru_cache
# Test gradients
eps = 1e-2  # finite-difference step for the gradient checks below


def _loss_with_replaced(index, value):
    """Loss sum(z*) from `l_manual` after replacing entry `index` of the global
    `data` tuple (L, c, A, b, G, h) with `value`."""
    data_new = tuple(value if k == index else item for k, item in enumerate(data))
    return l_manual(data_new, prob, x, constraints, params)[0][0]


def f_c(_c):
    """Loss as a function of the linear cost vector c (data[1])."""
    return _loss_with_replaced(1, _c)


def f_b(_b):
    """Loss as a function of the equality right-hand side b (data[3])."""
    return _loss_with_replaced(3, _b)


def f_h(_h):
    """Loss as a function of the inequality right-hand side h (data[5])."""
    return _loss_with_replaced(5, _h)


def f_G(_G):
    """Loss as a function of the inequality matrix G (data[4])."""
    return _loss_with_replaced(4, _G)

# print(grads_manual[1].round(3))
# print(grads_cvxpylayers[1].round(3))
# print(np.array([((f_c(data[1]+h_c) - f_c(data[1]-h_c)) / (2*h_c + 1e-6))[i] for i, h_c in enumerate(np.eye(n_vars)*eps)]).round(3))
# print()

# print(grads_manual[3].round(3))
# print(grads_cvxpylayers[3].round(3))
# print(np.array([((f_b(data[3]+h_b) - f_b(data[3]-h_b)) / (2*h_b + 1e-10))[i] for i, h_b in enumerate(np.eye(n_eq)*eps)]).round(3))
# print()

# print(grads_manual[5].round(3))
# print(grads_cvxpylayers[5].round(3))
# print(np.array([((f_h(data[5]+h_h) - f_h(data[5]-h_h)) / (2*h_h + 1e-10))[i] for i, h_h in enumerate(np.eye(n_ineq)*eps)]).round(3))

# G
# Compare row i of dl/dG: manual KKT formula, cvxpylayers, and central differences.
i = 3
print(grads_manual[4][i].round(3))
print(grads_cvxpylayers[4][i].round(3))


@lru_cache(maxsize=n_vars*n_ineq)
def make_h_G(i, j):
    """Perturbation shaped like G with `eps` at entry (i, j) and zeros elsewhere.

    The result is cached, so it is made read-only to keep the cache from being corrupted.
    """
    h_G = np.zeros_like(data[4])
    h_G[i, j] += eps
    h_G.flags.writeable = False
    return h_G


def _central_diff_G(i, j):
    """Central-difference estimate of dl/dG[i, j] with step size eps."""
    step = make_h_G(i, j)
    # f_G returns a scalar, so divide by the scalar step width 2*eps.
    return (f_G(data[4] + step) - f_G(data[4] - step)) / (2 * eps)


row = [_central_diff_G(i, j) for j in range(data[4].shape[1])]
print(np.array(row).round(3))

[ 0.    -0.     0.    -0.    -0.015  0.     0.031 -0.222 -0.306 -0.139
 -0.    -0.     0.    -0.119 -0.345 -0.289 -0.173 -0.26  -0.142  0.
  0.    -0.    -0.    -0.     0.    -0.    -0.     0.    -0.    -0.21
  0.     0.    -0.435 -0.23   0.    -0.121  0.     0.    -0.416 -0.
 -0.054 -0.     0.    -0.279  0.    -0.151 -0.277 -0.127  0.01   0.   ]
[-0.05   0.049  0.023  0.027  0.017  0.037  0.05  -0.245 -0.394 -0.201
  0.007  0.016  0.045 -0.119 -0.402 -0.231 -0.195 -0.266 -0.092 -0.073
  0.039  0.01   0.039 -0.018  0.009 -0.1    0.06  -0.092 -0.048 -0.21
  0.037 -0.045 -0.527 -0.22   0.009 -0.026 -0.021 -0.006 -0.422  0.053
  0.046 -0.003  0.025 -0.37  -0.047 -0.218 -0.28  -0.198 -0.038 -0.015]
[-0.    -0.     0.     0.    -0.015 -0.     0.031 -0.222 -0.306 -0.139
 -0.    -0.     0.    -0.119 -0.345 -0.289 -0.173 -0.259 -0.142  0.
  0.     0.    -0.006 -0.    -0.    -0.     0.    -0.    -0.    -0.21
 -0.    -0.    -0.435 -0.23  -0.    -0.121 -0.    -0.    -0.416  0.
 -0.054 -0.    -0. 

In [218]:
# Primal solutions from the direct SCS solve and from the cvxpylayers forward pass.
print(z_star, z_star_layer, sep="\n")

[ 1.62172195e-11  5.52643494e-12 -7.75841854e-02 -7.80358178e-02
 -9.38404829e-02]
[ 1.62171808e-11  5.52639232e-12 -7.75841854e-02 -7.80358178e-02
 -9.38404829e-02]


In [219]:
# Shapes first: the two gradient tuples must line up entry by entry.
print([g.shape for g in grads_manual])
print([g.shape for g in grads_cvxpylayers])
print(*[g.round(4) for g in grads_manual], sep='\n')
print()
print(*[g.round(4) for g in grads_cvxpylayers], sep='\n')

# Largest absolute discrepancy per gradient. A signed sum would let positive
# and negative errors cancel and hide real mismatches.
max_abs_diffs = [
    np.max(np.abs(np.asarray(g1) - np.asarray(g2)), initial=0.0)
    for g1, g2 in zip(grads_manual, grads_cvxpylayers)
]
print(*max_abs_diffs, sep='\n')

[(5, 5), (5,), (1, 5), (1,), (5, 5), (5,)]
[(5, 5), (5,), (1, 5), (1,), (5, 5), (5,)]
[[ 0.      0.     -0.0007 -0.0038 -0.0014]
 [-0.     -0.     -0.0022  0.0053 -0.0013]
 [-0.     -0.     -0.0038  0.0022 -0.0035]
 [-0.     -0.      0.004   0.0111  0.0061]
 [-0.     -0.     -0.0028  0.0008 -0.0027]]
[ 0.      0.      0.0139 -0.0304  0.0086]
[[ 0.     -0.     -0.0746 -0.0427 -0.0842]]
[-0.831]
[[-0.      0.      0.1721  0.098   0.1944]
 [-0.      0.      0.1522  0.1281  0.1795]
 [ 0.      0.      0.      0.      0.    ]
 [ 0.      0.      0.      0.      0.    ]
 [ 0.      0.      0.      0.      0.    ]]
[ 1.9157  1.8614 -0.     -0.     -0.    ]

[[ 0.      0.     -0.0007 -0.0038 -0.0014]
 [-0.     -0.     -0.0022  0.0053 -0.0013]
 [-0.     -0.     -0.0038  0.0022 -0.0035]
 [-0.     -0.      0.004   0.0111  0.0061]
 [-0.     -0.     -0.0028  0.0008 -0.0027]]
[ 0.      0.      0.0139 -0.0304  0.0086]
[[-0.     -0.     -0.0746 -0.0427 -0.0842]]
[-0.831]
[[ 0.      0.      0.1721  0.098 