# Parallelization: Sharded Matrices and Collective Operations

This tutorial covers the basics of sharded matrix operations and the collective primitives that are often used in distributed training and inference in deep learning.

## 0. Why JAX?

[JAX](https://docs.jax.dev/) is a high-performance numerical computing library that combines NumPy's familiar API with the power of automatic differentiation and hardware acceleration on GPUs and TPUs. Developed by Google Research, JAX enables writing high-performance code that can run efficiently on a single device or scale across multiple devices.


We use JAX for this notebook (and in the MinText library) for several reasons:

1. Beginner-friendly automatic parallelization using `jax.jit`
2. Ability to simulate multiple devices using `"XLA_FLAGS"`
3. Google Colab provides a free 8-device runtime, the v2-8 TPU. It consists of 8 TPU cores with 8 GB of memory each, 64 GB in total, so you can actually run distributed operations over 8 devices.
4. There are already several great pedagogical libraries in PyTorch (such as [HuggingFace Nanotron](https://github.com/huggingface/nanotron)) that serve a similar purpose. The key concepts from these tutorials can often be translated directly to PyTorch.

## 1. Multi-Device Computation

Choose the v2-8 TPU runtime in Google Colab to run this notebook. Once the runtime has restarted, start by importing the libraries and exploring the available devices.

In [None]:
# Install JAX if needed
# !pip install --upgrade jax jaxlib

In [1]:
# Uncomment this to simulate running the code on 8 CPU devices (use for local runs)
# import os
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

In [22]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
import numpy as np
import time
import matplotlib.pyplot as plt
from typing import Optional

In [23]:
# Check available devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

JAX version: 0.5.2
Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Device count: 8


If you cannot access the v2-8 TPU from Google Colab (for example, your session timed out or you are running this locally), restart this notebook and uncomment the `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` line to simulate multiple devices on your CPU. Set the flag before importing `jax`, since it is read when JAX initializes its backend.

The output of the cell above should then change to look something like this:

```bash
Available devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Device count: 8
```

### Creating a Device Mesh

JAX uses the concept of a device mesh to organize available devices. A device can be a CPU (or a CPU core), GPU, or TPU for JAX's purpose. A mesh is a multi-dimensional array of devices that can be addressed along different axes. This allows us to partition our data and computation along different dimensions of the mesh.

In [3]:
devices = jax.devices()

# If you have multiple devices, you can create a 2D mesh
if len(devices) >= 4:
    mesh_2d = jax.make_mesh((2, 2), ('x', 'y'))
    print(f"2D Mesh: {mesh_2d}")
    print([d.id for d in mesh_2d.devices.flat])
else:
    print("Not enough devices for a 2D mesh demonstration")

2D Mesh: Mesh('x': 2, 'y': 2)
[0, 1, 2, 3]
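
The collective examples later in this notebook refer to a 1D mesh with a single axis named `'devices'`. A minimal sketch of building one from whatever devices are available is shown here. The variable name `mesh_1d` matches the later cells; constructing it with `Mesh` directly, rather than `jax.make_mesh`, is just one option.

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Arrange all available devices along one mesh axis called 'devices'.
mesh_1d = Mesh(np.array(jax.devices()), axis_names=('devices',))

print(mesh_1d.axis_names)     # names of the mesh axes
print(mesh_1d.devices.shape)  # one entry per device along the single axis
```

A 1D mesh is enough for the collective demos because each collective only needs one named axis to communicate over.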


## 2. Sharded Matrices

Sharded matrices are arrays that are split across multiple devices. JAX provides abstractions to create and operate on these distributed arrays efficiently.

### Basic Concepts

- **Mesh**: A logical arrangement of devices
- **PartitionSpec (P)**: Specifies how to partition an array across mesh dimensions
- **jax.device_put()**: Places an array on a specific device or according to a sharding

Let's see how to create sharded arrays using JAX's sharding API:

![Sharding](https://github.com/jax-ml/scaling-book/blob/main/assets/img/sharding-example.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

In [4]:
# Let's create a large matrix and shard it
matrix_size = 8192  # Adjust based on your device memory
matrix = jnp.ones((matrix_size, matrix_size))
print(f"Matrix shape: {matrix.shape}, Size in memory: {matrix.size * 4 / (1024**2):.2f} MB")

Matrix shape: (8192, 8192), Size in memory: 256.00 MB


In [5]:
# Define the partition spec - shard rows along mesh axis 'x' and columns along mesh axis 'y'
partition_spec = P('x', 'y')

# Create a sharding object (NamedSharding was imported above)
shardings = NamedSharding(mesh_2d, partition_spec)

# Create the sharded array
sharded_matrix = jax.device_put(matrix, shardings)

print(f"Sharded matrix type: {type(sharded_matrix)}")
print(f"Sharding spec: {shardings}")

jax.debug.visualize_array_sharding(sharded_matrix)

Sharded matrix type: <class 'jaxlib.xla_extension.ArrayImpl'>
Sharding spec: NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y'), memory_kind=device)


### Inspecting Sharded Matrices

We can inspect how the array is distributed across devices:

In [10]:
print("Global matrix shape", sharded_matrix.shape)

# Get the local arrays on each device
print("Shapes of Matrix Shards:")
for shard in sharded_matrix.addressable_shards:
  print(shard.device, shard.index, shard.data.shape)

Global matrix shape (8192, 8192)
Shapes of Matrix Shards:
TPU_0(process=0,(0,0,0,0)) (slice(0, 4096, None), slice(0, 4096, None)) (4096, 4096)
TPU_1(process=0,(0,0,0,1)) (slice(0, 4096, None), slice(4096, 8192, None)) (4096, 4096)
TPU_2(process=0,(1,0,0,0)) (slice(4096, 8192, None), slice(0, 4096, None)) (4096, 4096)
TPU_3(process=0,(1,0,0,1)) (slice(4096, 8192, None), slice(4096, 8192, None)) (4096, 4096)


![Array Sharding Visualization](https://jax-ml.github.io/scaling-book/assets/img/sharding-colored4.png)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

### Performance Benefits

Let's compare the performance of element-wise operations on sharded vs. non-sharded matrices (we will cover matrix-level operations in a bit):

In [20]:
# `matrix` is present on a single device
%timeit -n 5 -r 5 jnp.sin(matrix).block_until_ready()

2.88 ms ± 36.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [21]:
# `sharded_matrix` is distributed across 4 devices
%timeit -n 5 -r 5 jnp.sin(sharded_matrix).block_until_ready()

1.64 ms ± 55.9 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
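
`%timeit` is an IPython magic, so it is only available in notebooks. A rough equivalent for plain Python scripts is sketched below. `time_fn` is a hypothetical helper, not part of JAX. It runs one warm-up call so that compilation is not timed, and it calls `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, repeats=5):
    """Return the best wall-clock time (seconds) of `fn(*args)` over `repeats` runs."""
    fn(*args).block_until_ready()          # warm-up: triggers compilation
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()      # wait for the async computation to finish
        best = min(best, time.perf_counter() - start)
    return best

x = jnp.ones((512, 512))
elapsed = time_fn(jnp.sin, x)
print(f"jnp.sin on {x.shape}: {elapsed * 1e3:.3f} ms")
```

The same helper can be called with a sharded array in place of `x` to reproduce the comparison above outside a notebook.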


### Different Sharding Strategies

Each axis of the matrix can be sharded across each possible axis in the device mesh. This gives rise to a combinatorial number of possible shardings.

In [24]:
# A small helper function to define a sharding mesh
default_mesh = jax.make_mesh((2, 2), ('a', 'b'))

def mesh_sharding(
    pspec: P, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)

In [25]:
# Shard first axis of matrix along mesh axis 'a', second axis of matrix along mesh axis 'b'
ix_jy_sharding = jax.device_put(matrix, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(ix_jy_sharding)

In [27]:
# Shard first axis of matrix along mesh axis 'b', second axis of matrix along mesh axis 'a'
iy_jx_sharding = jax.device_put(matrix, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(iy_jx_sharding)

In [29]:
# Shard first axis of matrix along mesh axis 'a'; leave the second axis unsharded, so each shard is replicated across mesh axis 'b'
ix_j_sharding = jax.device_put(matrix, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(ix_j_sharding, use_color=False)

In [31]:
# Shard second axis of matrix along mesh axis 'b'; leave the first axis unsharded, so each shard is replicated across mesh axis 'a'
i_jy_shard = jax.device_put(matrix, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(i_jy_shard, use_color=False)

For a 2D matrix sharded over a 2D device mesh, here are all the possible sharding strategies:

![Possible Array Shardings](https://jax-ml.github.io/scaling-book/assets/img/sharding-colored5.png)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>
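
To see how a partition spec determines what each device stores, the per-device block shape can be worked out by hand: every array dimension is divided by the product of the mesh axis sizes assigned to it. The helper below is a plain-Python sketch of that rule. `shard_shape` and the dictionary form of the mesh are illustrative names, not JAX APIs.

```python
def shard_shape(global_shape, spec, mesh_shape):
    """Per-device block shape for an array partitioned according to `spec`.

    `spec` has one entry per array dimension: None (not sharded), a mesh
    axis name, or a tuple of mesh axis names.
    """
    block = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            axes = ()
        elif isinstance(entry, str):
            axes = (entry,)
        else:
            axes = tuple(entry)
        n_parts = 1
        for axis in axes:
            n_parts *= mesh_shape[axis]
        if dim % n_parts != 0:
            raise ValueError(f"dimension {dim} does not split evenly into {n_parts} parts")
        block.append(dim // n_parts)
    return tuple(block)

mesh_shape = {'a': 2, 'b': 2}   # same layout as the 2x2 mesh used above
global_shape = (8192, 8192)
specs = [('a', 'b'), ('b', 'a'), ('a', None), (None, 'b'), (('a', 'b'), None), (None, None)]
for spec in specs:
    print(spec, '->', shard_shape(global_shape, spec, mesh_shape))
```

Specs that leave a mesh axis unused keep larger blocks per device, because the data is replicated across that axis instead of being split.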

### Choosing the Right Sharding Strategy

As we discuss collectives (next section) and different parallelism strategies (in the next tutorials), we will gradually take a deeper look at how to choose a sharding strategy for your use case.

## 3. Collective Operations

Collective operations are essential for distributed algorithms where devices need to share or aggregate information. In deep learning, these are utilized in performing computations with sharded arrays.

### Matrix Operations With Sharded Matrices

#### Case 1: No Sharded Contracting Dimension
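
In this case neither operand is sharded along the dimension being contracted, so every device already holds everything it needs and the product requires no communication. A minimal sketch follows, assuming a 1D mesh with axis `'devices'` and the rows of `A` sharded across it. The shapes and names are illustrative, and the import fallback is there because the location of `shard_map` differs between JAX releases.

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # e.g. JAX 0.5.x

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))

# A[I, J] has its rows (I) sharded; B[J, K] is replicated on every device.
a = jnp.arange(4 * n * 3, dtype=jnp.float32).reshape(4 * n, 3)
b = jnp.ones((3, 2), dtype=jnp.float32)
a_sharded = jax.device_put(a, NamedSharding(mesh, P('devices', None)))
b_replicated = jax.device_put(b, NamedSharding(mesh, P(None, None)))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('devices', None), P(None, None)),
         out_specs=P('devices', None))
def local_matmul(a_block, b_full):
    # No collective here: the contracting dimension J is whole on every device.
    return a_block @ b_full

c = local_matmul(a_sharded, b_replicated)
print(c.shape)     # global shape (4 * n, 2)
print(c.sharding)  # rows are still sharded over 'devices'
```

Because the function body contains no collective, each device computes its block of rows of the output independently.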

#### Case 2 (All-Gather): One matrix has a sharded contracting dimension

All-gather collects values from all devices, resulting in each device having a complete copy of all values. When only one operand has its contracting dimension sharded, all-gathering that operand first lets every device run the multiplication locally.

In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map

# 1D mesh over all available devices, with a single axis named 'devices'
n_devices = jax.device_count()
mesh_1d = jax.make_mesh((n_devices,), ('devices',))

# Create a sharded array where each device has a different value
sharded_values = jnp.arange(1, n_devices + 1).reshape((n_devices, 1))
shardings = NamedSharding(mesh_1d, P('devices', None))
sharded_array = jax.device_put(sharded_values, shardings)

# Collectives need a named mesh axis in scope, so they run inside shard_map,
# where the function body sees the (1, 1) block held by one device
@jax.jit
@partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P('devices', None))
def all_gather(x):
    # Explicit all-gather: every device receives all blocks, stacked along axis 0
    return jax.lax.all_gather(x, axis_name='devices', axis=0, tiled=True)

# Perform the all-gather; the output stacks the (n_devices, 1) copy held by each device
result = all_gather(sharded_array)
print(f"Original values:\n{sharded_values}")
print(f"After all-gather (one row per device):\n{result.reshape(n_devices, n_devices)}")

# Expected result: each device should have all values
expected_gather = jnp.arange(1, n_devices + 1).reshape((n_devices, 1))
print(f"Expected copy on each device:\n{expected_gather}")

#### Case 3 (All-Reduce): Both matrices have sharded contracting dimensions

All-reduce aggregates values across devices using an operation like sum, mean, max, etc. Each device ends up with the same aggregated result. When both operands have their contracting dimension sharded, each device computes a partial product and an all-reduce sums the partial products.

In [None]:
# Example of all-reduce, again inside shard_map so that 'devices' is in scope
@jax.jit
@partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P(None, None))
def all_reduce_sum(x):
    # Explicit all-reduce: every device ends up with the sum over all devices
    return jax.lax.psum(x, axis_name='devices')

# Use the same sharded array from before
result = all_reduce_sum(sharded_array)
print(f"Original values:\n{sharded_values}")
print(f"After all-reduce sum:\n{result}")

# Expected result: each device should have the sum of all values
expected_sum = jnp.sum(jnp.arange(1, n_devices + 1))
print(f"Expected sum: {expected_sum}")
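
To tie the collectives in Cases 2 and 3 back to matrix multiplication, the sketch below runs both patterns on small matrices and can be checked against the unsharded product. It assumes a 1D mesh with axis `'devices'`. The function names, shapes, and the import fallback are illustrative choices, not the only way to write this.

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # e.g. JAX 0.5.x

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))

I, J, K = 4, 2 * n, 3 * n
a = jnp.arange(I * J, dtype=jnp.float32).reshape(I, J)
b = (jnp.arange(J * K, dtype=jnp.float32) % 5).reshape(J, K)

# Case 2: A is sharded along J (contracting), B is sharded along K only.
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'devices'), P(None, 'devices')),
         out_specs=P(None, 'devices'))
def matmul_all_gather(a_block, b_block):
    # Gather the missing columns of A so the contracting dimension is whole.
    a_full = jax.lax.all_gather(a_block, 'devices', axis=1, tiled=True)  # (I, J)
    return a_full @ b_block                                             # (I, K / n)

# Case 3: both A and B are sharded along J (contracting).
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'devices'), P('devices', None)),
         out_specs=P(None, None))
def matmul_all_reduce(a_block, b_block):
    partial_product = a_block @ b_block            # (I, K), sums over J / n terms
    return jax.lax.psum(partial_product, 'devices')

a_cols = jax.device_put(a, NamedSharding(mesh, P(None, 'devices')))
b_cols = jax.device_put(b, NamedSharding(mesh, P(None, 'devices')))
b_rows = jax.device_put(b, NamedSharding(mesh, P('devices', None)))

c_gather = matmul_all_gather(a_cols, b_cols)
c_reduce = matmul_all_reduce(a_cols, b_rows)
print(c_gather.shape, c_reduce.shape)
```

In Case 2 the communication happens before the local multiply, while in Case 3 it happens after, on the partial products.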

### Other Collectives

#### Reduce-Scatter
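
Reduce-scatter sums values across devices like an all-reduce, but leaves each device with only one slice of the result instead of a full copy. In JAX this is exposed as `jax.lax.psum_scatter`. A minimal sketch, assuming a 1D mesh with axis `'devices'` (the array contents and the name `reduce_scatter` are illustrative):

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # e.g. JAX 0.5.x

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))

# Row d is the vector contributed by device d.
data = jnp.arange(n * 2 * n, dtype=jnp.float32).reshape(n, 2 * n)
sharded = jax.device_put(data, NamedSharding(mesh, P('devices', None)))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('devices', None), out_specs=P(None, 'devices'))
def reduce_scatter(block):
    # block: (1, 2n). Sum over devices, then keep only this device's (1, 2) slice.
    return jax.lax.psum_scatter(block, 'devices', scatter_dimension=1, tiled=True)

result = reduce_scatter(sharded)
print(result)  # column sums of `data`, with each device holding one slice
```

An all-reduce can be viewed as a reduce-scatter followed by an all-gather, which is why reduce-scatter is the cheaper choice when the next step only needs a shard of the result.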

#### All-to-All

All-to-all exchanges slices of data between all devices. This is useful for operations like matrix transposition or redistributing data with a different sharding.

In [None]:
if jax.device_count() >= 2:
    from jax.experimental.shard_map import shard_map

    n_devices = jax.device_count()

    # Example of all-to-all; it needs a named mesh axis, so it runs inside shard_map
    @jax.jit
    def all_to_all(x):
        # Each device splits its (1, n) row block into n column chunks,
        # sends chunk j to device j, and stacks what it receives along axis 0
        return shard_map(
            lambda block: jax.lax.all_to_all(
                block, axis_name='devices', split_axis=1, concat_axis=0, tiled=True),
            mesh=mesh_1d, in_specs=P('devices', None), out_specs=P(None, 'devices'),
        )(x)

    # Create a matrix where each device has part of the rows
    data = jnp.arange(n_devices * n_devices).reshape((n_devices, n_devices))
    shardings = NamedSharding(mesh_1d, P('devices', None))
    sharded_data = jax.device_put(data, shardings)

    # Perform the all-to-all
    result = all_to_all(sharded_data)
    print(f"Original data:\n{data}")
    print(f"After all-to-all (same values, now column-sharded):\n{result}")

    # This effectively transposes the sharding from rows to columns
    print(f"Original sharding: {sharded_data.sharding}")
    print(f"Result sharding: {result.sharding}")
else:
    print("Need at least 2 devices for all-to-all demonstration")

### Performance Considerations for Collective Operations

When working with collective operations, consider:

1. **Communication overhead**: Collective operations require device-to-device communication
2. **Data size**: Larger transfers take more time
3. **Network topology**: The physical connections between devices matter
4. **Frequency**: Minimize the number of collective operations in your code
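
To make the first two points concrete, a common back-of-the-envelope model for a bandwidth-bound ring all-reduce over N devices has each device send about 2(N-1)/N times the array size. The sketch below applies that model. The 100 GB/s link bandwidth is a placeholder assumption, not a measured figure for any particular hardware, and latency is ignored.

```python
def ring_all_reduce_seconds(num_bytes, n_devices, link_bandwidth_bytes_per_s):
    """Bandwidth-only estimate: each device sends 2 * (N - 1) / N * num_bytes."""
    if n_devices < 2:
        return 0.0  # nothing to communicate on a single device
    bytes_sent = 2 * (n_devices - 1) / n_devices * num_bytes
    return bytes_sent / link_bandwidth_bytes_per_s

ASSUMED_BANDWIDTH = 100e9  # bytes per second; hypothetical value for illustration

for size in (1024, 4096, 8192):
    num_bytes = size * size * 4  # float32 matrix
    seconds = ring_all_reduce_seconds(num_bytes, 8, ASSUMED_BANDWIDTH)
    print(f"{size}x{size} float32: ~{seconds * 1e3:.3f} ms")
```

Under this model the cost grows linearly with the data size and barely changes with the device count once N is large, since 2(N-1)/N approaches 2.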

Let's benchmark a simple all-reduce operation:

In [None]:
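if jax.device_count() > 1:
    from jax.experimental.shard_map import shard_map

    # Benchmark all-reduce with different data sizes (multiples of the
    # device count, so the rows split evenly across devices)
    n_devices = jax.device_count()
    sizes = [n_devices * k for k in (2, 16, 128, 1024)]
    times = []

    # Define the all-reduce once; psum needs a named mesh axis, so it runs inside shard_map
    @jax.jit
    def all_reduce(x):
        return shard_map(
            lambda block: jax.lax.psum(block, axis_name='devices'),
            mesh=mesh_1d, in_specs=P('devices', None), out_specs=P(None, None),
        )(x)

    for size in sizes:
        # Create data
        data = jnp.ones((size, size))
        shardings = NamedSharding(mesh_1d, P('devices', None))
        sharded_data = jax.device_put(data, shardings)

        # Warm-up (compiles for this input shape)
        all_reduce(sharded_data).block_until_ready()

        # Benchmark
        start = time.perf_counter()
        all_reduce(sharded_data).block_until_ready()
        elapsed = time.perf_counter() - start
        times.append(elapsed)

        print(f"Size {size}x{size}, Time: {elapsed:.6f} seconds")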

    # Plot the results
    plt.figure(figsize=(10, 6))
    plt.plot(sizes, times, 'o-')
    plt.xscale('log')
    plt.xlabel('Matrix Size')
    plt.ylabel('Time (seconds)')
    plt.title('All-Reduce Performance by Data Size')
    plt.grid(True)
    plt.show()
else:
    print("Multiple devices required for benchmark")

## 4. Practical Example: Distributed Matrix Multiplication

Let's implement a distributed matrix multiplication algorithm using sharded matrices and collective operations.

We'll implement a simple scheme rather than a full Cannon's algorithm: each device multiplies its local blocks, and an all-reduce combines the partial products:

In [None]:
def create_matrices(n):
    """Create two n x n matrices."""
    a = jnp.ones((n, n))
    b = jnp.ones((n, n))
    return a, b

def baseline_matmul(a, b):
    """Standard matrix multiplication on a single device."""
    return jnp.matmul(a, b)

from jax.experimental.shard_map import shard_map

def make_distributed_matmul(mesh):
    """Build a jitted matmul whose contracting dimension is sharded over `mesh`."""
    def local_matmul(a_block, b_block):
        # Compute local product: (n, n/d) @ (n/d, n) gives a partial (n, n) sum
        local_product = jnp.matmul(a_block, b_block)
        # Reduce across devices to get final result
        return jax.lax.psum(local_product, axis_name='devices')

    # A is sharded along its columns and B along its rows, so the
    # contracting dimension is split across devices
    return jax.jit(shard_map(
        local_matmul, mesh=mesh,
        in_specs=(P(None, 'devices'), P('devices', None)),
        out_specs=P(None, None),
    ))

# Benchmark the two approaches
if jax.device_count() > 1:
    # Parameters
    n = 2000  # Matrix size (must be divisible by the device count)

    # Create matrices
    a, b = create_matrices(n)

    # Baseline (single device)
    baseline_fn = jax.jit(baseline_matmul)
    baseline_fn(a, b).block_until_ready()  # Warm-up

    start = time.perf_counter()
    c_baseline = baseline_fn(a, b).block_until_ready()
    baseline_time = time.perf_counter() - start
    print(f"Baseline time: {baseline_time:.4f} seconds")

    # Distributed: shard the inputs and build the function once
    a_sharded = jax.device_put(a, NamedSharding(mesh_1d, P(None, 'devices')))
    b_sharded = jax.device_put(b, NamedSharding(mesh_1d, P('devices', None)))
    distributed_fn = make_distributed_matmul(mesh_1d)
    distributed_fn(a_sharded, b_sharded).block_until_ready()  # Warm-up

    start = time.perf_counter()
    c_distributed = distributed_fn(a_sharded, b_sharded).block_until_ready()
    distributed_time = time.perf_counter() - start
    print(f"Distributed time: {distributed_time:.4f} seconds")

    # Verify results are similar
    diff = jnp.abs(c_baseline - c_distributed).max()
    print(f"Maximum difference: {diff}")

    # Calculate speedup
    speedup = baseline_time / distributed_time
    print(f"Speedup: {speedup:.2f}x")
else:
    print("Multiple devices required for benchmark")

## 5. Advanced Topics

### Automatic Sharding with SPMD (Single Program Multiple Data)

JAX's automatic sharding was originally exposed as `pjit` (Partitioned JIT) and is now part of `jax.jit`. You specify the sharding of inputs and outputs, and the compiler chooses the intermediate shardings and inserts the collectives for you.

In [None]:
def auto_sharded_matmul(mesh):
    """Build a matmul whose intermediate shardings are chosen by the compiler."""
    # Define input and output shardings
    in_shardings = (
        NamedSharding(mesh, P('devices', None)),  # A: row sharding
        NamedSharding(mesh, P(None, 'devices')),  # B: column sharding
    )
    out_shardings = NamedSharding(mesh, P(None, None))  # Output replicated on all devices

    # `pjit` has been merged into `jax.jit`, which accepts shardings directly
    return jax.jit(
        lambda x, y: jnp.matmul(x, y),
        in_shardings=in_shardings,
        out_shardings=out_shardings,
    )

if jax.device_count() > 1:
    # Try the auto-sharded version
    a, b = create_matrices(2000)
    auto_fn = auto_sharded_matmul(mesh_1d)
    auto_fn(a, b).block_until_ready()  # Warm-up (compilation)

    start = time.perf_counter()
    c_auto = auto_fn(a, b).block_until_ready()
    auto_time = time.perf_counter() - start
    print(f"Auto-sharded time: {auto_time:.4f} seconds")

    # Compare with our manual implementation
    if 'distributed_time' in locals():
        print(f"Manual vs Auto ratio: {distributed_time / auto_time:.2f}x")
else:
    print("Multiple devices required for demonstration")

### Custom Sharding Rules

JAX allows defining custom partitioning rules for operations. This is useful for operations where the default partitioning might not be optimal.

In [None]:
# Advanced topic: Custom partitioning rules
from functools import partial

def custom_matmul_rule(mesh):
    # This is a simplified example - real custom rules would be more complex
    def matmul_with_custom_rule(x, y):
        # Custom implementation that's aware of sharding
        return jnp.matmul(x, y)

    return matmul_with_custom_rule

# Example usage (conceptual)
# my_custom_matmul = custom_matmul_rule(mesh_1d)
# result = my_custom_matmul(sharded_a, sharded_b)

## 6. Conclusion

In this tutorial, we've explored JAX's powerful capabilities for distributed computation using sharded matrices and collective operations. We've covered:

1. **Setting up a device mesh** for organizing available devices
2. **Creating and using sharded matrices** to distribute data across devices
3. **Different sharding strategies** and when to use them
4. **Collective operations** for efficient device-to-device communication
5. **Implementing distributed algorithms** using these primitives
6. **Automatic sharding** with pjit for easier distributed programming

JAX's sharding capabilities enable efficient scaling of numerical computations across multiple devices, making it a powerful tool for large-scale machine learning and scientific computing.

### Further Resources

- [JAX Documentation](https://jax.readthedocs.io/)
- [JAX Docs: Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
- [JAX Docs: Manual parallelism with `shard_map`](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
- [How to Scale Your Model: Sharded Matrices and How to Multiply Them](https://jax-ml.github.io/scaling-book/sharding/)