<a href="https://colab.research.google.com/github/shaabhishek/cvxpylayer_qp/blob/main/Differentiable_QP_Layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install cvxpylayers
!pip install jaxopt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from functools import partial, lru_cache
import numpy as np

import scipy
import cvxpy as cvx
import jax
jax.config.update("jax_enable_x64", True)
import jaxopt
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer

In [3]:
# Small array helpers used to assemble the KKT system.
D = jnp.diag  # vector -> diagonal matrix


def ccat(x):
    """Concatenate a sequence of arrays side by side (axis=1)."""
    return jnp.concatenate(x, axis=1)


def rcat(x):
    """Stack a sequence of arrays on top of each other (axis=0)."""
    return jnp.concatenate(x, axis=0)


def mv(A, b):
    """Matrix-vector product A @ b for a 2-D A and 1-D b."""
    return jnp.einsum("ij,j->i", A, b)


outer = jnp.outer

@lru_cache
def make_problem_cvxpy(n_vars, n_eq, n_ineq):
    """Build (memoized per problem size) a DPP-compliant parametric CVXPY QP.

        minimize    0.5 * ||L z||^2 + c^T z
        subject to  A z == b,  G z <= h

    Returns (problem, variable z, [equality, inequality] constraints,
    parameter list [L, c, A, b, G, h]).
    """
    z = cvx.Variable(n_vars)
    # Parameters are created in the order L, c, A, b, G, h.
    L, c, A, b, G, h = (
        cvx.Parameter(shape)
        for shape in (
            (n_vars, n_vars),
            n_vars,
            (n_eq, n_vars),
            n_eq,
            (n_ineq, n_vars),
            n_ineq,
        )
    )

    quadratic = 0.5 * cvx.sum_squares(L @ z)
    objective = cvx.Minimize(quadratic + c @ z)
    constraints = [A @ z == b, G @ z <= h]
    prob = cvx.Problem(objective, constraints)
    # Parameters must enter affinely so cvxpylayers can differentiate the problem.
    assert prob.is_dpp()
    return prob, z, constraints, [L, c, A, b, G, h]


def solve_problem_cvxpy(prob, x, constraints, params, data):
    """Assign `data` to `params`, solve with SCS, and return (z*, nu*, lambda*).

    nu* is the dual of constraints[0] (equalities) and lambda* the dual of
    constraints[1] (inequalities).
    """
    for parameter, value in zip(params, data):
        parameter.value = np.array(value)

    prob.solve(solver=cvx.SCS, verbose=False, warm_start=True, eps=1e-6)
    eq_constraint, ineq_constraint = constraints[0], constraints[1]
    return x.value, eq_constraint.dual_value, ineq_constraint.dual_value

def make_problem_cvxpylayer(prob, x, params):
    """Wrap `prob` as a differentiable cvxpylayers layer: inputs `params`, output `x`."""
    return CvxpyLayer(prob, parameters=params, variables=[x])

def solve_problem_cvxpylayer(cvxpylayer, data):
    """Evaluate the layer on the parameter tuple `data` with solver tolerance eps=1e-6."""
    solver_args = {"eps": 1e-6}
    return cvxpylayer(*data, solver_args=solver_args)

