# Block-cyclic data layout

To use cuSolverMg, matrices must be stored in **1D block-cyclic, column-major form**. This layout deals tiles of columns round-robin across the devices, so that all devices participating in a routine share the work evenly and can keep computing instead of waiting on other parts of the computation (see Dongarra and Walker, 1993). In `jaxmg`, we handle this transformation on the JAX side with a single **all-to-all** within a `jax.shard_map` context.
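As a reference point, the textbook 1D block-cyclic rule can be written down explicitly. The sketch below assumes 0-indexed columns, $n_{\mathrm{dev}}$ devices and tile size $T_A$, and that cuSolverMg's column distribution follows this standard definition. Column $j$ lies in tile $t(j)$, which is owned by device $p(j)$ at local column $\ell(j)$:

$$
t(j) = \left\lfloor \frac{j}{T_A} \right\rfloor, \qquad
p(j) = t(j) \bmod n_{\mathrm{dev}}, \qquad
\ell(j) = \left\lfloor \frac{t(j)}{n_{\mathrm{dev}}} \right\rfloor T_A + (j \bmod T_A).
$$

For example, with $T_A = 2$ and $n_{\mathrm{dev}} = 2$, column $j = 9$ lies in tile $t = 4$, so it lives on device $p = 0$ at local column $\ell = 2 \cdot 2 + 1 = 5$.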

Consider the case where we have 2 GPUs available and want to solve the linear
system $A\cdot x = b$, where $A$ is a $12\times12$ positive-definite matrix and $b$ is a vector of ones. With the columns of $A$ split across the two GPUs, each shard is of size $12\times 6$.
For `cuSolverMg` to work, we require a 1D cyclic tiling with tile size `T_A=2`:

<p align="center"><img src="../../_static/mat_example.png" alt="Matrix layout illustration" width="500"></p>
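To make the target layout concrete, the following plain-Python sketch lists which global columns of the $12\times12$ matrix each GPU owns under the standard block-cyclic rule with `T_A=2`. The helper `block_cyclic_columns` is hypothetical (not part of `jaxmg`), it ignores the zero padding that `jaxmg` adds, and it assumes cuSolverMg uses the textbook round-robin assignment of tiles.

```python
# Hypothetical helper: global columns owned by each device under the standard
# 1D block-cyclic rule (assumed to match cuSolverMg; padding is ignored).
def block_cyclic_columns(n_cols: int, tile: int, ndev: int) -> list[list[int]]:
    """Return, for each device, the global column indices it owns, in local order."""
    owned = [[] for _ in range(ndev)]
    for j in range(n_cols):
        tile_idx = j // tile               # tile that column j falls in
        owned[tile_idx % ndev].append(j)   # tiles are dealt round-robin
    return owned


cols = block_cyclic_columns(n_cols=12, tile=2, ndev=2)
for d, c in enumerate(cols):
    print(f"device {d}: columns {c}")
# device 0: columns [0, 1, 4, 5, 8, 9]
# device 1: columns [2, 3, 6, 7, 10, 11]
```

Compare this with the plain block sharding, where device 0 would hold columns 0 to 5 and device 1 columns 6 to 11.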


In order to interweave the blocks, we need the number of columns in each shard to be a multiple of
`ndev * T_A = 4`, so that we can reshape to `(ndev, T_A, ...)` and exchange the blocks via `jax.lax.all_to_all`. We therefore add 2 columns of zero padding to each shard, bringing it from 6 to 8 columns (see top figure). After interweaving the blocks, we are left with extra padding on the right, which the solver itself ignores. After the solver is called, we again use a
single `jax.lax.all_to_all` call to remap the data back to block-sharded form.
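The padding arithmetic and the tiling step can be sketched on the host with NumPy. This is an illustration under assumptions, not `jaxmg`'s implementation: the helper names `pad_columns` and `pad_and_tile` are hypothetical, and the axis order `(groups, ndev, rows, T_A)` is one plausible choice. With `jax.lax.all_to_all`, splitting along the `ndev` axis would send the k-th tile of every group to device k; the exact exchange `jaxmg` performs is not reproduced here.

```python
import numpy as np


def pad_columns(shard_cols: int, ndev: int, tile: int) -> int:
    """Zero columns needed so that a shard's width is a multiple of ndev * tile."""
    return (-shard_cols) % (ndev * tile)


def pad_and_tile(shard: np.ndarray, ndev: int, tile: int) -> np.ndarray:
    """Zero-pad a (rows, cols) shard on the right and split its columns into tiles.

    Returns an array of shape (groups, ndev, rows, tile), where
    groups = padded_cols // (ndev * tile) and axis 1 is the tile's
    position inside each group of ndev consecutive tiles.
    """
    rows, cols = shard.shape
    pad = pad_columns(cols, ndev, tile)
    padded = np.pad(shard, ((0, 0), (0, pad)))  # constant zeros on the right
    groups = padded.shape[1] // (ndev * tile)
    # Column c = g * ndev * tile + k * tile + t  ->  index (g, k, :, t)
    return padded.reshape(rows, groups, ndev, tile).transpose(1, 2, 0, 3)


# One 12 x 6 shard from the 2-GPU example with T_A = 2
shard = np.arange(12 * 6, dtype=np.float32).reshape(12, 6)
tiles = pad_and_tile(shard, ndev=2, tile=2)
print(pad_columns(6, ndev=2, tile=2))  # 2
print(tiles.shape)                     # (2, 2, 12, 2)
```

The last tile, `tiles[1, 1]`, consists entirely of the zero padding columns.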

We can achieve this layout in `jaxmg` with the following code:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxmg import cyclic_1d_layout

# Assumes at least one GPU is available; the example is written for 1 or 2 GPUs
devices = jax.devices("gpu")
assert len(devices) in [1, 2], "Example only works for 1 or 2 devices"
N = 12
T_A = 2
dtype = jnp.float32
ndev = len(devices)
# Create the diagonal test matrix A = diag(1, ..., N)
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
# Shard the columns of A across the devices (each holds N x N // ndev)
mesh = jax.make_mesh((ndev,), ("x",))
A = jax.device_put(A, NamedSharding(mesh, P(None, "x")))
# Remap the column shards into the 1D block-cyclic layout with tile size T_A
A_bc = cyclic_1d_layout(A, T_A)

for shard in A_bc.addressable_shards:
    print(f"shard on {shard.device}:\n{shard.data}")
```
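Because $A$ here is diagonal with entries $1, \dots, N$ and $b$ is a vector of ones, the exact solution is $x_i = 1/(i+1)$ (0-indexed). A host-side NumPy reference like the following is a convenient way to validate the result of a multi-GPU solve; it is an illustrative check only and does not use `jaxmg`.

```python
import numpy as np

N = 12
# Same diagonal test matrix as above, built on the host in float64
A_host = np.diag(np.arange(N, dtype=np.float64) + 1)
b_host = np.ones(N, dtype=np.float64)

# Single-device reference solve
x_ref = np.linalg.solve(A_host, b_host)

# For a diagonal A, x_i = b_i / A_ii = 1 / (i + 1)
x_exact = 1.0 / (np.arange(N) + 1)
print(np.allclose(x_ref, x_exact))  # True
```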

A more involved example is the case where we have 4 GPUs, `N=100` and a tile size of `T_A=4`.
Each shard now holds $100/4 = 25$ columns, and the next multiple of `ndev * T_A = 16` is 32, so we need 7 columns of padding on each GPU to perform the data remapping (figure produced with the code above):

<p align="center"><img src="../../_static/mat.png" alt="Matrix layout illustration" width="500"></p>

- Dongarra, J.J., and D.W. Walker. *The Design of Linear Algebra Libraries for High Performance Computers.* Office of Scientific and Technical Information (OSTI), August 1, 1993. https://doi.org/10.2172/10184308.