<a href="https://colab.research.google.com/github/rishabh135/2015/blob/master/A_tour_of_parallelism_in_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Imports
import os
# Ask XLA to expose 8 host-platform (CPU) devices; the cells below build
# 8-device meshes such as (2, 4), (4, 2) and (8,).
# NOTE(review): this is set before `import jax` on purpose — presumably the
# flag is only read when JAX initializes its backend; confirm against JAX docs.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax import random, debug, nn
from jax.nn import initializers
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P, PositionalSharding as S
import typing as tp
import optax
from functools import partial

## Sharding in JAX

In [None]:
# Sample an (8, 4) array of uniform values from a fixed seed; no sharding is
# requested here, so the visualization shows JAX's default placement.
x = random.uniform(key=random.PRNGKey(0), shape=(8, 4))
print(x.shape)
debug.visualize_array_sharding(x)

(8, 4)


In [None]:
# Arrange the devices as a 2x4 grid with named axes "a" (size 2) and "b"
# (size 4), then split x's rows over "a" and its columns over "b".
mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=("a", "b"))
y = jax.device_put(x, NamedSharding(mesh, P("a", "b")))

# NOTE(review): "gobal" looks like a typo for "global"; the string is kept
# as-is so it still matches the recorded cell output.
print("gobal shape:", y.shape)
print("shard shape:", y.sharding.shard_shape(y.shape))
debug.visualize_array_sharding(y)

gobal shape: (8, 4)
shard shape: (4, 1)


In [None]:
# Same array, but on a 4x2 device grid: rows are split 4 ways over "a" and
# columns 2 ways over "b".
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=("a", "b"))
y = jax.device_put(x, NamedSharding(mesh, P("a", "b")))

# NOTE(review): "gobal" looks like a typo for "global"; the string is kept
# as-is so it still matches the recorded cell output.
print("gobal shape:", y.shape)
print("shard shape:", y.sharding.shard_shape(y.shape))
debug.visualize_array_sharding(y)

gobal shape: (8, 4)
shard shape: (2, 2)


In [None]:
# 4x2 device grid again, but only the first array dimension is partitioned
# (over "a"); the second dimension gets `None` in the PartitionSpec.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=("a", "b"))
y = jax.device_put(x, NamedSharding(mesh, P("a", None)))

print(y.shape)
debug.visualize_array_sharding(y)

(8, 4)


## MLP

In [None]:
def mlp_init(key, din: int, dmid: int, dout: int) -> tp.Dict:
  """Create the weights of a bias-free two-layer MLP.

  Args:
    key: PRNG key; it is split in two, one sub-key per weight matrix.
    din: input feature size.
    dmid: hidden feature size.
    dout: output feature size.

  Returns:
    A dict with "w1" of shape (din, dmid) and "w2" of shape (dmid, dout),
    both drawn with the LeCun-normal initializer.
  """
  init_fn = initializers.lecun_normal()
  w1_key, w2_key = random.split(key, 2)
  return dict(
      w1=init_fn(w1_key, (din, dmid)),
      w2=init_fn(w2_key, (dmid, dout)),
  )

def mlp_forward(params: tp.Dict, x: jax.Array) -> jax.Array:
  """Apply the two-layer MLP: matmul with "w1", GELU, matmul with "w2".

  Args:
    params: dict holding the weight matrices under "w1" and "w2".
    x: input array whose last dimension matches the rows of "w1".

  Returns:
    The network output, i.e. gelu(x @ w1) @ w2.
  """
  hidden = nn.gelu(x @ params["w1"])
  return hidden @ params["w2"]

# Instantiate a 32 -> 128 -> 32 MLP.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)

# Toy dataset: 16 random inputs; the targets Y are this network's own outputs
# on X.
X = random.uniform(random.PRNGKey(0), (16, 32))
Y = mlp_forward(params, X)

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))

params: {'w1': (32, 128), 'w2': (128, 32)}


In [None]:
# Fresh parameters plus an Adam optimizer (learning rate 3e-4) and its state.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)
tx = optax.adam(3e-4)
opt_state = tx.init(params)

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))
print("opt_state:", jax.tree_util.tree_map(jnp.shape, opt_state))

