In [1]:
import os
# Alternative runtime configurations that were tried and left disabled:
#   os.environ['LIBTPU_INIT_ARGS'] = '--xla_enable_transpose_trace=true'
#   os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8, --xla_enable_transpose_trace=true'  # Use 8 CPU devices
#   os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'  # Use 8 CPU devices
# Author's note: "It does not work on TPU VMs, unable to create CPU device array."
# NOTE(review): presumably this refers to the forced host-device-count settings above; confirm.

# This cell runs before the cell that imports jax; the flag string is assembled from its parts.
os.environ["LIBTPU_INIT_ARGS"] = " ".join((
    "--xla_enable_async_all_gather=true",
    "TPU_MEGACORE=MEGACORE_DENSE",
))

In [2]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

In [4]:
# Show the devices JAX can see; the mesh shapes used in later cells must match this count.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [3]:
from jax.tree_util import tree_map, tree_all

def allclose(a, b, atol=1e-2, rtol=1e-2):
    """Return True when every pair of matching leaves in two pytrees is close.

    Args:
        a, b: Pytrees with the same structure (a bare array counts as a pytree).
        atol: Absolute tolerance forwarded to `jnp.allclose` (default 1e-2).
        rtol: Relative tolerance forwarded to `jnp.allclose` (default 1e-2).

    Returns:
        The conjunction of the leaf-wise `jnp.allclose` results.
    """
    leafwise = tree_map(partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
    return tree_all(leafwise)


`vmap` specifies which index to treat as the vectorization axis, similar to a batch dimension.
Both `vmap` and `pmap` can assign a name to an axis, and that name can then be used in collective operations.
The number of replicas taking part in a collective operation must match the number of devices along that axis.

In [4]:
def check_vmap(f, xs):
  """Print whether `jax.vmap(f)` over axis 0 of `xs` agrees with a per-element loop.

  The loop result, `f` applied to each element of `xs` and stacked, is the
  reference semantics that the vmapped result is compared against.
  """
  vectorized = jax.vmap(f, in_axes=(0,), out_axes=0)
  got = vectorized(xs)
  want = jnp.stack(list(map(f, xs)))  # vmap reference semantics
  print(allclose(got, want))

def sq(x):
    """Return the product of `x` with its own transpose, `x @ x.T`."""
    return x @ x.T

In [5]:
# Four (3, 1) column vectors stacked on axis 0; `sq` maps each one to a (3, 3) matrix.
check_vmap(sq, jnp.arange(12).reshape(4, 3, 1))

True


In effect, `shard_map` performs a `device_put` on every input according to `in_specs`; by default the per-device results are then assembled in the main program on the assumption that `out_specs` describes them correctly (there are reasonable checks).
`shard_map` does not perform any collective by itself; the function is expected to provide the local computation.

In the following example, we keep the result of the matmul sharded across the replicas.
Each local matmul result has the right shape, but it holds only a partial sum.
A reduce-scatter (`psum_scatter`) gives each shard its correct partition of the matmul result.
The output stays sharded.
When it is printed, however, it is gathered to show the full value.


In [7]:


# Operands for the sharded matmul examples: `a` is 8x16 and `b` is 16x4, filled with consecutive floats.
a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

# Arrange the devices in a 2x2 grid and name the two mesh axes 'x' and 'y'.
devices = mesh_utils.create_device_mesh((2,2))
mesh = Mesh(devices, axis_names=('x', 'y'))

@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('y', None)),
    out_specs=P('x', 'y'),
)
def matmul_keep_sharded (a_block, b_block):
    """Local matmul followed by a reduce-scatter over 'y'; the output stays sharded on both axes.

    The contracting dimension is split over 'y', so the local product is only a
    partial sum. `psum_scatter` adds the partial sums across 'y' and leaves each
    device with its own slice of the columns (dimension 1).
    """
    partial_product = jnp.dot(a_block, b_block)
    scattered = jax.lax.psum_scatter(partial_product, 'y', scatter_dimension=1, tiled=True)
    # Report the per-device shapes after and before the reduce-scatter.
    jax.debug.print("printing c_block shape: {a}", a=scattered.shape)  # works with shard_map
    jax.debug.print("printing c_partialsum shape: {a}", a=partial_product.shape)  # works with shard_map
    return scattered

