<a href="https://colab.research.google.com/github/present42/PyTorchPractice/blob/main/Following_Jax_Tutorial_shard_map_toy_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Matrix Multiplication

In [2]:
# 1-D mesh over (up to) the first four devices, with a single axis named 'i'.
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  """Transfer ``x`` onto ``mesh`` laid out as described by ``pspec``.

  Args:
    x: the value to place.
    pspec: ``PartitionSpec`` naming which axes of ``x`` are split over which
      mesh axes; ``None`` entries are not split.

  Returns:
    The result of ``jax.device_put`` with a ``NamedSharding`` on ``mesh``.
  """
  target_sharding = NamedSharding(mesh, pspec)
  return jax.device_put(x, target_sharding)

## Example 1: `all-gather` on one side

Consider a matrix multiplication where we shard the left-hand side argument on its leading dimension:

In [None]:
# 8x8 left-hand side. P('i', None) splits its rows (leading axis) across the
# devices of mesh axis 'i' and keeps every column on each device.
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

In [None]:
# 8x4 right-hand side, also split by rows over mesh axis 'i'. Each device
# therefore holds only a row block of rhs, while lhs @ rhs needs all of it.
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

In [None]:
# Show which device holds which row block of lhs.
jax.debug.visualize_array_sharding(lhs)

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplication against the sharded left-hand side:

In [None]:
from functools import partial

### Review
 - `in_specs` identifies some of the corresponding input array's axes with mesh axes by name using `PartitionSpec`s, representing how to split that input into the blocks to which the body function is applied.
 - `out_specs` identifies some of the corresponding output array's axes with mesh axes by name, representing how the output blocks should be assembled back together to form the final output value.

In [None]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  """Compute lhs @ rhs by all-gathering the row-sharded rhs, then multiplying locally.

  The body sees one row block of lhs and one row block of rhs. The gather has
  to complete before the local matmul can begin, so there is no overlap
  between communication and compute.
  """
  # tiled=True concatenates the shards along axis 0 (instead of stacking them
  # on a new leading axis), which rebuilds the complete rhs on every device.
  full_rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  # This device's rows of lhs times all of rhs = this device's output rows.
  local_out = lhs_block @ full_rhs
  return local_out

In [None]:
# Compare the shard_map result against a plain matmul of the same arrays.
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs))

True


We're not getting any compute/communication overlap here: before we can start the matmul, we need the `all_gather` to complete.

In [None]:
from jax import lax

In [None]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  """Compute lhs @ rhs by rotating rhs row blocks around a ring.

  Each device multiplies the column slice of its lhs row block that pairs with
  the rhs block it currently holds. It then passes that rhs block to the next
  device with ``ppermute``. After ``size`` steps every rhs block has visited
  every device, so the running sum equals this device's row block of lhs @ rhs.
  Each step needs only the block just received. This lets the transfer overlap
  with compute, and the full rhs is never materialized on any one device.
  """
  size = jax.lax.psum(1, 'i')  # number of devices
  idx = jax.lax.axis_index('i')

  # Ring shift: device i sends to device i + 1 (and so receives from i - 1).
  shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)])

  # Width of the lhs column slice that pairs with one rhs row block
  # (assumes lhs_block.shape[1] is divisible by size).
  B = lhs_block.shape[1] // size

  # start_index: i * B / slice_size: B / axis: 1
  # i.e. the columns of lhs_block that multiply rhs row block i.
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  # Initially device idx holds rhs row block idx.
  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    # After i shifts this device holds the block that started on device idx - i.
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block

In [None]:
# Looser tolerance, presumably because the ring accumulates partial products in
# a different order than a single matmul (float rounding differences).
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True


The implementation allows overlap between communication and computation, and also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring.

