<a href="https://colab.research.google.com/github/sw32-seo/ProJAX/blob/main/Introduction_to_sharded_computation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to sharded computation

In [38]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [39]:
import jax
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

# Key concept: Data sharding

- Data sharding: How data is laid out on available devices.
- Every `jax.Array` has an associated `jax.sharding.Sharding` object.
- `jax.sharding.Sharding` describes which shard of the global data is required by each global device.

In [40]:
# By default, a newly created array lives entirely on a single device
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()

{CpuDevice(id=0)}

In [41]:
arr.sharding

SingleDeviceSharding(device=CpuDevice(id=0))

In [42]:
jax.debug.visualize_array_sharding(arr)

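- `jax.sharding.Mesh` arranges devices into an N-dimensional grid with named axes, which allows for precise device placement.
- `NamedSharding` pairs a `Mesh` with a `PartitionSpec` that says which mesh axis (if any) each array dimension is split over.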

In [43]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
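
To make the role of `PartitionSpec` concrete, the sketch below asks a few shardings for the per-device block shape they would produce. It is an illustration under assumptions: the mesh is built from however many devices are present (2 x 4 when there are 8), and `global_shape` is a hypothetical array shape chosen so that every spec divides evenly.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: build a 2-D mesh from whatever devices are available.
n = jax.device_count()
rows = 2 if n % 2 == 0 else 1
cols = n // rows
mesh = Mesh(np.array(jax.devices()).reshape(rows, cols), ('x', 'y'))

# Hypothetical global shape, chosen so that every spec below divides evenly.
global_shape = (2 * rows * cols, 2 * cols)

specs = {
    "P('x', 'y')": P('x', 'y'),                  # dim 0 over 'x', dim 1 over 'y'
    "P('x', None)": P('x', None),                # dim 0 over 'x', dim 1 replicated
    "P(None, 'y')": P(None, 'y'),                # dim 0 replicated, dim 1 over 'y'
    "P(('x', 'y'), None)": P(('x', 'y'), None),  # dim 0 over both mesh axes
}

shard_shapes = {}
for label, spec in specs.items():
    # shard_shape reports the block each device would hold for this global shape.
    shard_shapes[label] = tuple(NamedSharding(mesh, spec).shard_shape(global_shape))
    print(f"{label:22s} -> per-device block {shard_shapes[label]}")
```

A `None` entry leaves that dimension unsplit, so each device holds the full extent along it.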


In [44]:
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
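
Besides the visualization, the shards of an array can be inspected directly. The following is a minimal sketch, assuming a 1-D mesh over all available devices and a hypothetical array `demo` with one row per device; it lists each shard's device, the slice of the global array it covers, and its local shape.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with axis name 'x'.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))

# Hypothetical array with one row per device, split along the leading axis.
demo = jax.device_put(
    jnp.arange(4.0 * n).reshape(n, 4),
    NamedSharding(mesh, P('x', None)),
)

for shard in demo.addressable_shards:
    # shard.index is the slice of the global array held by shard.device.
    print(shard.device, shard.index, shard.data.shape)

shard_shapes = [shard.data.shape for shard in demo.addressable_shards]
```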


# 1. Automatic parallelism via `jit`

A function compiled with `jax.jit()` automatically partitions the computation to follow the sharding of its inputs, and the compiler inserts any inter-device communication that is needed.

In [45]:
@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match: ", result.sharding == arr_sharded.sharding)

shardings match:  True


In [46]:
@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]
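
As a sanity check on the automatic partitioning, a sharded computation can be compared against the same reduction done in plain NumPy. This is a sketch under assumptions: the mesh is 1-D over all available devices, and the names `column_sums`, `data`, and `reference` are illustrative.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with axis name 'x'.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))

# Split the columns across devices; the reduction runs over the unsplit rows.
data = np.arange(4 * 2 * n, dtype=np.float32).reshape(4, 2 * n)
sharded = jax.device_put(data, NamedSharding(mesh, P(None, 'x')))

@jax.jit
def column_sums(x):
  return x.sum(axis=0)

out = column_sums(sharded)
reference = data.sum(axis=0)  # same reduction, computed on the host

print(out.sharding)
print(np.allclose(np.asarray(out), reference))
```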


# 2. Semi-automated sharding with constraints

`jax.lax.with_sharding_constraint()` lets you control the sharding of intermediate values and outputs inside a `jit`-compiled function.

In [47]:
@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
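  # Build a 1-D mesh over all 8 devices; the mesh shape is passed as a tuple.
  # (Recent JAX versions also offer jax.make_mesh((8,), ('x',)).)
  devices = mesh_utils.create_device_mesh((8,))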
  mesh = jax.sharding.Mesh(devices, 'x')
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]
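
The constraint changes where the result lives, not its value. The sketch below applies two different constraints to the same reduction and compares the results. It assumes a 1-D mesh over all available devices; the function names `sums_replicated` and `sums_split` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with axis name 'x'.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))
replicated = NamedSharding(mesh, P())  # every device holds the full result
split = NamedSharding(mesh, P('x'))    # the result is split across devices

@jax.jit
def sums_replicated(x):
  return jax.lax.with_sharding_constraint(x.sum(axis=0), replicated)

@jax.jit
def sums_split(x):
  return jax.lax.with_sharding_constraint(x.sum(axis=0), split)

data = jnp.arange(4.0 * 2 * n).reshape(4, 2 * n)
out_rep = sums_replicated(data)
out_split = sums_split(data)

print(out_rep.sharding)
print(out_split.sharding)
print(np.allclose(out_rep, out_split))
```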


# 3. Manual parallelism with `shard_map`

`jax.experimental.shard_map.shard_map` maps a function over the shards of its input, so the function runs once on each device-local shard.

In [48]:
from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')  # `jax.devices()` will give 1-D device array.
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x')
)

arr = jnp.arange(32)
f_elementwise_sharded(arr)

Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116902,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314587,  1.840334  ,  2.9812148 ,
        2.3005757 ,  0.42419338, -0.92279494, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244087, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

In the example below, the function `f` multiplies each shard of the data by 2.

In [49]:
x = jnp.arange(32)
print(f"global shape: {x.shape}")

def f(x):
  print(f"device local shape: {x.shape}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

global shape: (32,)
device local shape: (4,)


Because each function instance only "sees" the device-local part of the data, aggregation-like functions require some extra thought.

For example, here's what a `shard_map` of a `jax.numpy.sum()` looks like:

In [50]:
def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

The function `f` operates separately on each shard, and the resulting summation shape reflects this.
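
One way to read this result: a per-shard sum is the same as reshaping the global array into one row per device and summing each row. The sketch below checks that equivalence. It assumes a 1-D mesh over all available devices and equal-sized shards of 4 elements; the `try`/`except` import is an assumption to cover JAX versions where `shard_map` is exposed as `jax.shard_map`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Assumption: newer JAX releases expose shard_map at the top level.
try:
  from jax import shard_map
except ImportError:
  from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))
values = jnp.arange(4 * n)  # 4 elements per device

def local_sum(block):
  # Runs once per shard; keepdims gives each shard a length-1 output.
  return jnp.sum(block, keepdims=True)

per_shard = shard_map(local_sum, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(values)

# Host-side equivalent: one row per device, summed along each row.
expected = np.arange(4 * n).reshape(n, 4).sum(axis=1)
print(per_shard)
print(expected)
```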

If you want to sum across shards, you need to use a collective operation such as `jax.lax.psum()`:

In [51]:
def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)

Array(496, dtype=int32)

Because the output no longer has a sharded dimension, set `out_specs=P()`.
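
The same pattern extends to other global statistics. As a sketch, a global mean can be formed by combining per-shard sums with `jax.lax.psum()` and dividing by the global length. The helper name `global_mean` is hypothetical, the mesh is assumed to be 1-D over all available devices, and the `try`/`except` import is an assumption to cover JAX versions where `shard_map` is exposed as `jax.shard_map`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Assumption: newer JAX releases expose shard_map at the top level.
try:
  from jax import shard_map
except ImportError:
  from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))

def global_mean(x):
  def total(block):
    # Sum within the shard, then add the partial sums across all shards.
    return jax.lax.psum(block.sum(), 'x')
  summed = shard_map(total, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
  # The global length is known outside shard_map, so divide there.
  return summed / x.shape[0]

values = jnp.arange(4 * n)
mean = global_mean(values)
print(mean)
```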

# Comparing the three approaches

In [52]:
@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)

In [53]:
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)

Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

## Automatic sharding

If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:

In [54]:
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)

Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

Alternatively, you can use `jax.lax.with_sharding_constraint()` in the function to automatically distribute unsharded inputs:

In [55]:
@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)

Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

Finally, you can do the same thing with `shard_map`, using `jax.lax.psum()` to indicate the cross-shard collective required for the matrix product.

In [56]:
from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)

Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
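
To confirm that the three approaches agree, they can be run side by side and compared against a plain NumPy reference. The sketch below is self-contained and rests on assumptions: the feature dimension is sized to the number of available devices (32 when there are 8), the inputs are cast to `float32`, and the `try`/`except` import covers JAX versions where `shard_map` is exposed as `jax.shard_map`.

```python
from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: newer JAX releases expose shard_map at the top level.
try:
  from jax import shard_map
except ImportError:
  from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))

rng = np.random.default_rng(0)
features = 4 * n  # divisible by the device count
x = rng.normal(size=(features,)).astype(np.float32)
weights = rng.normal(size=(features, 4)).astype(np.float32)
bias = rng.normal(size=(4,)).astype(np.float32)

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  # Each shard computes a partial product; psum adds the partials together.
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

# Host-side reference, computed without JAX.
reference = 1.0 / (1.0 + np.exp(-(x @ weights + bias)))

out_jit = layer(jax.device_put(x, sharding), jax.device_put(weights, sharding), bias)
out_auto = layer_auto(x, weights, bias)
out_manual = layer_sharded(x, weights, bias)

for name, out in [("jit", out_jit), ("constraint", out_auto), ("shard_map", out_manual)]:
  print(name, np.allclose(np.asarray(out), reference, atol=1e-5))
```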