# Run the sharded matmul, then inspect the global shape and how the result is laid out on the devices.
c = matmul_keep_sharded(a, b)
print(c.shape)
jax.debug.visualize_array_sharding(c)
# The last expression is the cell output: does the result match an unsharded matmul?
allclose(c, jnp.dot(a, b))

printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
(8, 4)


True

Another approach is to keep the result sharded along axis 0 but complete (unsharded) along axis 1.

In [8]:
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('y', None)),
    out_specs=P('x', None),
)
def matmul_keep_sharded_x (a_block, b_block):
    """Local matmul followed by an all-reduce over 'y'; the output is sharded on 'x' only.

    `psum` adds the partial products across 'y', so every device ends up holding
    the complete row block instead of a column slice of it.
    """
    partial_product = jnp.dot(a_block, b_block)
    reduced = jax.lax.psum(partial_product, 'y')
    # Report the per-device shapes after and before the all-reduce.
    jax.debug.print("printing c_block shape: {a}", a=reduced.shape)  # works with shard_map
    jax.debug.print("printing c_partialsum shape: {a}", a=partial_product.shape)  # works with shard_map
    return reduced

# Run the 'x'-sharded variant, then inspect the global shape and the device layout of the result.
c = matmul_keep_sharded_x(a, b)
print(c.shape)
jax.debug.visualize_array_sharding(c)
# The last expression is the cell output: does the result match an unsharded matmul?
allclose(c, jnp.dot(a, b))

printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
(8, 4)


True

Yet another approach is to keep the result sharded along the 'y' mesh axis but replicated along 'x'.

In [7]:
from jax.tree_util import tree_map, tree_all
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

def allclose(a, b, atol=1e-2, rtol=1e-2):
    """Return True when every pair of matching leaves in two pytrees is close.

    Args:
        a, b: Pytrees with the same structure (a bare array counts as a pytree).
        atol: Absolute tolerance forwarded to `jnp.allclose` (default 1e-2).
        rtol: Relative tolerance forwarded to `jnp.allclose` (default 1e-2).

    Returns:
        The conjunction of the leaf-wise `jnp.allclose` results.
    """
    leafwise = tree_map(partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
    return tree_all(leafwise)
    
# Operands for the sharded matmul example: `a` is 8x16 and `b` is 16x4, filled with consecutive floats.
a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

# Arrange the devices in a 2x2 grid and name the two mesh axes 'x' and 'y'.
devices = mesh_utils.create_device_mesh((2,2))
mesh = Mesh(devices, axis_names=('x', 'y'))
    
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('y', None)),
    out_specs=P(None, 'y'),
    check_rep=False,
)
def matmul_keep_sharded_y (a_block, b_block):
    """Local matmul, reduce-scatter over 'y', then all-gather over 'x'.

    The reduce-scatter leaves each device with a column slice of its row block;
    the tiled all-gather over 'x' then concatenates the row blocks, so each
    device returns every row of its column slice.
    """
    partial_product = jnp.dot(a_block, b_block)
    column_slice = jax.lax.psum_scatter(partial_product, 'y', scatter_dimension=1, tiled=True)
    all_rows = jax.lax.all_gather(column_slice, 'x', tiled=True)
    # Report the per-device shapes at each stage.
    jax.debug.print("printing c_block shape: {a}, c-x shape: {b}", a=column_slice.shape, b=all_rows.shape)
    jax.debug.print("printing c_partialsum shape: {a}", a=partial_product.shape)
    return all_rows

# Run the 'y'-sharded variant, then inspect the global shape and the device layout of the result.
c = matmul_keep_sharded_y(a, b)
print(c.shape)
jax.debug.visualize_array_sharding(c)
# The last expression is the cell output: does the result match an unsharded matmul?
allclose(c, jnp.dot(a, b))

printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32)), c-x shape: (Array(8, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32)), c-x shape: (Array(8, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32)), c-x shape: (Array(8, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32)), c-x shape: (Array(8, dtype=int32), Array(2, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
(8, 4)


True

In [24]:
from jax.tree_util import tree_map, tree_all
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

import numpy as np

def allclose(a, b, atol=1e-2, rtol=1e-2):
    """Return True when every pair of matching leaves in two pytrees is close.

    Args:
        a, b: Pytrees with the same structure (a bare array counts as a pytree).
        atol: Absolute tolerance forwarded to `jnp.allclose` (default 1e-2).
        rtol: Relative tolerance forwarded to `jnp.allclose` (default 1e-2).

    Returns:
        The conjunction of the leaf-wise `jnp.allclose` results.
    """
    leafwise = tree_map(partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
    return tree_all(leafwise)

# Problem size: A is (M, K) and X is (K, N), drawn from two different PRNG keys.
M, K, N = 16384, 1024, 2048
A = jax.random.normal(jax.random.key(0), (M, K))
X = jax.random.normal(jax.random.key(1), (K, N))

# 4x1 device grid: every device lies along 'x'; the 'y' axis has size 1.
devices = mesh_utils.create_device_mesh((4, 1))
mesh = Mesh(devices, axis_names=('x', 'y'))

# Shard the rows of A across 'x'; X is passed as created, without an explicit device_put.
A_x = jax.device_put(A, NamedSharding(mesh, P('x', None)))
#X = jax.device_put(X, NamedSharding(mesh, P()))




In [8]:
@jax.jit
def f(lhs, rhs):
  """Baseline: multiply `lhs @ rhs` and constrain the product to the `P()` sharding on `mesh`."""
  unpartitioned = NamedSharding(mesh, P())
  return jax.lax.with_sharding_constraint(lhs @ rhs, unpartitioned)
  

# JAX dispatches work asynchronously, so wait for the result inside the trace;
# otherwise the profiler may stop before the matmul has actually executed.
with jax.profiler.trace("/tmp/jax-trace"):
  C1 = f(A_x, X)
  C1.block_until_ready()

# Compare against an unsharded matmul of the original operands.
print(allclose(C1, jnp.dot(A, X)))

# Show how the output and the sharded input are laid out across the devices.
jax.debug.visualize_array_sharding(C1)
jax.debug.visualize_array_sharding(A_x)




True


In [19]:
# @partial(shard_map,
#   mesh=mesh,
#   in_specs=(P('x', None), P()),
#   out_specs=P(),
#   check_rep=False
#   )
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
    """All-gather `lhs` along mesh axis 'x' one ring step at a time, multiplying as blocks arrive.

    Must run where the axis name 'x' is bound (it is wrapped with `shard_map` at
    the call site). Each device starts with one row block of `lhs` and the whole
    `rhs`. On every step the device multiplies the block it currently holds by
    `rhs`, writes the product into the matching rows of the output, and passes
    the block to its left neighbour.

    Args:
        lhs: Local row block of the left operand.
        rhs: Right operand.

    Returns:
        Array of shape `(lhs.shape[0] * axis_size, rhs.shape[1])` containing the
        product of every row block with `rhs`.
    """
    with jax.named_scope('collective_matmul'):
        n_shards = jax.lax.psum(1, axis_name='x')
        my_rank = jax.lax.axis_index(axis_name='x')
        rows_per_shard = lhs.shape[0]

        # One ring step: device j sends its block to device j - 1 (wrapping around).
        ring = [(src, (src - 1) % n_shards) for src in range(n_shards)]

        def pass_left(block):
            return jax.lax.ppermute(block, axis_name='x', perm=ring)

        out = jnp.zeros((rows_per_shard * n_shards, rhs.shape[1]), dtype=rhs.dtype)
        for step in range(n_shards):
            if step > 0:
                lhs = pass_left(lhs)
            # After `step` shifts the held block originated on device (rank + step) % n_shards.
            row_start = ((my_rank + step) % n_shards) * rows_per_shard
            out = jax.lax.dynamic_update_slice(out, lhs @ rhs, (row_start, 0))
    return out

# Wrap the per-device function: lhs rows are sharded over 'x', rhs is passed whole
# to every device, and each device returns the full product.
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting,
  mesh,
  in_specs=(P('x', None), P()),
  out_specs=P(),
  check_rep=False
  ))

# JAX dispatches work asynchronously, so wait for the result inside the trace;
# otherwise the profiler may stop before the computation has actually executed.
with jax.profiler.trace("/tmp/jax-trace"):
    C = jit_sharded_f(A_x, X)
    C.block_until_ready()

print(allclose(C, A @ X))
    

True


In [21]:
def collective_matmul_allgather_lhs_non_contracting_bidi(lhs, rhs):
    """Bidirectional ring version of the all-gather collective matmul along mesh axis 'x'.

    Must run where the axis name 'x' is bound (e.g. inside `shard_map`). The
    local row block of `lhs` is split into two pieces. The first piece
    (`lhs_hi`) travels left around the ring and the second (`lhs_lo`) travels
    right, so each step sends half as much data in each direction.

    Args:
        lhs: Local row block of the left operand.
        rhs: Right operand.

    Returns:
        Array of shape `(lhs.shape[0] * axis_size, rhs.shape[1])` containing the
        product of every row block with `rhs`.
    """
    with jax.named_scope('collective_matmul'):
        axis_size = jax.lax.psum(1, axis_name='x')
        axis_index = jax.lax.axis_index(axis_name='x')
        chunk_size = lhs.shape[0]
        # Rows [0, half) form the left-travelling piece; rows [half, chunk_size) travel right.
        half = (chunk_size + 1) // 2

        shift_right = partial(
            jax.lax.ppermute,
            axis_name='x',
            perm=[(j, (j + 1) % axis_size) for j in range(axis_size)]
        )
        shift_left = partial(
            jax.lax.ppermute,
            axis_name='x',
            perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
        )

        result = jnp.zeros((chunk_size * axis_size, rhs.shape[1]), dtype=rhs.dtype)
        # The local block needs no communication: write its product first.
        result = jax.lax.dynamic_update_slice(result, lhs @ rhs, (axis_index * chunk_size, 0))

        lhs_hi = jax.lax.dynamic_slice_in_dim(lhs, 0, half)
        # The start index used to be written `chunk_size + 1 // 2`, which parses as
        # `chunk_size`. It selected the right rows only because out-of-range slice
        # starts are clamped; the intended offset is `half`.
        lhs_lo = jax.lax.dynamic_slice_in_dim(lhs, half, chunk_size - half)

        for i in range(1, axis_size):
            lhs_hi = shift_left(lhs_hi)
            lhs_lo = shift_right(lhs_lo)
            # After i left-shifts the first piece originated on device (axis_index + i);
            # after i right-shifts the second piece originated on device (axis_index - i).
            update_index_hi = (((axis_index + i) % axis_size) * chunk_size, 0)
            update_index_lo = (((axis_index - i) % axis_size) * chunk_size + half, 0)
            result = jax.lax.dynamic_update_slice(result, lhs_lo @ rhs, update_index_lo)
            result = jax.lax.dynamic_update_slice(result, lhs_hi @ rhs, update_index_hi)
    return result


# Wrap the bidirectional per-device function: lhs rows are sharded over 'x', rhs is
# passed whole to every device, and each device returns the full product.
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting_bidi,
  mesh,
  in_specs=(P('x', None), P()),
  out_specs=P(),
  check_rep=False
  ))