In [None]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  """Bidirectional ring variant of the overlapped all-gather matmul.

  Each rhs row block is split into a "lo" and a "hi" half. The lo halves
  rotate up the ring (i -> i + 1) and the hi halves rotate down (i -> i - 1),
  so both directions of the ring carry traffic at every step.
  """
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i - 1) % size) for i in range(size)])

  # Width of the lhs column slice that pairs with one half rhs block
  # (assumes lhs_block.shape[1] is divisible by 2 * size).
  B = lhs_block.shape[1] // size // 2
  # Columns of lhs_block that multiply half `hi` (0 = lo, 1 = hi) of rhs block i.
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2 * i + hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    # The lo half now held came from device idx - i, the hi half from idx + i.
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block

In [None]:
# Check the bidirectional version against a plain matmul.
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

True


In [None]:
# Display the bidirectional all-gather result.
out

Array([[ 0.4204327 , -3.2336767 ,  0.46415687,  2.3247633 ],
       [-3.5942657 , -0.24846068, -1.9459741 ,  1.9542422 ],
       [-2.1133657 ,  0.61801076, -0.18297681,  3.210961  ],
       [-0.55151576, -3.1808214 ,  1.4808508 ,  0.21939349],
       [-0.9624958 ,  5.5949306 ,  1.8143425 ,  0.3755707 ],
       [-3.1725793 , -1.1635897 ,  1.3896794 , -0.84395945],
       [ 1.0931401 , -2.4149542 ,  1.5044949 ,  1.3278372 ],
       [ 0.1603303 , -0.08382738, -0.7018988 , -1.267864  ]],      dtype=float32)

In [None]:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  """Re-definition of ``matmul_allgather_overlapped_bidi``; rebinds the name.

  The code is the same as the earlier definition of this name; only the line
  wrapping differs. rhs row blocks are split into lo/hi halves, and the halves
  rotate around the ring in opposite directions while being multiplied with
  the matching column slices of the local lhs block.
  """
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  # lo halves travel i -> i + 1, hi halves travel i -> i - 1.
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  # Columns of lhs_block that multiply half `hi` (0 = lo, 1 = hi) of rhs block i.
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block

In [None]:
# NOTE(review): this check's recorded output is False, even though the function
# above matches the earlier definition whose check printed True. That suggests
# the cells were run out of order (e.g. lhs/rhs or their specs rebound). Re-run
# the notebook top to bottom to confirm.
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

False


In [None]:
# Display the result of the re-defined function.
out

Array([[ 0.4204327 , -3.2336767 ,  0.46415687,  2.3247633 ],
       [-3.5942657 , -0.24846068, -1.9459741 ,  1.9542422 ],
       [-2.1133657 ,  0.61801076, -0.18297681,  3.210961  ],
       [-0.55151576, -3.1808214 ,  1.4808508 ,  0.21939349],
       [-0.9624958 ,  5.5949306 ,  1.8143425 ,  0.3755707 ],
       [-3.1725793 , -1.1635897 ,  1.3896794 , -0.84395945],
       [ 1.0931401 , -2.4149542 ,  1.5044949 ,  1.3278372 ],
       [ 0.1603303 , -0.08382738, -0.7018988 , -1.267864  ]],      dtype=float32)

In [None]:
# Re-shard lhs for the reduce-scatter examples: P(None, 'i') keeps every row on
# each device and splits the columns (axis 1) across mesh axis 'i'.
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec) # Divide the data column by column (axis 1)

In [None]:
# Each device should now hold a full-height column stripe of lhs.
jax.debug.visualize_array_sharding(lhs)

In [None]:
# rhs stays split by rows, so device k holds exactly the rhs rows that multiply
# the lhs columns it holds; its local product is a partial sum of lhs @ rhs.
rhs_spec = P('i', None) # Divide the data row by row (axis 0)
rhs = device_put(rhs, rhs_spec)

In [None]:
# Show the row-block layout of rhs.
jax.debug.visualize_array_sharding(rhs)