@jax.jit
def train_step(params, opt_state, x, y):
  """Run one jitted Adam step on the mean-squared-error loss.

  Args:
    params: MLP weights (dict with "w1" and "w2").
    opt_state: optimizer state produced by `tx.init` / a previous step.
    x: input batch.
    y: regression targets with the same shape as the MLP output.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  def mse(p):
    residual = y - mlp_forward(p, x)
    return jnp.mean(residual**2)

  grads = jax.grad(mse)(params)
  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

# Take a single optimization step on the full (X, Y) batch.
params, opt_state = train_step(params, opt_state, X, Y)

params: {'w1': (32, 128), 'w2': (128, 32)}
opt_state: (ScaleByAdamState(count=(), mu={'w1': (32, 128), 'w2': (128, 32)}, nu={'w1': (32, 128), 'w2': (128, 32)}), EmptyState())


## Data Parallel Training with shard_map

In [None]:
from jax.experimental.shard_map import shard_map

# Fresh parameters and Adam state for the shard_map data-parallel example.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)
tx = optax.adam(3e-4)
opt_state = tx.init(params)

# mesh: all 8 devices along a single "data" axis
mesh = Mesh(
    devices=mesh_utils.create_device_mesh((8,)),
    axis_names=("data",),
)

# create shardings
# Weights (and the Adam moments that mirror them) are not split over any mesh
# axis; only the batch dimension of the data is split over "data".
params_spec = {'w1': P(None), 'w2': P(None)}
# Must mirror the pytree structure of `opt_state` printed below.
opt_state_spec = (
  optax.ScaleByAdamState(
    count=P(), mu=params_spec, nu=params_spec
  ),
  optax.EmptyState()
)
data_spec = P("data")

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))
print("opt_state:", jax.tree_util.tree_map(jnp.shape, opt_state))
print("\nglobal x.shape", X.shape)

@jax.jit
@partial(
  shard_map,
  mesh=mesh,
  in_specs=(params_spec, opt_state_spec, data_spec, data_spec),
  out_specs=(params_spec, opt_state_spec),
)
def train_step(params, opt_state, x, y):
  """Data-parallel Adam step written per-shard with `shard_map`.

  Inside this function `x` and `y` are the per-device slices selected by
  `data_spec`; `params` and `opt_state` follow their own specs.

  Args:
    params: MLP weights (dict with "w1" and "w2").
    opt_state: optimizer state matching `opt_state_spec`.
    x: this shard's slice of the input batch.
    y: this shard's slice of the targets.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  # Runs at trace time, so it reports the per-shard shape.
  print("shard  x.shape:", x.shape)

  def local_mse(p):
    err = y - mlp_forward(p, x)
    return jnp.mean(err**2)

  local_grads = jax.grad(local_mse)(params)

  # all reduce avg: average the per-shard gradients over the "data" axis
  grads = jax.lax.pmean(local_grads, axis_name="data")

  updates, new_opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state

# Train against the targets Y (the original call passed X twice, using the
# inputs as targets, unlike every other training cell in this notebook).
params, opt_state = train_step(params, opt_state, X, Y)

# Inspect how the step's outputs are laid out across devices.
print("\nVisualize final sharding")
print("params")
debug.visualize_array_sharding(params["w1"])
print("opt_state")
debug.visualize_array_sharding(opt_state[0].mu["w1"])

params: {'w1': (32, 128), 'w2': (128, 32)}
opt_state: (ScaleByAdamState(count=(), mu={'w1': (32, 128), 'w2': (128, 32)}, nu={'w1': (32, 128), 'w2': (128, 32)}), EmptyState())

global x.shape (16, 32)
shard  x.shape: (2, 32)

Visualize final sharding
params


opt_state


## Data Parallel with pjit

In [None]:
# Fresh parameters and Adam state for the jit-based data-parallel example.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)
tx = optax.adam(3e-4)
opt_state = tx.init(params)

# mesh: all 8 devices along a single "data" axis
mesh = Mesh(
    devices=mesh_utils.create_device_mesh((8,)),
    # One-element tuple (the original `("data")` was just a parenthesized
    # string), matching how the other cells spell their axis names.
    axis_names=("data",),
)
def sharding(*args):
  """Build a NamedSharding on the current `mesh` from PartitionSpec entries."""
  return NamedSharding(mesh, P(*args))