# JAX dispatches work asynchronously, so wait for the result inside the trace;
# otherwise the profiler may stop before the computation has actually executed.
with jax.profiler.trace("/tmp/jax-trace"):
    C = jit_sharded_f(A_x, X)
    C.block_until_ready()

print(allclose(C, A @ X))

True


In [27]:
# Right operand made of four stacked (K, N) blocks, one per device along 'x'.
# Use a fresh PRNG key: key 1 was already consumed to create X.
Y = jax.random.normal(jax.random.key(2), (4 * K, N))

# Both operands are sharded on rows over 'x'. Each device all-gathers A and
# multiplies it by its own (K, N) block of Y; the per-device (M, N) products are
# then concatenated along axis 0, giving a (4 * M, N) output.
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting_bidi,
  mesh,
  in_specs=(P('x', None), P('x', None)),
  out_specs=P('x', None),
  check_rep=False
  ))

X_x = jax.device_put(Y, NamedSharding(mesh, P('x', None)))
# JAX dispatches work asynchronously, so wait for the result inside the trace.
with jax.profiler.trace("/tmp/jax-trace"):
    C = jit_sharded_f(A_x, X_x)
    C.block_until_ready()

# C has 4 * M rows, so it cannot be compared with A @ X directly. Check each
# device's row block against A multiplied by the matching block of Y instead.
print(all(
    allclose(C[d * M:(d + 1) * M], A @ Y[d * K:(d + 1) * K])
    for d in range(4)
))

