<a href="https://colab.research.google.com/github/present42/PyTorchPractice/blob/main/Following_JAX_Tutorial_Using_JAX_in_multi_host_and_multi_process_environments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

In [2]:
# Fail fast: the cells below build device meshes out of 8 devices.
if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

Goals of this tutorial:
 - `jax.Array`: unified datatype for representing arrays, even with physical storage spanning multiple devices
 - how using `jax.Array`s together with `jax.jit` can provide automatic compiler-based parallelization

In [3]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

In [4]:
# 1-D PositionalSharding over an 8-device mesh; later cells reshape it to match the array rank.
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [5]:
# x is created without an explicit sharding; y holds the same values placed
# with the sharding reshaped to a 4x2 grid of devices.
x = jax.random.normal(jax.random.key(0), (8192, 8192))
y = jax.device_put(x, sharding.reshape(4, 2))

In [6]:
jax.debug.visualize_array_sharding(y)

In [11]:
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()

9.74 ms ± 61.1 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [12]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

2.39 ms ± 72.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


## `Sharding` describes how array values are laid out in memory across devices

In JAX, `Sharding` objects describe distributed memory layouts. They can be used with `jax.device_put` to produce a value with a distributed layout.

In [33]:
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))

In [14]:
jax.debug.visualize_array_sharding(x)

In [15]:
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

In [16]:
devices

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
       TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
       TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
       TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1),
       TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)],
      dtype=object)

