# Lab 3: Introduction to Parallelization in JAX
---
Author: Dr. Jan Blechschmidt\
Email: Jan.Blechschmidt@math.tu-freiberg.de\
Credits: This tutorial is based on [this tutorial](https://docs.jax.dev/en/latest/sharded-computation.html).

---

This notebook gives a short introduction to the three approaches to parallelization in JAX.

## Introduction

In JAX, parallelization is realized by **sharding** or **partitioning** arrays across
multiple accelerators.

### Three ways to use parallelization

- Automatic sharding via `jax.jit()`: the compiler decides how to partition the computation using its own heuristics. You write the program as if it ran on a single device and may not even notice that it is executed on multiple devices.
- Explicit sharding is similar to automatic sharding in that you write a global-view program:
    - the sharding of an array can be an explicit part of the model (it is part of the array's JAX-level type)
    - these shardings are propagated through the program
    - it is still the compiler's responsibility to turn the whole-array program into per-device programs (for example, turning `jnp.sum` into `psum`)
    - the compiler is heavily constrained by the user-supplied shardings.
- Manual sharding via `jax.shard_map()`: you write per-device code and insert the communication collectives explicitly.

### Sharding in JAX

The key concept underlying all of these approaches is **data sharding**, which describes how data is laid out on the available devices.

- JAX's array type, the immutable `jax.Array`, represents arrays whose physical storage spans one or multiple devices
- this makes parallelism a core feature of JAX
- *every* `jax.Array` has a `sharding` attribute, which describes which shard of the global data is required by each device
- when you create a `jax.Array` from scratch, you can also specify its sharding.

---

The function `jax.devices(backend=None)` returns a list of all available devices.

If backend is `None`, it returns all the devices from the default backend.
The default backend is generally `gpu` or `tpu` if available, otherwise `cpu`.

In [None]:
import jax
import jax.numpy as jnp

jax.devices()
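Besides `jax.devices()`, a few related functions from the public `jax` namespace help to check what JAX sees. A minimal sketch; the printed values depend on your machine, and it assumes the CPU backend has not been disabled (e.g. via `JAX_PLATFORMS`):

```python
import jax

# Backend used when no device is specified: 'cpu', 'gpu' or 'tpu'
backend = jax.default_backend()
# Devices of an explicitly requested backend (the CPU backend is normally available)
cpu_devices = jax.devices('cpu')
# Total number of devices of the default backend
n_devices = jax.device_count()

print(f'{backend = }, {n_devices = }')
print(f'{cpu_devices = }')
```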

For an array, the attribute `device` shows the device on which it is stored.
Similarly, the attribute `sharding` returns the sharding of the array.

In [None]:
A = jnp.arange(64.0).reshape(4,16)
A.device

In [None]:
A.sharding

Using the XLA compiler flag

    xla_force_host_platform_device_count

one can set the number of CPU devices visible to JAX. Note that the flag only takes effect if it is set **before** JAX is imported: restart the kernel and then either set the environment variable `XLA_FLAGS` before starting the notebook or set it via `os.environ` before `import jax`.

If you are using Google Colab with activated TPU environment, you should already see multiple TPU devices.

In [None]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
import jax.numpy as jnp

jax.devices()

The default device is the one with `id=0`.

In [None]:
A = jnp.arange(64.0).reshape(4,16)
print(f'{A.device = }\n')
print(f'{A.sharding = }\n')

One can explicitly set the device using the optional `device` argument of array-creation functions like `jnp.array`, `jnp.ones`, `jnp.zeros`, etc.

Note that the `device` argument can be `None`, a `Device` or a `Sharding`.
If it is a specific `Device` or a `Sharding`, the created array is committed to it.

In [None]:
B = jnp.ones((4,16), device=jax.devices()[1])
print(f'{B.device = }\n')
print(f'{B.sharding = }\n')
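An existing array can also be moved to (and committed on) a device with `jax.device_put`. The sketch below compares both routes; it uses the last visible device so that it also runs when only one device exists, and the names `target`, `C_dev` and `D_dev` are only illustrative:

```python
import jax
import jax.numpy as jnp

target = jax.devices()[-1]  # last visible device (device 0 if only one exists)

# Option 1: create the array directly on the target device
C_dev = jnp.zeros((4, 16), device=target)
# Option 2: create it on the default device, then move (commit) it
D_dev = jax.device_put(jnp.arange(64.0).reshape(4, 16), target)

# Array.devices() returns the set of devices holding parts of the array
print(C_dev.devices(), D_dev.devices())
```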

### Visualization of sharding

Using `jax.debug.visualize_array_sharding`, one can visualize the sharding of arrays. For the arrays `A` and `B` this is boring, since the arrays are only committed to one device.

Note: Using the optional argument `scale`, one can scale the output.

In [None]:
print('Sharding of array A:\n')
jax.debug.visualize_array_sharding(A, scale=0.5)

print('\nSharding of array B:\n')
jax.debug.visualize_array_sharding(B, scale=0.5)

### Sharding over multiple devices

To create an array with a non-trivial sharding, you can define a sharding specification (from the `jax.sharding` module) for the array and pass it to `jax.device_put()`.

#### 1st step: Create mesh

The function `jax.make_mesh` attempts to automatically compute a good
mapping from a set of logical axes (the mesh that we want to use in our programs) to the physical devices. This matters on TPUs, which are arranged in a 2D torus or 3D mesh, and on GPUs, where some of the GPUs might be connected via a fast NVLink connection.

Such a multi-dimensional mesh lets us make use of the hardware topology, in particular on TPUs, where some accelerators are connected with higher bandwidth than others. Thus, instead of keeping all devices in a one-dimensional list, we can arrange them in a multi-dimensional grid, e.g., to accommodate data and model parallelization using a `data` axis and a `model` axis.

The following creates a `mesh` of size 2 by 4, where the first axis is named `x` and the second axis `y`.

In [None]:
mesh = jax.make_mesh((2, 4), ('x', 'y'))
print(mesh)

Note that printing reveals that the mesh axes are currently in `Auto` mode. Later, we will learn how to work with explicit axes.
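A mesh can also be built with the plain `jax.sharding.Mesh` constructor, which uses the devices exactly in the order given instead of choosing an order for you. The sketch below does this for a `data`/`model` layout as described above and inspects the resulting object; the 2 x (n/2) layout is an assumption chosen so that the code also runs with an odd number of devices (then it falls back to 1 x n):

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
# Use a 2 x (n/2) layout when the device count is even, otherwise 1 x n
rows = 2 if devices.size % 2 == 0 else 1
mesh_2d = Mesh(devices.reshape(rows, -1), ('data', 'model'))

print(mesh_2d.axis_names)     # ('data', 'model')
print(dict(mesh_2d.shape))    # axis name -> axis size
print(mesh_2d.devices.shape)  # shape of the underlying array of devices
```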

#### 2nd step: Define a sharding
Next, define a `NamedSharding`, which expresses a sharding using the *named axes* of the mesh.

A `NamedSharding` is a pair of a `mesh` (a multi-dimensional grid of devices with named axes) and a `PartitionSpec`, which describes how to shard an array across the mesh.

In [None]:
from jax.sharding import PartitionSpec as P

sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
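A sharding can be queried for the block each device would hold, without creating any array. A minimal sketch on a one-dimensional mesh over all visible devices (the names `mesh_1d`, `row_sharding` and `global_shape` are illustrative):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh_1d = Mesh(np.array(jax.devices()), ('x',))

# Split dimension 0 over mesh axis 'x', keep dimension 1 whole
row_sharding = NamedSharding(mesh_1d, P('x', None))
global_shape = (4 * n, 3)

# Shape of the block that each device stores
print(row_sharding.shard_shape(global_shape))  # (4, 3)
# Which slice of the global array lives on which device
for device, index in row_sharding.devices_indices_map(global_shape).items():
    print(device.id, index)
```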

**Task**: Put the array `A` on multiple accelerators using the function `jax.device_put` and the named sharding that we just defined. Store the sharded array under `A_sharded`.

Afterwards, visualize the sharding using `jax.debug.visualize_array_sharding`.

In [None]:
### Your code here

**Task**: Check the devices and the sharding of `A_sharded`.

In [None]:
### Your code here

Note that printing the arrays `A` and `A_sharded` doesn't reveal the sharding:

In [None]:
A

In [None]:
A_sharded
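Since printing shows only the global values, the per-device blocks can be inspected through the `addressable_shards` attribute. The sketch below also passes a `Sharding` as the `device` argument of `jnp.ones`, so the array is created directly in sharded form; the mesh over all visible devices and the names `mesh_1d`, `row_sharding` and `S` are assumptions for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh_1d = Mesh(np.array(jax.devices()), ('x',))
row_sharding = NamedSharding(mesh_1d, P('x', None))

# Passing a Sharding as `device` creates the array directly in sharded form
S = jnp.ones((2 * n, 4), device=row_sharding)

# Each entry of addressable_shards describes the block held by one device
for shard in S.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```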

**Task**: Explore what it means if you change the partition specification (`PartitionSpec`) to
- `P('x', None)`
- `P(None, 'y')`
- `P(None, None)`
- `P(None, 'x')`

In [None]:
### Your code here

## 1st way: Automatic parallelization using `jax.jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the *sharded data* to a `jit`-compiled function:
- the XLA compiler used during just-in-time compilation includes heuristics for optimizing computations across multiple devices
- inter-device communication is done automatically
- you can explicitly specify how the inputs and outputs of your code should be partitioned

For simple functions, the XLA compiler heuristics result in code where the computation follows the data.

In [None]:
@jax.jit
def f(x):
  return 2 * jnp.sin(x)**2 + 1

Z = f(A_sharded)

In [None]:
jax.debug.visualize_array_sharding(Z, scale=0.3)

For simple functions, the output sharding often equals the input sharding:

In [None]:
Z.sharding == A_sharded.sharding

In [None]:
f(A).sharding == A.sharding
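As mentioned in the list at the start of this section, the output partitioning can also be requested explicitly. A minimal sketch using the `out_shardings` argument of `jax.jit` to replicate the result on every device; the 1-D mesh and the names `mesh_1d`, `x_sharded` and `double_sin` are illustrative assumptions:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh_1d = Mesh(np.array(jax.devices()), ('x',))
x_sharded = jax.device_put(jnp.arange(8.0 * n).reshape(2 * n, 4),
                           NamedSharding(mesh_1d, P('x', None)))

# P() means "no dimension is split": every device holds the full result
replicated = NamedSharding(mesh_1d, P())

def double_sin(v):
    return 2 * jnp.sin(v) ** 2 + 1

# Without out_shardings the output typically follows the input sharding;
# with it, the compiler inserts the communication needed to replicate.
y_default = jax.jit(double_sin)(x_sharded)
y_replicated = jax.jit(double_sin, out_shardings=replicated)(x_sharded)
print(y_default.sharding)
print(y_replicated.sharding)
```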

As computations get more complex, the compiler makes decisions about
how to propagate the sharding of the data.

Let's consider this summation over the first axis, i.e. a columnwise summation:

In [None]:
@jax.jit
def f_sum(x):
  return x.sum(axis=0)

Z = f_sum(A_sharded)

print('Input sharding: \n')
jax.debug.visualize_array_sharding(A_sharded, scale=.5)

print('Output sharding: \n')
jax.debug.visualize_array_sharding(Z, scale=.5)

print(f'Output: {Z}')

In [None]:
Z.sharding

Thus, the result is partially replicated: that is, the first four elements of the
array are replicated on devices 0 and 4, the second four on 1 and 5, etc.
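The grouping into replicas can be read off the sharding itself. The sketch below builds a spec that is split over `y` and not over `x` (as for the column sum), then groups devices by the slice they hold. With eight devices in their default order it should reproduce the pairs described above (0 and 4, 1 and 5, ...); the mesh built with `jax.sharding.Mesh` is an assumption that mirrors the 2 x 4 mesh from before and falls back to 1 x n for odd device counts:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
rows = 2 if devices.size % 2 == 0 else 1
mesh_2d = Mesh(devices.reshape(rows, -1), ('x', 'y'))

# Spec of a 1-D result that is split over 'y' and replicated over 'x'
col_sum_sharding = NamedSharding(mesh_2d, P('y'))
global_shape = (4 * mesh_2d.shape['y'],)

# Group devices by the slice they hold: devices in the same group are replicas
replicas = {}
for device, index in col_sum_sharding.devices_indices_map(global_shape).items():
    replicas.setdefault(str(index), []).append(device.id)
for index, ids in replicas.items():
    print(index, '-> devices', ids)
```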

**Task**: Create two arbitrary arrays `B` and `C` of shape 4 by 8.
The array `B` should be sharded along the `x`-axis of the mesh and replicated along the `y`-axis.
Array `C` should be sharded along the `y`-axis and replicated along the `x` axis.

What is the sharding of `B + C`?

What is the sharding of the matrix product `B.T @ C`? Think about an answer before verifying your guess.

In [None]:
### Your code here


### Data parallel layer of a feedforward neural network using automatic parallelization

Let's consider as an example the following definition of a (vectorized) layer of a neural network with inputs:
- `x` array of data of shape `(n_data, n_features)`
- `W` weight matrix of shape `(n_features, n_out)`
- `b` bias of shape `(n_out,)`

Note that this layer could be the terminal layer in a classification problem.

In [None]:
@jax.jit
def layer(x, W, b):
  return jax.nn.softmax(x @ W + b)

In [None]:
n_features = 10
n_data = 128
n_out = 2

key = jax.random.key(0)
key, *keys = jax.random.split(key, 4)
x = jax.random.normal(keys[0], shape=(n_data,n_features))
W = jax.random.normal(keys[1], shape=(n_features, n_out))
b = jax.random.normal(keys[2], shape=(n_out,))

Since all arrays are created without specifying a sharding, they are placed on the default device, i.e., the one with `id=0`, and the computation is performed there.

In [None]:
out = layer(x, W, b)
print(out.shape)
jax.debug.visualize_array_sharding(out)

**Task**: In a full data-parallel model, we would shard the data array `x` across our devices and replicate the model parameters `W` and `b` across the devices. Implement this approach and verify that the output sharding is the same as the sharding of your data.

Use a mesh with one dimension and four or eight accelerators.

In [None]:
### Your code here

## 2nd way: Parallelization using explicit sharding

- main idea is that the JAX-level type of a variable includes a description of how the variable is sharded.
- we can query the JAX-level type of any JAX variable using `jax.typeof`
- this also works for NumPy arrays, Python scalars and other variables

Note: To use `jax.typeof`, you need a recent version of JAX. On Google Colab, you can update JAX by

    !pip install jax==0.6.0

Afterwards, the kernel has to be restarted.

Let's consider the following examples:

In [None]:
import numpy as np
np_array = np.arange(8)
print(f"JAX-level type of Numpy array: {jax.typeof(np_array)}")
print(f"JAX-level type of JAX array: {jax.typeof(A)}")
print(f"JAX-level type of Python scalar: {jax.typeof(4.2)}")

Thus, one can think of the JAX-level type as the information about a value that is available during just-in-time compilation.
It is even possible to query the JAX-level type while a jit-compiled function is being traced.

In [None]:
@jax.jit
def g(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x
g(A);

Note that the JAX-level type of a sharded array still reveals its global shape, e.g., `float32[4,16]` in the case of `A_sharded`.

In [None]:
g(A_sharded);

To make use of explicit sharding, we first have to create a mesh with *explicit axes*.

In [None]:
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("x", "y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
print(mesh)

As the print reveals, the `axis_types` are now `Explicit`.

We now create a new array `A` of shape 8 by 4 and shard it over our mesh with explicit axes such that it is sharded along the `x` axis and replicated along the `y` axis.

In [None]:
A = np.arange(32.).reshape(8, 4)
A_sharded = jax.device_put(A, jax.NamedSharding(mesh, P("x", None)))

Let's take a look at the JAX-level types of `A` and `A_sharded`.

In [None]:
print(f"Type of A: {jax.typeof(A)}")
print(f"Type of A_sharded: {jax.typeof(A_sharded)}")

One can read the type `float32[8@x, 4]` as an 8-by-4 array of 32-bit floats whose first dimension is sharded along mesh axis `x`. The array is replicated along the other mesh axis.

In [None]:
jax.debug.visualize_array_sharding(A_sharded)

The next cell defines two sharded arrays: the first one is sharded along the `x` axis and replicated along `y`; the second one is replicated along the `x` axis and sharded along the `y` axis.

In [None]:
A = jax.device_put(np.arange(4).reshape(4, 1),
                   jax.NamedSharding(mesh, P("x", None)))
B = jax.device_put(np.arange(8).reshape(1, 8),
                   jax.NamedSharding(mesh, P(None, "y")))

print('Sharding of A:\n')
jax.debug.visualize_array_sharding(A)

print('\nSharding of B:\n')
jax.debug.visualize_array_sharding(B, scale=0.2)

The next cell implements a function which just performs an elementwise multiplication (using the `*` operator) and prints the JAX-level types of the inputs and output.

In [None]:
@jax.jit
def multiply_arrays(x, y):
  ans = x * y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

C = multiply_arrays(A, B)

jax.debug.visualize_array_sharding(C, scale=0.4)

C

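**Note**: Shardings propagate deterministically at trace time, i.e., the shardings of all intermediate values and of the output are fixed when the function is traced, before any computation runs.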

### Data parallel layer of a feedforward neural network using explicit sharding

Below, you find again the example of one layer of a feedforward neural network from above.

In [None]:
@jax.jit
def layer(x, W, b):
  return jax.nn.softmax(x @ W + b)
    
n_features = 10
n_data = 128
n_out = 2

key = jax.random.key(0)
key, *keys = jax.random.split(key, 4)
x = jax.random.normal(keys[0], shape=(n_data,n_features))
W = jax.random.normal(keys[1], shape=(n_features, n_out))
b = jax.random.normal(keys[2], shape=(n_out,))

**Task**: Implement a data-parallel execution of the layer in **explicit sharding** mode. Again, use a mesh with one dimension and four or eight accelerators.

Write a wrapper function that executes the layer function and also prints the JAX-level types of its inputs and output during tracing.

In [None]:
### Your code here

## 3rd way: Manual parallelization

The first two approaches (automatic parallelization and explicit sharding) deal with the case where we write a function as if it operated on the full dataset, and `jax.jit` splits that computation across multiple devices.

The 3rd way uses manual parallelization:
- here, we write a function that handles a single shard of data, and `jax.shard_map()` constructs the full function $\leadsto$ this gives the most flexibility regarding parallelization and communication
- `shard_map()` works by mapping a function across a particular mesh of devices

In [None]:
# Newer JAX versions export shard_map at the top level;
# older versions only provide it in jax.experimental
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4,), ('x',))

@jax.jit
def f(x):
    return 2*jnp.sin(x)**2 + 1
    
f_sharded = shard_map(
    f,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

A = jnp.arange(32.)
Z = f_sharded(A)

print('Sharding of f_sharded(A):\n')
jax.debug.visualize_array_sharding(Z, scale=0.2)


Note the following:
- the mesh (a `jax.sharding.Mesh`, here created with `jax.make_mesh`) allows for precise device placement
- the `in_specs` argument determines how the inputs are split into shards (and thus the shard sizes)
- the `out_specs` argument determines how the output blocks are assembled back together
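Inside a `shard_map`-ed function, each device can also ask for its own position along a mesh axis with `jax.lax.axis_index`, which makes the "per-device code" idea concrete. A minimal sketch on a 1-D mesh over all visible devices; the function name `tag_with_device_index` is illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX versions
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX versions

n = len(jax.devices())
mesh_1d = Mesh(np.array(jax.devices()), ('x',))

def tag_with_device_index(block):
    # `block` is the local shard; axis_index is this device's position on 'x'
    return block + jax.lax.axis_index('x')

tag = shard_map(tag_with_device_index, mesh=mesh_1d,
                in_specs=P('x'), out_specs=P('x'))

# Each device receives 2 entries and adds its own index to them
tagged = tag(jnp.zeros(2 * n, dtype=jnp.int32))
print(tagged)
```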

---

The function passed to `shard_map` only sees a single shard of the data, which you
can check by printing the device-local shape:

In [None]:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
    print(f"Shape of x on local device: {x.shape}")
    print('\nValue of x on local device: ', x)
    print('\nType of x: ', jax.typeof(x))
    return x * 2

f_sharded = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
y = f_sharded(x)

Here, you can see the explicit shards during execution of the function.

---

Because the function only sees the device-local part of the data,
aggregation-like functions such as `jnp.sum` do not automatically aggregate over the whole array.

See for example the function which sums all elements of an array:

In [None]:
def f(x):
  return jnp.sum(x, keepdims=True)

f_sharded = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
f_sharded(x)

Here, you see that the function operates separately on each shard, and the resulting
summation reflects this.

If you want to sum across shards instead, you need to request it explicitly using collective operations like `jax.lax.psum()`.

In [None]:
def f(x):
  local_sum = x.sum()
  return jax.lax.psum(local_sum, 'x')

f_sharded = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())
f_sharded(x)

Note also that the output no longer has a sharded dimension, since we specified `out_specs=P()`.
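The same pattern works for other reductions. The sketch below contrasts per-shard means (one value per device, reassembled with `out_specs=P('x')`) with a global mean obtained via the collective `jax.lax.pmean`; it assumes all shards have the same size, which is what makes the mean of the local means equal to the global mean. The names `mesh_1d`, `v` and `global_mean` are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX versions
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX versions

n = len(jax.devices())
mesh_1d = Mesh(np.array(jax.devices()), ('x',))
v = jnp.arange(4.0 * n)

# Without a collective: one local mean per device, concatenated along 'x'
local_means = shard_map(lambda block: jnp.mean(block, keepdims=True),
                        mesh=mesh_1d, in_specs=P('x'), out_specs=P('x'))(v)

def global_mean(block):
    # Average the local means over all devices on 'x'
    return jax.lax.pmean(jnp.mean(block), 'x')

mean_v = shard_map(global_mean, mesh=mesh_1d, in_specs=P('x'), out_specs=P())(v)
print(local_means, mean_v)
```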

### Data parallel layer of a feedforward neural network using `shard_map`

Below, you find again the example of one layer of a feedforward neural network from above.

In [None]:
@jax.jit
def layer(x, W, b):
  return jax.nn.softmax(x @ W + b)
    
n_features = 10
n_data = 128
n_out = 2

key = jax.random.key(0)
key, *keys = jax.random.split(key, 4)
x = jax.random.normal(keys[0], shape=(n_data,n_features))
W = jax.random.normal(keys[1], shape=(n_features, n_out))
b = jax.random.normal(keys[2], shape=(n_out,))

**Task**: Implement a data-parallel execution of the layer using manual parallelization through `shard_map`. Use a mesh with one dimension and four or eight accelerators. Store the output under `out_man`.

In [None]:
### Your code here

In [None]:
out_man.shape

In [None]:
jax.debug.visualize_array_sharding(out_man)

## Neural networks using automatic parallelization

The following implements a feedforward neural network for solving regression problems. To see a speedup from parallelization, it should be executed on a machine with multiple TPUs or GPUs.
On a CPU, the XLA compiler in `jax.jit` already parallelizes such simple programs quite well and uses all available CPU resources, even for unsharded data.
You can verify this with a resource monitor (e.g. `htop`) while the computations run.

Note, however, that you can still implement the data parallelization requested in the task below. We use a feedforward network similar to the one from Lab 2.

In [None]:
import os
# Only has an effect if JAX has not been imported yet in this kernel
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
import jax
from tqdm import tqdm

# Define population line
def f(x):
    return jnp.sin(6*x) - .5 * x**2

def gen_data(key, f, n=100, sigma=0.1, xmin=0.0, xmax=2.0):
    keys = jax.random.split(key, 2)
    x = jax.random.uniform(keys[0], shape=(n, 1), minval=xmin, maxval=xmax)
    y = f(x) + sigma * jax.random.normal(keys[1], shape=(n, 1))
    return x, y

def random_layer_params(n_in, n_out, key, scale=1e-1):
    w_key, b_key = jax.random.split(key)
    # Weight matrix
    w = scale * jax.random.normal(w_key, (n_in, n_out))
    # Bias vector
    b = scale * jax.random.normal(b_key, (n_out,))
    return w, b

def init_network_params(sizes, key):
    keys = jax.random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

def forward(params, x):
    for w, b in params[:-1]:
        outputs = jnp.dot(x, w) + b
        x = jnp.tanh(outputs)

    final_w, final_b = params[-1]
    out = jnp.dot(x, final_w) + final_b
    return out
    
def loss(params, x, y):
    return jnp.mean(jnp.square(forward(params, x) - y))
    
grad_loss = jax.jit(jax.grad(loss, argnums=0))

def train(params, x, y, epochs=100, step_size=1e-3):

    for _ in tqdm(range(epochs)):
        grads = grad_loss(params, x, y)
        params = [(W - step_size * dW, b - step_size * db)
                  for (W, b), (dW, db) in zip(params, grads)]
    return params


In [None]:
global_key = jax.random.key(seed=0)
gen_key, *keys = jax.random.split(global_key,3)

if jax.devices()[0].device_kind == 'cpu':
    n = 1024
    layer_sizes = [1, 1024, 1024, 1]
else:
    n = 8192*2
    layer_sizes = [1, 8192, 8192, 8192, 1]

x, y = gen_data(keys[0], f, n=n)
params = init_network_params(layer_sizes, keys[1])

# Execute grad_loss once to compile function
grad_loss(params, x, y);

In [None]:
params_single = train(params, x, y)

**Task**: Implement the training step using automatic parallelization. The training data `(x,y)` should be sharded across all devices, the parameters `(w, b)` should be replicated. Finally, train the model using your sharded data.

In [None]:
### Your code here

**Task**: Compare the runtimes of the function `grad_loss` using the sharded and the unsharded data. On a machine with 8 TPU devices, like the one available for free on Google Colab, you should observe a speedup factor of around 6 to 8.

On a CPU, your parallelized implementation probably takes longer, since the XLA compiler already parallelizes the computation on the unsharded data across all CPU cores.
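When timing JAX code, keep in mind that jitted functions are dispatched asynchronously and that the first call for new input shapes or shardings includes compilation. A minimal timing helper that accounts for both; `time_fn` is a hypothetical name, not part of JAX, and the toy function `square_sum` only demonstrates its use:

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, repeats=10):
    """Average wall-clock seconds per call of fn(*args) (hypothetical helper)."""
    # Warm-up call: triggers compilation for these argument shapes/shardings
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the (asynchronously dispatched) result before continuing
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats

square_sum = jax.jit(lambda v: jnp.sum(v ** 2))
t = time_fn(square_sum, jnp.ones(10_000))
print(f'{t * 1e6:.1f} µs per call')
```

Since `jax.block_until_ready` also accepts pytrees, the same helper can be applied to a function returning a list of gradient tuples, such as `grad_loss`.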

In [None]:
### Your code here