# create shardings
# Weights and Adam moments are not split over any mesh axis; only the batch
# dimension of the data is split over "data".
params_sharding = {'w1': sharding(None), 'w2': sharding(None)}
opt_state_sharding = (
  optax.ScaleByAdamState(
    count=sharding(), mu=params_sharding, nu=params_sharding
  ),
  optax.EmptyState()
)
data_sharding = sharding("data")

# shard state
params = jax.device_put(params, params_sharding)
opt_state = jax.device_put(opt_state, opt_state_sharding)
x, y = jax.device_put(X, data_sharding), jax.device_put(Y, data_sharding)

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))
print("opt_state:", jax.tree_util.tree_map(jnp.shape, opt_state))

print("\nInput sharding")
debug.visualize_array_sharding(x)

@jax.jit
def train_step(params, opt_state, x, y):
  """One jitted Adam step on the MSE loss with explicitly constrained grads.

  Args:
    params: MLP weights, already placed with `params_sharding`.
    opt_state: optimizer state for `tx`.
    x: input batch.
    y: regression targets.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  def mse(p):
    residual = y - mlp_forward(p, x)
    return jnp.mean(residual**2)

  # constrain grads: ask for the gradients in the same layout as the params
  grads = jax.lax.with_sharding_constraint(
      jax.grad(mse)(params), params_sharding
  )

  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

# Take two optimization steps, feeding each step's outputs into the next.
for _ in range(2):
  params, opt_state = train_step(params, opt_state, x, y)

# Inspect how the step's outputs are laid out across devices.
print("\nVisualize final sharding")
print("params")
debug.visualize_array_sharding(params["w1"])
print("opt_state")
debug.visualize_array_sharding(opt_state[0].mu["w1"])

params: {'w1': (32, 128), 'w2': (128, 32)}
opt_state: (ScaleByAdamState(count=(), mu={'w1': (32, 128), 'w2': (128, 32)}, nu={'w1': (32, 128), 'w2': (128, 32)}), EmptyState())

Input sharding



Visualize final sharding
params


opt_state


## Model Parallel

In [None]:
def mlp_forward(params: tp.Dict, x: jax.Array, viz=False) -> jax.Array:
  """Apply the two-layer MLP, optionally visualizing every array's sharding.

  Args:
    params: dict holding the weight matrices under "w1" and "w2".
    x: input array whose last dimension matches the rows of "w1".
    viz: when true, print a label and the device layout of the input, both
      weights, the hidden activation and the output.

  Returns:
    The network output `gelu(x @ w1) @ w2`.
  """
  y = nn.gelu(x @ params["w1"])
  z = y @ params["w2"]
  if viz:
    stages = (
        ("x", x),
        ("w1", params['w1']),
        ("y", y),
        ("w2", params['w2']),
        ("z", z),
    )
    for label, array in stages:
      print(label)
      debug.visualize_array_sharding(array)
  # Return the output z; the original returned the input x, which made any
  # loss built on this function independent of the weights.
  return z

In [None]:
# Fresh parameters and Adam state for the model-parallel example.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)
tx = optax.adam(3e-4)
opt_state = tx.init(params)

# mesh: all 8 devices along a single "model" axis
mesh = Mesh(
    devices=mesh_utils.create_device_mesh((8,)),
    axis_names=("model",),
)
def sharding(*args):
  """Build a NamedSharding on the current `mesh` from PartitionSpec entries."""
  return NamedSharding(mesh, P(*args))

# create shardings
# The hidden dimension is split over "model": w1 by columns, w2 by rows. The
# Adam moments reuse the same layout; the data is not split.
params_sharding = {'w1': sharding(None, "model"), 'w2': sharding("model", None)}
opt_state_sharding = (
  optax.ScaleByAdamState(
    count=sharding(), mu=params_sharding, nu=params_sharding
  ),
  optax.EmptyState()
)
data_sharding = sharding(None)

# shard state
params = jax.device_put(params, params_sharding)
opt_state = jax.device_put(opt_state, opt_state_sharding)
x, y = jax.device_put(X, data_sharding), jax.device_put(Y, data_sharding)

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))
print("opt_state:", jax.tree_util.tree_map(jnp.shape, opt_state))

# visualize computation sharding
mlp_forward(params, x, viz=True)

@jax.jit
def train_step(params, opt_state, x, y):
  """One jitted Adam step on the MSE loss with explicitly constrained grads.

  Args:
    params: MLP weights, already placed with `params_sharding`.
    opt_state: optimizer state for `tx`.
    x: input batch.
    y: regression targets.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  def mse(p):
    residual = y - mlp_forward(p, x)
    return jnp.mean(residual**2)

  # constrain grads: ask for the gradients in the same layout as the params
  grads = jax.lax.with_sharding_constraint(
      jax.grad(mse)(params), params_sharding
  )

  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