In [None]:
from functools import partial

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  """Compute lhs @ rhs from a column-sharded lhs and row-sharded rhs via psum_scatter.

  With lhs_spec = P(None, 'i') and rhs_spec = P('i', None) from the cells above,
  each device's local product is a full-size partial sum of the result.
  ``psum_scatter`` adds those partial sums across 'i' and hands each device its
  own row block of the total. The scatter can only start once the whole local
  matmul is done.
  """
  partial_sum = lhs_block @ rhs_block
  # tiled=True: split the summed result into contiguous row blocks along
  # axis 0, rather than requiring a leading axis of size |i| to scatter over.
  row_block = jax.lax.psum_scatter(partial_sum, 'i', tiled=True)
  return row_block

# Check the reduce-scatter matmul against a plain matmul of the re-sharded arrays.
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs))

True


But the scattering communication must wait for the entire local matrix multiply to finish before it can start.

In [None]:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  """Compute lhs @ rhs as a ring reduce-scatter interleaved with local matmuls.

  Each device holds a column stripe of lhs and the matching row block of rhs,
  so its local product contributes a partial sum to *every* output row block.
  Rather than computing the whole local product and then calling
  ``psum_scatter``, the code works on one output row block at a time. A
  running partial sum for that block is passed around the ring, and each
  device adds its own contribution before forwarding it. Device ``idx`` ends
  up with the fully summed output row block ``idx``.
  """
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  # interleave the communication steps with local matrix multiplication
  # Ring shift: device i sends to device i - 1, so each device receives from
  # device i + 1.
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  # shape of lhs_block (A, B // size)
  # changed shape      (size, A // size, B // size)
  # i.e. also split the rows, giving one slice per output row block
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])

  # Start with this device's contribution to output row block (idx + 1).
  # E.g. device 0 starts on row block 1, i.e. the second row block.
  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    # Send the running partial sum to device idx - 1 and receive one from
    # device idx + 1. E.g. the sum started on device 0 travels
    # 0 -> size-1 -> ... -> 1, and the one started on device 1 travels
    # 1 -> 0 -> ... -> 2.
    out_summand = shift(out_summand)
    # The sum just received started on device idx + i, so it belongs to output
    # row block idx + i + 1; add this device's contribution to that block.
    # At the last step (i = size - 1) that block is (idx + size) % size == idx.
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand

In [None]:
# Check the overlapped reduce-scatter matmul against a plain matmul.
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs))

True


In [None]:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  """Bidirectional ring variant of the overlapped reduce-scatter matmul.

  Each output row block is split into a "lo" and a "hi" half. Running partial
  sums for lo halves travel up the ring (i -> i + 1) and those for hi halves
  travel down (i -> i - 1), so both ring directions carry traffic each step.
  Device ``idx`` finishes with both halves of output row block ``idx`` and
  returns them concatenated along axis 0.
  """
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2 # half-sized blocks
  # Rows of lhs_block that produce half `hi` (0 = lo, 1 = hi) of output row
  # block i (assumes lhs_block.shape[0] is divisible by 2 * size).
  lhs_blocks = lambda i, hi: jax.lax.dynamic_slice_in_dim(
      lhs_block,
      (2 * i + hi) * B,
      B,
      0
  )

  # out_summand_lo is shifted to the next gpu (idx + 1) each step.
  # It starts as the lo half of row block idx - 1, and after size - 1 shifts
  # it ends up at the (idx - 1)-th gpu, the one that owns that row block.
  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  # out_summand_hi is shifted to the prev gpu (idx - 1) each step.
  # It starts as the hi half of row block idx + 1 and ends up at the
  # (idx + 1)-th gpu.
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block

  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)

    # The lo sum just received started on device idx - i, so it belongs to
    # row block idx - i - 1. E.g. at i = 1, device 0 receives from gpu N-1
    # and adds its contribution to row block N-2.
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block

    # The hi sum just received started on device idx + i, so it belongs to
    # row block idx + i + 1. E.g. at i = 1, device 0 receives from gpu 1
    # and adds its contribution to row block 2.
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block

  return jnp.concatenate([out_summand_lo, out_summand_hi])

In [None]:
# Check the bidirectional reduce-scatter matmul (with a looser tolerance).
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-2, rtol=1e-2))

True


# Neural Networks

In [11]:
import jax
import jax.numpy as jnp

