<a href="https://colab.research.google.com/github/sw32-seo/ProJAX/blob/main/Distributed_arrays_and_automatic_parallelization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

In [2]:
# Uncomment to emulate 8 devices on a CPU-only host.
# NOTE(review): XLA presumably reads this flag when the backend is first
# initialised, so it should be set before any JAX call touches devices - confirm.
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

# The cells below build 8-device layouts (8, 4x2, 2x4), so stop early
# when fewer than 8 local devices are available.
if len(jax.local_devices()) < 8:
    raise RuntimeError("Notebook requires 8 devices to run")

# Intro and a quick example

In [3]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

In [4]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [5]:
# Draw an 8192x8192 array of normal samples from a fixed key.
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# Place x using the device sharding reshaped to a 4x2 grid, then show the layout.
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)

In [6]:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)

`x` is present on a single device

In [7]:
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()

The slowest run took 8.68 times longer than the fastest. This could mean that an intermediate result is being cached.
24.8 ms ± 30 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


`y` is sharded across 8 devices

In [8]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

2.37 ms ± 34.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


# `Sharding` describes how array values are laid out in memory across devices

## Sharding basics, and the `PositionalSharding` subclass

In [9]:
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))

In [10]:
jax.debug.visualize_array_sharding(x)

In [11]:
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

In [12]:
from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)

In [13]:
sharding

PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}], shape=(8,))

In [14]:
sharding.reshape(8, 1)

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]], shape=(8, 1))

In [15]:
sharding = sharding.reshape(4, 2)
print(sharding)

PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]], shape=(4, 2))


In [16]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

In [17]:
sharding = sharding.reshape(1, 8)
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

Sometimes, we want to *replicate* some slices, meaning storing copies of a slice's values in multiple devices' memories.

With the reducer method `replicate`, we can express replication.

In [18]:
sharding = sharding.reshape(4, 2)
# before replicate
print(sharding)
# after replicate
print(sharding.replicate(axis=0, keepdims=True))

PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]], shape=(4, 2))
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]], shape=(1, 2))


In [19]:
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)

In [20]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

The slowest run took 14.65 times longer than the fastest. This could mean that an intermediate result is being cached.
21.8 ms ± 31.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


# `NamedSharding` gives a way to express shardings with names

In [21]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

# Short alias for building partition specs.
P = PartitionSpec

# Arrange the devices in a 4x2 grid and name its two axes 'a' and 'b'.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
# Place x with the spec P('a', 'b') on that mesh and show the resulting layout.
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)

We can define a helper function to make things simpler:

In [22]:
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
  """Build a ``NamedSharding`` for ``pspec``.

  Args:
    pspec: Partition spec to pair with the mesh.
    mesh: Mesh to use. When ``None``, the module-level ``default_mesh``
      is used instead.

  Returns:
    ``NamedSharding(mesh, pspec)`` for the selected mesh.
  """
  selected_mesh = default_mesh if mesh is None else mesh
  return NamedSharding(selected_mesh, pspec)

In [23]:
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)

`None` in place of an axis name means that the array will not be sharded along that dimension, and shards are replicated across it.

In [24]:
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)

# Computation follows data sharding and is automatically parallelized.

Functions decorated with `jax.jit` can operate over sharded arrays.

For example, the simplest computation is an elementwise one:

In [25]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [26]:
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding: ')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding: ')
jax.debug.visualize_array_sharding(y)

input sharding: 


output sharding: 


We can do the same for more than just elementwise operations.

In [27]:
# lhs: 4x2 device grid with replication along grid axis 1.
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
# rhs: 4x2 device grid with replication along grid axis 0.
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))

print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

# Multiply the two sharded operands and inspect how the product is laid out.
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)

lhs sharding:


rhs sharding:


out sharding:


Let's do the timing experiment for running in parallel.

In [28]:
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)

In [29]:
np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z))

True

In [30]:
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()

49.7 ms ± 217 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [31]:
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()

7.51 ms ± 45.3 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


Even copying a sharded `Array` produces a sharded result.

In [32]:
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)

If the sharding axes are reversed, the resulting value is the same. The only difference is the replication.

In [33]:
w_wrong = jnp.dot(z, y)
jax.debug.visualize_array_sharding(w_wrong)

