First install the repo and requirements.

In [None]:
%pip --quiet install git+https://github.com/wilson-labs/cola.git

# Linear Solves

CoLA offers two main ways to solve linear systems. The first is the high-level API documented [here](https://cola.readthedocs.io/en/latest/package/cola.linalg.html). The second is to call one of the algorithms that CoLA provides directly, documented [here](https://cola.readthedocs.io/en/latest/package/cola.algorithms.html).

Let's start with the high-level API. Below we create a random linear system of size $N=100$.

In [1]:
from jax import numpy as jnp
from jax.random import PRNGKey, normal, split
import cola

N = 100
key = PRNGKey(seed=21)
A = normal(key, shape=(N, N))
key = split(key, num=1)[0]  # split returns an array of keys; take the single new key
rhs = normal(key, shape=(N,))
rhs /= jnp.linalg.norm(rhs)



To solve the linear system we use `cola.inverse` as follows.

In [2]:
soln = cola.inverse(cola.ops.Dense(A)) @ rhs
soln_jax = jnp.linalg.solve(A, rhs)
abs_diff = jnp.linalg.norm(soln - soln_jax)
print(f"{abs_diff:1.2e}")

6.47e-06
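
Comparing against a second solver is one way to check a solution. Another check that does not need a reference solver is the relative residual $\|Ax - b\| / \|b\|$. The following is a minimal sketch of that check in plain NumPy rather than CoLA; the helper name `relative_residual` is hypothetical and the random system only mirrors the shape of the one above, not its values.

```python
import numpy as np


def relative_residual(A, x, b):
    """Return ||A x - b|| / ||b||, a solver-independent accuracy measure."""
    return np.linalg.norm(A @ x - b) / np.linalg.norm(b)


# Illustrative random system of the same size as above (values differ).
rng = np.random.default_rng(21)
N = 100
A = rng.standard_normal((N, N))
b = rng.standard_normal(N)
b /= np.linalg.norm(b)

x = np.linalg.solve(A, b)
res = relative_residual(A, x, b)
print(f"{res:1.2e}")
```

A small residual shows that `x` nearly satisfies the system. For an ill-conditioned `A`, it does not by itself guarantee a small error in `x`.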


To choose a specific algorithm, here conjugate gradients (CG), we call it directly on a symmetric positive definite operator as follows.

In [3]:
import jax
jax.config.update("jax_enable_x64", True)  # use float64 for the comparison below

N = 1_000
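key_A, key_rhs = split(PRNGKey(seed=21))  # separate keys so A and rhs are independent
dtype = jnp.float64
A = cola.ops.Dense(normal(key_A, shape=(N, N), dtype=dtype))
mu = 1.e-1  # A @ A.T is already PSD; adding mu > 0 makes S positive definite
S = A @ A.T + mu * cola.ops.I_like(A)
rhs = normal(key_rhs, shape=(N,), dtype=dtype)
rhs /= jnp.linalg.norm(rhs)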

In [None]:
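# Solve with S (the SPD operator), not A, so the result matches the dense solve below.
# Assumes cg returns a (solution, info) pair, as described in the CoLA API reference.
soln, info = cola.algorithms.cg(S, rhs)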
soln_jax = jnp.linalg.solve(S.to_dense(), rhs)
abs_diff = jnp.linalg.norm(soln - soln_jax)
print(f"{abs_diff:1.2e}")
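
To make concrete what a CG solve does, here is a textbook conjugate gradient loop in plain NumPy. It is a sketch for illustration, not CoLA's implementation: the function name `conjugate_gradient`, its parameters `tol` and `max_iters`, and the smaller problem size are assumptions made for this example. It builds the same kind of operator, $S = AA^\top + \mu I$, and only needs a matrix-vector product.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-8, max_iters=5000):
    """Textbook CG for a symmetric positive definite operator given as a matvec.

    Stops when the recurrence residual satisfies ||r|| <= tol * ||b||.
    Returns the approximate solution and the number of iterations used.
    """
    x = np.zeros_like(b)
    r = b - matvec(x)  # initial residual
    p = r.copy()       # initial search direction
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    iters = 0
    while iters < max_iters and np.sqrt(rs_old) > tol * b_norm:
        Sp = matvec(p)
        alpha = rs_old / (p @ Sp)       # step length along p
        x += alpha * p
        r -= alpha * Sp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p   # next S-conjugate direction
        rs_old = rs_new
        iters += 1
    return x, iters


# Illustrative SPD system, smaller than the N = 1_000 example above.
rng = np.random.default_rng(21)
N = 200
A = rng.standard_normal((N, N))
mu = 1.0e-1
S = A @ A.T + mu * np.eye(N)
b = rng.standard_normal(N)
b /= np.linalg.norm(b)

x_cg, iters = conjugate_gradient(lambda v: S @ v, b)
x_direct = np.linalg.solve(S, b)
print(f"iterations: {iters}, abs diff: {np.linalg.norm(x_cg - x_direct):1.2e}")
```

Because the loop only touches `S` through `matvec`, the same routine would work for an operator that is never formed as a dense matrix, which is the setting where iterative solvers are most useful.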