def predict(params, inputs):
  """Run an MLP forward pass.

  Args:
    params: iterable of ``(W, b)`` pairs, one per layer.
    inputs: array fed to ``jnp.dot`` with the first layer's ``W``.

  Returns:
    The last layer's affine output ``x @ W + b``, *without* a ReLU. ReLU is
    applied only to the activations passed on to the next layer.
  """
  activations = inputs
  for weight, bias in params:
    pre_activation = jnp.dot(activations, weight) + bias
    activations = jax.nn.relu(pre_activation)
  return pre_activation

def loss(params, batch):
  """Mean over examples of the squared error summed over the last axis.

  Args:
    params: layer parameters accepted by ``predict``.
    batch: ``(inputs, targets)`` pair.
  """
  inputs, targets = batch
  residual = predict(params, inputs) - targets
  per_example = jnp.sum(residual ** 2, axis=-1)
  return jnp.mean(per_example)

In [None]:
def init_layer(key, n_in, n_out):
  """Create one dense layer's parameters.

  Returns:
    ``(W, b)``: ``W`` has shape ``(n_in, n_out)``, drawn from a standard normal
    and divided by ``sqrt(n_in)``; ``b`` has shape ``(n_out,)``, drawn from a
    standard normal.
  """
  w_key, b_key = jax.random.split(key)
  fan_in_scale = jnp.sqrt(n_in)
  weights = jax.random.normal(w_key, (n_in, n_out)) / fan_in_scale
  bias = jax.random.normal(b_key, (n_out,))
  return weights, bias

def init(key, layer_sizes, batch_size):
  """Build MLP parameters plus one random ``(inputs, targets)`` batch.

  Args:
    key: JAX PRNG key.
    layer_sizes: widths ``[d_0, ..., d_L]``; one layer per consecutive pair.
    batch_size: number of rows in the generated batch.

  Returns:
    ``(params, (inputs, targets))``. ``params`` is a list of ``(W, b)`` pairs,
    ``inputs`` has shape ``(batch_size, d_0)`` and ``targets`` has shape
    ``(batch_size, d_L)``.
  """
  # One key per layer, plus one kept back for generating the data.
  key, *layer_keys = jax.random.split(key, len(layer_sizes))
  params = [
      init_layer(layer_key, n_in, n_out)
      for layer_key, n_in, n_out in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])
  ]

  _, input_key, target_key = jax.random.split(key, 3)
  inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
  targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))
  return params, (inputs, targets)

In [None]:
# Input width 784, five hidden layers of width 128, output width 8.
layer_sizes = [784] + [128] * 5 + [8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)

With automatic partitioning we don't need to edit the model functions to use different parallelization strategies, but with `shard_map` we often do.

## 8-way batch data parallelism

 - shard the batch of inputs and targets over multiple devices
 - replicate the parameters over those devices

To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end.

In [None]:
from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

# 1-D array of 8 devices; the following cells need 8 devices to be available.
devices = mesh_utils.create_device_mesh((8,))

In [None]:
# Replicate initial params on all devices, shard data batch over devices:
# P('batch') splits the leading axis of both inputs and targets across the
# mesh, and P() leaves every parameter leaf unsplit (a full copy per device).
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

In [None]:
# Data-parallel version of `loss`: each device scores its own slice of the
# batch, then the per-device losses are averaged across devices (pmean).
def loss_dp(params, batch):
  """Compute the mean-squared-error loss with the batch split over mesh axis 'batch'.

  ``params`` is closed over by the per-device function, so every device uses
  the full parameters. ``batch`` ``(inputs, targets)`` is split along its
  leading axis by ``in_specs``.
  """
  def per_device_loss(local_batch):
    local_inputs, local_targets = local_batch
    residual = predict(params, local_inputs) - local_targets
    shard_mean = jnp.mean(jnp.sum(residual ** 2, axis=-1))
    # shard_map splits the batch into equal shards, so the mean of the
    # per-shard means equals the mean over the whole batch.
    return jax.lax.pmean(shard_mean, 'batch')

  sharded_loss = shard_map(per_device_loss, mesh=mesh,
                           in_specs=P('batch', None), out_specs=P())
  return sharded_loss(batch)