def solve_problem_jaxopt(osqpsolver, data):
    """Solve the QP with a jaxopt OSQP `run`-style callable and return its `.params`.

    Args:
        osqpsolver: callable accepting params_obj=(Q, c), params_eq=(A, b),
            params_ineq=(G, h) and returning an object with a `.params` field.
        data: tuple (L, c, A, b, G, h).

    Returns:
        The solver's `.params` (primal/dual solution).
    """
    L, c, A, b, G, h = data
    # The objective is 0.5 * ||L z||^2 = 0.5 * z^T (L^T L) z. Squaring L
    # (L @ L) only matches this when L happens to be symmetric.
    Q = L.T @ L
    return osqpsolver(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
    
@jax.jit
def compute_grads_manual(z_star, nu_star, lambda_star, data, grad_l):
    """Backpropagate `grad_l` through the QP solution via the transposed KKT system.

    Problem: minimize 0.5 * ||L z||^2 + c^T z  s.t.  A z == b,  G z <= h,
    with stationarity  Q z + c + A^T nu + G^T lambda = 0  and  Q = L^T L.

    Args:
        z_star: primal solution, shape (n_vars,).
        nu_star: equality-constraint duals, shape (n_eq,).
        lambda_star: inequality-constraint duals, shape (n_ineq,).
        data: tuple (L, c, A, b, G, h).
        grad_l: dl/d(z, lambda, nu) stacked in that order,
            shape (n_vars + n_ineq + n_eq,).

    Returns:
        (grad_L, grad_c, grad_A, grad_b, grad_G, grad_h), each shaped like the
        matching entry of `data`.
    """
    L, c, A, b, G, h = data
    # Sizes come from the (static) array shapes rather than notebook globals,
    # so the jitted function is correct for every problem size it is traced with.
    n_vars, n_eq, n_ineq = z_star.shape[0], nu_star.shape[0], lambda_star.shape[0]
    n_total = n_vars + n_ineq + n_eq

    # 0.5 * ||L z||^2 = 0.5 * z^T (L^T L) z, so the Hessian is L^T L. L @ L.T
    # only coincides with it for symmetric L, and grad_L below assumes L^T L.
    Q = L.T @ L

    # Differentiated KKT conditions; unknowns ordered (z, lambda, nu).
    M = jnp.block([
        [Q, G.T, A.T],
        [lambda_star[:, None] * G, jnp.diag(G @ z_star - h), jnp.zeros((n_ineq, n_eq))],
        [A, jnp.zeros((n_eq, n_ineq)), jnp.zeros((n_eq, n_eq))],
    ])
    assert M.shape == (n_total, n_total)

    adjoint = -jnp.linalg.solve(M.T, grad_l)
    d_z = adjoint[:n_vars]
    d_lambda = adjoint[n_vars:n_vars + n_ineq]
    d_nu = adjoint[n_vars + n_ineq:]

    grad_L = L @ (jnp.outer(d_z, z_star) + jnp.outer(z_star, d_z))
    grad_c = d_z
    grad_A = jnp.outer(d_nu, z_star) + jnp.outer(nu_star, d_z)
    grad_b = -d_nu
    # diag(lambda) @ outer(d_lambda, z) == outer(lambda * d_lambda, z)
    grad_G = jnp.outer(lambda_star, d_z) + jnp.outer(lambda_star * d_lambda, z_star)
    grad_h = -lambda_star * d_lambda
    return grad_L, grad_c, grad_A, grad_b, grad_G, grad_h


In [4]:
@lru_cache
def make_problem_cvxpy_batch(n_vars, n_eq, n_ineq, n_batch):
    """Stack n_batch independent copies of the parametric QP into one CVXPY problem.

    Copy i has its own variable z[i] and parameters L[i], c[i], A[i], b[i],
    G[i], h[i]. The objective is the sum of the per-copy objectives
    0.5 * ||L[i] z[i]||^2 + c[i]^T z[i].

    Returns (problem, list of variables, (equality constraints, inequality
    constraints), [L, c, A, b, G, h]), where every parameter entry is a list of
    length n_batch.
    """
    copies = range(n_batch)
    z = [cvx.Variable(n_vars) for _ in copies]

    def make_params(shape):
        # One fresh parameter of `shape` per batch element.
        return [cvx.Parameter(shape) for _ in copies]

    L = make_params((n_vars, n_vars))
    c = make_params(n_vars)
    A = make_params((n_eq, n_vars))
    b = make_params(n_eq)
    G = make_params((n_ineq, n_vars))
    h = make_params(n_ineq)

    cost = 0
    for L_i, c_i, z_i in zip(L, c, z):
        cost += 0.5 * cvx.sum_squares(L_i @ z_i) + c_i @ z_i
    objective = cvx.Minimize(cost)

    eq_constraints = [A_i @ z_i == b_i for A_i, b_i, z_i in zip(A, b, z)]
    ineq_constraints = [G_i @ z_i <= h_i for G_i, h_i, z_i in zip(G, h, z)]

    prob = cvx.Problem(objective, eq_constraints + ineq_constraints)
    assert prob.is_dpp()
    return prob, z, (eq_constraints, ineq_constraints), [L, c, A, b, G, h]


def solve_problem_cvxpy_batch(prob, x, constraints, params, data):
    """Assign batched `data` to the per-copy parameters, solve with ECOS,
    and return stacked (z*, nu*, lambda*) with a leading batch axis.
    """
    for param_list, value_batch in zip(params, data):
        for param, value in zip(param_list, value_batch):
            param.value = np.array(value)

    # This batched variant uses ECOS; the unbatched solver uses SCS with eps=1e-6.
    prob.solve(solver=cvx.ECOS, verbose=False, warm_start=True)

    eq_constraints, ineq_constraints = constraints
    z_star = np.stack([variable.value for variable in x])
    nu_star = np.stack([con.dual_value for con in eq_constraints])
    lambda_star = np.stack([con.dual_value for con in ineq_constraints])
    return z_star, nu_star, lambda_star


In [52]:
# Wrapped as a jax custom_vjp: the forward pass runs the cvxpy solver, and the
# VJP rule is registered separately with `defvjp`. `problem_params` is
# non-differentiable (nondiff_argnums=(1,)).
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def solve_QP_cvxpy(data, problem_params):
    """Primal QP solution z* from cvxpy; `problem_params` = (prob, x, constraints, params)."""
    prob, x, constraints, params = problem_params
    solution = solve_problem_cvxpy(prob, x, constraints, params, data)
    return solution[0]

def f_fwd_cvxpy(data, problem_params):
    """custom_vjp forward rule: return z* plus residuals (z*, nu*, lambda*, data)."""
    prob, x, constraints, params = problem_params
    solution = solve_problem_cvxpy(prob, x, constraints, params, data)
    z_star, nu_star, lambda_star = solution
    residuals = (z_star, nu_star, lambda_star, data)
    return z_star, residuals

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def solve_QP_jaxopt(data, problem_params):
    """Primal QP solution z* from jaxopt OSQP; `problem_params` is the solver callable."""
    return solve_problem_jaxopt(problem_params, data).primal

def f_fwd_jaxopt(data, problem_params):
    """custom_vjp forward rule for the jaxopt path: return z* and residuals
    (z*, nu*, lambda*, data), with duals taken from the OSQP solution.
    """
    sol = solve_problem_jaxopt(problem_params, data)
    residuals = (sol.primal, sol.dual_eq, sol.dual_ineq, data)
    return sol.primal, residuals

@partial(jax.jit, static_argnames=["shapes"])
def _compute_grads(opt_solution, data, shapes, g):
    z_star, nu_star, lambda_star = opt_solution
    n_vars, n_eq, n_ineq = shapes
    # For backward pass
    (L, c, A, b, G, h) = data
    Q = L @ L.T
    M = rcat(
        [
            ccat([Q, G.T, A.T]),
            ccat([D(lambda_star)@G, D(mv(G, z_star) - h), jnp.zeros((n_ineq,n_eq))]),
            ccat([A, jnp.zeros((n_eq,n_ineq)), jnp.zeros((n_eq,n_eq))])
            ])
    
    assert M.shape == (n_vars+n_eq+n_ineq, n_vars+n_eq+n_ineq)
    g = jnp.concatenate([g, jnp.zeros_like(lambda_star), jnp.zeros_like(nu_star) ], axis=-1)
    eq_37 = -jnp.linalg.solve( M.T, g )
    d_z, d_lambda, d_nu = eq_37[:n_vars], eq_37[n_vars:n_vars+n_ineq], eq_37[n_vars+n_ineq:]
    # grads_manual = compute_grads_manual(z_star, nu_star, lambda_star, data, g)
    grad_L = L @ (outer(d_z, z_star) + outer(z_star, d_z))
    grad_c = d_z
    grad_A = outer(d_nu, z_star) + outer(nu_star, d_z)
    grad_b = -d_nu
    grad_G = outer(lambda_star, d_z) + D(lambda_star) @ outer(d_lambda, z_star)
    grad_h = -mv(D(lambda_star), d_lambda)
    grads = (grad_L, grad_c, grad_A, grad_b, grad_G, grad_h)
    return grads

def f_bwd(problem_params, residualdata, g):
    """custom_vjp backward rule shared by solve_QP_cvxpy and solve_QP_jaxopt.

    Args:
        problem_params: the non-differentiable argument (unused here).
        residualdata: (z_star, nu_star, lambda_star, data) from the forward rule.
        g: cotangent of the output z_star.

    Returns:
        A 1-tuple holding the cotangent for `data` (the only differentiable argument).
    """
    z_star, nu_star, lambda_star, data = residualdata
    # Problem sizes come from the saved solution, not from the notebook-level
    # `x` / `constraints` globals. `problem_params` differs between the cvxpy
    # and jaxopt paths, and those globals may describe a different problem.
    shapes = (z_star.shape[0], nu_star.shape[0], lambda_star.shape[0])
    grads = _compute_grads((z_star, nu_star, lambda_star), data, shapes, g)
    return (grads,)

solve_QP_cvxpy.defvjp(f_fwd_cvxpy, f_bwd)
solve_QP_jaxopt.defvjp(f_fwd_jaxopt, f_bwd)

Create random problem data, i.e. the parameters $(L, c, A, b, G, h)$, where the quadratic term is $Q = L^\top L$.

In [53]:
def get_data(n_vars, n_eq, n_ineq):
    """Draw a random QP instance (L, c, A, b, G, h) from NumPy's global RNG.

    Q = B B^T for a Gaussian B is positive semidefinite. L is its (symmetric)
    matrix square root, so Q = L^T L, matching the 0.5 * ||L z||^2 objective.

    Returns:
        L (n_vars, n_vars), c (n_vars,), A (n_eq, n_vars), b (n_eq,),
        G (n_ineq, n_vars), h (n_ineq,).
    """
    Q = np.random.randn(n_vars, n_vars)
    Q = Q @ Q.T
    # Q is symmetric: eigvalsh returns real eigenvalues. Allow tiny negative
    # round-off relative to the spectrum's scale.
    eigs = np.linalg.eigvalsh(Q)
    tol = 1e-8 * max(1.0, np.abs(eigs).max(initial=0.0))
    assert np.all(eigs >= -tol)

    # sqrtm can return a complex array with negligible imaginary parts.
    L = np.real_if_close(scipy.linalg.sqrtm(Q))
    assert np.allclose(Q, L.T @ L)
    c = np.random.randn(n_vars)

    A = np.random.randn(n_eq, n_vars)
    b = np.random.randn(n_eq)
    G = np.random.randn(n_ineq, n_vars)
    h = np.random.randn(n_ineq)
    return (L, c, A, b, G, h)

def get_data_batch(n_vars, n_eq, n_ineq, n_batch):
    """Draw n_batch random QP instances stacked along a leading batch axis.

    Each Q[i] = B_i B_i^T is PSD, and L[i] is its symmetric square root, so
    Q[i] = L[i]^T L[i].

    Returns:
        (L, c, A, b, G, h) with shapes (n_batch, n_vars, n_vars),
        (n_batch, n_vars), (n_batch, n_eq, n_vars), (n_batch, n_eq),
        (n_batch, n_ineq, n_vars), (n_batch, n_ineq).
    """
    Q = np.random.randn(n_batch, n_vars, n_vars)
    Q = np.einsum("bij,bkj->bik", Q, Q)
    # Check every batch element (not just the first). eigvalsh handles stacked
    # symmetric matrices and returns real eigenvalues.
    eigs = np.linalg.eigvalsh(Q)
    tol = 1e-8 * max(1.0, np.abs(eigs).max(initial=0.0))
    assert np.all(eigs >= -tol)
    # sqrtm can return a complex array with negligible imaginary parts.
    L = np.stack([np.real_if_close(scipy.linalg.sqrtm(_Q)) for _Q in Q], axis=0)

    c = np.random.randn(n_batch, n_vars)
    A = np.random.randn(n_batch, n_eq, n_vars)
    b = np.random.randn(n_batch, n_eq)
    G = np.random.randn(n_batch, n_ineq, n_vars)
    h = np.random.randn(n_batch, n_ineq)

    return (L, c, A, b, G, h)

In [54]:
def l_cvxpylayer(data, prob, x, params):
    """Loss = sum of the cvxpylayers solution; returns (loss, z*) for has_aux."""
    layer = make_problem_cvxpylayer(prob, x, params)
    (solution,) = solve_problem_cvxpylayer(layer, data)
    loss = jnp.sum(solution)
    return loss, solution


l_cvxpylayer_value_and_grad = jax.value_and_grad(l_cvxpylayer, has_aux=True)

def l_manual(data, prob, x, constraints, params):
    """Loss sum(z*) with gradients from the hand-written KKT adjoint.

    Solves with cvxpy, then calls compute_grads_manual with upstream gradient
    dl/d(z, lambda, nu) = (1, 0, 0). Returns ((loss, z*), grads), mirroring the
    ((value, aux), grads) layout of jax.value_and_grad(has_aux=True).
    """
    solution = solve_problem_cvxpy(prob, x, constraints, params, data)
    z_star, nu_star, lambda_star = solution
    loss = z_star.sum()
    # Ordering (z, lambda, nu) matches the rows of the KKT matrix.
    upstream = jnp.concatenate([jnp.ones_like(z_star), jnp.zeros_like(lambda_star), jnp.zeros_like(nu_star)])
    grads_manual = compute_grads_manual(z_star, nu_star, lambda_star, data, upstream)
    return (loss, z_star), grads_manual

def l_jax(data, prob, x, constraints, params):
    """Loss sum(z*) through the cvxpy-backed custom_vjp; returns (loss, z*) for has_aux."""
    z_star = solve_QP_cvxpy(data, (prob, x, constraints, params))
    return jnp.sum(z_star), z_star


l_jax_value_and_grad = jax.value_and_grad(l_jax, has_aux=True)


def l_jaxopt(data, qpsolver):
    """Loss sum(z*) through the jaxopt (OSQP) custom_vjp; returns (loss, z*) for has_aux."""
    z_star = solve_QP_jaxopt(data, qpsolver)
    return jnp.sum(z_star), z_star


l_jaxopt_value_and_grad = jax.value_and_grad(l_jaxopt, has_aux=True)
# OSQP solver instance; its jitted `run` is the callable passed as `qpsolver`.
qp = jaxopt.OSQP()
qpsolver = jax.jit(qp.run)

In [62]:

# Problem sizes: one inequality row per variable (n_ineq == n_vars).
n_vars = 60
n_eq = 2
n_ineq = n_vars
n_batch = 20
# Seed NumPy's global RNG (drawn from by get_data / get_data_batch) so that
# Restart & Run All regenerates the same problem instances and outputs.
SEED = 0
np.random.seed(SEED)
data = get_data(n_vars, n_eq, n_ineq)
data_batch = get_data_batch(n_vars, n_eq, n_ineq, n_batch)
prob, x, constraints, params = make_problem_cvxpy(n_vars, n_eq, n_ineq)
prob_batch, x_batch, constraints_batch, params_batch = make_problem_cvxpy_batch(n_vars, n_eq, n_ineq, n_batch)

# Solution z* and gradients of sum(z*) w.r.t. (L, c, A, b, G, h) from each method.
(_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)
(_, z_star_jax), grads_jax = l_jax_value_and_grad(data, prob, x, constraints, params)
(_, z_star_jaxopt), grads_jaxopt = l_jaxopt_value_and_grad(data, qpsolver)
(_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)
# Evaluate the cvxpylayer on the whole batch in one call (result is discarded).
l_cvxpylayer_value_and_grad(data_batch, prob, x, params);




Batch Variants

In [63]:
# %time z_star_b_l, = solve_problem_cvxpylayer(make_problem_cvxpylayer(prob, x, params), get_data_batch(n_vars, n_eq, n_ineq, n_batch))
# %time z_star_b_m, nu_star_b_m, lambda_star_b_m = solve_problem_cvxpy_batch(prob_batch, x_batch, constraints_batch, params_batch, get_data_batch(n_vars, n_eq, n_ineq, n_batch))

In [64]:
# solve_QP(data, (prob, x, constraints, params))
# f_fwd(data, (prob, x, constraints, params))
# jax.jacobian(solve_QP)(data, (prob, x, constraints, params))

Compare solutions

In [65]:
# Reference: manual cvxpy solve; the other three must agree to within 1e-3.
reference = z_star
candidates = (z_star_layer, z_star_jax, z_star_jaxopt)
for solution in (reference, *candidates):
    print(solution.round(3))

for solution in candidates:
    assert np.allclose(reference, solution, atol=1e-3), np.sum(np.abs(reference - solution))

[ 0.381  0.509 -0.158  0.093 -0.155 -0.075 -0.448  0.073  0.144  0.008
  0.12   0.238  0.395  0.404  0.128  0.09  -0.033 -0.44  -0.277  0.091
 -0.153  0.382 -0.208 -0.06  -0.129  0.005  0.198 -0.024  0.008 -0.075
 -0.045  0.187  0.082 -0.126 -0.162  0.058 -0.085 -0.087 -0.071 -0.407
 -0.124 -0.005  0.121 -0.255  0.046  0.051  0.386 -0.195 -0.25  -0.391
  0.061 -0.475 -0.061 -0.096 -0.064  0.245  0.057 -0.138  0.054  0.055]
[ 0.381  0.509 -0.158  0.093 -0.155 -0.075 -0.448  0.073  0.144  0.008
  0.12   0.238  0.395  0.404  0.128  0.09  -0.033 -0.44  -0.277  0.091
 -0.153  0.382 -0.208 -0.06  -0.129  0.005  0.198 -0.024  0.008 -0.075
 -0.045  0.187  0.082 -0.126 -0.162  0.058 -0.085 -0.087 -0.071 -0.407
 -0.124 -0.005  0.121 -0.255  0.046  0.051  0.386 -0.195 -0.25  -0.391
  0.061 -0.475 -0.061 -0.096 -0.064  0.245  0.057 -0.138  0.054  0.055]
[ 0.381  0.509 -0.158  0.093 -0.155 -0.075 -0.448  0.073  0.144  0.008
  0.12   0.238  0.395  0.404  0.128  0.09  -0.033 -0.44  -0.277  0.091
 -0.

Test Gradients

In [66]:
# Test gradients: compare each method's analytic gradient with a central
# finite difference of the loss sum(z*) computed by l_manual.
eps = 1e-1


def _loss_with_param(index, value):
    """Loss from l_manual when entry `index` of `data` is replaced by `value`."""
    perturbed = list(data)
    perturbed[index] = value
    return l_manual(tuple(perturbed), prob, x, constraints, params)[0][0]


def f_c(_c):
    return _loss_with_param(1, _c)


def f_b(_b):
    return _loss_with_param(3, _b)


def f_h(_h):
    return _loss_with_param(5, _h)


def f_G(_G):
    return _loss_with_param(4, _G)


# c
for grad_c_estimate in (grads_manual[1], grads_jax[1], grads_jaxopt[1], grads_cvxpylayers[1]):
    print(grad_c_estimate.round(3))
fd_grad_c = [
    ((f_c(data[1] + step) - f_c(data[1] - step)) / (2 * step + 1e-10))[i]
    for i, step in enumerate(np.eye(n_vars) * eps)
]
print(np.array(fd_grad_c).round(3))
print()
print("======================================================")

# # b
# print(grads_manual[3].round(3))
# print(grads_jax[3].round(3))
# print(grads_jaxopt[3].round(3))
# print(grads_cvxpylayers[3].round(3))
# print(np.array([((f_b(data[3]+h_b) - f_b(data[3]-h_b)) / (2*h_b + 1e-10))[i] for i, h_b in enumerate(np.eye(n_eq)*eps)]).round(3))
# print()
# print("======================================================")

# h
# print(grads_manual[5].round(3))
# print(grads_jax[5].round(3))
# print(grads_jaxopt[5].round(3))
# print(grads_cvxpylayers[5].round(3))
# print(np.array([((f_h(data[5]+h_h) - f_h(data[5]-h_h)) / (2*h_h + 1e-10))[i] for i, h_h in enumerate(np.eye(n_ineq)*eps)]).round(3))

# # G
# i = 1
# print(grads_manual[4][i].round(3))
# print(grads_jax[4][i].round(3))
# print(grads_jaxopt[4][i].round(3))
# print(grads_cvxpylayers[4][i].round(3))

# @lru_cache(maxsize=n_vars*n_ineq)
# def make_h_G(i,j):
#     h_G = np.zeros_like(data[4])
#     h_G[i,j]+=eps
#     return h_G

# row = [((f_G(data[4]+make_h_G(i,j)) - f_G(data[4]-make_h_G(i,j))) / (2*make_h_G(i,j) + 1e-10))[i,j] for j in range(data[4].shape[1])]
# print(np.array(row).round(3))

[-0.051 -0.02  -0.11  -0.098 -0.054  0.07   0.008 -0.041 -0.133 -0.022
 -0.013 -0.063 -0.015 -0.076 -0.065  0.037 -0.03  -0.032 -0.001 -0.07
 -0.072  0.02  -0.027 -0.159 -0.099  0.069  0.017 -0.064 -0.028  0.006
 -0.096 -0.083 -0.148 -0.102  0.065 -0.046 -0.031  0.013 -0.005 -0.028
  0.022 -0.056 -0.12   0.031 -0.12  -0.037 -0.07   0.053 -0.044 -0.016
 -0.015 -0.024 -0.061 -0.002 -0.018 -0.144 -0.077 -0.     0.012  0.011]
[-0.051 -0.02  -0.11  -0.098 -0.054  0.07   0.008 -0.041 -0.133 -0.022
 -0.013 -0.063 -0.015 -0.076 -0.065  0.037 -0.03  -0.032 -0.001 -0.07
 -0.072  0.02  -0.027 -0.159 -0.099  0.069  0.017 -0.064 -0.028  0.006
 -0.096 -0.083 -0.148 -0.102  0.065 -0.046 -0.031  0.013 -0.005 -0.028
  0.022 -0.056 -0.12   0.031 -0.12  -0.037 -0.07   0.053 -0.044 -0.016
 -0.015 -0.024 -0.061 -0.002 -0.018 -0.144 -0.077 -0.     0.012  0.011]
[-0.051 -0.02  -0.11  -0.098 -0.054  0.07   0.008 -0.041 -0.133 -0.022
 -0.013 -0.063 -0.015 -0.076 -0.065  0.037 -0.03  -0.032 -0.001 -0.07
 -0.072

Profiling Code

In [61]:
# Wall-clock comparison of "solve + gradient" for each method.
# NOTE(review): this cell was executed as In [61], before the setup cell In [62];
# on a fresh kernel it must run after setup.
# NOTE(review): statements inside %timeit run in timeit's own namespace, so the
# assignments below are presumably not persisted to the notebook variables — confirm.
# Unbatched data
%timeit -n 20 (_, z_star), grads_manual = l_manual(data, prob, x, constraints, params)
%timeit -n 20 (_, z_star_jax), grads_jax = l_jax_value_and_grad(data, prob, x, constraints, params)
%timeit -n 20 (_, z_star_jaxopt), grads_jaxopt = l_jaxopt_value_and_grad(data, qpsolver)
%timeit -n 20 (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data, prob, x, params)

# Batched data
# The first three loop in Python over the n_batch problems (zip(*data_batch)
# yields one (L, c, A, b, G, h) tuple per problem); the cvxpylayers line passes
# the whole stacked batch to a single layer call.
%timeit -n 5 [l_manual(_data, prob, x, constraints, params) for _data in zip(*data_batch)]
%timeit -n 5 [l_jax_value_and_grad(_data, prob, x, constraints, params) for _data in zip(*data_batch)]
%timeit -n 5 [l_jaxopt_value_and_grad(_data, qpsolver) for _data in zip(*data_batch)]
%timeit -n 5 (_, z_star_layer), grads_cvxpylayers = l_cvxpylayer_value_and_grad(data_batch, prob, x, params)

19.4 ms ± 662 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)
27.3 ms ± 729 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)
31.2 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 20 loops each)
84.8 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 20 loops each)
1.12 s ± 136 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
1.23 s ± 127 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
424 ms ± 120 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
1.15 s ± 128 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
# Print all six gradients (L, c, A, b, G, h) from the manual, custom_vjp and
# cvxpylayers methods, then the per-parameter L1 distance between the manual
# and cvxpylayers gradients.
separator = "======================================================"
print(separator)
for method_grads in (grads_manual, grads_jax, grads_cvxpylayers):
    print(*[g.round(4) for g in method_grads], sep='\n')
    print(separator)
print(*[np.sum(np.abs(g1 - g2)) for g1, g2 in zip(grads_manual, grads_cvxpylayers)], sep='\n')
print(separator)