In [1]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

Flax provides ..
 1. Interface to specify partition of your data when defining `flax.linen.Module`
 2. Utility functinos to generate the sharding information that `jax.jit` requires to run.
 3. An interface to customize your axis name called "logical axis annotations" to decouple both your Module code and partition plan to experiment with different partition layouts more easily.

In [2]:
import functools
from typing import Optional, Callable

import numpy as np
import jax
from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints

import optax

In [3]:
print(f"We have 8 fake Jax deivces now: {jax.devices()}")

We have 8 fake Jax deivces now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [4]:
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils

In [6]:
# create a mesh and annotate each axis with a name
device_mesh =mesh_utils.create_device_mesh((2, 4))
print(device_mesh)

# annotate each axis with a name using the axis_names param
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: jax.sharding.PartitionSpec) -> NamedSharding:
  return NamedSharding(mesh, pspec)

[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
 [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh('data': 2, 'model': 4)


# Define a layer

To shard the params efficiently, apply the following APIs to annotate the parameters and intermediate variables:
 1. Use `flax.linen.with_partitioning` to decorate the initializer function when creating sub-layers or raw parameters.
 2. Apply `jax.lax.with_sharding_constraint` to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.

In [7]:
class DotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()

  @nn.compact
  def __call__(self, x):

    y = nn.Dense(self.depth,
                 kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
                 use_bias=False, # or overwrite with `bias_init`
                 )(x)
    y = jax.nn.relu(y)

    # Force a local sharding annotation
    y = with_sharding_constraint(y, mesh_sharding(P('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_partitioning(self.dense_init, ('model', None)),
        (self.depth, x.shape[-1])
    )

    z = jnp.dot(y, W2)
    # Force a local sharding annotation
    z = with_sharding_constraint(z, mesh_sharding(P('data', None)))

    # returns a tuple to conform with the API `flax.linen.scan`
    # as shown in the cell below
    return z, None

- When you define `W1` with shape `(x.shape[-1], self.depth)` and annotate as `(None, 'model')`:
 - The first dimension is replicated across all devices
 - The second dimension (of length `self.depth`) will be sharded over the `model` axis of the device mesh.
- When you annotate the output `z` as `('data', None)`:
 - The first dimension will be sharded over the 'data' axis
 - The second dimension will be replicated across all devices.

## Define a model with `flax.linen.scan` lifted transformation

In [8]:
class MLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool

  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(DotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = DotReluDot(self.depth)(x)
    return x

In [9]:
# ML hyperparameters
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False

# create fake inputs
x = jnp.ones((BATCH, DEPTH))
# initialize a PRNG Key
k = random.key(0)

# create an optax optimizer
optimizer = optax.adam(learning_rate=0.001)
# instantiate the model
model = MLP(LAYERS, DEPTH, USE_SCAN)

Note that device axis names like `data`, `model`, or `None` are passed into both `flax.linen.with_partitioning` and `jax.lax.with_sharding_constraint` API calls

> `flax.linen.with_partitioning`: Wraps a function’s return value with `Partitioned`

## Specify sharding

### input's sharding

In [10]:
# for data parallelism, you can shard the batched input x across the data axis
x_sharding = mesh_sharding(P('data', None))
# use jax.device_put to place it onto the correct devices
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)

### Output's sharding
You need to compile `model.init()` and its output as a pytree of parameters. Additionally, you may sometimes need to wrap it with a `flax.training.train_data` to track other variables, such as optimizer states and that would make the output an even more complex pytree.

To achieve this, luckily, you don't have to hardcode the output's sharding by hand. Instead,
1. Evaluate `model.init` abstractly using `jax.eval_shape`
2. Use `flax.linen.get_sharding` to automatically generate the `jax.sharding.NamedSharding`.
 - This step utilizes the `flax.linen.with_partitioning` annotations in the earlier def to generate the correct sharding for the parameters.

In [11]:
def init_fn(k, x, model, optimizer):
  variables = model.init(k, x) # Initialize the model
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optimizer
  )
  return state


In [13]:
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer),
    k,
    x
)

# This `state_sharding` has the same pytree structure as `state`
# , which is the output of the `init_fn`
state_sharding = nn.get_sharding(abstract_variables, mesh)
state_sharding

TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), apply_fn=<bound method Module.apply of MLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = False
)>, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSha

In [15]:
jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                      in_shardings=(mesh_sharding(()), x_sharding),
                      out_shardings=state_sharding)

initialized_state = jit_init_fn(k, x, model, optimizer)

jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)


Inspect the Module output

In [16]:
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)


<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(None, 'model')
(1024, 1024)


You can accesss the raw `jax.Arrays` by calling `flax.linen.meta.unbox()` upon the dictionary, or call `.value` upon individual variable. You can also use `flax.linen.meta.replace_boxed()` to change the underlying `jax.Array` without modifying the sharding annotations.

In [17]:
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree_map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0

  all_zero = jax.tree_map(jnp.zeros_like, unboxed_params)


You can also check the underlying `jax.sharding` of each parameter, which is now more internal than `NamedSharding`.

In [18]:
initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding

NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))

In [19]:
print(initialized_state.step)
initialized_state.step.sharding

0


NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec())

You can use `jax.tree.map` to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays

In [21]:
diff = jax.tree.map(
    lambda a, b: a - b,
    initialized_state.params['DotReluDot_0'],
    initialized_state.params['DotReluDot_0']
)
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)
print(diff_array.sharding)

<class 'jaxlib.xla_extension.ArrayImpl'>
(1024, 1024)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))


In [22]:
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=state_sharding)
def train_step(state, x):
  # a fake loss fcn
  def loss_unrolled(params):
    y = model.apply({'params': params}, x)
    return y.sum()
  grad_fn = jax.grad(loss_unrolled)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

with mesh:
  new_state = train_step(initialized_state, x)

In [23]:
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)


Sharding of Weight 1:


Sharding of Weight 2:


In [24]:
# create a compiled inference step
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=x_sharding)
def apply_fn(state, x):
  return state.apply_fn({'params': state.params}, x)

with mesh:
  y = apply_fn(new_state, x)

print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)

<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)


# Profiling

If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function

In [25]:
%%timeit

def block_all(xs):
  jax.tree.map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(initialized_state, x))

403 ms ± 65.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Logical axis annotation

Jax's automatic SPMD encourages users to explore different sharding layouts to find the optimal one.

The `LogicalDotReluDot` and `LogicalMLP` Module def below are similar to the Modules you created earlier, except for the following:
 1. All axes are annotated with more concrete, meaningful names.
 2. `flax.linen.with_logical_partitioning` replaces `flax.linen.with_partitioning`; and `flax.linen.with_logical_constraint` replaces `jax.with_sharding_constraint`



In [26]:
class LogicalDotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.depth,
                 kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
                 use_bias=False,
                 )(x)
    y = jax.nn.relu(y)
    # force a local sharding annotation
    y = with_sharding_constraint(y, mesh_sharding(P('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
        (self.depth, x.shape[-1])
    )

    z = jnp.dot(y, W2)
    # Force a local sharding annotation
    z = nn.with_logical_constraint(z, ('batch', 'embed'))
    return z, None

class LogicalMLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool

  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = LogicalDotReluDot(self.depth)(x)
    return x

Now, initiate a model and try to figure out what sharding its `state` should have.

To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the devices axis 'data' or 'model'.

In [30]:
# Unspecified rule means unsharded by default,
# so no need to specify ('embed', None),
# and ('layer', None)
rules = (('batch', 'data'),
         ('hidden', 'model'))

logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)

logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x
)
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print('annotations are logical, not mesh-specific: ',
      logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'])


annotations are logical, not mesh-specific:  PartitionSpec('embed', 'hidden')


In [31]:
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
print('sharding annotations are mesh-specific: ',
      logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)

sharding annotations are mesh-specific:  PartitionSpec(None, 'model')


You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous example. This allows you to `jax.jit` your Module's `flax.linen.Module.init` and `flax.linen.Module.apply` the same way in the above

In [32]:
state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0']

True

In [33]:
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                              in_shardings=(mesh_sharding(()), x_sharding),
                              out_shardings=logical_state_sharding)

logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)

In [34]:
print(f"sharding of weight 1:")
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)
print(f"sharding of weight 2:")
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value)

sharding of weight 1:


sharding of weight 2:


When to use device axis / logical axis
- Device mesh axis: If you want a very simple model, or you