Here, `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:

In [35]:
from jax.sharding import PositionalSharding

# Wrap the device array in a PositionalSharding, then rebind x to a version
# placed with the (8, 1)-shaped sharding: 8 entries along axis 0, 1 along axis 1.
sharding = PositionalSharding(devices)
x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)

In [20]:
sharding.reshape((2, 4))

PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7} {TPU 4} {TPU 5}]])

In [22]:
from typing import Sequence

def is_congruent(x_shape: Sequence[int],
                 sharding_shape: Sequence[int]) -> bool:
  """Tell whether an array shape can be evenly tiled by a sharding shape.

  The two shapes are congruent when they have the same rank and every array
  dimension is an exact multiple of the matching sharding dimension.
  """
  if len(x_shape) != len(sharding_shape):
    return False
  for dim, n_shards in zip(x_shape, sharding_shape):
    if dim % n_shards != 0:
      return False
  return True

In [23]:
is_congruent(x.shape, (4, 2))

True

In [24]:
# Rebind `sharding` to the same devices arranged as a 4x2 grid.
sharding = sharding.reshape(4, 2)
print(sharding)

PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])


In [26]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

In [27]:
# Rebind `sharding` to a single row of 8 devices (shape (1, 8)).
sharding = sharding.reshape(1, 8)
print(sharding)

PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])


In [28]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

In [37]:
# Back to a 4x2 grid; replicating along axis 0 merges the four row entries of
# each column into one set of devices, leaving a (1, 2)-shaped sharding.
sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))

PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])


In [38]:
# Place x with the axis-0-replicated sharding and display the resulting layout.
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)

In [39]:
print(sharding.replicate(0).shape)

(1, 2)


In [40]:
print(sharding.replicate(1).shape)

(4, 1)


In [42]:
# Replicating along axis 1 gives a (4, 1)-shaped sharding; place x with it and display the layout.
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)


## `NamedSharding` gives a way to express shardings with names

In [43]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

# Short alias used by the PartitionSpec examples that follow.
P = PartitionSpec

# 4x2 device mesh whose axes are named 'a' (size 4) and 'b' (size 2).
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
# P('a', 'b'): array axis 0 is partitioned over mesh axis 'a', array axis 1 over 'b'.
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)

In [47]:
# Module-level 4x2 mesh with axes 'a' and 'b'; `mesh_sharding` falls back to it
# when no mesh is passed.
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
  """Build a `NamedSharding` for `pspec`.

  Args:
    pspec: partition spec naming the mesh axes to shard over.
    mesh: mesh to use; when `None`, the module-level `default_mesh` is used.

  Returns:
    `NamedSharding(mesh, pspec)` for the selected mesh.
  """
  target_mesh = default_mesh if mesh is None else mesh
  return NamedSharding(target_mesh, pspec)

In [48]:
# Array axis 0 partitioned over mesh axis 'a', array axis 1 over mesh axis 'b'.
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)

In [49]:
# Swapped roles: array axis 0 partitioned over mesh axis 'b', array axis 1 over 'a'.
y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)

In [50]:
# Array axis 0 partitioned over mesh axis 'a'; `None` leaves array axis 1 unpartitioned.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)

In [51]:
# Array axis 0 unpartitioned; array axis 1 partitioned over mesh axis 'b'.
y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)

In [52]:
# Array axis 0 unpartitioned; array axis 1 partitioned over mesh axis 'a'.
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)

In [53]:
# A tuple of mesh axis names partitions array axis 0 over both 'a' and 'b';
# array axis 1 is left unpartitioned.
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)

## Computation follows data sharding and is automatically parallelized.

Computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.

In [54]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [55]:
# Place x on a 4x2 device grid before applying an elementwise op to it.
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

input sharding:


In [57]:
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)

output sharding:


In [58]:
# Matmul operands built from the same values with different layouts:
# y uses the 4x2 sharding replicated along axis 1, z the same sharding
# replicated along axis 0.
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))

print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

lhs sharding:


rhs sharding:


In [59]:
# Multiply the two sharded operands and show the sharding of the result.
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)

out sharding:


In [60]:
# Single-device baseline: the same values placed entirely on the first device.
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)

In [61]:
# Sanity check: the sharded matmul agrees numerically with the single-device one.
np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z))

True

In [63]:
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()

49.5 ms ± 48.7 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [64]:
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()

7.41 ms ± 20 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [65]:
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)

When we explicitly shard data with `jax.device_put` and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding.

## When explicit shardings disagree, JAX errors

In [66]:
import textwrap
from termcolor import colored

In [67]:
def print_exception(e):
  """Print `e` as `<TypeName>: <message>`, type name in red, wrapped by `textwrap.fill`."""
  red_name = colored(type(e).__name__, 'red')
  message = f'{red_name}: {str(e)}'
  print(textwrap.fill(message))

In [68]:
# Two shardings over disjoint halves of the device list.
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])

y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
# Adding arrays placed on different device sets raises ValueError, which is
# caught and pretty-printed here.
try:
  y + z
except ValueError as e:
  print_exception(e)

ValueError: Received incompatible devices for jitted computation. Got
argument x1 of jax.numpy.add with shape float32[8192,8192] and device
ids [0, 1, 2, 3] on platform TPU and argument x2 of jax.numpy.add with
shape float32[8192,8192] and device ids [4, 5, 6, 7] on platform TPU


In [69]:
# Same 8 devices, but listed in a different order for the second sharding.
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)

y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))

# The differing device order is enough for the addition to raise ValueError.
try:
  y + z
except ValueError as e:
  print_exception(e)

ValueError: Received incompatible devices for jitted computation. Got
argument x1 of jax.numpy.add with shape float32[8192,8192] and device
ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape float32[8192,8192] and device ids [0, 1, 2,
3, 6, 7, 4, 5] on platform TPU


We say arrays that have been explicitly placed or sharded with `jax.device_put` are committed to their device(s), and so won't be automatically moved.

The outputs of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted.

In [70]:
# y is placed with an explicit sharding; the other operands are created on the
# fly without device_put, and combining them with y raises no error.
y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')

no error!


## Constraining shardings of intermediates in `jit`ted code

While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can give it hints using `jax.lax.with_sharding_constraint`. It is much like `jax.device_put` except we use it inside staged-out (`jit`-decorated) functions:

In [71]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [72]:
# Recreate x from the same key and place it on a 4x2 device grid.
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))

In [73]:
@jax.jit
def f(x):
  """Return `x + 1` with a sharding constraint of `sharding.reshape(2, 4)`.

  Reads the notebook-level `sharding`; the name `f` is rebound by later cells.
  """
  incremented = x + 1
  return jax.lax.with_sharding_constraint(incremented, sharding.reshape(2, 4))

In [74]:
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)


In [77]:
@jax.jit
def f(x):
  """Return `x + 1` with a sharding constraint of `sharding.replicate()`.

  Reads the notebook-level `sharding`; this rebinds the earlier `f`.
  """
  incremented = x + 1
  return jax.lax.with_sharding_constraint(incremented, sharding.replicate())

In [78]:
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)

By adding `with_sharding_constraint`, we've constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.

## Examples: neural networks

Warning: real examples may require more use of `with_sharding_constraint`

In [79]:
import jax
import jax.numpy as jnp

In [80]:
def predict(params, inputs):
  """Run the forward pass of a ReLU multi-layer perceptron.

  Args:
    params: iterable of `(W, b)` pairs, one per layer, applied in order.
    inputs: array fed to the first layer as `jnp.dot(inputs, W) + b`.

  Returns:
    The last layer's pre-activation output (no ReLU is applied to it). When
    `params` is empty, `inputs` is returned unchanged rather than raising
    `UnboundLocalError`.
  """
  # Identity for a zero-layer model; overwritten by every layer otherwise.
  outputs = inputs
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    # ReLU is applied only to what feeds the next layer.
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  """Mean over the batch of the summed squared prediction error.

  Args:
    params: layer parameters forwarded to `predict`.
    batch: `(inputs, targets)` pair.

  Returns:
    Scalar: squared residuals summed over the last axis, then averaged.
  """
  inputs, targets = batch
  residual = predict(params, inputs) - targets
  per_example = jnp.sum(residual ** 2, axis=-1)
  return jnp.mean(per_example)


In [81]:
# Jit-compiled loss, and jit-compiled gradient of the loss (jax.grad
# differentiates with respect to the first argument, `params`, by default).
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))

In [83]:
def init_layer(key, n_in, n_out):
  """Sample the parameters of one dense layer.

  Args:
    key: PRNG key; split once into a weight key and a bias key.
    n_in: input width of the layer.
    n_out: output width of the layer.

  Returns:
    `(W, b)`: `W` has shape `(n_in, n_out)`, standard-normal draws divided by
    `sqrt(n_in)`; `b` has shape `(n_out,)`, standard-normal draws.
  """
  weight_key, bias_key = jax.random.split(key)
  weights = jax.random.normal(weight_key, (n_in, n_out)) / jnp.sqrt(n_in)
  bias = jax.random.normal(bias_key, (n_out,))
  return weights, bias

def init_model(key, layer_sizes, batch_size):
  """Create random MLP parameters together with one random batch.

  Args:
    key: PRNG key; split so that each layer and each batch array gets its own key.
    layer_sizes: sequence of layer widths; consecutive pairs define the layers.
    batch_size: number of rows in the generated inputs and targets.

  Returns:
    `(params, (inputs, targets))`, where `params` is a list of `(W, b)` pairs,
    `inputs` has shape `(batch_size, layer_sizes[0])` and `targets` has shape
    `(batch_size, layer_sizes[-1])`.
  """
  key, *layer_keys = jax.random.split(key, len(layer_sizes))
  params = [
      init_layer(layer_key, n_in, n_out)
      for layer_key, n_in, n_out in zip(
          layer_keys, layer_sizes[:-1], layer_sizes[1:])
  ]

  key, input_key, target_key = jax.random.split(key, 3)
  inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
  targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

In [84]:
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

In [104]:
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

## 8-way batch data parallelism

In [88]:
# (8, 1)-shaped sharding over all devices, used for the batch below.
sharding = PositionalSharding(jax.devices()).reshape(8, 1)

In [89]:
# Data parallelism: the batch arrays get the (8, 1) sharding, while the
# parameters get the fully replicated version of it.
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())

In [90]:
loss_jit(params, batch)

Array(23.469475, dtype=float32)

In [91]:
step_size = 1e-5

In [93]:
# 30 steps of plain gradient descent on the data-parallel batch.
for _ in range(30):
  grads = gradfun(params, batch)
  params = [
      (W - step_size * dW, b - step_size * db)
      for (W, b), (dW, db) in zip(params, grads)
  ]

print(loss_jit(params, batch))

10.760102


In [94]:
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()

53.7 ms ± 1.16 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [99]:
# Single-device copies of the batch and parameters, used as a timing baseline.
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])

In [100]:
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()

361 ms ± 101 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


## 4-way batch data parallelism and 2-way model tensor parallelism

AttributeError: 'tuple' object has no attribute 'shape'

In [105]:
# Rearrange the devices as a 4x2 grid; the batch uses that sharding replicated
# along axis 1.
sharding = sharding.reshape(4, 2)

batch = jax.device_put(batch, sharding.replicate(1))
# `batch` is a tuple, so each array is visualized separately.
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])

In [109]:
# Unpack the four layers so each parameter can be given its own placement.
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# Layer 1: fully replicated.
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

# Layer 2: the 4x2 sharding replicated along axis 0.
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

# Layer 3: W3 uses the transpose of the layer-2 sharding; b3 is fully replicated.
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

# Layer 4: fully replicated.
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

# Repack in the same (W, b)-per-layer structure.
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

In [117]:
print(loss_jit(params, batch))

10.76014


In [118]:
# Learning rate for the gradient-descent loop below.
step_size = 1e-5

# 30 steps of plain gradient descent; `params` is rebound to a list of
# updated (W, b) pairs on every step.
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
    for (W, b), (dW, db) in zip(params, grads)]

In [119]:
print(loss_jit(params, batch))

10.752443


In [120]:
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)

In [122]:
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()

58.3 ms ± 536 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


# Sharp bits
## Generating random numbers

In [123]:
@jax.jit
def f(key, x):
  """Return `x` plus uniform random numbers of the same shape drawn with `key`."""
  return x + jax.random.uniform(key, x.shape)

# Inputs for the RNG demo: a key and a 24-element array placed with a 1-D
# sharding over all devices.
key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)

In [124]:
jax.debug.visualize_array_sharding(f(key, x))

In principle, random number generation should be a pure map over counter values. It should require no cross-device communication, nor any redundant computation across devices.

However, the existing stable RNG implementation is not automatically partitionable, for historical reasons.

In [125]:
# Compile f for these inputs and look for a 'collective-permute' op in the
# compiled program text as evidence of cross-device communication.
f_exe = f.lower(key, x).compile()
print("Communicating?", 'collective-permute' in f_exe.as_text())

Communicating? True


One way to work around this is to enable the `jax_threefry_partitionable` configuration option:

In [126]:
# Turn on the partitionable threefry option, then recompile f and repeat the
# 'collective-permute' check.
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print("Communicating?", 'collective-permute' in f_exe.as_text())

Communicating? False


In [127]:
jax.debug.visualize_array_sharding(f(key, x))