In [34]:
np.allclose(jnp.dot(x_single, x_single), jnp.dot(z, y))

True

In [35]:
%timeit -n 5 -r 5 jnp.dot(z, y).block_until_ready()

55.7 ms ± 181 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


If the replication is wrong, matrix multiplication gives a result that fails the `np.allclose` check against the single-device product.

In [36]:
# Neither operand is replicated: n is placed on a 4x2 grid, m on a 2x4 grid.
n = jax.device_put(x, sharding.reshape(4, 2))
m = jax.device_put(x, sharding.reshape(2, 4))

# Multiply them and show the layout of the product.
k = jnp.dot(n, m)
jax.debug.visualize_array_sharding(k)

In [37]:
np.allclose(jnp.dot(x_single, x_single), k)

False

In [89]:
n = jax.device_put(x, sharding.reshape(4, 2))
m = jax.device_put(x, sharding.replicate(0))

In [90]:
k = jnp.dot(n, m)
np.allclose(jnp.dot(x_single, x_single), k)

True

In [91]:
%timeit -n 5 -r 5 jnp.dot(n, m).block_until_ready()

8.04 ms ± 19.5 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


# Constraining shardings of intermediates in `jit`ted code

In [41]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [42]:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))

In [43]:
@jax.jit
def f(x):
    """Add one to ``x`` and constrain the result's sharding.

    The constraint uses the module-level ``sharding`` reshaped to ``(2, 4)``.
    """
    shifted = x + 1
    return jax.lax.with_sharding_constraint(shifted, sharding.reshape(2, 4))

In [44]:
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)

# Examples: neural networks

In [45]:
import jax
import jax.numpy as jnp

In [46]:
def predict(params, inputs):
    """Run a fully connected ReLU network forward.

    Args:
        params: Iterable of ``(W, b)`` pairs, one per layer, applied in order.
        inputs: Input array fed to the first layer.

    Returns:
        The affine output of the last layer (no ReLU is applied to it).
        When ``params`` is empty, ``inputs`` is returned unchanged.
    """
    # Start from the inputs so an empty parameter list yields the identity
    # instead of raising UnboundLocalError at the return below.
    outputs = inputs
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        # The ReLU output feeds the next layer; the last layer's
        # pre-activation value is what gets returned.
        inputs = jnp.maximum(outputs, 0)
    return outputs

def loss(params, batch):
    """Squared-error loss of ``predict`` on one batch.

    Args:
        params: Network parameters, passed straight to ``predict``.
        batch: Pair ``(inputs, targets)``.

    Returns:
        The squared difference between predictions and targets, summed
        over the last axis and then averaged over what remains.
    """
    inputs, targets = batch
    predictions = predict(params, inputs)
    squared_error = (predictions - targets) ** 2
    # Use jnp.sum (not np.sum) so the reduction is an explicit JAX operation
    # rather than relying on NumPy dispatching to the array's own method.
    return jnp.mean(jnp.sum(squared_error, axis=-1))

In [47]:
# Compiled version of the loss.
loss_jit = jax.jit(loss)
# Compiled gradient of the loss (jax.grad differentiates with respect to
# the first argument by default, i.e. params).
gradfun = jax.jit(jax.grad(loss))

In [48]:
def init_layer(key, n_in, n_out):
    """Create random parameters for one dense layer.

    Args:
        key: JAX PRNG key; it is split into one key per returned array.
        n_in: Number of input features.
        n_out: Number of output features.

    Returns:
        ``(W, b)`` where ``W`` has shape ``(n_in, n_out)`` and holds normal
        samples divided by ``sqrt(n_in)``, and ``b`` has shape ``(n_out,)``
        and holds normal samples.
    """
    weight_key, bias_key = jax.random.split(key)
    scale = jnp.sqrt(n_in)
    weights = jax.random.normal(weight_key, (n_in, n_out)) / scale
    biases = jax.random.normal(bias_key, (n_out,))
    return weights, biases