In [28]:
from jax.tree_util import tree_map, tree_all
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

import numpy as np

def allclose(a, b, atol=1e-2, rtol=1e-2):
    """Return True when every pair of matching leaves in two pytrees is close.

    Args:
        a, b: Pytrees with the same structure (a bare array counts as a pytree).
        atol: Absolute tolerance forwarded to `jnp.allclose` (default 1e-2).
        rtol: Relative tolerance forwarded to `jnp.allclose` (default 1e-2).

    Returns:
        The conjunction of the leaf-wise `jnp.allclose` results.
    """
    leafwise = tree_map(partial(jnp.allclose, atol=atol, rtol=rtol), a, b)
    return tree_all(leafwise)

# Problem size: A is (M, K) and X is (K, N), drawn from two different PRNG keys.
M, K, N = 16384, 1024, 2048
A = jax.random.normal(jax.random.key(0), (M, K))
X = jax.random.normal(jax.random.key(1), (K, N))

# 4x1 device grid: every device lies along 'x'; the 'y' axis has size 1.
devices = mesh_utils.create_device_mesh((4, 1))
mesh = Mesh(devices, axis_names=('x', 'y'))

# Both operands are sharded along K, the contracting dimension:
# the columns of A and the rows of X are split over 'x'.
lhs_spec = P(None, 'x')
rhs_spec = P('x', None)

A_x = jax.device_put(A, NamedSharding(mesh, lhs_spec))
X_x = jax.device_put(X, NamedSharding(mesh, rhs_spec))

In [30]:
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=rhs_spec,
)
def matmul_psumscatter(lhs_block, rhs_block):
  """Multiply the local blocks, then reduce-scatter the partial products over 'x'.

  The local product is one summand of the full result. `psum_scatter` adds the
  summands across 'x' and leaves each device with one tile of the sum.
  """
  local_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(local_summand, 'x', tiled=True)

# Run the reduce-scatter matmul and compare it with a plain matmul of the sharded operands.
out = matmul_psumscatter(A_x, X_x)
print(jnp.allclose(out, A_x @ X_x, atol=1e-3, rtol=1e-3))

True


In [31]:
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=rhs_spec,
)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  """Reduce-scatter matmul written as a ring of shifts interleaved with local matmuls.

  The local lhs block is split into one row chunk per device. A running partial
  sum travels around the ring; at every stop the device adds the product of one
  of its row chunks with `rhs_block`, chosen by its axis index and the step.
  """
  n_dev = jax.lax.psum(1, 'x')
  rank = jax.lax.axis_index('x')
  # One ring step: device i sends its running sum to device i - 1 (wrapping around).
  ring = [(src, (src - 1) % n_dev) for src in range(n_dev)]
  row_chunks = lhs_block.reshape(n_dev, -1, lhs_block.shape[1])  # split 1st axis

  acc = row_chunks[(rank + 1) % n_dev] @ rhs_block
  for step in range(2, n_dev + 1):
    acc = jax.lax.ppermute(acc, axis_name='x', perm=ring)
    acc = acc + row_chunks[(rank + step) % n_dev] @ rhs_block
  return acc

In [33]:
# Run the overlapped variant and compare it with a plain matmul of the sharded operands.
out = matmul_psumscatter_overlapped(A_x, X_x)
print(jnp.allclose(out, A_x @ X_x,  atol=1e-3, rtol=1e-3))

True
