
In [1]:
import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

In [2]:
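# Expose 8 virtual CPU devices. The flag must be set before JAX initializes
# its backend, i.e. before the first call that touches devices.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

if len(jax.local_devices()) < 8:
    raise RuntimeError("Notebook requires 8 devices to run")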
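The device-count flag only takes effect if it is in place before JAX initializes its CPU backend. A conservative variant, sketched below, sets it before `jax` is imported at all and assumes a fresh Python process. The names `cpu_devices` and `num_cpu_devices` are illustrative and not part of the notebook.

```python
import os

# Set the flag first; it only matters if it is in place before JAX
# initializes its CPU backend (assumption: a fresh Python process).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# Ask for the CPU backend explicitly so the count does not depend on
# whether an accelerator happens to be the default backend.
cpu_devices = jax.devices("cpu")
num_cpu_devices = len(cpu_devices)
print(f"{num_cpu_devices} CPU devices available")
```

If the count is lower than expected, the most likely cause is that JAX was already initialized earlier in the same process; restarting the runtime and running this cell first is the usual fix.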

# Intro and a quick example

In [3]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

In [4]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))

In [5]:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)

In [6]:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)

`x` is present on a single device

In [7]:
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()

911 ms ± 341 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


`y` is sharded across 8 devices

In [8]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

682 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


# `Sharding` describes how array values are laid out in memory across devices

## Sharding basics, and the `PositionalSharding` subclass

In [9]:
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))

In [10]:
jax.debug.visualize_array_sharding(x)

In [11]:
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

In [12]:
from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)

In [13]:
sharding

PositionalSharding([{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}], shape=(8,))

In [14]:
sharding.reshape(8, 1)

PositionalSharding([[{CPU 0}]
                    [{CPU 1}]
                    [{CPU 2}]
                    [{CPU 3}]
                    [{CPU 4}]
                    [{CPU 5}]
                    [{CPU 6}]
                    [{CPU 7}]], shape=(8, 1))

In [15]:
sharding = sharding.reshape(4, 2)
print(sharding)

PositionalSharding([[{CPU 0} {CPU 1}]
                    [{CPU 2} {CPU 3}]
                    [{CPU 4} {CPU 5}]
                    [{CPU 6} {CPU 7}]], shape=(4, 2))


In [16]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

In [17]:
sharding = sharding.reshape(1, 8)
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
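The three device grids used above, `(8, 1)`, `(4, 2)` and `(1, 8)`, each give every device a differently shaped block of the `(8192, 8192)` array. The sketch below computes those block shapes without allocating the array. It is an illustration under assumptions: it uses `NamedSharding` (introduced later in this notebook) rather than `PositionalSharding`, and the helper `shard_shape_for_grid` and the axis names `"i"` and `"j"` are made up for this example.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

cpu_devices = jax.devices("cpu")[:8]
global_shape = (8192, 8192)

def shard_shape_for_grid(grid_shape):
    """Per-device block shape when the array is split over a device grid."""
    # Hypothetical helper: arrange the 8 devices into the requested grid and
    # shard array dimension 0 over grid axis "i", dimension 1 over axis "j".
    device_grid = np.array(cpu_devices).reshape(grid_shape)
    mesh = Mesh(device_grid, axis_names=("i", "j"))
    sharding = NamedSharding(mesh, P("i", "j"))
    return tuple(sharding.shard_shape(global_shape))

shard_shapes = {grid: shard_shape_for_grid(grid) for grid in [(8, 1), (4, 2), (1, 8)]}
for grid, shape in shard_shapes.items():
    print(grid, "->", shape)
```

Each block shape is the global shape divided elementwise by the grid shape, so the `(4, 2)` grid gives blocks of 8192/4 by 8192/2.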

Sometimes we want to *replicate* some slices, that is, store copies of a slice's values in the memory of multiple devices.

The reducer method `replicate` expresses replication.

In [18]:
sharding = sharding.reshape(4, 2)
# before replicate
print(sharding)
# after replicate
print(sharding.replicate(axis=0, keepdims=True))

PositionalSharding([[{CPU 0} {CPU 1}]
                    [{CPU 2} {CPU 3}]
                    [{CPU 4} {CPU 5}]
                    [{CPU 6} {CPU 7}]], shape=(4, 2))
PositionalSharding([[{CPU 0, 2, 4, 6} {CPU 1, 3, 5, 7}]], shape=(1, 2))


In [19]:
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)

In [20]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

2.62 s ± 28.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
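The replicated layout can also be checked programmatically. The sketch below is a small-scale illustration under assumptions: it uses a tiny `(8, 4)` array and a `NamedSharding` with `P(None, "b")` on a 4x2 mesh, which should correspond to the `replicate(axis=0, keepdims=True)` layout above (columns split in two, each half copied to four devices). The names `x_small`, `y_small`, `replica_counts` and `all_match` are illustrative.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

device_grid = np.array(jax.devices("cpu")[:8]).reshape(4, 2)
mesh = Mesh(device_grid, axis_names=("a", "b"))

# Rows are not sharded (None); columns are split over mesh axis "b".
x_small = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)
y_small = jax.device_put(x_small, NamedSharding(mesh, P(None, "b")))

x_host = np.asarray(x_small)
replica_counts = {}  # (col_start, col_stop) -> number of devices holding it
all_match = True
for shard in y_small.addressable_shards:
    # shard.index is a tuple of slices into the global array.
    col_range = shard.index[1].indices(x_host.shape[1])[:2]
    replica_counts[col_range] = replica_counts.get(col_range, 0) + 1
    # Every copy should hold exactly the matching slice of the global array.
    all_match = all_match and np.array_equal(np.asarray(shard.data), x_host[shard.index])

print(replica_counts, all_match)
```

Each column half is held by four devices, so a computation on this layout repeats the same work on every replica.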


# `NamedSharding` gives a way to express shardings with names

In [29]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)

We can define a helper function to make things simpler:

In [30]:
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)

In [31]:
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)

A `None` entry in a `PartitionSpec` means the array is not sharded along the corresponding dimension; its shards are instead replicated across the mesh axis that is left unused.

In [32]:
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
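To make the effect of `None` concrete, the sketch below tabulates the per-device block shape for a few partition specs on the same 4x2 mesh, without allocating the `(8192, 8192)` array. It is an illustration only; the `specs` and `layouts` dictionaries are made up for this example.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

device_grid = np.array(jax.devices("cpu")[:8]).reshape(4, 2)
mesh = Mesh(device_grid, axis_names=("a", "b"))
global_shape = (8192, 8192)

specs = {
    "P('a', 'b')": P("a", "b"),        # both dimensions sharded
    "P('a', None)": P("a", None),      # rows sharded, replicated over 'b'
    "P(None, 'b')": P(None, "b"),      # columns sharded, replicated over 'a'
    "P(None, None)": P(None, None),    # nothing sharded
}

layouts = {}
for label, spec in specs.items():
    sharding = NamedSharding(mesh, spec)
    layouts[label] = (
        tuple(sharding.shard_shape(global_shape)),
        sharding.is_fully_replicated,
    )
    print(label, "->", layouts[label])
```

A dimension marked `None` keeps its full extent on every device, and only the spec with no mesh axis names at all is fully replicated.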

# Computation follows data sharding and is automatically parallelized