# Take two optimization steps, feeding each step's outputs into the next.
for _ in range(2):
  params, opt_state = train_step(params, opt_state, x, y)

# Inspect how the step's outputs are laid out across devices.
print("\nVisualize final sharding")
print("params")
debug.visualize_array_sharding(params["w1"])
print("opt_state")
debug.visualize_array_sharding(opt_state[0].mu["w1"])

params: {'w1': (32, 128), 'w2': (128, 32)}
opt_state: (ScaleByAdamState(count=(), mu={'w1': (32, 128), 'w2': (128, 32)}, nu={'w1': (32, 128), 'w2': (128, 32)}), EmptyState())
x


w1


y


w2


z



Visualize final sharding
params


opt_state


## Data + Model Parallel

In [None]:
# Fresh parameters and Adam state for the combined data + model example.
params = mlp_init(random.PRNGKey(0), 32, 128, 32)
tx = optax.adam(3e-4)
opt_state = tx.init(params)

# mesh: 2x4 device grid, "data" axis of size 2 and "model" axis of size 4
mesh = Mesh(
    devices=mesh_utils.create_device_mesh((2, 4)),
    axis_names=("data", "model"),
)
def sharding(*args):
  """Build a NamedSharding on the current `mesh` from PartitionSpec entries."""
  return NamedSharding(mesh, P(*args))

# create shardings
# The hidden dimension is split over "model": w1 by columns, w2 by rows. The
# Adam moments reuse the same layout; the batch dimension of the data is split
# over "data".
params_sharding = {'w1': sharding(None, "model"), 'w2': sharding("model", None)}
opt_state_sharding = (
  optax.ScaleByAdamState(
    count=sharding(), mu=params_sharding, nu=params_sharding
  ),
  optax.EmptyState()
)
data_sharding = sharding("data")

# shard state
params = jax.device_put(params, params_sharding)
opt_state = jax.device_put(opt_state, opt_state_sharding)
x, y = jax.device_put(X, data_sharding), jax.device_put(Y, data_sharding)

# `jax.tree_map` is deprecated in newer JAX releases; `jax.tree_util.tree_map`
# is the equivalent that is available in both old and new versions.
print("params:", jax.tree_util.tree_map(jnp.shape, params))
print("opt_state:", jax.tree_util.tree_map(jnp.shape, opt_state))

# visualize computation sharding
mlp_forward(params, x, viz=True)

@jax.jit
def train_step(params, opt_state, x, y):
  """One jitted Adam step on the MSE loss with explicitly constrained grads.

  Args:
    params: MLP weights, already placed with `params_sharding`.
    opt_state: optimizer state for `tx`.
    x: input batch.
    y: regression targets.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  def mse(p):
    residual = y - mlp_forward(p, x)
    return jnp.mean(residual**2)

  # constrain grads: ask for the gradients in the same layout as the params
  grads = jax.lax.with_sharding_constraint(
      jax.grad(mse)(params), params_sharding
  )

  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

# Take two optimization steps, feeding each step's outputs into the next.
for _ in range(2):
  params, opt_state = train_step(params, opt_state, x, y)

# Inspect how the step's outputs are laid out across devices.
print("\nVisualize final sharding")
print("params")
debug.visualize_array_sharding(params["w1"])
print("opt_state")
debug.visualize_array_sharding(opt_state[0].mu["w1"])

params: {'w1': (32, 128), 'w2': (128, 32)}
opt_state: (ScaleByAdamState(count=(), mu={'w1': (32, 128), 'w2': (128, 32)}, nu={'w1': (32, 128), 'w2': (128, 32)}), EmptyState())
x


w1


y


w2


z



Visualize final sharding
params


opt_state
