(sharded-computation)=
# Introduction to sharded computation

JAX's {class}`jax.Array` object is designed with distributed data and computation in mind.

This section will cover three modes of parallel computation:

- Automatic parallelism via {func}`jax.jit`, in which we let the compiler choose the optimal computation strategy
- Semi-automatic parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`

These examples will be run on Colab's free TPU runtime, which provides eight devices to work with:

In [1]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Key concept: data sharding

Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.

Each concrete {class}`jax.Array` object has a `sharding` attribute and a `devices()` method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:

In [2]:
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [3]:
arr.sharding

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array:

In [4]:
jax.debug.visualize_array_sharding(arr)

To create an array with a non-trivial sharding, we can define a `sharding` specification for the array and pass this to {func}`jax.device_put`.
Here we'll define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes:

In [5]:
# Pardon the boilerplate; constructing a sharding will become easier soon!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))


Passing this `sharding` to {func}`jax.device_put`, we obtain a sharded array:

In [6]:
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
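
For a textual view of the same layout, each array exposes its per-device pieces through `addressable_shards`. The following is a minimal sketch, not part of the original notebook: it assumes you may not be on a TPU runtime, so it asks XLA for eight virtual CPU devices through the `--xla_force_host_platform_device_count` flag and builds the mesh by reshaping the device list directly (so the device order may differ from the one `mesh_utils.create_device_mesh` produces on TPU).

```python
import os
# Assumption: no TPU available, so request eight virtual CPU devices.
# This has to happen before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange eight CPU devices in a 2x4 grid with named axes 'x' and 'y'.
devices = np.array(jax.devices("cpu")[:8])
mesh = Mesh(devices.reshape(2, 4), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

data = np.arange(32.0, dtype=np.float32).reshape(4, 8)
arr_sharded = jax.device_put(data, sharding)

# Each shard records its device, the slice of the global array it
# holds (`index`), and the device-local data.
shard_shapes = []
for shard in arr_sharded.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
    shard_shapes.append(shard.data.shape)
```

With a `(4, 8)` array split over a `(2, 4)` mesh, every device should hold a `(2, 2)` block.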



## Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!
The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.
In the simplest of cases, those heuristics boil down to *computation follows data*.

For example, here's a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:

In [7]:
@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)

shardings match: True


As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here we sum along the leading axis of `x`:

In [8]:
@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]


The result is partially replicated: the first two elements of the array are stored on both devices `0` and `6`, the next two on devices `1` and `7`, and so on.
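
To see what partial replication means independently of the compiler's choice, you can construct such a layout by hand. This sketch is an illustration under assumptions, not the notebook's code: it uses eight virtual CPU devices (requested through `XLA_FLAGS`), and shards a length-8 vector over the `'y'` mesh axis only, so each slice is replicated along the unused `'x'` axis. The device ids that share a slice will generally differ from the TPU ids quoted above.

```python
import os
# Assumption: no TPU available, so request eight virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8])
mesh = Mesh(devices.reshape(2, 4), ('x', 'y'))

# Shard over 'y' (4 ways); 'x' is not mentioned, so each slice is
# copied to both rows of the mesh.
values = np.array([48., 52., 56., 60., 64., 68., 72., 76.], dtype=np.float32)
vec = jax.device_put(values, NamedSharding(mesh, P('y')))

# Group device ids by the data they hold.
holders = {}
for shard in vec.addressable_shards:
    key = tuple(np.asarray(shard.data).tolist())
    holders.setdefault(key, []).append(shard.device.id)

for chunk, ids in holders.items():
    print(chunk, "held by devices", ids)
```

Four distinct two-element slices, each held by two devices, is exactly the "partially replicated" pattern described above.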



## Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function.

For example, suppose that within `f_contract` above, you'd prefer the output not to be partially replicated, but rather to be fully sharded across the eight devices:

In [9]:
@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # mesh = jax.create_mesh((8,), 'x')
  devices = mesh_utils.create_device_mesh((8,))  # mesh shape is a sequence of ints
  mesh = jax.sharding.Mesh(devices, 'x')
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]


This gives you a function with the particular output sharding you'd like.
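
One possible variation, sketched here under assumptions rather than taken from the notebook, is to build the target sharding once outside the jitted function and then check the result's shards directly. The name `f_contract_flat` is hypothetical, and the sketch runs on eight virtual CPU devices requested through `XLA_FLAGS` so that it does not depend on a TPU runtime.

```python
import os
# Assumption: no TPU available, so request eight virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8])
mesh_2d = Mesh(devices.reshape(2, 4), ('x', 'y'))  # layout of the input
mesh_1d = Mesh(devices, ('x',))                    # layout requested for the output
out_sharding = NamedSharding(mesh_1d, P('x'))

@jax.jit
def f_contract_flat(x):  # hypothetical name, mirrors f_contract_2
    out = x.sum(axis=0)
    # Ask the compiler to split the length-8 result across all 8 devices.
    return jax.lax.with_sharding_constraint(out, out_sharding)

data = np.arange(32.0, dtype=np.float32).reshape(4, 8)
arr_sharded = jax.device_put(data, NamedSharding(mesh_2d, P('x', 'y')))

result = f_contract_flat(arr_sharded)
shard_shapes = [s.data.shape for s in result.addressable_shards]
print(shard_shapes)
```

If the constraint is honored, each of the eight devices holds a single element of the result, and the values are unchanged from `f_contract`.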



## Manual parallelism with `shard_map`

In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices.
By contrast, with `shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.

`shard_map` works by mapping a function across a particular *mesh* of devices:

In [10]:
from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)

Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

The function you write only "sees" a single shard of the data, which we can confirm by printing the device-local shape:

In [11]:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

global shape: x.shape=(32,)
device local shape: x.shape=(4,)
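
The device-local shape can also be predicted without running anything, because a sharding object can report the shard shape implied by a given global shape. This is a small sketch assuming eight virtual CPU devices (requested through `XLA_FLAGS`) in place of the TPU devices used in this section.

```python
import os
# Assumption: no TPU available, so request eight virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8])

# 1-D mesh: a length-32 vector split 8 ways leaves 4 elements per device.
mesh_1d = Mesh(devices, ('x',))
local_1d = NamedSharding(mesh_1d, P('x')).shard_shape((32,))

# 2-D mesh: a (4, 8) array split 2 ways along rows and 4 ways along columns.
mesh_2d = Mesh(devices.reshape(2, 4), ('x', 'y'))
local_2d = NamedSharding(mesh_2d, P('x', 'y')).shard_shape((4, 8))

print(local_1d, local_2d)
```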


Because each function invocation sees only the device-local part of the data, aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a `sum` looks like:

In [12]:
def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

Our function `f` operates separately on each shard, and the resulting summation reflects this.
If we want to sum across shards, we need to explicitly request it using collective operations like {func}`jax.lax.psum`:

In [13]:
def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)

Array(496, dtype=int32)

Because the output no longer has a sharded dimension, we set `out_specs=P()`.
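
The arithmetic behind the last two cells can be reproduced in plain NumPy. The sketch below is only an emulation for intuition, not how `shard_map` or `psum` are implemented: `emulate_shard_map` is a hypothetical helper that splits the input into equal shards, applies the function to each one, and concatenates the results.

```python
import numpy as np

def emulate_shard_map(f, x, num_shards):
    """Hypothetical helper: apply f to each equal-sized shard of x, then concatenate."""
    shards = np.split(x, num_shards)
    return np.concatenate([f(shard) for shard in shards])

x = np.arange(32)

# Per-shard sum: each of the 8 shards holds 4 consecutive integers.
per_shard = emulate_shard_map(lambda s: np.sum(s, keepdims=True), x, num_shards=8)
print(per_shard.tolist())  # [6, 22, 38, 54, 70, 86, 102, 118]

# A psum over the shard axis adds the per-shard values together,
# so every shard ends up with the same total.
total = per_shard.sum()
print(int(total))  # 496
```

The per-shard values match the `out_specs=P('x')` result, and their sum matches the `psum` result with `out_specs=P()`.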



## Comparing the three approaches

With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.
We'll define our canonical function like this:

In [14]:
@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)

In [15]:
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)

Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

We can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.
If we shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:

In [16]:
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)

Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

Alternatively, we can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:

In [17]:
@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs

Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

Finally, we can do the same thing with `shard_map`, using `psum` to indicate the cross-shard collective required for the matrix product:

In [18]:
from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)

Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
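
The reason a single `psum` suffices in `layer_sharded` is that a matrix product over a split contraction axis is the sum of the per-shard products. The following NumPy sketch checks that identity for the same shapes used above; it is an illustration of the arithmetic only, and `sigmoid` here is a local stand-in for `jax.nn.sigmoid`.

```python
import numpy as np

def sigmoid(z):
    # Local stand-in for jax.nn.sigmoid.
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

# Split the contraction axis 8 ways, as P('x') and P('x', None) do.
x_shards = np.split(x, 8)        # eight pieces of shape (4,)
w_shards = np.split(weights, 8)  # eight pieces of shape (4, 4)

# Each "device" computes a partial product of shape (4,).
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# Summing the partial products plays the role of psum over 'x'.
summed = np.sum(partials, axis=0)

out_sharded = sigmoid(summed + bias)
out_full = sigmoid(x @ weights + bias)
print(np.allclose(out_sharded, out_full))
```

The bias is replicated rather than sharded, so it is added once after the cross-shard sum, as in `layer_sharded`.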

This section has been a brief introduction to sharded and parallel computation;
for more discussion of `shard_map`, see {doc}`../notebooks/shard_map`.