Merge pull request google#455 from Algue-Rythme:convenienceapi
PiperOrigin-RevId: 544446827
JAXopt authors committed Jun 29, 2023
2 parents 7238f76 + 47c2d39 commit 5ff06fe
Showing 5 changed files with 588 additions and 77 deletions.
234 changes: 227 additions & 7 deletions docs/quadratic_programming.rst
The specifics of each solver are summarized in the table below.
The best choice will depend on the usage.

.. list-table:: Quadratic programming solvers
:widths: 45, 15, 20, 20, 15, 15, 15, 22, 15, 15
:header-rows: 1

* - Name
- jit
- pytree
- matvec
- quad. fun
- precision
- stability
- speed
- derivative
- input format
* - :class:`jaxopt.EqualityConstrainedQP`
- yes
- yes
- yes
- yes
- ++
- \+
- +++
- implicit
- (Q, c), (A, b)
* - :class:`jaxopt.CvxpyQP`
- no
- no
- no
- no
- +++
- +++
- \+
- implicit
- (Q, c), (A, b), (G, h)
* - :class:`jaxopt.OSQP`
- yes
- yes
- yes
- yes
- \+
- ++
- ++
- implicit
- (Q, c), (A, b), (G, h)
* - :class:`jaxopt.BoxOSQP`
- yes
- yes
- yes
- yes
- \+
- ++
- ++
- both
- (Q, c), A, (l, u)
* - :class:`jaxopt.BoxCDQP`
- yes
- no
- no
- no
- ++
- +++
- ++
- both
- (Q, c), (l, u)

- *jit*: the algorithm can be used with jit or vmap, on GPU/TPU.
- *pytree*: the algorithm can be used with pytrees of matrices (see below).
- *matvec*: the QP parameters can be given as matvec instead of dense matrices (see below).
- *quad. fun*: the algorithm can be used with a quadratic function argument (see below).
- *precision*: accuracy expected when the solver converges successfully.
- *stability*: capacity to handle badly scaled problems and matrices with poor conditioning.
- *speed*: typical speed on large instances to reach maximum accuracy.
- *derivative*: whether differentiation is supported only via implicit differentiation, or by both implicit differentiation and unrolling.
- *input format*: see subsections below.


This table is given as a rule of thumb only; on some particular instances
some solvers may behave unexpectedly better (or worse!) than others.
In case of difficulties, we suggest testing different combinations of
algorithms, ``maxiter`` and ``tol`` values.

Equality-constrained QPs
------------------------

The problem takes the form:

.. math::

    \min_{x} \frac{1}{2} x^\top Q x + c^\top x \quad
    \textrm{subject to} \quad A x = b

.. autosummary::
  :toctree: _autosummary

    jaxopt.EqualityConstrainedQP

This class is optimized for QPs with equality constraints only: it supports jit, pytrees and matvec.
It is based on the KKT conditions of the problem.
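
Concretely, the KKT conditions (stationarity and primal feasibility) reduce to
the linear system

.. math::

    \begin{bmatrix} Q & A^\top \\ A & 0 \end{bmatrix}
    \begin{bmatrix} x \\ \lambda \end{bmatrix}
    =
    \begin{bmatrix} -c \\ b \end{bmatrix}

where :math:`\lambda` is the dual variable associated with the equality constraint.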

Example (a minimal sketch with made-up data)::
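
    import jax.numpy as jnp
    from jaxopt import EqualityConstrainedQP

    # Hypothetical data: min 1/2 x^T Q x + c^T x  s.t.  x1 + x2 = 1.
    Q = jnp.array([[1.0, 0.5],
                   [0.5, 4.0]])
    c = jnp.array([1.0, -1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])

    qp = EqualityConstrainedQP()
    sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
    print(sol.primal, sol.dual_eq)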

OSQP
----

This solver is a pure JAX re-implementation of the OSQP algorithm.
It is jittable and supports pytrees and matvecs, but its accuracy is usually
lower than that of :class:`CvxpyQP` when run in float32.
It is meant as a drop-in replacement for :class:`CvxpyQP`, but since it is
a wrapper over :class:`BoxOSQP`, we recommend using :class:`BoxOSQP` directly
to avoid a costly problem transformation.

.. autosummary::
  :toctree: _autosummary

    jaxopt.OSQP
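
Example (a minimal sketch with made-up data, in the ``(Q, c), (A, b), (G, h)`` format)::

    import jax.numpy as jnp
    from jaxopt import OSQP

    # min 1/2 x^T Q x + c^T x  s.t.  x1 + x2 = 1  and  x >= 0  (i.e. -x <= 0).
    Q = jnp.array([[1.0, 0.5],
                   [0.5, 4.0]])
    c = jnp.array([1.0, -1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = -jnp.eye(2)
    h = jnp.zeros(2)

    qp = OSQP()
    sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
    print(sol.primal)
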
BoxOSQP
-------

The problem takes the form:

.. math::

    \min_{x, z} \frac{1}{2} x^\top Q x + c^\top x \quad
    \textrm{subject to} \quad A x = z, \; l \le z \le u
:class:`jaxopt.BoxOSQP` uses the same underlying solver as :class:`jaxopt.OSQP`
but accepts problems in the above box-constrained format instead. The bounds
``u`` (resp. ``l``) can be set to ``inf`` (resp. ``-inf``) if required.
Equality can be enforced with ``l = u``.

Example (a minimal sketch with made-up data, using ``l = u`` to encode an equality)::
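
    import jax.numpy as jnp
    from jaxopt import BoxOSQP

    # Hypothetical data: min 1/2 x^T Q x + c^T x
    # s.t. x1 + x2 = 1 (encoded with l = u) and x >= 0.
    Q = jnp.array([[1.0, 0.5],
                   [0.5, 4.0]])
    c = jnp.array([1.0, -1.0])
    A = jnp.array([[1.0, 1.0],
                   [1.0, 0.0],
                   [0.0, 1.0]])
    l = jnp.array([1.0, 0.0, 0.0])
    u = jnp.array([1.0, jnp.inf, jnp.inf])

    qp = BoxOSQP()
    sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
    print(sol.primal)  # for BoxOSQP the primal contains (x, z) with z = Ax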

Unconstrained QPs
-----------------

For unconstrained QPs, one of the most efficient
algorithms is *conjugate gradient*. In JAXopt, this can be done as follows::
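
    import jax.numpy as jnp
    from jaxopt.linear_solve import solve_cg

    # Hypothetical data: minimizing 1/2 x^T Q x + c^T x amounts to solving Qx = -c.
    Q = jnp.array([[1.0, 0.5],
                   [0.5, 4.0]])
    c = jnp.array([1.0, -1.0])

    def matvec(x):
        return jnp.dot(Q, x)
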
    sol = solve_cg(matvec, b=-c)

    print(sol)

Pytree of matrices API
----------------------

Solvers :class:`EqualityConstrainedQP`, :class:`OSQP` and :class:`BoxOSQP` support
the pytree-of-matrices API: the matrices `Q`, `A` and `G` can be provided as
block-diagonal operators whose blocks are the leaves of a pytree.
This corresponds to separable problems that can be solved in parallel (one per leaf).

It offers several advantages:

* This model of parallelism succeeds even if the problems have different shapes,
  contrary to the `jax.vmap` API.
* This formulation is more efficient than a single big block-diagonal matrix,
  especially when there are many blocks and the blocks themselves are small.
* The tolerance is defined globally and shared by all the problems,
  and the number of iterations is the same for all the problems.

We illustrate below the parallel solving of two problems with different shapes::

    import jax.numpy as jnp
    from jaxopt import EqualityConstrainedQP

    Q1 = jnp.array([[1.0, -0.5],
                    [-0.5, 1.0]])
    Q2 = jnp.array([[2.0]])
    Q = {'problem1': Q1, 'problem2': Q2}

    c1 = jnp.array([-0.4, 0.3])
    c2 = jnp.array([0.1])
    c = {'problem1': c1, 'problem2': c2}

    a1 = jnp.array([[-0.5, 1.5]])
    a2 = jnp.array([[10.0]])
    A = {'problem1': a1, 'problem2': a2}

    b1 = jnp.array([0.3])
    b2 = jnp.array([5.0])
    b = {'problem1': b1, 'problem2': b2}

    qp = EqualityConstrainedQP(tol=1e-3)
    hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
    # Solve the two problems in parallel with a single call.
    sol = qp.run(**hyperparams).params
    print(sol.primal['problem1'], sol.primal['problem2'])

Matvec API
----------

Solvers :class:`EqualityConstrainedQP`, :class:`OSQP` and :class:`BoxOSQP` support the matvec API:
the user can provide a function ``matvec`` that computes the matrix-vector product,
either for the objective (`x -> Qx`) or for the constraints (`x -> Ax`, `x -> Gx`).

It offers several advantages:

* the code is easier to read and closer to the mathematical formulation of the problem;
* sparse matrix-vector products are available, which can be much faster than dense ones;
* the derivatives w.r.t ``(params_obj, params_eq, params_ineq)`` may be easier to compute
  than with materialized matrices;
* it is faster than the quadratic function API.

This is the recommended API when the matrices are not block-diagonal operators,
especially when other sparsity patterns are involved, or in conjunction with
implicit differentiation::

    import jax.numpy as jnp
    from jaxopt import BoxOSQP
    from sklearn import datasets

    # Objective:
    #   min ||data @ x - targets||_2^2 + 2 * n * lam * ||x||_1
    #
    # With the BoxOSQP formulation:
    #
    #   min_{x, y, t} y^T y + 2 * n * lam * 1^T t
    #   under     targets = data @ x - y
    #                   0 <= x + t <= infinity
    #           -infinity <= x - t <= 0
    data, targets = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
    n = data.shape[0]  # number of samples
    lam = 10.0

    def matvec_Q(params_Q, xyt):
        del params_Q  # unused
        x, y, t = xyt
        return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t)

    c = jnp.zeros(data.shape[1]), jnp.zeros(data.shape[0]), 2 * n * lam * jnp.ones(data.shape[1])

    def matvec_A(params_A, xyt):
        x, y, t = xyt
        residuals = params_A @ x - y
        return residuals, x + t, x - t

    l = targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf)
    u = targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0])

    hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u))
    osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2)
    sol, state = osqp.run(None, **hyper_params)

Quadratic function API
----------------------

Solvers :class:`EqualityConstrainedQP`, :class:`OSQP` and :class:`BoxOSQP` support the quadratic
function API: the whole objective `x -> 1/2 x^T Q x + c^T x + K` can be provided directly as a
function ``fun``. This function must be differentiable w.r.t `x`.

It offers several advantages:

* the code is easier to read and closer to the mathematical formulation of the problem;
* there is no need to provide the matrix `Q` and the vector `c` separately, nor to remove
  the constant term `K`;
* the derivatives w.r.t ``(params_obj, params_eq, params_ineq)`` may be even easier to compute
  than with materialized matrices.

Take care that this API also has drawbacks:

* the function ``fun`` must be differentiable w.r.t `x` (with JAX's AD), even if you are not
  interested in the derivatives of your QP;
* to extract `x -> Qx` and `c` from the function, we need to compute the Hessian-vector product
  and the gradient of ``fun``, which may be expensive (see the sketch below);
* with this API, `init_params` must be provided to `run`, unlike with the other APIs.
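
To make the second point concrete, here is a sketch of how `x -> Qx` and `c` can be
recovered from ``fun`` with JAX's AD (the quadratic ``fun`` below is made up for
illustration; this is not necessarily how JAXopt implements the extraction)::

    import jax
    import jax.numpy as jnp

    def fun(x):  # f(x) = 1/2 x^T Q x + c^T x + K
        Q = jnp.array([[2.0, 0.0], [0.0, 4.0]])
        c = jnp.array([1.0, -1.0])
        return 0.5 * x @ Q @ x + c @ x + 3.0

    x0 = jnp.zeros(2)
    c = jax.grad(fun)(x0)  # the gradient at zero recovers c
    # Hessian-vector product: v -> Qv, without materializing Q.
    matvec_Q = lambda v: jax.jvp(jax.grad(fun), (x0,), (v,))[1]
    print(c, matvec_Q(jnp.array([1.0, 0.0])))  # [1. -1.] [2. 0.]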

We illustrate this API with Non-Negative Least Squares (NNLS)::

    import jax
    import jax.numpy as jnp
    import numpy as onp
    from jaxopt import OSQP

    # min_W ||Y - U W||_F^2  s.t.  W >= 0
    n, m, rank = 20, 10, 3
    onp.random.seed(654)
    U = jax.nn.relu(onp.random.randn(n, rank))
    W_0 = jax.nn.relu(onp.random.randn(rank, m))
    Y = U @ W_0

    def fun(W, params_obj):
        Y, U = params_obj
        # The objective is implicitly a quadratic polynomial in W.
        return jnp.sum(jnp.square(Y - U @ W))

    def matvec_G(params_G, W):
        del params_G  # unused
        return -W  # -W <= 0 encodes W >= 0.

    zeros = jnp.zeros_like(W_0)
    hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros))

    solver = OSQP(fun=fun, matvec_G=matvec_G)

    init_W = jnp.zeros_like(W_0)  # mandatory with the `fun` API.
    init_params = solver.init_params(init_W, **hyper_params)
    W_sol = solver.run(init_params=init_params, **hyper_params).params.primal

This API is not recommended for large-scale problems or nested differentiation; use the matvec API instead.

Implicit differentiation pitfalls
---------------------------------

When using implicit differentiation, the parameters w.r.t which we differentiate
must be passed to `params_obj`, `params_eq` or `params_ineq`. They should not be captured
from an enclosing scope by ``fun`` or ``matvec``. We illustrate this common mistake below::

    import jax
    import jax.numpy as jnp
    from jaxopt import EqualityConstrainedQP

    def wrong_solver(Q):  # don't do this!

        def matvec_Q(params_Q, x):
            del params_Q  # unused
            # Error: Q is captured from the enclosing scope.
            # It does not fail now, but it will fail under differentiation.
            return jnp.dot(Q, x)

        c = jnp.zeros(Q.shape[0])

        A = jnp.array([[1.0, 2.0]])
        b = jnp.array([1.0])

        eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
        sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params
        loss = jnp.sum(sol.primal)
        return loss

    Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
    _ = wrong_solver(Q)  # no error... but it will fail later.
    _ = jax.grad(wrong_solver)(Q)  # raises a CustomVJPException

Also, notice that since the problems are convex, the optimum is independent of the
starting point `init_params`. Hence, derivatives w.r.t `init_params` are always
zero (mathematically).

The correct implementation is given below::

    def correct_solver(Q):

        def matvec_Q(params_Q, x):
            return jnp.dot(params_Q, x)

        c = jnp.zeros(Q.shape[0])

        A = jnp.array([[1.0, 2.0]])
        b = jnp.array([1.0])

        eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
        # Q is passed as a parameter, not captured from the enclosing scope.
        sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params
        loss = jnp.sum(sol.primal)
        return loss

    Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
    _ = correct_solver(Q)  # no error
    _ = jax.grad(correct_solver)(Q)  # no error