def init_model(key, layer_sizes, batch_size):
    """Create random network parameters and one random batch.

    Args:
        key: JAX PRNG key.
        layer_sizes: Sequence of layer widths; consecutive entries give the
            ``(n_in, n_out)`` of each layer.
        batch_size: Number of rows in the generated inputs and targets.

    Returns:
        ``(params, (inputs, targets))`` where ``params`` is a list of
        ``(W, b)`` pairs from ``init_layer``, ``inputs`` has shape
        ``(batch_size, layer_sizes[0])`` and ``targets`` has shape
        ``(batch_size, layer_sizes[-1])``.
    """
    # The first split key is carried forward; the rest seed one layer each.
    key, *layer_keys = jax.random.split(key, len(layer_sizes))
    fan_ins, fan_outs = layer_sizes[:-1], layer_sizes[1:]
    params = [
        init_layer(layer_key, n_in, n_out)
        for layer_key, n_in, n_out in zip(layer_keys, fan_ins, fan_outs)
    ]

    # Split the carried key again; the last two keys seed the batch.
    _, input_key, target_key = jax.random.split(key, 3)
    inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
    targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)


# Layer widths: 784 -> 8192 -> 8192 -> 8192 -> 10 (four layers).
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

# Build the parameters and one random (inputs, targets) batch from a fixed key.
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

## 8-way batch data parallelism

In [49]:
sharding = PositionalSharding(jax.devices()).reshape(8, 1)

In [50]:
# Place the batch with the current sharding, and place the parameters with
# its fully replicated form (replicate() with no axis argument).
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())

In [51]:
loss_jit(params, batch)

Array(23.469475, dtype=float32)

In [52]:
# 30 steps of plain gradient descent with a fixed step size.
step_size = 1e-5
for _ in range(30):
    # grads is unpacked below as one (dW, db) pair per (W, b) pair in params.
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

# Loss after the updates.
print(loss_jit(params, batch))

10.760102


In [53]:
%timeit -n 5 -r 5 loss_jit(params, batch).block_until_ready()

14.5 ms ± 87.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [54]:
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])

In [55]:
%timeit -n 5 -r 5 loss_jit(params_single, batch_single).block_until_ready()

146 ms ± 83.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


## 4-way batch data parallelism and 2-way model tensor parallelism

In [56]:
sharding = sharding.reshape(4, 2)

In [57]:
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])

In [78]:
# Unpack the four layers so each one can be placed with its own sharding.
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# Layer 1: fully replicated weights and bias.
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

# Layer 2: replicated along grid axis 0 only, so the split along grid
# axis 1 is kept.
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

# Layer 3: the same axis-0-replicated layout as layer 2, transposed for the
# weights; the bias is fully replicated.
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

# Layer 4: fully replicated weights and bias.
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

# Re-pack the placed arrays as a tuple of (W, b) pairs.
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

In [79]:
jax.debug.visualize_array_sharding(W2)

In [80]:
jax.debug.visualize_array_sharding(W3)

In [81]:
print(loss_jit(params, batch))

10.745022


In [82]:
jax.debug.visualize_array_sharding((batch[0] @ W1 + b1))

The result after the second hidden layer becomes a 4 by 2 sharded array.

In [83]:
jax.debug.visualize_array_sharding((batch[0] @ W1 + b1) @ W2 + b2)

After the third layer, the result becomes 4 by 1

In [84]:
jax.debug.visualize_array_sharding(((batch[0] @ W1 + b1) @ W2 + b2) @ W3 + b3)

In [85]:
# 30 more steps of plain gradient descent with a fixed step size.
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    # Each step rebinds params to a list of updated (W, b) tuples.
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

In [86]:
print(loss_jit(params, batch))

10.737434


In [87]:
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()

58.3 ms ± 365 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


# Sharp bits

## Generating random numbers

In [None]:
@jax.jit
def f(key, x):
    """Return ``x`` plus uniform random numbers of the same shape.

    Args:
        key: JAX PRNG key used to draw the random numbers.
        x: Array whose shape determines how many numbers are drawn.
    """
    return x + jax.random.uniform(key, x.shape)

# Fixed key, and the integers 0..23 placed with a sharding built from all devices.
key = jax.random.PRNGKey(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)

In [None]:
jax.debug.visualize_array_sharding(f(key, x))

In [None]:
# Lower and compile f for these arguments, then search the compiled program
# text for a 'collective-permute' op as a sign of cross-device communication.
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

Communicating? True
