<a href="https://colab.research.google.com/github/ybw9000/jax_playground/blob/main/jax_pmap_shard_map_and_jit_inshardings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import jax.nn as nn

In [2]:
devices = jax.devices()
devices

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
# would error out if there are less than 7 XLA devices
result = jax.pmap(lambda x: x ** 2)(jnp.arange(8))
print(result)

[ 0  1  4  9 16 25 36]


In [6]:
from jax import make_jaxpr

def f(x, y):
  a = jnp.dot(x, y)
  b = jnp.tanh(a)
  return b

xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))

# 2D dot general aka gemm
print("f jaxpr")
print(make_jaxpr(f)(xs[0], ys[0]))

# 3D dot general aka BMM
print("vmap(f) jaxpr")
print(make_jaxpr(jax.vmap(f))(xs, ys))

# 2D dot general nested in a xla_pmap
print("pmap(f) jaxpr")
print(make_jaxpr(jax.pmap(f))(xs, ys))

f jaxpr
{ lambda ; a:f32[2,3] b:f32[3,4]. let
    c:f32[2,4] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    d:f32[2,4] = tanh c
  in (d,) }
vmap(f) jaxpr
{ lambda ; a:f32[8,2,3] b:f32[8,3,4]. let
    c:f32[8,2,4] = dot_general[
      dimension_numbers=(([2], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
    d:f32[8,2,4] = tanh c
  in (d,) }
pmap(f) jaxpr
{ lambda ; a:f32[8,2,3] b:f32[8,3,4]. let
    c:f32[8,2,4] = xla_pmap[
      axis_name=<axis 0x7968e03cdbd0>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[3,4]. let
          f:f32[2,4] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] d e
          g:f32[2,4] = tanh f
        in (g,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=8
      in_axes=(0, 0)
      is_explicit_global_axis_size=False
      name=f
      

In [19]:
result = jax.pmap(lambda x: x ** 2)(jnp.arange(8))
z = result / 2
print(z)

[ 0.   0.5  2.   4.5  8.  12.5 18.  24.5]


In [11]:
# sharded over multiple devices
result.sharding

PmapSharding(sharding_spec=ShardingSpec((Unstacked(8),), (ShardedAxis(axis=0),)), devices=[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)])

In [12]:
# moved back to device 0 as result / 2 is not a sharded operation
z.sharding

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)

In [13]:
result.global_shards

[Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(0,), replica_id=0, data=0),
 Shard(device=TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), index=(1,), replica_id=0, data=1),
 Shard(device=TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), index=(2,), replica_id=0, data=4),
 Shard(device=TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), index=(3,), replica_id=0, data=9),
 Shard(device=TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), index=(4,), replica_id=0, data=16),
 Shard(device=TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), index=(5,), replica_id=0, data=25),
 Shard(device=TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), index=(6,), replica_id=0, data=36),
 Shard(device=TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1), index=(7,), replica_id=0, data=49)]

In [14]:
# can not use jnp apis on each shard directly
type(result.global_shards[1])

In [17]:
z_sharded = jax.pmap(lambda x: x / 2)(result)

In [18]:
# same value but different layout
print(z_sharded)
print(z_sharded.sharding)

