<a href="https://colab.research.google.com/github/present42/PyTorchPractice/blob/main/Following_JAX_Tutorial_Shard_Map.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SPMD multi-device parallelism with shard_map

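`shard_map` is a single-program, multiple-data (SPMD) API for mapping a function over shards of data spread across devices. Mapped function applications, or instances, communicate with each other via explicit collective communication operations.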

In [1]:
import os
# Emulate 8 CPU devices; this flag must be set before JAX is imported
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

Let's look at a `shard_map` example: a matrix multiplication split across a 4x2 device mesh.

In [3]:
from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

In [4]:
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

In [14]:
a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block f32[2, 8]
  # b_block f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block f32[2, 4]
  return c_block

c = matmul_basic(a, b)
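
The `psum` call above is the only point where instances talk to each other. As a minimal sketch of that collective in isolation, the snippet below sums a sharded vector: each instance reduces its own block, then `psum` adds the per-device results across the mesh axis. The function name `total`, the 1-D mesh axis `'i'`, and the fallback import are assumptions made for illustration; the mesh simply uses however many devices are available.

```python
import os
# Assumption: emulate 8 CPU devices; only takes effect if set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases, as used in the cells above
    from jax.experimental.shard_map import shard_map

# 1-D mesh over every available device (hypothetical axis name 'i').
mesh = Mesh(np.array(jax.devices()), ('i',))
n = mesh.shape['i']

x = jnp.arange(2 * n, dtype=jnp.float32)  # two elements per device

def total(x_block):
    local = jnp.sum(x_block)         # each instance reduces its own block
    return jax.lax.psum(local, 'i')  # all-reduce across the 'i' mesh axis

# After psum every instance holds the same value, so the output is replicated: P().
result = shard_map(total, mesh=mesh, in_specs=P('i'), out_specs=P())(x)
print(result, jnp.sum(x))
```

Because the result of `psum` is identical on every instance, `out_specs=P()` is enough here; no output axis is concatenated.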

In [6]:
jax.debug.visualize_array_sharding(c)

In [10]:
from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(
      partial(jnp.allclose, atol=1e-2, rtol=1e-2),
      a,
      b
    )
  )

allclose(jnp.dot(a, b), c)

True
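
To make the block structure concrete, the computation in `matmul_basic` can be emulated on a single device with plain slicing. This is only a sketch of the semantics, not how `shard_map` is implemented: the names `a_blocks`, `b_blocks`, and `c_emulated` are illustrative, and Python's `sum` over the `'y'` partial products stands in for `psum`.

```python
import jax.numpy as jnp

a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
nx, ny = 4, 2  # sizes of the 'x' and 'y' mesh axes

# P('x', 'y'): split rows over 'x' and columns over 'y' -> a_blocks[i][j] is f32[2, 8]
a_blocks = [jnp.split(rows, ny, axis=1) for rows in jnp.split(a, nx, axis=0)]
# P('y', None): split rows over 'y', keep columns whole -> b_blocks[j] is f32[8, 4]
b_blocks = jnp.split(b, ny, axis=0)

c_rows = []
for i in range(nx):
    # Each (i, j) "device" computes a partial product of shape f32[2, 4].
    partials = [a_blocks[i][j] @ b_blocks[j] for j in range(ny)]
    # Summing over j plays the role of psum over the 'y' axis.
    c_rows.append(sum(partials))

# P('x', None): concatenate the row blocks back together along axis 0.
c_emulated = jnp.concatenate(c_rows, axis=0)
print(jnp.allclose(c_emulated, a @ b))
```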

The same computation with automatic parallelization, where `jax.jit` partitions the work based on the input shardings:

In [11]:
mesh

Mesh(device_ids=array([[0, 1],
       [2, 3],
       [6, 7],
       [4, 5]]), axis_names=('x', 'y'))

In [15]:
from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))

True

We can think of `shard_map` as performing a `device_put` or `with_sharding_constraint` on its inputs according to its `mesh` and `in_specs` arguments, so the blocks over which `matmul_basic` operates are the same as in `matmul_reference`.

In [16]:
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)

a blocks:


b blocks:


c blocks:
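
A text-only way to inspect the same blocks is to list the shards of a `device_put` array. The sketch below assumes the 8 emulated CPU devices from the first cell; if a different number of devices is available it falls back to an `(n // 2, 2)` or `(n, 1)` mesh, which is an assumption made so the snippet still runs. The array is sized so that every shard is `f32[2, 8]`, matching the `a_block` shape in `matmul_basic`.

```python
import os
# Assumption: emulate 8 CPU devices; only takes effect if set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
ny = 2 if n % 2 == 0 else 1   # use a 2-wide 'y' axis when the device count allows it
nx = n // ny
mesh = Mesh(np.array(jax.devices()).reshape(nx, ny), ('x', 'y'))

# With 8 devices this is the same f32[8, 16] array as `a` above.
a = jnp.arange(2 * nx * 8 * ny, dtype=jnp.float32).reshape(2 * nx, 8 * ny)
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))

# Each addressable shard records its device, its index into the global array, and its data.
shard_shapes = [s.data.shape for s in a.addressable_shards]
for s in a.addressable_shards:
    print(s.device, s.index, s.data.shape)
```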


## Slow down, start with the basics!

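We can think of `pmap` and `vmap` as unstacking each array input along an axis, applying their body function to each piece, and stacking the results back together. Because of the unstacking, the body function sees inputs with one fewer dimension than the arrays passed in; this is the rank-reducing behavior.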

In [17]:
def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))

True
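
The rank reduction can be observed directly by recording the shape the body function sees while `vmap` traces it. This is a small sketch; the names `body` and `seen` are illustrative.

```python
import jax
import jax.numpy as jnp

seen = []  # shapes observed inside the mapped function (filled at trace time)

def body(x):
    seen.append(x.shape)
    return x @ x  # inner product of a vector with itself -> scalar

xs = jnp.arange(12).reshape(4, 3)
out = jax.vmap(body)(xs)

# The input is rank 2, but the body sees rank-1 slices; the scalar results
# are stacked back into a rank-1 output.
print(xs.shape, seen[0], out.shape)
```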


`shard_map` does not have rank-reducing behavior. Instead, we can think of it as slicing (or 'unconcatenating') along input axes into blocks, applying the body function, and concatenating the results back together.

In [18]:
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i', ))

In [20]:
def check_shmap(f, y):
  ans = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))

True


In other words, `shard_map` is a rank-preserving map that unconcatenates its inputs into blocks and concatenates the per-block outputs.
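
The same shape-recording trick shows the rank-preserving behavior. In this sketch the body receives blocks with the same number of dimensions as the full input, and the per-block outputs are concatenated along axis 0. The names `body` and `seen`, the use of at most four devices, and the fallback import are assumptions made for illustration.

```python
import os
# Assumption: emulate 8 CPU devices; only takes effect if set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases, as used in the cells above
    from jax.experimental.shard_map import shard_map

n_i = min(4, len(jax.devices()))  # up to 4 devices, as in the cell above
mesh = Mesh(np.array(jax.devices()[:n_i]), ('i',))

seen = []  # block shapes observed inside the mapped function

def body(y_block):
    seen.append(y_block.shape)
    return y_block.T @ y_block  # (2, 4) block -> (4, 4) result

y = jnp.arange(2 * n_i * 4).reshape(2 * n_i, 4)
out = shard_map(body, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)

# The block keeps rank 2; the n_i results of shape (4, 4) are concatenated along axis 0.
print(y.shape, seen[0], out.shape)
```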