
# Introduction to sharded computation

In [None]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

# Key concept: Data sharding

- Data sharding: how data is laid out across the available devices.
- Every `jax.Array` has an associated `jax.sharding.Sharding` object.
- `jax.sharding.Sharding` describes which shard of the global data is required by each global device.
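
As a minimal sketch of the last point, the per-device layout can be read off a sharding with `devices_indices_map`, which maps each device to the index (a tuple of slices) of the global array it holds. The variable names here are illustrative, and the example assumes a freshly created array that has not been explicitly placed anywhere.

```python
import jax
import jax.numpy as jnp

# A freshly created array; by default it is placed on a single device.
arr = jnp.arange(32.0).reshape(4, 8)

# Map each device in the sharding to the slice of the global array it holds.
index_map = arr.sharding.devices_indices_map(arr.shape)

for device, index in index_map.items():
    # Indexing the global array with `index` gives that device's shard.
    print(device, index, arr[index].shape)
```

For an array on one device, the map has a single entry whose index covers the whole array.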

In [None]:
# By default, a newly created array is placed on a single device (no sharding)
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
arr.sharding

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

In [None]:
jax.debug.visualize_array_sharding(arr)

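- `jax.sharding.Mesh` is an N-dimensional grid of devices with named axes; it allows for precise device placement.
- `NamedSharding` pairs a `Mesh` with a `PartitionSpec`, which maps each array dimension to a mesh axis (or to `None` to leave that dimension unsharded).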

In [None]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec
# Arrange the 8 devices into a 2x4 grid
devices = mesh_utils.create_device_mesh((2, 4))
# Name the grid axes 'x' (size 2) and 'y' (size 4)
mesh = Mesh(devices, ('x', 'y'))
# Shard array dimension 0 along 'x' and dimension 1 along 'y'
sharding = NamedSharding(mesh, P('x', 'y'))
print(sharding)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))


In [None]:
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
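
The printed values are the same as for the unsharded array; only the layout across devices differs. The sketch below compares the per-device shard shape that several `PartitionSpec` choices would produce. It is an illustration under stated assumptions, not the notebook's own code: the mesh layout rule (two rows when the device count is even, otherwise one) and the array size are hypothetical choices made so that the example runs with any number of devices. On the 8 TPU cores above it yields a 2x4 mesh and a (4, 8) array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical layout rule: two mesh rows if the device count is even, else one.
n = len(jax.devices())
rows = 2 if n % 2 == 0 else 1
cols = n // rows
mesh = Mesh(np.array(jax.devices()).reshape(rows, cols), ('x', 'y'))

# Size the array so both dimensions divide evenly across the mesh axes.
global_shape = (2 * rows, 2 * cols)
x = jnp.arange(float(global_shape[0] * global_shape[1])).reshape(global_shape)

specs = {
    "P('x', 'y')": P('x', 'y'),    # shard both dimensions
    "P('x', None)": P('x', None),  # shard rows only
    "P(None, 'y')": P(None, 'y'),  # shard columns only
    "P()": P(),                    # no sharding: every device holds a full copy
}

shard_shapes = {}
for name, spec in specs.items():
    sharding = NamedSharding(mesh, spec)
    # Shape of the block each device would hold under this sharding.
    shard_shapes[name] = sharding.shard_shape(global_shape)
    print(f"{name:14s} -> per-device shard shape {shard_shapes[name]}")

# Place the array and inspect the shards held by this process's devices.
x_sharded = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
local_shapes = [shard.data.shape for shard in x_sharded.addressable_shards]
```

A dimension mapped to `None` (or omitted) is left whole on each device, so `P()` leaves every device with a full copy of the array.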


# Automatic parallelism via `jit`