# Using preconditioners


Preconditioning can notably improve the convergence of iterative methods, and preconditioners are particularly well suited to the large sparse systems that arise from discretized PDEs. In this example, we show how to use a simple [Jacobi (diagonal) preconditioner](https://en.wikipedia.org/wiki/Preconditioner#Jacobi_(or_diagonal)_preconditioner) to solve a 2D Laplacian linear system with `lx.GMRES`. We first run the solver without preconditioning, then with Jacobi preconditioning, and compare the residuals.



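Let's define a [Poisson problem in 2D](https://en.wikipedia.org/wiki/Discrete_Poisson_equation): the finite-difference Laplacian on an `n`-by-`m` grid, assembled with SciPy and converted to a JAX `BCOO` sparse matrix, together with a random right-hand side `b`.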
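Written out, `poisson(n, m)` assembles the standard five-point Laplacian with unit grid spacing (no $1/h^2$ factor) and zero boundary values. Storing the unknown $u_{i,j}$ at index $j\,n + i$, which is the ordering implied by the two `kron` terms in the code, gives

$$
A = I_m \otimes L_n + L_m \otimes I_n, \qquad
L_p = \begin{pmatrix} 2 & -1 & & \\ -1 & 2 & \ddots & \\ & \ddots & \ddots & -1 \\ & & -1 & 2 \end{pmatrix} \in \mathbb{R}^{p \times p},
$$

$$
(A u)_{i,j} = 4\,u_{i,j} - u_{i-1,j} - u_{i+1,j} - u_{i,j-1} - u_{i,j+1},
$$

where $u_{i,j} = 0$ outside $0 \le i < n$, $0 \le j < m$. In particular, every diagonal entry of $A$ equals $2 + 2 = 4$.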

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from scipy.sparse import diags, kron, eye
import jax.experimental.sparse as js
import lineax as lx

def poisson(n, m):
    """
    Create a 2D Laplacian matrix on an n-by-m grid as a JAX BCOO sparse matrix.
    """
    lap_1d_n = diags([-1, 2, -1], [-1, 0, 1], shape=(n, n), format="csr")
    lap_1d_m = diags([-1, 2, -1], [-1, 0, 1], shape=(m, m), format="csr")
    lap_2d = kron(eye(m, format="csr"), lap_1d_n) + kron(lap_1d_m, eye(n, format="csr"))
    return js.BCOO.from_scipy_sparse(lap_2d)


# Set up the problem: A x = b
n, m = 200, 200
A = poisson(n, m)
key = jr.PRNGKey(0)
b = jr.uniform(key, (A.shape[0],))

# Shape/dtype description of b, used later to build the preconditioner operator
in_structure = jax.eval_shape(lambda: b)
```

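Our Laplacian matrix `A` is a large sparse matrix of size `(n*m, n*m)`. We do not want to materialize it as a dense array inside a plain `MatrixLinearOperator`, which is designed for dense matrices. Instead, we subclass it as a `SparseMatrixLinearOperator` whose `mv` method computes the sparse matrix-vector product `A @ x` directly on the `BCOO` matrix. We also register the new operator as positive semidefinite (this Laplacian is symmetric positive definite), so that solvers which check this property, such as `lx.CG`, can accept it.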
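To put numbers on this, the sketch below rebuilds the same matrix with SciPy alone and compares the number of stored nonzeros with the number of entries a dense matrix would need. `poisson_scipy` is a hypothetical helper that mirrors `poisson` but skips the conversion to `BCOO`.

```python
from scipy.sparse import diags, eye, kron


def poisson_scipy(n, m):
    """Same construction as `poisson`, but returned as a SciPy sparse matrix."""
    lap_1d_n = diags([-1, 2, -1], [-1, 0, 1], shape=(n, n), format="csr")
    lap_1d_m = diags([-1, 2, -1], [-1, 0, 1], shape=(m, m), format="csr")
    return kron(eye(m, format="csr"), lap_1d_n) + kron(lap_1d_m, eye(n, format="csr"))


n, m = 200, 200
lap = poisson_scipy(n, m)
num_rows = lap.shape[0]                  # n * m unknowns
stored = lap.count_nonzero()             # nonzeros kept by the sparse format
dense_entries = num_rows * num_rows      # entries a dense matrix would hold
print(f"nonzeros: {stored:,} of {dense_entries:,} entries")
print(f"dense float32 storage: {dense_entries * 4 / 1e9:.1f} GB")
```

With about five nonzeros per row, the sparse format stores roughly 200,000 values instead of 1.6 billion, which is why the product `A @ x` should stay sparse.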

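```python
# Wrap the sparse matrix in a linear operator and solve with GMRES
class SparseMatrixLinearOperator(lx.MatrixLinearOperator):
    def mv(self, vector):
        # Sparse matrix-vector product with the BCOO matrix
        return self.matrix @ vector


# The 2D Laplacian is symmetric positive definite, so tag the operator accordingly
@lx.is_positive_semidefinite.register(SparseMatrixLinearOperator)
def _(op):
    return True


operator = SparseMatrixLinearOperator(A)
solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)
# throw=False: report non-convergence via the solution's `result` field instead of raising
x = lx.linear_solve(operator, b, solver=solver, throw=False).value
```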
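As a quick sanity check of the `mv` method, the sketch below builds the same kind of Laplacian on a tiny 4-by-3 grid, where the dense matrix is cheap to form, and compares the sparse product with the dense one. The names `lap_small`, `A_small`, and `v` are illustrative only.

```python
import jax.numpy as jnp
import jax.random as jr
import jax.experimental.sparse as js
from scipy.sparse import diags, eye, kron

# Small 4-by-3 grid so the dense matrix is cheap to form for comparison
lap_1d_4 = diags([-1, 2, -1], [-1, 0, 1], shape=(4, 4), format="csr")
lap_1d_3 = diags([-1, 2, -1], [-1, 0, 1], shape=(3, 3), format="csr")
lap_small = kron(eye(3, format="csr"), lap_1d_4) + kron(lap_1d_3, eye(4, format="csr"))

A_small = js.BCOO.from_scipy_sparse(lap_small)   # same conversion as `poisson`
v = jr.normal(jr.PRNGKey(1), (A_small.shape[0],))

sparse_mv = A_small @ v                # what SparseMatrixLinearOperator.mv computes
dense_mv = A_small.todense() @ v       # reference product with the dense matrix
print(jnp.max(jnp.abs(sparse_mv - dense_mv)))
```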

Let's check the quality of this solution by computing the residual norm `‖b - A x‖`.

```python
# Check the residual norm
error = jnp.linalg.norm(b - (A @ x))
error
```

```
Array(19.014511, dtype=float32)
```

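That's pretty bad: the residual is far larger than the requested tolerance of `1e-5`, because 30 GMRES steps are not enough to converge on this 40,000-unknown system. Now let's try a simple Jacobi preconditioner. The Jacobi preconditioner is the diagonal matrix `M = diag(A)`, built from the diagonal entries of `A`. We pass it to the solver as a `FunctionLinearOperator` that applies `M⁻¹`, a cheap approximate inverse of `A`, by dividing a vector elementwise by the diagonal of `A`, and we tag it as positive semidefinite. First, we need a small utility function that extracts the diagonal of a `BCOO` matrix.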

```python
@jax.jit
def get_diagonal(matrix):
    """
    Extract the diagonal from a sparse matrix.
    """
    # Keep only stored entries whose row index equals their column index
    is_diag = matrix.indices[:, 0] == matrix.indices[:, 1]
    diag_values = jnp.where(is_diag, matrix.data, 0)
    # Scatter-add into a dense vector (duplicate entries are summed)
    diag = jnp.zeros(matrix.shape[0], dtype=matrix.data.dtype)
    diag = diag.at[matrix.indices[:, 0]].add(diag_values)
    return diag


jacobi = get_diagonal(A)

# Jacobi preconditioner: multiplies a vector by the inverse of diag(A)
preconditioner = lx.FunctionLinearOperator(
    lambda x: x / jacobi,
    in_structure,
    tags=[lx.positive_semidefinite_tag],
)

solver = lx.GMRES(atol=1e-5, rtol=1e-5, max_steps=30)
x = lx.linear_solve(
    operator,
    b,
    solver=solver,
    options={"preconditioner": preconditioner},
    throw=False,
).value
```
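Because `get_diagonal` scatter-adds the selected values, it stays correct even when a `BCOO` matrix stores the same position more than once, which `BCOO` allows. The small check below uses a hypothetical 3-by-3 matrix whose `(0, 0)` entry is stored twice and compares the result with the diagonal of the dense matrix. The function is repeated here without `@jax.jit` so the block runs on its own.

```python
import jax.numpy as jnp
import jax.experimental.sparse as js


def get_diagonal(matrix):
    # Same logic as above: keep stored values on the diagonal,
    # then scatter-add them into a length-n vector
    is_diag = matrix.indices[:, 0] == matrix.indices[:, 1]
    diag_values = jnp.where(is_diag, matrix.data, 0)
    diag = jnp.zeros(matrix.shape[0], dtype=matrix.data.dtype)
    return diag.at[matrix.indices[:, 0]].add(diag_values)


# (0, 0) is stored twice (1.0 and 2.0); (1, 2) is an off-diagonal entry
data = jnp.array([1.0, 2.0, 5.0, 7.0, 3.0])
indices = jnp.array([[0, 0], [0, 0], [1, 1], [1, 2], [2, 2]])
dup_matrix = js.BCOO((data, indices), shape=(3, 3))

print(get_diagonal(dup_matrix))           # [3. 5. 3.]
print(jnp.diag(dup_matrix.todense()))     # same values from the dense matrix
```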

```python
# Check the residual norm
error = jnp.linalg.norm(b - (A @ x))
error
```

```
Array(19.014511, dtype=float32)
```

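The residual is identical to the unpreconditioned one, so here Jacobi preconditioning did not help. This is expected for this particular matrix: every diagonal entry of the 2D Laplacian equals 4, so the Jacobi preconditioner reduces to multiplying by the constant 1/4, and rescaling by a constant does not change the iterates GMRES computes (in exact arithmetic). Jacobi preconditioning pays off when the diagonal of `A` varies strongly, for example with variable coefficients or non-uniform grids. More advanced preconditioners, such as multigrid preconditioners, could be used to actually improve the convergence of the solver on this problem.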
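Whether Jacobi preconditioning helps depends on how much the diagonal of the matrix varies. The sketch below illustrates this with a hand-written preconditioned conjugate gradient (PCG) in NumPy rather than lineax, on a small 10-by-10 grid where a dense matrix is cheap. `laplacian_2d` and `pcg` are hypothetical helpers written only for this example. On the constant-diagonal Laplacian, Jacobi should leave the step count unchanged; after a symmetric rescaling `S A S` whose diagonal spans four orders of magnitude, Jacobi should need noticeably fewer steps than the unpreconditioned solve.

```python
import numpy as np


def laplacian_2d(n, m):
    """Dense 2D Laplacian with the same kron layout as `poisson` (small grids only)."""
    def lap_1d(size):
        return 2 * np.eye(size) - np.eye(size, k=1) - np.eye(size, k=-1)
    return np.kron(np.eye(m), lap_1d(n)) + np.kron(lap_1d(m), np.eye(n))


def pcg(mat, rhs, apply_minv, rtol=1e-8, max_steps=1000):
    """Textbook preconditioned CG; `apply_minv(r)` applies the inverse preconditioner.

    Returns the approximate solution and the number of steps taken.
    """
    x = np.zeros_like(rhs)
    r = rhs.copy()                      # residual for the initial guess x = 0
    z = apply_minv(r)
    p = z.copy()
    rz = r @ z
    for step in range(max_steps):
        if np.linalg.norm(r) <= rtol * np.linalg.norm(rhs):
            return x, step
        mat_p = mat @ p
        alpha = rz / (p @ mat_p)
        x = x + alpha * p
        r = r - alpha * mat_p
        z = apply_minv(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x, max_steps


rng = np.random.default_rng(0)
lap = laplacian_2d(10, 10)              # constant diagonal: every entry is 4
rhs = rng.uniform(size=lap.shape[0])
diag_lap = np.diag(lap)


def no_precond(r):
    return r


def jacobi_lap(r):
    return r / diag_lap                 # identical to r / 4 for this matrix


x_plain, steps_plain = pcg(lap, rhs, no_precond)
_, steps_jacobi = pcg(lap, rhs, jacobi_lap)
print("constant diagonal:", steps_plain, "vs", steps_jacobi)   # counts should match

# The symmetric rescaling S A S stays SPD, but its diagonal 4 * s**2 now
# ranges from 4 to 40,000; this is where Jacobi preconditioning helps.
s = np.logspace(0, 2, lap.shape[0])
lap_scaled = s[:, None] * lap * s[None, :]
diag_scaled = np.diag(lap_scaled)


def jacobi_scaled(r):
    return r / diag_scaled


_, steps_plain_scaled = pcg(lap_scaled, rhs, no_precond)
_, steps_jacobi_scaled = pcg(lap_scaled, rhs, jacobi_scaled)
print("varying diagonal:", steps_plain_scaled, "vs", steps_jacobi_scaled)
```

CG is used here only because it is short to write out; the argument that a constant preconditioner leaves the iterates unchanged applies to GMRES as well.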