# JAX Parallelization: Sharded Matrices and Collective Operations

This tutorial explores how to effectively use JAX's distributed array abstractions and collective operations for efficient model parallelism. We'll cover:

1. **Introduction to JAX and its parallelism features**
2. **Setting up a JAX environment for multi-device computation**
3. **Sharded Matrices: Partitioning arrays across devices**
4. **Collective Operations: Efficient communication between devices**
5. **Practical examples: Implementing distributed algorithms**

Let's begin!

## 1. Introduction to JAX

JAX is a high-performance numerical computing library that combines NumPy's familiar API with the power of automatic differentiation and hardware acceleration on GPUs and TPUs. Developed by Google Research, JAX enables writing high-performance code that can run efficiently on a single device or scale across multiple devices.

Key features of JAX include:

- **NumPy-like API**: Familiar interface for array operations
- **Automatic differentiation**: Compute gradients with `grad`, `value_and_grad`, etc.
- **Acceleration**: Seamless execution on GPUs and TPUs
- **Function transformations**: `jit`, `vmap`, `pmap`, etc.
- **Parallelism primitives**: Distribute computation across devices
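As a quick illustration of how these transformations compose, here is a minimal sketch using `jax.grad`, `jax.jit` and `jax.vmap` on a toy function `f` (a hypothetical example, not part of the original tutorial). If you plan to simulate several CPU devices with `XLA_FLAGS` as shown in Section 2, run that setup cell first, because JAX reads the flag when it initializes its backend.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued function of a vector: sum of squares
    return jnp.sum(x ** 2)

grad_f = jax.jit(jax.grad(f))   # gradient of f is 2 * x, compiled with jit
batched_f = jax.vmap(f)         # apply f independently to each row of a batch

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))                          # [2. 4. 6.]
print(batched_f(jnp.stack([x, 2 * x])))   # [14. 56.]
```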

In this tutorial, we'll focus specifically on JAX's capabilities for distributed computation using sharded matrices and collective operations.

## 2. Setting Up for Multi-Device Computation

Let's start by setting up our environment and exploring the available devices.

In [None]:
# Install JAX if needed
# !pip install --upgrade jax jaxlib

In [None]:
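# To simulate 8 devices on a CPU-only machine, uncomment these lines and run
# this cell before importing JAX (the flag is read when JAX initializes its backend)
# import os
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'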

In [2]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
import numpy as np
import time
import matplotlib.pyplot as plt

In [4]:
# Check available devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

JAX version: 0.6.1
Available devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Device count: 8


### Creating a Device Mesh

JAX uses the concept of a device mesh to organize available devices. A mesh is a multi-dimensional array of devices that can be addressed along different axes. This allows us to partition our data and computation along different dimensions of the mesh.
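A minimal sketch of how a mesh maps named axes onto a grid of devices. It assumes nothing about the device count: it builds a 2 x (n/2) grid when the number of devices `n` is even and a 1 x n grid otherwise (the variable names here are illustrative).

```python
import jax
import numpy as np
from jax.sharding import Mesh

device_grid = np.array(jax.devices())
n = device_grid.size

# Pick a 2D shape (rows, cols) with rows * cols == n
rows = 2 if n % 2 == 0 else 1
grid_mesh = Mesh(device_grid.reshape(rows, n // rows), axis_names=('rows', 'cols'))

print(grid_mesh.shape)          # axis name -> axis size
print(grid_mesh.devices.shape)  # the underlying (rows, cols) device grid
# grid_mesh.devices[i, j] is the device at coordinate (i, j) along ('rows', 'cols')
```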

In [5]:
# Create a simple 1D mesh using all available devices
devices = jax.devices()
mesh_1d = Mesh(devices, axis_names=('devices',))
print(f"1D Mesh: {mesh_1d}")

# If you have multiple devices, you can create a 2D mesh
if len(devices) >= 4:
    # Reshape into a 2D mesh (for example, 2x2 for 4 devices)
    devices_array = np.array(devices[:4]).reshape((2, 2))
    mesh_2d = Mesh(devices_array, axis_names=('rows', 'cols'))
    print(f"2D Mesh: {mesh_2d}")
else:
    print("Not enough devices for a 2D mesh demonstration")

1D Mesh: Mesh('devices': 8, axis_types=(Auto,))
2D Mesh: Mesh('rows': 2, 'cols': 2, axis_types=(Auto, Auto))


## 3. Sharded Matrices in JAX

Sharded matrices are arrays that are split across multiple devices. JAX provides abstractions to create and operate on these distributed arrays efficiently.

### Basic Concepts

- **Mesh**: A logical arrangement of devices
- **PartitionSpec (P)**: Specifies how to partition an array across mesh dimensions
- **jax.device_put()**: Places an array on a specific device or according to a sharding
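Putting the three concepts together: a minimal sketch that places a small array on a 1D mesh with a row-wise `NamedSharding` and then checks which block each device holds. The array size is chosen as a multiple of the device count so the rows split evenly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))
n = len(jax.devices())

x = jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n * 4, 3)
sharding = NamedSharding(mesh, P('devices', None))  # split rows, keep columns whole
x_sharded = jax.device_put(x, sharding)

# Every device holds a (4, 3) block; shard_shape derives it from the global shape
print(sharding.shard_shape(x.shape))
for shard in x_sharded.addressable_shards:
    # shard.index is the slice of the global array stored on shard.device
    print(shard.device, shard.index, shard.data.shape)
```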

Let's see how to create sharded arrays using JAX's sharding API:

In [6]:
# Let's create a large matrix and shard it
matrix_size = 8000  # Adjust based on your device memory
large_matrix = jnp.ones((matrix_size, matrix_size))
print(f"Matrix shape: {large_matrix.shape}, Size in memory: {large_matrix.nbytes / (1024**2):.2f} MB")

Matrix shape: (8000, 8000), Size in memory: 244.14 MB


In [7]:
# Create a sharded version of the matrix
with mesh_1d:
    # Define the partition spec - shard along the first dimension
    partition_spec = P('devices', None)
    
    # Create a shardings object
    from jax.sharding import NamedSharding
    shardings = NamedSharding(mesh_1d, partition_spec)
    
    # Create the sharded array
    sharded_matrix = jax.device_put(large_matrix, shardings)
    
    print(f"Sharded matrix type: {type(sharded_matrix)}")
    print(f"Sharding spec: {shardings}")

Sharded matrix type: <class 'jaxlib._jax.ArrayImpl'>
Sharding spec: NamedSharding(mesh=Mesh('devices': 8, axis_types=(Auto,)), spec=PartitionSpec('devices', None), memory_kind=unpinned_host)


### Inspecting Sharded Arrays

We can inspect how the array is distributed across devices:

In [8]:
# Inspecting the sharding
print(f"Global shape: {sharded_matrix.shape}")

# How is the array sharded?
sharding = sharded_matrix.sharding
print(f"Sharding: {sharding}")

# Each addressable shard is the block of the array stored on one device
shards = sharded_matrix.addressable_shards
print(f"Number of shards: {len(shards)}")
print(f"First shard shape: {shards[0].data.shape}")

Global shape: (8000, 8000)
Sharding: NamedSharding(mesh=Mesh('devices': 8, axis_types=(Auto,)), spec=PartitionSpec('devices', None), memory_kind=unpinned_host)
Number of shards: 8
First shard shape: (1000, 8000)


### Performance Benefits

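Let's compare the performance of a matrix multiplication on a single device with the same operation on a row-sharded input. Keep in mind that simulated CPU devices like the ones used here all share the same host cores and memory, so sharding usually brings no speedup on them (the recorded run below is actually slower); the benefit shows up on real multi-accelerator hardware. Timings also vary from run to run and machine to machine.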

In [9]:
from jax.sharding import NamedSharding

# Define a simple matrix multiplication operation
@jax.jit
def matrix_multiply(x):
    return jnp.matmul(x, x)

# Create a smaller matrix for this benchmark
benchmark_size = 4000  # Adjust based on your hardware; must be divisible by the device count
matrix = jnp.ones((benchmark_size, benchmark_size))

# Benchmark on a single device (the first call compiles, so warm up before timing)
matrix_multiply(matrix).block_until_ready()
start = time.time()
result_single = matrix_multiply(matrix).block_until_ready()
time_single = time.time() - start
print(f"Single device time: {time_single:.4f} seconds")

# Benchmark with sharding (if multiple devices are available)
if jax.device_count() > 1:
    shardings = NamedSharding(mesh_1d, P('devices', None))
    sharded_matrix = jax.device_put(matrix, shardings)

    # jit compiles a separate executable for the sharded input, so warm that up too
    matrix_multiply(sharded_matrix).block_until_ready()

    start = time.time()
    result_sharded = matrix_multiply(sharded_matrix).block_until_ready()
    time_sharded = time.time() - start
    print(f"Sharded time: {time_sharded:.4f} seconds")
    print(f"Speedup: {time_single / time_sharded:.2f}x")
else:
    print("Multiple devices required for sharding benchmark")

Single device time: 0.1912 seconds
Sharded time: 0.2180 seconds
Speedup: 0.88x
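A single wall-clock measurement is noisy and, without a warm-up call, also includes tracing and compilation. A small helper along these lines (`benchmark` is a hypothetical name, not a JAX API) excludes the compiling call and reports the median of several runs:

```python
import statistics
import time

import jax
import jax.numpy as jnp

def benchmark(fn, *args, repeats=5):
    """Median wall-clock time of fn(*args), excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers tracing and compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so wait for the result before stopping the clock
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)

matmul = jax.jit(lambda x: x @ x)
t = benchmark(matmul, jnp.ones((512, 512)))
print(f"median time: {t:.6f} s")
```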


### Different Sharding Strategies

JAX supports various sharding strategies, including:

1. **Row Sharding**: Distribute rows across devices
2. **Column Sharding**: Distribute columns across devices
3. **2D Block Sharding**: Distribute blocks in a grid pattern
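To see what each strategy means for an individual device, the sketch below computes the per-device block shape and the slice held by the first device for all three layouts. It assumes a 2 x (n/2) mesh when the device count `n` is even (1 x n otherwise); a tuple of axis names in a `PartitionSpec` splits one array dimension over several mesh axes at once.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

device_grid = np.array(jax.devices())
n = device_grid.size
rows = 2 if n % 2 == 0 else 1
grid_mesh = Mesh(device_grid.reshape(rows, n // rows), ('rows', 'cols'))
shape = (4 * n, 4 * n)  # divisible by every axis size used below

strategies = {
    "row": P(('rows', 'cols'), None),     # dim 0 split over all n devices
    "column": P(None, ('rows', 'cols')),  # dim 1 split over all n devices
    "2D block": P('rows', 'cols'),        # dim 0 over 'rows', dim 1 over 'cols'
}
first_device = grid_mesh.devices.flat[0]
for name, spec in strategies.items():
    sharding = NamedSharding(grid_mesh, spec)
    block = sharding.shard_shape(shape)
    # devices_indices_map: the slice of the global array each device holds
    index = sharding.devices_indices_map(shape)[first_device]
    print(f"{name:>8}: block {block}, first device holds {index}")
```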

Let's implement and compare these strategies:

In [None]:
# Create a matrix for our demonstrations
demo_matrix = jnp.ones((1024, 1024))

with mesh_1d:
    # 1. Row sharding
    row_spec = P('devices', None)
    row_shardings = NamedSharding(mesh_1d, row_spec)
    row_sharded = jax.device_put(demo_matrix, row_shardings)
    
    # 2. Column sharding
    col_spec = P(None, 'devices')
    col_shardings = NamedSharding(mesh_1d, col_spec)
    col_sharded = jax.device_put(demo_matrix, col_shardings)
    
    print(f"Row sharding: {row_shardings}")
    print(f"Column sharding: {col_shardings}")
    
    # If we have a 2D mesh, we can do 2D block sharding
    if 'mesh_2d' in locals():
        with mesh_2d:
            # 3. 2D Block sharding
            block_spec = P('rows', 'cols')
            block_shardings = NamedSharding(mesh_2d, block_spec)
            block_sharded = jax.device_put(demo_matrix, block_shardings)
            print(f"Block sharding: {block_shardings}")

### Choosing the Right Sharding Strategy

The optimal sharding strategy depends on:

1. **Computation pattern**: Match sharding to your algorithm's access patterns
2. **Device topology**: Consider the physical layout of your devices
3. **Operation types**: Different operations benefit from different sharding strategies

For example:
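- Matrix multiplication: sharding A by rows while replicating B needs no communication, whereas sharding both operands along the contracting dimension requires summing partial products with an all-reduce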
- Convolutional operations may benefit from spatial partitioning
- Element-wise operations work well with any partitioning
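One way to check the matrix-multiplication claim is to compile `A @ B` under both layouts and search the compiled HLO for collective ops. This is a sketch: the helper name `collectives_in` is hypothetical, the exact collective XLA chooses (for example all-reduce versus reduce-scatter) depends on the compiler, and on a single device no collectives appear at all.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('devices',))
n = len(jax.devices())
a = jnp.ones((8 * n, 8 * n))
b = jnp.ones((8 * n, 8 * n))

def collectives_in(in_specs):
    """Compile a @ b with the given input specs and report which collectives appear."""
    shardings = tuple(NamedSharding(mesh, s) for s in in_specs)
    compiled = jax.jit(lambda x, y: x @ y, in_shardings=shardings).lower(a, b).compile()
    hlo = compiled.as_text()
    return {op: op in hlo for op in ("all-reduce", "all-gather", "reduce-scatter")}

# A split by rows, B replicated: each device computes its own rows of C
print("rows x replicated:", collectives_in((P('devices', None), P())))
# Contracting dimension split: partial products must be combined across devices
print("contracting split:", collectives_in((P(None, 'devices'), P('devices', None))))
```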

## 4. Collective Operations in JAX

Collective operations enable efficient communication between devices. These are essential for distributed algorithms where devices need to share or aggregate information.

Common collective operations include:
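- **all-reduce**: Aggregate values across devices (sum, mean, max, etc.) so every device receives the result (`jax.lax.psum`, `jax.lax.pmean`, `jax.lax.pmax`)
- **all-gather**: Concatenate the per-device pieces so every device holds the full array (`jax.lax.all_gather`)
- **reduce-scatter**: Aggregate across devices, but leave each device with only its own slice of the result (`jax.lax.psum_scatter`)
- **all-to-all**: Each device splits its data into pieces and sends one piece to every other device (`jax.lax.all_to_all`)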

Let's explore these operations in JAX:
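Before the device-level versions, it can help to pin down what each collective computes. The sketch below is a plain NumPy reference model (hypothetical helper names, no devices involved), where `blocks[i]` stands for the array held by device `i` and splits happen along axis 0. It mirrors the `tiled=True` behavior of the corresponding `jax.lax` functions.

```python
import numpy as np

def ref_all_reduce(blocks):
    total = sum(blocks)
    return [total.copy() for _ in blocks]            # every device gets the sum

def ref_all_gather(blocks):
    full = np.concatenate(blocks, axis=0)
    return [full.copy() for _ in blocks]             # every device gets everything

def ref_reduce_scatter(blocks):
    total = sum(blocks)
    return np.split(total, len(blocks), axis=0)      # device i keeps chunk i of the sum

def ref_all_to_all(blocks):
    n = len(blocks)
    pieces = [np.split(b, n, axis=0) for b in blocks]  # pieces[i][j]: sent from device i to j
    return [np.concatenate([pieces[i][j] for i in range(n)], axis=0) for j in range(n)]

blocks = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # two "devices"
print(ref_all_reduce(blocks))      # each device: [4, 6]
print(ref_all_gather(blocks))      # each device: [1, 2, 3, 4]
print(ref_reduce_scatter(blocks))  # device 0: [4], device 1: [6]
print(ref_all_to_all(blocks))      # device 0: [1, 3], device 1: [2, 4]
```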

### All-Reduce

All-reduce aggregates values across devices using an operation like sum, mean, max, etc. Each device ends up with the same aggregated result.

In [None]:
from functools import partial
from jax.sharding import NamedSharding
from jax.experimental.shard_map import shard_map

# One value per device: device i holds the row [i + 1]
n_devices = jax.device_count()
sharded_values = jnp.arange(1, n_devices + 1).reshape((n_devices, 1))
shardings = NamedSharding(mesh_1d, P('devices', None))
sharded_array = jax.device_put(sharded_values, shardings)

# Collectives such as psum need a bound axis name. Plain jax.jit does not bind
# mesh axis names, so the per-device body is wrapped in shard_map.
@jax.jit
@partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P('devices', None))
def all_reduce_sum(x):
    # x is this device's (1, 1) block; psum adds the blocks from all devices
    return jax.lax.psum(x, axis_name='devices')

# Perform the all-reduce
result = all_reduce_sum(sharded_array)
print(f"Original values:\n{sharded_values}")
print(f"After all-reduce sum (one row per device):\n{result}")

# Expected result: every device holds the sum of all values
expected_sum = jnp.sum(jnp.arange(1, n_devices + 1))
print(f"Expected sum: {expected_sum}")

### All-Gather

All-gather collects values from all devices, resulting in each device having a complete copy of all values.

In [None]:
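from functools import partial
from jax.experimental.shard_map import shard_map

# all_gather with tiled=True concatenates the per-device (1, 1) blocks along
# axis 0, so every device ends up with the full (n_devices, 1) array.
# out_specs=P() declares that result as replicated across the mesh.
@jax.jit
@partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P())
def all_gather(x):
    return jax.lax.all_gather(x, axis_name='devices', axis=0, tiled=True)

# Use the same sharded array from before
result = all_gather(sharded_array)
print(f"Original values:\n{sharded_values}")
print(f"After all-gather:\n{result}")
print(f"Result sharding: {result.sharding}")

# Expected result: each device holds all values
expected_gather = jnp.arange(1, n_devices + 1).reshape((n_devices, 1))
print(f"Expected result:\n{expected_gather}")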

### All-to-All

All-to-all exchanges slices of data between all devices. This is useful for operations like matrix transposition or redistributing data with a different sharding.

In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map

if jax.device_count() >= 2:
    # Create a matrix where each device holds one row
    n_devices = jax.device_count()
    data = jnp.arange(n_devices * n_devices).reshape((n_devices, n_devices))
    shardings = NamedSharding(mesh_1d, P('devices', None))
    sharded_data = jax.device_put(data, shardings)

    # Device i starts with row i, shape (1, n). all_to_all splits that row into n
    # pieces along axis 1, sends piece j to device j, and concatenates what it
    # receives along axis 0, so device j ends up with column j, shape (n, 1).
    @jax.jit
    @partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P(None, 'devices'))
    def all_to_all(x):
        return jax.lax.all_to_all(x, axis_name='devices', split_axis=1, concat_axis=0, tiled=True)

    # Perform the all-to-all
    result = all_to_all(sharded_data)
    print(f"Original data:\n{data}")
    print(f"After all-to-all:\n{result}")

    # The values are unchanged; only the layout moved from row- to column-sharded
    print(f"Values unchanged: {bool(jnp.array_equal(result, data))}")
    print(f"Original sharding: {sharded_data.sharding}")
    print(f"Result sharding: {result.sharding}")
else:
    print("Need at least 2 devices for all-to-all demonstration")
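Reduce-scatter is listed among the common collectives but not demonstrated above. A minimal sketch with `jax.lax.psum_scatter`: each device holds one row of an (n, n) matrix, the rows are summed across devices, and device `j` keeps only column `j` of that sum. The import fallback is an assumption about JAX versions (`jax.shard_map` in newer releases, `jax.experimental.shard_map` in older ones).

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), ('devices',))
n = len(jax.devices())

# Device i holds row i of an (n, n) matrix
data = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
sharded = jax.device_put(data, NamedSharding(mesh, P('devices', None)))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('devices', None), out_specs=P(None, 'devices'))
def reduce_scatter(x):
    # x: local (1, n) row. Sum rows across devices, keep column j on device j: (1, 1)
    return jax.lax.psum_scatter(x, 'devices', scatter_dimension=1, tiled=True)

result = reduce_scatter(sharded)
print(result)                            # column sums of data, shape (1, n)
print(data.sum(axis=0, keepdims=True))   # same values, computed without sharding
```

Compared with a full all-reduce, each device ends up with only 1/n of the summed data, which is why reduce-scatter is often paired with a later all-gather when the full result is needed.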

### Performance Considerations for Collective Operations

When working with collective operations, consider:

1. **Communication overhead**: Collective operations require device-to-device communication
2. **Data size**: Larger transfers take more time
3. **Network topology**: The physical connections between devices matter
4. **Frequency**: Minimize the number of collective operations in your code
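To put "data size" in numbers, a common back-of-the-envelope model is the ring all-reduce: a reduce-scatter phase followed by an all-gather phase, each sending (p - 1) chunks of size S/p per device. This is an assumed cost model; XLA may choose other algorithms depending on the hardware and topology. The helper name below is hypothetical.

```python
def ring_allreduce_bytes_per_device(num_elements, bytes_per_element, num_devices):
    """Bytes each device sends in a ring all-reduce (reduce-scatter + all-gather)."""
    if num_devices == 1:
        return 0
    total_bytes = num_elements * bytes_per_element
    # Two phases, each sending (p - 1) chunks of total_bytes / p
    return 2 * (num_devices - 1) * total_bytes / num_devices

# float32 8192 x 8192 matrix (256 MiB) all-reduced over 8 devices
mib = ring_allreduce_bytes_per_device(8192 * 8192, 4, 8) / 2**20
print(f"{mib:.0f} MiB sent per device")  # 448 MiB
```

In this model the traffic per device approaches 2x the array size as the device count grows, so cost scales mainly with data size rather than with the number of devices.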

Let's benchmark a simple all-reduce operation:

In [None]:
if jax.device_count() > 1:
    from functools import partial
    from jax.experimental.shard_map import shard_map

    # Benchmark all-reduce with different data sizes.
    # Sizes must be divisible by the device count to shard along rows (true here for 8 devices).
    sizes = [128, 512, 2048, 8192]
    times = []
    shardings = NamedSharding(mesh_1d, P('devices', None))

    # Defined once; jit compiles a new executable for each input shape
    @jax.jit
    @partial(shard_map, mesh=mesh_1d, in_specs=P('devices', None), out_specs=P('devices', None))
    def all_reduce(x):
        return jax.lax.psum(x, axis_name='devices')

    for size in sizes:
        sharded_data = jax.device_put(jnp.ones((size, size)), shardings)

        # Warm-up (includes compilation for this shape)
        all_reduce(sharded_data).block_until_ready()

        # Benchmark
        start = time.time()
        all_reduce(sharded_data).block_until_ready()
        elapsed = time.time() - start
        times.append(elapsed)

        print(f"Size {size}x{size}, Time: {elapsed:.6f} seconds")

    # Plot the results
    plt.figure(figsize=(10, 6))
    plt.plot(sizes, times, 'o-')
    plt.xscale('log')
    plt.xlabel('Matrix side length N (N x N matrix)')
    plt.ylabel('Time (seconds)')
    plt.title('All-Reduce Performance by Data Size')
    plt.grid(True)
    plt.show()
else:
    print("Multiple devices required for benchmark")

## 5. Practical Example: Distributed Matrix Multiplication

Let's implement a distributed matrix multiplication algorithm using sharded matrices and collective operations.

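We'll implement a simple scheme in which each device multiplies its local pieces of the sharded inputs and a `psum` all-reduce adds up the partial products. (This is much simpler than Cannon's algorithm, which shifts blocks around a 2D device grid.)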

In [None]:
from functools import partial
from jax.experimental.shard_map import shard_map

def create_matrices(n, seed=0):
    """Create two random n x n matrices (random, so the correctness check is meaningful)."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(seed))
    a = jax.random.normal(key_a, (n, n))
    b = jax.random.normal(key_b, (n, n))
    return a, b

def baseline_matmul(a, b):
    """Standard matrix multiplication on a single device."""
    return jnp.matmul(a, b)

def make_distributed_matmul(mesh):
    """Build a jitted matmul that shards the contracting dimension over 'devices'."""
    # A is split by columns and B by rows, so device k holds A[:, k-block] and B[k-block, :]
    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P(None, 'devices'), P('devices', None)),
             out_specs=P())
    def sharded_matmul(a_block, b_block):
        # Local partial product: full (n, n) shape, but only this device's share of the sum
        local_product = jnp.matmul(a_block, b_block)
        # All-reduce the partial products so every device holds A @ B
        return jax.lax.psum(local_product, axis_name='devices')

    return sharded_matmul

# Benchmark the two approaches
if jax.device_count() > 1:
    # Parameters
    n = 2000  # Matrix size; must be divisible by the device count

    # Create matrices
    a, b = create_matrices(n)

    # Baseline (single device)
    baseline_fn = jax.jit(baseline_matmul)
    baseline_fn(a, b).block_until_ready()  # Warm-up

    start = time.time()
    c_baseline = baseline_fn(a, b).block_until_ready()
    baseline_time = time.time() - start
    print(f"Baseline time: {baseline_time:.4f} seconds")

    # Distributed: place the inputs with the shardings the function expects
    a_sharded = jax.device_put(a, NamedSharding(mesh_1d, P(None, 'devices')))
    b_sharded = jax.device_put(b, NamedSharding(mesh_1d, P('devices', None)))
    distributed_fn = make_distributed_matmul(mesh_1d)
    distributed_fn(a_sharded, b_sharded).block_until_ready()  # Warm-up

    start = time.time()
    c_distributed = distributed_fn(a_sharded, b_sharded).block_until_ready()
    distributed_time = time.time() - start
    print(f"Distributed time: {distributed_time:.4f} seconds")

    # Verify results agree (up to floating-point summation order)
    diff = jnp.abs(c_baseline - c_distributed).max()
    print(f"Maximum difference: {diff}")

    # Calculate speedup
    speedup = baseline_time / distributed_time
    print(f"Speedup: {speedup:.2f}x")
else:
    print("Multiple devices required for benchmark")
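For comparison, another common layout keeps both A and B sharded by rows and all-gathers B before multiplying. Each device then computes its own rows of C with no reduction step, at the cost of temporarily holding a full copy of B. This is a sketch under the same divisibility assumption; `allgather_matmul` is an illustrative name, and the import fallback assumes either `jax.shard_map` (newer JAX) or `jax.experimental.shard_map` (older JAX).

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), ('devices',))
n_dev = len(jax.devices())

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('devices', None), P('devices', None)),
         out_specs=P('devices', None))
def allgather_matmul(a_rows, b_rows):
    # Reassemble all of B on every device (communication: one all-gather of B)
    b_full = jax.lax.all_gather(b_rows, 'devices', axis=0, tiled=True)
    # Each device computes its own rows of C = A @ B; no reduction is needed
    return a_rows @ b_full

n = 64 * n_dev
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (n, n))
b = jax.random.normal(key_b, (n, n))
row = NamedSharding(mesh, P('devices', None))
c = allgather_matmul(jax.device_put(a, row), jax.device_put(b, row))
print(c.shape, float(jnp.abs(c - a @ b).max()))
```

The psum-based version communicates partial products the size of the whole output, while this one communicates B once and leaves C row-sharded; which is cheaper depends on matrix shapes and on whether the next step needs a sharded or replicated result.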

## 6. Advanced Topics

### Automatic Sharding with SPMD (Single Program Multiple Data)

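Instead of writing the per-device program yourself, you can let the compiler partition it. `jax.jit` accepts `in_shardings` and `out_shardings` (the former `pjit` API has been merged into `jax.jit`): you specify how inputs and outputs are laid out, and XLA's SPMD partitioner propagates shardings to the intermediate values and inserts the required collectives.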

In [None]:
from jax.sharding import NamedSharding

def make_auto_sharded_matmul(mesh):
    """Matrix multiplication where XLA chooses the intermediate shardings."""
    # Inputs: A row-sharded, B column-sharded; output fully replicated
    in_shardings = (NamedSharding(mesh, P('devices', None)),
                    NamedSharding(mesh, P(None, 'devices')))
    out_sharding = NamedSharding(mesh, P())

    def matmul(x, y):
        return jnp.matmul(x, y)

    # jax.jit with sharding annotations (the functionality pjit used to provide)
    return jax.jit(matmul, in_shardings=in_shardings, out_shardings=out_sharding)

if jax.device_count() > 1:
    # Try the auto-sharded version
    a, b = create_matrices(2000)
    auto_fn = make_auto_sharded_matmul(mesh_1d)
    a_in = jax.device_put(a, NamedSharding(mesh_1d, P('devices', None)))
    b_in = jax.device_put(b, NamedSharding(mesh_1d, P(None, 'devices')))
    auto_fn(a_in, b_in).block_until_ready()  # Warm-up (compilation)

    start = time.time()
    c_auto = auto_fn(a_in, b_in).block_until_ready()
    auto_time = time.time() - start
    print(f"Auto-sharded time: {auto_time:.4f} seconds")
    print(f"Output sharding: {c_auto.sharding}")

    # Compare with our manual implementation
    if 'distributed_time' in locals():
        print(f"Manual vs Auto ratio: {distributed_time / auto_time:.2f}x")
else:
    print("Multiple devices required for demonstration")
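Between fully manual `shard_map` code and fully automatic partitioning, `jax.lax.with_sharding_constraint` lets you pin the sharding of a single intermediate inside a jitted function while XLA still decides everything else. A minimal sketch with a hypothetical two-layer computation `two_layer`, keeping the hidden activations sharded along the batch dimension:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('devices',))
n = 16 * len(jax.devices())
row_sharding = NamedSharding(mesh, P('devices', None))

@jax.jit
def two_layer(x, w1, w2):
    h = jnp.tanh(x @ w1)
    # Pin the intermediate to row (batch) sharding; XLA propagates around this hint
    h = jax.lax.with_sharding_constraint(h, row_sharding)
    return h @ w2

x = jnp.ones((n, 32))
w1 = jnp.full((32, 64), 0.01)
w2 = jnp.full((64, 8), 0.01)
out = two_layer(jax.device_put(x, row_sharding), w1, w2)
print(out.shape, out.sharding)
```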

### Custom Sharding Rules

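JAX also lets you attach custom partitioning rules to a function through `jax.experimental.custom_partitioning`, for cases where the partitioner's default choice is not optimal. That API requires callbacks describing how operand shardings map to result shardings and what each device computes, so the cell below is only a placeholder showing where such a rule would plug in; it does not register any rule.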

In [None]:
# Placeholder only: no custom partitioning rule is registered here.
# A real rule would use jax.experimental.custom_partitioning to tell XLA how to
# partition this function for given operand shardings.
def custom_matmul_rule(mesh):
    def matmul_with_custom_rule(x, y):
        # Currently just the default jnp.matmul; the mesh argument is unused
        return jnp.matmul(x, y)

    return matmul_with_custom_rule

# Example usage (conceptual)
# my_custom_matmul = custom_matmul_rule(mesh_1d)
# result = my_custom_matmul(sharded_a, sharded_b)

## 7. Real-World Application: Distributed Training

Let's see how sharded matrices and collective operations are used in a simplified distributed training loop:

In [None]:
from functools import partial
from jax.sharding import NamedSharding
from jax.experimental.shard_map import shard_map

def create_model(rng_key):
    """Create the weights of a simple linear model (initialized to ones; rng_key is unused)."""
    return jnp.ones((1000, 100))

def create_batch(rng_key, batch_size=32):
    """Create a batch of data, using independent keys for x and y."""
    key_x, key_y = jax.random.split(rng_key)
    x = jax.random.normal(key_x, (batch_size, 1000))
    y = jax.random.normal(key_y, (batch_size, 100))
    return x, y

def loss_fn(params, x, y):
    """Compute MSE loss."""
    preds = jnp.matmul(x, params)
    return jnp.mean((preds - y) ** 2)

def make_train_step(mesh, learning_rate=0.01):
    """Build a jitted data-parallel step: params replicated, batch sharded along 'devices'."""
    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P(), P('devices', None), P('devices', None)),
             out_specs=(P(), P()))
    def train_step(params, x, y):
        # Loss and gradients on this device's slice of the batch
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

        # Average across devices so every replica applies the same update
        grads = jax.lax.pmean(grads, axis_name='devices')
        avg_loss = jax.lax.pmean(loss, axis_name='devices')

        # The reported loss is for the parameters before this update
        return params - learning_rate * grads, avg_loss

    return train_step

# Example training loop
if jax.device_count() > 1:
    # Initialize model and data layout
    rng_key = jax.random.PRNGKey(42)
    model_params = create_model(rng_key)
    data_sharding = NamedSharding(mesh_1d, P('devices', None))
    train_step = make_train_step(mesh_1d)  # Compiled once, reused every step

    # Train for a few steps
    for step in range(5):
        # Create a new batch and shard it along the batch dimension
        rng_key, subkey = jax.random.split(rng_key)
        x, y = create_batch(subkey)
        x = jax.device_put(x, data_sharding)
        y = jax.device_put(y, data_sharding)

        # Update parameters
        model_params, loss = train_step(model_params, x, y)
        print(f"Step {step}, Loss: {float(loss):.6f}")
else:
    print("Multiple devices required for distributed training")
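Why is averaging per-device gradients with `pmean` the right thing to do? With a mean-reduced loss and equally sized shards, the average of the shard gradients equals the full-batch gradient. The sketch below checks this on one device by splitting a batch into 4 equal pieces, as 4 devices would see it (the model and sizes are illustrative):

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    return jnp.mean((x @ params - y) ** 2)

key_p, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = jax.random.normal(key_p, (10, 3))
x = jax.random.normal(key_x, (32, 10))
y = jax.random.normal(key_y, (32, 3))

full_grad = jax.grad(mse_loss)(params, x, y)

# Gradients on 4 equal shards of the batch, then their mean (what pmean computes)
num_shards = 4
shard_grads = [jax.grad(mse_loss)(params, xs, ys)
               for xs, ys in zip(jnp.split(x, num_shards), jnp.split(y, num_shards))]
averaged = sum(shard_grads) / num_shards

print(float(jnp.abs(full_grad - averaged).max()))  # close to zero (rounding only)
```

If shards had different sizes, a plain mean of per-shard means would weight the examples unequally, which is one reason data-parallel batches are usually split evenly across devices.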

## 8. Conclusion

In this tutorial, we've explored JAX's powerful capabilities for distributed computation using sharded matrices and collective operations. We've covered:

1. **Setting up a device mesh** for organizing available devices
2. **Creating and using sharded matrices** to distribute data across devices
3. **Different sharding strategies** and when to use them
4. **Collective operations** for efficient device-to-device communication
5. **Implementing distributed algorithms** using these primitives
6. **Automatic sharding** with `jax.jit` sharding annotations (the functionality formerly exposed as `pjit`)
7. **Application to distributed training** for machine learning models

JAX's sharding capabilities enable efficient scaling of numerical computations across multiple devices, making it a powerful tool for large-scale machine learning and scientific computing.

### Further Resources

- [JAX Documentation](https://jax.readthedocs.io/)
- [JAX Docs: Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
- [JAX Docs: Manual parallelism with `shard_map`](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
- [How to Scale Your Model: Sharded Matrices and How to Multiply Them](https://jax-ml.github.io/scaling-book/sharding/)