In [None]:
# The data-parallel loss should match the single-program loss.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))

22.795706
22.795706


In [None]:
def allclose(a, b, atol=1e-2, rtol=1e-2):
  """Return True if every pair of corresponding leaves of ``a`` and ``b`` is close.

  Used here to compare gradient pytrees from different loss formulations.

  Args:
    a, b: pytrees with the same structure (e.g. lists of ``(W, b)`` tuples);
      ``jax.tree_util.tree_map`` raises if the structures differ.
    atol, rtol: tolerances forwarded to ``jnp.allclose`` (loose by default).

  Returns:
    A Python bool.
  """
  # jax.tree_map is deprecated (and removed in recent JAX releases); use the
  # jax.tree_util spelling, consistent with the tree_all call below.
  leaf_close = jax.tree_util.tree_map(
      partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
  return jax.tree_util.tree_all(leaf_close)

In [None]:
# Gradients of the reference and data-parallel losses should agree (within
# allclose's tolerance).
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))

True


## 8-way fully sharded data parallelism (FSDP)

- Additionally shard the parameters over the devices, and all-gather each one when the full value is needed for the `jnp.dot` or bias addition.

- So now we need collectives in two places: the model prediction function `predict` needs to all-gather the parameters before they're used, and, as in the DP case, the loss function needs to sum the local losses to compute the total loss.

- One more ingredient: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a custom policy, though XLA typically does that rematerialization automatically.

In [5]:
def init_layer(key, n_in, n_out):
  """Create one dense layer's parameters.

  Returns:
    ``(W, b)``: ``W`` has shape ``(n_in, n_out)``, drawn from a standard normal
    and divided by ``sqrt(n_in)``; ``b`` has shape ``(n_out,)``, drawn from a
    standard normal.
  """
  w_key, b_key = jax.random.split(key)
  fan_in_scale = jnp.sqrt(n_in)
  weights = jax.random.normal(w_key, (n_in, n_out)) / fan_in_scale
  bias = jax.random.normal(b_key, (n_out,))
  return weights, bias

def init(key, layer_sizes, batch_size):
  """Build MLP parameters plus one random ``(inputs, targets)`` batch.

  Args:
    key: JAX PRNG key.
    layer_sizes: widths ``[d_0, ..., d_L]``; one layer per consecutive pair.
    batch_size: number of rows in the generated batch.

  Returns:
    ``(params, (inputs, targets))``. ``params`` is a list of ``(W, b)`` pairs,
    ``inputs`` has shape ``(batch_size, d_0)`` and ``targets`` has shape
    ``(batch_size, d_L)``.
  """
  # One key per layer, plus one kept back for generating the data.
  key, *layer_keys = jax.random.split(key, len(layer_sizes))
  params = [
      init_layer(layer_key, n_in, n_out)
      for layer_key, n_in, n_out in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])
  ]

  _, input_key, target_key = jax.random.split(key, 3)
  inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
  targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))
  return params, (inputs, targets)

In [6]:
# Input width 784, five hidden layers of width 128, output width 8.
layer_sizes = [784] + [128] * 5 + [8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)

In [8]:
from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

# 1-D array of 8 devices; the following cells need 8 devices to be available.
devices = mesh_utils.create_device_mesh((8,))

In [9]:
# shard the data batch *and the params* over devices: P('batch') splits the
# leading axis of every leaf (batch rows, and each W's / b's first axis).
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

In [10]:
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them).
# The remat policy returns whether a primitive's output may be saved for the
# backward pass. It returns False for all_gather, so gathered weights are
# recomputed (re-gathered) instead of kept alive. `*_` / `**__` discard the
# policy's remaining positional and keyword arguments.
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  """FSDP forward pass: like ``predict``, but with sharded parameters.

  Each layer's ``(W_frag, b_frag)`` is this device's shard along the leading
  axis; the full ``W`` and ``b`` are all-gathered over mesh axis 'batch' just
  before use, so this must run inside a ``shard_map`` that binds 'batch'.
  """
  for W_frag, b_frag in params_frag:
    # tiled=True concatenates the shards along axis 0, rebuilding full W / b.
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  """Mean-squared-error loss with both params and batch sharded over 'batch'."""
  # A single P('batch') acts as a prefix spec for both arguments, so every
  # leaf of params and of batch is split along its leading axis.
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))
    # Shards are equal-sized, so averaging per-shard means gives the global mean.
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)