[ 0.   0.5  2.   4.5  8.  12.5 18.  24.5]
PmapSharding(sharding_spec=ShardingSpec((Unstacked(8),), (ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7], device_platform=TPU, device_shape=(8,))


In [20]:
from jax import random

# create 8 random keys
keys = random.split(random.key(0), 8)
# create a 5000 x 6000 matrix on each device by mapping over keys
mats = jax.pmap(lambda key: random.normal(key, (128, 256)))(keys)
# the stack of matrices is represented logically as a single array
print(mats.shape)
print(mats.sharding)

(8, 128, 256)
PmapSharding(sharding_spec=ShardingSpec((Unstacked(8), NoSharding(), NoSharding()), (ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7], device_platform=TPU, device_shape=(8,))


In [21]:
# run a local matmul on each device in parallel (no data transfer)
result = jax.pmap(lambda x: jnp.matmul(x, x.T))(mats)
result.shape

(8, 128, 128)

In [23]:
# only one scalar per device was pulled back to the host
res = jax.pmap(jnp.sum)(result)
print(res, res.sharding)

[33369.695 33508.656 33175.78  32662.605 31562.873 32247.285 35248.633
 30591.709] PmapSharding(sharding_spec=ShardingSpec((Unstacked(8),), (ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7], device_platform=TPU, device_shape=(8,))


In [30]:
# all reduce with psum
from jax import lax

def all_reduce(x):
  return lax.psum(x, 'k')

x = jax.pmap(lambda key: random.normal(key, (2, 4)))(keys)
y = jax.pmap(all_reduce, axis_name='k')(x)
single_device_y = jnp.sum(x, axis=0)
# all reduce so each the shape of y is still 3D aka 8, 2, 4
print(y.shape, single_device_y.shape)
# value should match
print(y[0], single_device_y)
print(jnp.allclose(y[0], single_device_y))

(8, 2, 4) (2, 4)
[[-0.8284544  -4.95767     0.55136424 -2.3609824 ]
 [-0.40268528  1.9361467  -4.852154   -1.8133668 ]] [[-0.8284545  -4.95767     0.5513642  -2.3609827 ]
 [-0.40268534  1.9361467  -4.852154   -1.813367  ]]
True


In [7]:
from jax import random
from jax import lax

# let's work on a gemm K shard
def gemm_kshard(x, y):
  z = jnp.matmul(x, y)
  print(z)
  return lax.psum(z, 'i')

keys = random.split(random.key(0), 8)
x = jax.pmap(lambda key: random.normal(key, (2, 4)))(keys)
y = jax.pmap(lambda key: random.normal(key, (4, 2)))(keys)
z = jax.pmap(gemm_kshard, axis_name='i')(x, y)
print(z.sharding, z.shape)
x_0 = x.transpose(1, 0, 2).reshape(x.shape[1], -1)
y_0 = y.reshape(-1, y.shape[-1])
print(x_0.sharding, y_0.sharding)
z_0 = jnp.matmul(x_0, y_0)
print(jnp.allclose(z[0], z_0))

Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace(level=0/1)>
PmapSharding(sharding_spec=ShardingSpec((Unstacked(8), NoSharding(), NoSharding()), (ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7], device_platform=TPU, device_shape=(8,)) (8, 2, 2)
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device) SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)
True


In [46]:
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.shard_map import shard_map
from functools import partial

mesh = Mesh(devices, ('k',))

@partial(shard_map, mesh=mesh, in_specs=(P(None, 'k'), P('k', None)), out_specs=P(None, None))
def gemm_shard_map(x, y):
  z = jnp.matmul(x, y)
  return lax.psum(z, 'k')

z_shard = gemm_shard_map(x_0, y_0)
print(z_shard.sharding)

NamedSharding(mesh=Mesh('k': 8), spec=PartitionSpec(), memory_kind=device)


In [15]:
z_shard.shape

(2, 2)

In [16]:
print(jnp.allclose(z_shard, z_0))

True


In [19]:
@partial(shard_map, mesh=mesh, in_specs=(P(None, 'k'), P('k', None)), out_specs=P('k', None))
def gemm_reduce_scatter(x, y):
  z = jnp.matmul(x, y)
  return lax.psum_scatter(z, axis_name='k', scatter_dimension=0)

x = random.normal(keys[0], (8, 16))
y = random.normal(keys[0], (16, 2))
z = gemm_reduce_scatter(x, y)
print(z.sharding)

NamedSharding(mesh=Mesh('k': 8), spec=PartitionSpec('k',), memory_kind=device)


In [20]:
z.shape

(16,)

In [24]:
# check_rep has to be False otherwise it is legalizing against input is shard yet output is not sharded
@partial(shard_map, mesh=mesh, in_specs=(P(None, 'k'), P('k', None)), out_specs=P(None, None), check_rep=False)
def gemm_all_reduce(x, y):
  z = jnp.matmul(x, y)
  z_rs = lax.psum_scatter(z, axis_name='k', scatter_dimension=0)
  z_ag = lax.all_gather(z_rs, axis_name='k', axis=0)
  return z_ag

x = random.normal(keys[0], (8, 16))
y = random.normal(keys[0], (16, 2))
z = gemm_all_reduce(x, y)
print(z.sharding)
z_psum = gemm_shard_map(x, y)
print(z_psum.sharding)
print(jnp.allclose(z, z_psum))

NamedSharding(mesh=Mesh('k': 8), spec=PartitionSpec(), memory_kind=device)
NamedSharding(mesh=Mesh('k': 8), spec=PartitionSpec(), memory_kind=device)
True


In [51]:
def qkv(x, q, k, v, num_heads=8):
    # x: S, K
    # Q, K, V: K, M
    x_q = jnp.matmul(x, q).reshape(x.shape[0], num_heads, -1).transpose(1, 0, 2)  # S, M -> S, H, E -> H, S, E
    x_k = jnp.matmul(x, k).reshape(x.shape[0], num_heads, -1).transpose(1, 0, 2)  # S, M -> S, H, E -> H, S, E
    x_v = jnp.matmul(x, v).reshape(x.shape[0], num_heads, -1).transpose(1, 0, 2)  # S, M -> S, H, E -> H, S, E
    return x_q, x_k, x_v

def attention(q, k, v, proj):
    k_t = k.transpose(0, 2, 1)  # H, S, E -> H, E, S
    qk = jnp.matmul(q, k_t)  # H, S, S
    qk = qk / jnp.sqrt(q.shape[-1])
    qk = jax.nn.softmax(qk, axis=-1)
    mm1 = jnp.matmul(qk, v)  # H, S, E
    head_fuse = mm1.transpose(1, 0, 2).reshape(x.shape[0], -1)  # S, H, E -> S, E*H
    proj_out = jnp.matmul(head_fuse, proj)
    return proj_out

def mlp(up, down, x):
    out = jnp.matmul(x, up)
    out = jax.nn.gelu(out)
    out = jnp.matmul(out, down)
    return out

def transformer_layer(x, q, k, v, proj, up, down):
    x_q, x_k, x_v = qkv(x, q, k, v)
    out = attention(x_q, x_k, x_v, x, proj) + x
    out = mlp(up, down, out) + out
    return out

In [30]:
x = random.normal(keys[0], (128, 1024))
q = random.normal(keys[1], (1024, 1024)) / jnp.sqrt(1024)
k = random.normal(keys[2], (1024, 1024)) / jnp.sqrt(1024)
v = random.normal(keys[3], (1024, 1024)) / jnp.sqrt(1024)
proj = random.normal(keys[4], (1024, 1024)) / jnp.sqrt(1024)
up = random.normal(keys[5], (1024, 4096)) / jnp.sqrt(1024)
down = random.normal(keys[6], (4096, 1024)) / jnp.sqrt(4096)

In [31]:
out = transformer_layer(x, q, k, v, proj, up, down)
print(out.sharding)

SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)


In [34]:
in_shardings = (
    jax.sharding.NamedSharding(mesh, P(None, None)),  # x
    jax.sharding.NamedSharding(mesh, P(None, 'k')),  # q
    jax.sharding.NamedSharding(mesh, P(None, 'k')),  # k
    jax.sharding.NamedSharding(mesh, P(None, 'k')),  # v
    jax.sharding.NamedSharding(mesh, P('k', None)),  # proj
    jax.sharding.NamedSharding(mesh, P(None, 'k')),  # up
    jax.sharding.NamedSharding(mesh, P('k', None))   # down
)
transformer_layer_gspmd = jax.jit(transformer_layer, in_shardings=in_shardings)

In [35]:
out_gspmd = transformer_layer_gspmd(x, q, k, v, proj, up, down)
print(out_gspmd.sharding)

NamedSharding(mesh=Mesh('k': 8), spec=PartitionSpec(None, None), memory_kind=device)


In [36]:
out_gspmd.shape

(128, 1024)

In [37]:
out.shape

(128, 1024)

In [38]:
print(jnp.allclose(out, out_gspmd))

False


In [41]:
print(jnp.allclose(out[:, 0], out_gspmd[:, 0]))

True


In [44]:
print(jnp.allclose(out[0, :128], out_gspmd[0, :128]))

True


In [48]:
x_sharded = jax.device_put(x, NamedSharding(mesh, P(None, None)))
q_sharded = jax.device_put(q,  NamedSharding(mesh, P(None, 'k')))
k_sharded = jax.device_put(k,  NamedSharding(mesh, P(None, 'k')))
v_sharded = jax.device_put(v,  NamedSharding(mesh, P(None, 'k')))
proj_sharded = jax.device_put(proj,  NamedSharding(mesh, P('k', None)))
up_sharded = jax.device_put(up,  NamedSharding(mesh, P(None, 'k')))
down_sharded = jax.device_put(down,  NamedSharding(mesh, P('k', None)))

In [49]:
out_gspmd_sharded_input = jax.jit(transformer_layer)(x_sharded, q_sharded, k_sharded, v_sharded, proj_sharded, up_sharded, down_sharded)

In [50]:
print(jnp.allclose(out_gspmd, out_gspmd_sharded_input))

True


In [None]:
qkv_manual_shard = shard_map(qkv, mesh, in_specs=(P(None, None), P(None, 'k'), P(None, 'k'), P(None, 'k')), out_specs=(P('k', None, None)) * 3)

@partial(shard_map, mesh=mesh, in_specs=(P('k', None, None)) * 3 + (P('k', None)), out_specs=P(None, None))
def attention_manual_shard(q, k, v proj):
    out = attention(q, k, v, proj)
    return lax.psum(out, 'k')


@partial(shard_map, mesh=mesh, in_specs=(P(None, 'k'), P('k', None), P(None, None)), out_specs=P(None, None))
def mlp_manual_shard(up, down, x):
    out = jnp.matmul(x, up)
    out = jax.nn.gelu(out)
    out = jnp.matmul(out, down)
    return out

def transformer_layer_manual_shard(x, q, k, v, proj, up, down):
    x_q, x_k, x_v = qkv_manual_shard(x, q, k, v)
    out = mlp_manual_shard(x_q, x_k, x_v, x, proj) + x
    out = mlp_manual_shard(up, down, out) + out
    return out


In [None]:
out_manual_shard = transformer_layer_manual_shard(x_sharded, q_sharded, k_sharded, v_sharded, proj_sharded, up_sharded, down_sharded)
print(jnp.allclose(out_gspmd, out_manual_shard))