# Calling the solver in a `jax.shard_map` context


The `potrs` interface internally calls `jax.shard_map` to re-lay out the data in 1D cyclic form and invoke the underlying cuSolverMg API. In practice, you may already have a more complicated jitted function that operates on shards inside a `shard_map` context; in that case the solver must be called from within that function, on the individual shards.

To support this use case, we also provide `potrs_no_shardmap`, an API that must be called inside a `shard_map` context.
Here, we rely on the user to call `jax.shard_map` correctly, passing the correct `in_specs` and `out_specs` for their own function.

In the example below, we use this API on a simple diagonal matrix, now with the `complex64` data type, and apply a diagonal shift to the
matrix `A` before handing it to the solver:

In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxmg import potrs_no_shardmap
from functools import partial

# Requires at least one GPU visible to JAX; jax.devices("gpu") queries them.
devices = jax.devices("gpu")
assert len(devices) in (1, 2), "Example only works for 1 or 2 devices"

# Problem configuration: matrix size, the T_A value forwarded to the solver
# (meaning defined by jaxmg -- presumably a tiling parameter; confirm), and
# the element type.
N, T_A, dtype = 8, 2, jnp.complex64

# A = diag(1, 2, ..., N); b is an (N, 1) column of ones.
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
b = jnp.ones((N, 1), dtype=dtype)

# One-axis mesh named "x" over all devices. A is split by columns across
# "x"; b is fully replicated on every device.
ndev = len(devices)
mesh = jax.make_mesh((ndev,), ("x",))
A, b = (
    jax.device_put(A, NamedSharding(mesh, P(None, "x"))),
    jax.device_put(b, NamedSharding(mesh, P(None, None))),
)

# Value added to A's diagonal inside the shard_map'd function.
diag_shift = 0.1

@partial(jax.jit, static_argnames=("_T_A",))
def shift_and_solve(_a, _b, _ds, _T_A):
    """Add `_ds` to the diagonal entries held by this shard, then call the solver.

    Must run inside `jax.shard_map` over mesh axis "x", with `_a` split by
    columns (spec ``P(None, "x")``), so each device holds all rows but only
    ``_a.shape[1]`` consecutive columns of the global matrix.

    Args:
        _a: This device's column slab of the global matrix.
        _b: Right-hand side (replicated on every device).
        _ds: Scalar shift added to the global diagonal.
        _T_A: Static value forwarded unchanged to `potrs_no_shardmap`.

    Returns:
        Whatever `potrs_no_shardmap` returns for this shard (the solution and
        a status, per the caller's out_specs).
    """
    # Number of global columns stored on this device.
    shard_size = _a.shape[1]
    # Local column indices of this slab. Iterating over these (instead of all
    # N global rows) keeps every scatter index in bounds, rather than relying
    # on JAX silently dropping out-of-bounds updates.
    local_cols = jnp.arange(shard_size)
    # Global row index of the diagonal entry in each local column: this
    # device's slab starts at global column shard_size * device_index.
    global_rows = local_cols + shard_size * jax.lax.axis_index("x")
    _a = _a.at[global_rows, local_cols].add(_ds)
    jax.debug.print("dev{}:_a={}", jax.lax.axis_index("x"), _a)
    # Call solver in shard_map context
    return potrs_no_shardmap(_a, _b, _T_A)

@partial(jax.jit, static_argnames=("_T_A",))
def jitted_potrs(_a, _b, _ds, _T_A):
    """Apply `shift_and_solve` to every shard of `_a` using `jax.shard_map`.

    Args:
        _a: Global matrix, sharded by columns over mesh axis "x".
        _b: Right-hand side, replicated.
        _ds: Scalar diagonal shift, replicated.
        _T_A: Static value forwarded to the solver.

    Returns:
        The outputs of `shift_and_solve` assembled per `out_specs`:
        (solution with spec P(None, None), status with spec P(None)).
    """
    per_shard_fn = partial(shift_and_solve, _T_A=_T_A)
    sharded_fn = jax.shard_map(
        per_shard_fn,
        mesh=mesh,
        # _a: columns split over "x"; _b: replicated; _ds: scalar, replicated.
        in_specs=(P(None, "x"), P(None, None), P()),
        # Neither output is partitioned along "x".
        out_specs=(P(None, None), P(None)),
        # Replication checking is disabled; presumably JAX cannot infer the
        # output sharding of the solver's FFI call.
        check_vma=False,
    )
    return sharded_fn(_a, _b, _ds)

# Solve the shifted system across the mesh and report the solver status.
out, status = jitted_potrs(A, b, diag_shift, T_A)
print(f"Status: {status}")

# A is diagonal and b is all ones, so the shifted system has the closed-form
# solution x_i = 1 / (A_ii + diag_shift).
expected_out = 1.0 / (jnp.arange(N, dtype=dtype) + 1 + diag_shift)
print(jnp.allclose(out.ravel(), expected_out))

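When run on two devices, this prints each device's shifted shard of `_a` (via `jax.debug.print`), then the solver status, and finally `True` if the solution matches the expected result.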



> **Note:** `potrs_no_shardmap` always returns a status.

> **Note:** JAX will complain about replication errors unless you pass `check_vma=False`. This is likely because it cannot infer the output sharding from the FFI call.