# Cholesky decomposition with jaxmg.potrs

Here we give an example of calling `jaxmg.potrs`, which solves the linear system $Ax=b$ for a symmetric, positive-definite matrix $A$ via a Cholesky decomposition.
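
For reference, the textbook Cholesky solve proceeds in three steps: factor $A$, then perform a forward and a backward triangular solve. This is a sketch of the general method for real symmetric positive-definite $A$, not a description of `jaxmg` internals:

$$
A = LL^\top, \qquad Ly = b, \qquad L^\top x = y,
$$

where $L$ is lower triangular with a positive diagonal. For the diagonal matrix $A = \mathrm{diag}(1, \dots, N)$ with $b = (1, \dots, 1)^\top$ used in the example below, $L = \mathrm{diag}(\sqrt{1}, \dots, \sqrt{N})$, so $y_i = 1/\sqrt{i}$ and $x_i = 1/i$.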

The interface of `jaxmg.potrs` is simple to use: one needs to supply the underlying mesh of the sharded data and specify the input shardings:

In [None]:
# examples/readme.py
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxmg import potrs

# Assumes we have at least one GPU available
devices = jax.devices("gpu")
assert len(devices) in [1, 2], "Example only works for 1 or 2 devices"
N = 12
T_A = 2
dtype = jnp.float64
# Create diagonal matrix and `b` all equal to one
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
b = jnp.ones((N, 1), dtype=dtype)
ndev = len(devices)
# Make mesh and place data (columns sharded)
mesh = jax.make_mesh((ndev,), ("x",))
A = jax.device_put(A, NamedSharding(mesh, P(None, "x")))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))
# Call potrs
out = potrs(A, b, T_A=T_A, mesh=mesh, in_specs=(P(None, "x"), P(None, None)))
expected_out = 1.0 / (jnp.arange(N, dtype=dtype) + 1)
print(jnp.allclose(out.flatten(), expected_out))

Note that we did not have to perform the cyclic relayout here since `jaxmg.potrs` calls `cyclic_1d_layout` before calling the solver.

> **Note:** If the user can ensure that the matrix `A` is already in block-cyclic form, then `jaxmg.potrs` can be called with the argument `cyclic_1d=True` (`False` by default). If the data is not laid out correctly, calling `jaxmg.potrs` results in an array of `NaN`s and a nonzero status, indicating that the solver failed.
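
To make "block-cyclic form" concrete, the following is a minimal sketch of the textbook 1D block-cyclic column layout, using the sizes from the example above (`N = 12`, `T_A = 2`, two devices). It assumes that `T_A` plays the role of the tile width. The helper names `block_cyclic_owner` and `block_cyclic_permutation` are hypothetical and not part of `jaxmg`, and the exact convention used by `cyclic_1d_layout` may differ in its details.

```python
def block_cyclic_owner(col: int, tile: int, ndev: int) -> int:
    """Device that owns column `col` when tiles of width `tile`
    are dealt out to `ndev` devices in round-robin order."""
    return (col // tile) % ndev


def block_cyclic_permutation(n: int, tile: int, ndev: int) -> list[int]:
    """Column order that places each device's tiles next to each other,
    so that a contiguous column sharding reproduces the cyclic layout."""
    return [
        col
        for dev in range(ndev)
        for col in range(n)
        if block_cyclic_owner(col, tile, ndev) == dev
    ]


# Sizes taken from the example above (illustrative only).
N, T_A, ndev = 12, 2, 2
owners = [block_cyclic_owner(col, T_A, ndev) for col in range(N)]
perm = block_cyclic_permutation(N, T_A, ndev)
print(owners)  # [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
print(perm)    # [0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11]
```

By contrast, the plain column sharding `P(None, "x")` gives the first device columns 0 to 5 and the second device columns 6 to 11, which is why a relayout is needed before the solver runs.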