In [12]:
# Reference loss evaluated on the FSDP-sharded params and batch.
print(jax.jit(loss)(params, batch))

22.795706


In [13]:
# FSDP loss; should match the reference loss.
print(jax.jit(loss_fsdp)(params, batch))

22.795706


In [28]:
def allclose(a, b, atol=1e-2, rtol=1e-2):
  """Return True if every pair of corresponding leaves of ``a`` and ``b`` is close.

  Used here to compare gradient pytrees from different loss formulations.

  Args:
    a, b: pytrees with the same structure (e.g. lists of ``(W, b)`` tuples);
      ``jax.tree_util.tree_map`` raises if the structures differ.
    atol, rtol: tolerances forwarded to ``jnp.allclose`` (loose by default).

  Returns:
    A Python bool.
  """
  # jax.tree_map is deprecated (and removed in recent JAX releases); use the
  # jax.tree_util spelling, consistent with the tree_all call below.
  leaf_close = jax.tree_util.tree_map(
      partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
  return jax.tree_util.tree_all(leaf_close)

In [19]:
# Gradients of the reference and FSDP losses should agree (within allclose's
# tolerance).
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))

True


## 8-way tensor parallelism

The parallelization idea is that we'll keep the data/activations sharded over their feature axis (rather than their batch axis), and we'll similarly shard weight matrices over their input-feature axis (and biases over their feature axis). Then to perform the parallel matrix multiplications, we'll perform local matrix multiplications followed by a `psum_scatter` to sum the local results and efficiently scatter the result's shards.

In [24]:
# Tensor parallelism: a new 1-D mesh over the same devices, axis named 'feats'.
mesh = Mesh(devices, ('feats',))

# Shard the batch over its feature (last) axis, and every parameter leaf over
# its leading axis (W's input-feature axis, b's only axis).
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  """Tensor-parallel forward pass: ``predict`` with each layer computed by ``gemm_tp``.

  Like ``predict``, returns the last layer's output without a final ReLU.
  """
  hidden = inputs
  for weight, bias in params:
    outputs = gemm_tp(hidden, weight, bias)
    hidden = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  """One tensor-parallel dense layer computing ``inputs @ W + b``.

  Inside the body each device holds a feature slice of ``inputs`` and the
  matching row slice of ``W``, so the local product is a partial sum over the
  full output width. ``psum_scatter`` adds those partial sums across 'feats'
  and leaves each device with its own slice of output features
  (scatter_dimension=1). The device's slice of ``b`` is then added.
  """
  partial_sums = jnp.dot(inputs, W)
  local_out = jax.lax.psum_scatter(partial_sums, 'feats',
                                   scatter_dimension=1, tiled=True)
  return local_out + b

def loss_tp(params, batch):
  """Mean-squared-error loss using the tensor-parallel forward pass."""
  inputs, targets = batch
  residual = predict_tp(params, inputs) - targets
  # NOTE psum: the predictions are sharded over their feature (last) axis, so
  # this sum spans devices. It runs outside shard_map, so JAX's automatic
  # partitioning is left to insert the cross-device reduction.
  return jnp.mean(jnp.sum(residual ** 2, axis=-1))


In [25]:
# Reference loss evaluated on the feature-sharded params and batch.
print(jax.jit(loss)(params, batch))

22.795746


In [26]:
# Tensor-parallel loss; should match the reference loss.
print(jax.jit(loss_tp)(params, batch))

22.795746


In [29]:
# Gradients of the reference and tensor-parallel losses should agree (within
# allclose's tolerance).
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_tp))(params, batch)))

True
