# Distributed Arrays with JAX

In [None]:
import jax
import numpy as np

In [None]:
# All devices on the GPU backend; jax.devices raises if no GPU backend is available.
gpus = jax.devices(backend='gpu')

In [None]:
# Print each GPU in `gpus` next to its index in that list.
for device_index, device in enumerate(gpus):
    label = f"Device {device_index}:"
    print(label, device)

In [None]:
# Number of devices addressable by this process (shown as the cell's output).
jax.local_device_count()

### `Sharding` describes how array values are laid out in memory across devices

In [None]:
# One-dimensional mesh over every device of the default backend. Its single
# axis is named 'x' so that PartitionSpecs can refer to it.
mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('x',))

In [None]:
# 16000 x 16000 standard-normal matrix with the default float dtype (about 1 GB
# at float32), created without any explicit sharding.
# NOTE(review): dimension 0 is later split over the mesh axis 'x'; keep 16000
# divisible by the device count in case your JAX version rejects uneven shards.
x = jax.random.normal(
    key=jax.random.key(0),
    shape=(16000, 16000),
)

In [None]:
# Row-shard x: dimension 0 is split across mesh axis 'x', and dimension 1 is
# not partitioned (None), so each device holds a block of complete rows.
y = jax.device_put(
    x,
    jax.sharding.NamedSharding(
        mesh,
        jax.sharding.PartitionSpec('x', None),
    ),
)

In [None]:
# Column sums of the row-sharded array: the reduction runs over the sharded axis.
jax.numpy.sum(y, axis=0)

In [None]:
# Render a diagram of how y's shards are assigned to the devices in the mesh.
jax.debug.visualize_array_sharding(y)

In [None]:
@jax.jit
def my_complex_sin(z):
    """Return the axis-0 sum of ``sin(z) + cos(z) + z ** 2``.

    Despite its name, the function just combines sin, cos and square
    elementwise and then reduces over the first axis. That makes it a simple
    workload for comparing sharded and unsharded inputs. It is compiled
    with ``jax.jit``.
    """
    jnp = jax.numpy
    # Deliberately not called ``y``, so it does not shadow the notebook-level
    # sharded array of that name.
    elementwise = jnp.sin(z) + jnp.cos(z) + z ** 2
    return elementwise.sum(axis=0)

In [None]:
%time z = my_complex_sin(x).block_until_ready()

In [None]:
%time z2 = my_complex_sin(y).block_until_ready()

In [None]:
# Inspect the sharding JAX assigned to the result of the computation on the sharded input y.
z2.sharding

### JAX takes sharding into account when performing computations

#### <mark>Hands-on:</mark> Perform some standard array operations (e.g., matrix-matrix multiplication) on sharded arrays, compare their execution time with the same operations on non-sharded arrays, and check the sharding of the resulting arrays.