In [2]:
import os
#os.environ['LIBTPU_INIT_ARGS'] = '--xla_enable_transpose_trace=true'
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8, --xla_enable_transpose_trace=true' # Use 8 CPU devices
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
# It does not work on TPU VMs, unable to create CPU device array.

In [3]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

In [4]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [11]:
from jax.tree_util import tree_map, tree_all

def allclose(a, b):
    return tree_all(
        tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b)
    )


In [12]:
devices = mesh_utils.create_device_mesh((2,2))
mesh = Mesh(devices, axis_names=('x', 'y'))

vmap specifies which index to treat as the vectorization axis, similar to batch.
both vmap and pmap can provide any named axis which can be used for a collective operation.
The replicas for collective operations must match the number of devices in the said axis

In [27]:
def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(ans.shape)
  print(allclose(ans, expected))

def sq(x):
    y = x @ x.T
    return y

In [28]:
check_vmap(sq, jnp.arange(12).reshape(4, 3, 1))

(4, 3, 3)
True


In [43]:
a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

Shardmap in effect does a device put on all the inputs as per in_spec, and then results of shardmap are by default collated in the main program assuming the out_spec to be true (there are reasonable checks).
Shardmap does not perform any collective by itself, the function is expected to provide the local computation.

In the following example, we keep the results of the matmul sharded in each replica.
The local matmul result has the right shape but the results are partial only.
Reduce scatter helps each shard get the correct matmul result partition.
The output stays sharded.
But when it's printed it's gathered to show the full value.


In [48]:
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', 'y'))
def not_matmul (a_block, b_block):
    c_partialsum = jnp.dot(a_block, b_block)
    c_block = jax.lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True)
    #c_block = c_partialsum
    jax.debug.print("printing c_block shape: {a}", a=c_block.shape)  # works with shard_map
    jax.debug.print("printing c_partialsum shape: {a}", a=c_partialsum.shape)  # works with shard_map


    return c_block

c = not_matmul(a, b)
print(c.shape)
jax.debug.visualize_array_sharding(c)
allclose(c, jnp.dot(a, b))

printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(2, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
(8, 4)


True

Another approach could be to keep the results sharded on axis=0, but complete on axis=1

In [50]:
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None))
def basic_matmul (a_block, b_block):
    c_partialsum = jnp.dot(a_block, b_block)
    c_block = jax.lax.psum(c_partialsum, 'y')
    jax.debug.print("printing c_block shape: {a}", a=c_block.shape)  # works with shard_map
    jax.debug.print("printing c_partialsum shape: {a}", a=c_partialsum.shape)  # works with shard_map


    return c_block

c = basic_matmul(a, b)
print(c.shape)
jax.debug.visualize_array_sharding(c)
allclose(c, jnp.dot(a, b))

printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_block shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
printing c_partialsum shape: (Array(4, dtype=int32), Array(4, dtype=int32))
(8, 4)


True

In [51]:
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

NameError: name 'f' is not defined