# Data Parallel and Fully Sharded Data Parallel Training

In the previous tutorial, we explored the basics of JAX parallelization, including device meshes, sharded matrices, and collective operations. In this tutorial, we'll build on those concepts to implement data parallel (DP) and fully sharded data parallel (FSDP) training for neural networks.

We'll cover:
1. Data parallel training with plain JAX
2. Fully sharded data parallel training with plain JAX
3. Using Flax NNX API for data parallel training
4. Using Flax NNX API for FSDP training

By the end of this tutorial, you'll understand how to scale neural network training using parallelization techniques and how to leverage Flax's high-level abstractions for efficient distributed training.

## Setup

Let's start by importing the necessary libraries and initializing our environment.

In [None]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
import numpy as np
import matplotlib.pyplot as plt
import time
from functools import partial
from typing import Dict, Tuple, Any, List, Callable

# pjit compiles a function for sharded inputs and outputs (recent JAX versions
# also accept in_shardings/out_shardings directly on jax.jit)
from jax.experimental.pjit import pjit

# Check available devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Number of devices: {jax.device_count()}")

## 1. Data Parallel Training with Plain JAX

Data parallelism is a strategy where we replicate the model across multiple devices and shard the data batch. Each device processes its portion of the batch with its own copy of the model, and the gradients are then averaged across all devices (an all-reduce), so every replica applies the same update and the copies stay identical.
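
Why is averaging the right way to combine gradients? Because the loss is a mean over the batch and the shards are equal-sized, the gradient of the global loss equals the average of the per-shard gradients. The sketch below checks this numerically on a single device; it uses a hypothetical linear model instead of the MLP, and 4 equal batch shards stand in for 4 devices.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Hypothetical linear model with mean squared error over the rows of x
    return jnp.mean((x @ w - y) ** 2)

kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (64, 5))
y = jax.random.normal(ky, (64, 1))
w = jax.random.normal(kw, (5, 1))

num_shards = 4  # stand-in for the number of devices
full_grad = jax.grad(loss_fn)(w, x, y)

# Split the batch into equal shards, as data parallelism does across devices
x_shards = jnp.split(x, num_shards)
y_shards = jnp.split(y, num_shards)
shard_grads = [jax.grad(loss_fn)(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]

# Averaging the per-shard gradients is what a mean all-reduce computes
avg_grad = sum(shard_grads) / num_shards
print(jnp.max(jnp.abs(full_grad - avg_grad)))  # tiny: float32 rounding only
```

With unequal shard sizes, a plain average would over-weight the smaller shards; that is one reason the batch size is kept divisible by the number of devices.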

Let's implement a simple neural network and train it using data parallelism in JAX.

### 1.1 Define a Simple Neural Network

First, we'll define a simple Multi-Layer Perceptron (MLP) with two hidden layers.

In [None]:
def init_mlp_params(layer_sizes, key):
    """Initialize parameters for a simple MLP."""
    keys = jax.random.split(key, len(layer_sizes))
    params = []
    
    for i in range(len(layer_sizes) - 1):
        in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
        w_key, b_key = jax.random.split(keys[i])
        
        # He (Kaiming) normal initialization: std = sqrt(2 / fan_in), suited to ReLU
        w = jax.random.normal(w_key, (in_dim, out_dim)) * jnp.sqrt(2.0 / in_dim)
        b = jnp.zeros((out_dim,))
        
        params.append((w, b))
    
    return params

def mlp_forward(params, x):
    """Forward pass through the MLP."""
    activations = x
    
    # Apply each layer with ReLU activation except the last layer
    for i, (w, b) in enumerate(params[:-1]):
        activations = jnp.dot(activations, w) + b
        activations = jax.nn.relu(activations)
    
    # Last layer without activation (for regression)
    w, b = params[-1]
    output = jnp.dot(activations, w) + b
    
    return output

def mse_loss(params, x_batch, y_batch):
    """Mean squared error loss function."""
    predictions = mlp_forward(params, x_batch)
    return jnp.mean((predictions - y_batch) ** 2)

### 1.2 Generate Synthetic Data

Let's create a synthetic dataset to train our model. We'll use a simple quadratic function with some noise.

In [None]:
def generate_data(num_samples, input_dim, key):
    """Generate synthetic data from a quadratic function with noise."""
    x_key, noise_key = jax.random.split(key)
    
    # Generate random input features
    x = jax.random.normal(x_key, (num_samples, input_dim))
    
    # Generate target values: sum of squares with noise
    y_clean = jnp.sum(x ** 2, axis=1, keepdims=True)
    noise = 0.1 * jax.random.normal(noise_key, (num_samples, 1))
    y = y_clean + noise
    
    return x, y

# Generate training data
seed = 42
key = jax.random.PRNGKey(seed)
num_samples = 10000
input_dim = 10
x_data, y_data = generate_data(num_samples, input_dim, key)

# Display a few examples
for i in range(5):
    print(f"x[{i}] = {x_data[i][:3]}...  =>  y[{i}] = {y_data[i][0]:.4f}")

### 1.3 Single-Device Training

First, let's implement and benchmark training on a single device to establish a baseline.

In [None]:
def train_step(params, x_batch, y_batch, learning_rate):
    """Perform a single training step using gradient descent."""
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(mse_loss)(params, x_batch, y_batch)
    
    # Update parameters using gradient descent
    new_params = [(w - learning_rate * dw, b - learning_rate * db)
                 for (w, b), (dw, db) in zip(params, grads)]
    
    return new_params, loss_value

# JIT-compile the training step for better performance
train_step_jit = jax.jit(train_step)

# Initialize model parameters
layer_sizes = [input_dim, 32, 16, 1]  # [input_dim, hidden1, hidden2, output_dim]
init_key = jax.random.PRNGKey(123)
params = init_mlp_params(layer_sizes, init_key)

# Training hyperparameters
learning_rate = 0.01
batch_size = 128
num_epochs = 5
steps_per_epoch = num_samples // batch_size

# Training loop
losses = []
start_time = time.time()

for epoch in range(num_epochs):
    # Shuffle data at the beginning of each epoch
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]
    
    epoch_losses = []
    
    for step in range(steps_per_epoch):
        # Get batch
        idx_start = step * batch_size
        idx_end = idx_start + batch_size
        x_batch = x_shuffled[idx_start:idx_end]
        y_batch = y_shuffled[idx_start:idx_end]
        
        # Perform training step
        params, loss = train_step_jit(params, x_batch, y_batch, learning_rate)
        epoch_losses.append(loss)
    
    # Compute average loss for this epoch
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
single_device_time = end_time - start_time
print(f"Single-device training time: {single_device_time:.2f} seconds")

### 1.4 Data Parallel Training

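Now let's implement data parallel training, sharding each global batch across all available devices. The step function is written exactly as for a single device: `pjit` compiles it for the global batch, and since `mse_loss` averages over that batch, XLA inserts the cross-device gradient all-reduce for us.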
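
To see the split that `P(batch_dim)` requests, you can place a batch on a mesh yourself with `jax.device_put` and a `NamedSharding`, then inspect the per-device shards. This standalone sketch builds its own 1D mesh from `jax.devices()` rather than reusing `device_mesh`; the 128-samples-per-device batch is only an example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over every visible device, with a single axis named 'batch'
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('batch',))

# Example global batch: 128 samples per device, 10 features each
x_batch = jnp.ones((128 * n, 10))

# P('batch') splits dimension 0 across the 'batch' axis; dimension 1 stays whole
x_sharded = jax.device_put(x_batch, NamedSharding(mesh, P('batch')))

for shard in x_sharded.addressable_shards:
    print(shard.device, shard.data.shape)  # each device holds a (128, 10) slice
```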

In [None]:
# Create a device mesh for data parallelism
devices = jax.devices()
num_devices = len(devices)
device_mesh = mesh_utils.create_device_mesh((num_devices,))
print(f"Device mesh shape: {device_mesh.shape}")

# Define partition specs for data-parallel training
# We'll shard the batch dimension (first dimension) across devices
batch_dim = 'batch'

def data_parallel_train_step(params, x_batch, y_batch, learning_rate):
    """Training step function for data parallel training."""
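    # Written for the global batch: under pjit, each device computes the gradient
    # contribution of its shard and XLA all-reduces the results automatically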
    def loss_fn(p):
        return mse_loss(p, x_batch, y_batch)
    
    loss_value, grads = jax.value_and_grad(loss_fn)(params)
    
    # Update parameters using gradient descent
    new_params = [(w - learning_rate * dw, b - learning_rate * db)
                 for (w, b), (dw, db) in zip(params, grads)]
    
    return new_params, loss_value

# Set up pjit for data-parallel training
# We'll shard the batch dimension of inputs and replicate the parameters
dp_train_step = pjit(
    data_parallel_train_step,
    in_shardings=(P(), P(batch_dim), P(batch_dim), P()),  # Replicate params and lr, shard data
    out_shardings=(P(), P())  # Replicate updated params and loss
)

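# Keep batch_size samples per device, so the global batch grows with the device count
# (this also means fewer optimizer steps per epoch than in the single-device run)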
global_batch_size = batch_size * num_devices
dp_steps_per_epoch = num_samples // global_batch_size

# Initialize parameters (same as before)
dp_params = init_mlp_params(layer_sizes, init_key)

# Training loop with data parallelism
dp_losses = []
start_time = time.time()

with Mesh(device_mesh, (batch_dim,)):
    for epoch in range(num_epochs):
        # Shuffle data at the beginning of each epoch
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, num_samples)
        x_shuffled = x_data[perm]
        y_shuffled = y_data[perm]
        
        epoch_losses = []
        
        for step in range(dp_steps_per_epoch):
            # Get global batch
            idx_start = step * global_batch_size
            idx_end = idx_start + global_batch_size
            x_batch = x_shuffled[idx_start:idx_end]
            y_batch = y_shuffled[idx_start:idx_end]
            
            # Perform data-parallel training step
            dp_params, loss = dp_train_step(dp_params, x_batch, y_batch, learning_rate)
            epoch_losses.append(loss)
        
        # Compute average loss for this epoch
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        dp_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
dp_time = end_time - start_time
print(f"Data-parallel training time: {dp_time:.2f} seconds")
print(f"Speedup: {single_device_time / dp_time:.2f}x")
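
The `pjit` step never mentions communication: XLA derives the gradient all-reduce from the shardings. To make that collective visible, here is a standalone sketch of one data-parallel SGD step written with `shard_map`, where each device computes gradients on its own slice and `jax.lax.pmean` averages them. It uses a hypothetical linear model rather than the MLP, and the import fallback assumes `shard_map` is available as `jax.shard_map` in newer JAX and as `jax.experimental.shard_map.shard_map` in older releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('batch',))

def loss_fn(w, x, y):
    # Hypothetical linear model with mean squared error
    return jnp.mean((x @ w - y) ** 2)

def local_step(w, x, y, lr):
    # Runs on every device with that device's slice of x and y
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # The explicit all-reduce: average loss and gradient over the 'batch' axis
    loss = jax.lax.pmean(loss, 'batch')
    grad = jax.lax.pmean(grad, 'batch')
    return w - lr * grad, loss

dp_step = jax.jit(shard_map(
    local_step,
    mesh=mesh,
    in_specs=(P(), P('batch'), P('batch'), P()),  # replicate w and lr, shard the batch
    out_specs=(P(), P()),                          # outputs are identical on all devices
))

kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (32 * n, 5))
y = jax.random.normal(ky, (32 * n, 1))
w = jnp.zeros((5, 1))
w_new, loss = dp_step(w, x, y, 0.1)

# Reference: the same update computed on the full batch without sharding
ref_loss, ref_grad = jax.value_and_grad(loss_fn)(w, x, y)
print(float(loss), float(ref_loss))  # should agree up to float32 rounding
```

Writing the collective by hand like this is more verbose, but it makes the communication pattern explicit instead of relying on the compiler to infer it.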

### 1.5 Visualizing the Data Parallelism Strategy

Let's visualize how the data is sharded across devices in the data parallel training setup.

In [None]:
def visualize_device_placement(mesh_shape, batch_shape):
    """Visualize how a batch is distributed across devices in data parallelism."""
    num_devices = np.prod(mesh_shape)
    batch_size = batch_shape[0]
    samples_per_device = batch_size // num_devices
    
    # Create a grid to represent the devices
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Draw the global batch
    ax.add_patch(plt.Rectangle((0, 0), 10, 2, fill=False, edgecolor='black', linewidth=2))
    ax.text(5, 1, f"Global Batch (size={batch_size})", ha='center', va='center', fontsize=12)
    
    # Draw device allocations
    for i in range(num_devices):
        ax.add_patch(plt.Rectangle((i * (10/num_devices), 3), 10/num_devices, 1.5, 
                                   fill=True, edgecolor='black', 
                                   facecolor=plt.cm.tab10(i % 10),
                                   alpha=0.7, linewidth=1))
        ax.text(i * (10/num_devices) + (5/num_devices), 3.75, 
                f"Device {i}\n{samples_per_device} samples", 
                ha='center', va='center', fontsize=10)
        
        # Connect with arrows
        ax.arrow(i * (10/num_devices) + (5/num_devices), 3, 0, -0.5, 
                 head_width=0.2, head_length=0.2, fc='black', ec='black')
    
    # Add model replication indicators
    for i in range(num_devices):
        ax.add_patch(plt.Rectangle((i * (10/num_devices), 5), 10/num_devices, 1.5, 
                                   fill=True, edgecolor='black', 
                                   facecolor='lightgreen',
                                   alpha=0.7, linewidth=1))
        ax.text(i * (10/num_devices) + (5/num_devices), 5.75, 
                f"Model Copy\non Device {i}", 
                ha='center', va='center', fontsize=10)
    
    # Add title and labels
    ax.set_title("Data Parallel Training: Batch Sharding with Model Replication", fontsize=14)
    ax.set_xlim(-0.5, 10.5)
    ax.set_ylim(-0.5, 7)
    ax.set_axis_off()
    
    plt.tight_layout()
    plt.show()

# Visualize data parallelism for our setup
visualize_device_placement((num_devices,), (global_batch_size, input_dim))

## 2. Fully Sharded Data Parallel (FSDP) Training with Plain JAX

Now, let's implement Fully Sharded Data Parallel (FSDP) training. In FSDP, we shard both the data and the model parameters across devices. This approach helps scale to larger models that might not fit in a single device's memory.
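
A back-of-the-envelope sketch of what sharding saves, counting parameters only and assuming float32 (4 bytes) and an even split across devices; `count_params` and `param_bytes_per_device` are hypothetical helpers. Under full sharding, gradients and optimizer state shrink by the same factor.

```python
def count_params(layer_sizes):
    # Each dense layer has an (in_dim, out_dim) weight matrix and an (out_dim,) bias
    return sum(i * o + o for i, o in zip(layer_sizes[:-1], layer_sizes[1:]))

def param_bytes_per_device(layer_sizes, num_devices, fully_sharded, bytes_per_param=4):
    total = count_params(layer_sizes) * bytes_per_param
    # DP keeps a full copy on every device; full sharding keeps ~1/num_devices of it
    return total / num_devices if fully_sharded else total

sizes = [10, 32, 16, 1]  # the tutorial's layer_sizes
print(count_params(sizes))                                     # 897
print(param_bytes_per_device(sizes, 8, fully_sharded=False))   # 3588
print(param_bytes_per_device(sizes, 8, fully_sharded=True))    # 448.5
```

For this toy MLP the savings are negligible; the same arithmetic becomes decisive once the parameter count reaches billions.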

### 2.1 Setting Up a 2D Mesh for FSDP

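We'll use a 2D mesh that shards the data along one axis (`data`) and the model parameters along another (`model`). Strictly speaking, splitting each weight matrix along a separate `model` axis is tensor (model) parallelism combined with data parallelism; classic FSDP, as in ZeRO-3 or PyTorch FSDP, shards the parameters along the *same* axis as the data and all-gathers each layer's weights just before they are used. Both layouts spare each device from storing a full copy of the shardable parameters, which is the memory saving we are after.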

In [None]:
# Use a 2D (data, model) mesh when the device count is even and at least 4;
# otherwise fall back to a (1, num_devices) mesh
if num_devices >= 4 and num_devices % 2 == 0:
    # e.g. 2x4 for 8 devices
    mesh_shape = (2, num_devices // 2)
else:
    # Trivial data axis; every device sits on the model axis
    mesh_shape = (1, num_devices)

fsdp_device_mesh = mesh_utils.create_device_mesh(mesh_shape)
print(f"FSDP device mesh shape: {fsdp_device_mesh.shape}")

# Define mesh axes names
data_axis = 'data'  # For sharding data
model_axis = 'model'  # For sharding model parameters

### 2.2 Implementing FSDP Training

The model code stays the same: we reuse `mse_loss`, assign a `PartitionSpec` to every parameter, and let `pjit` place the parameters and insert the communication the sharded computation needs.
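
Sharding a weight with `P(None, model_axis)` splits its columns, so on a `model` axis of size k each device holds `out_dim / k` of them. The standalone sketch below places one weight that way and prints the per-device shard shapes. `spec_for_weight` is a hypothetical helper that falls back to replication when the output dimension does not divide evenly, as with the MLP's final 16 -> 1 layer, because an uneven split of a `pjit` input is likely to be rejected.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A (1, n) mesh: a trivial 'data' axis and a 'model' axis spanning every device
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(1, n), ('data', 'model'))

def spec_for_weight(out_dim, model_size):
    # Shard the output columns only when they divide evenly across the 'model' axis
    return P(None, 'model') if out_dim % model_size == 0 else P(None, None)

w = jnp.ones((32, 16 * n))  # output dim chosen to be divisible by n
w_sharded = jax.device_put(w, NamedSharding(mesh, spec_for_weight(w.shape[1], n)))
print([s.data.shape for s in w_sharded.addressable_shards])  # n shards of (32, 16)

# Like the MLP's final layer: one output column cannot be split across n > 1 devices
print(spec_for_weight(1, n))
```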

In [None]:
def fsdp_train_step(params, x_batch, y_batch, learning_rate):
    """Training step function for FSDP training."""
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(mse_loss)(params, x_batch, y_batch)
    
    # Update parameters using gradient descent
    new_params = [(w - learning_rate * dw, b - learning_rate * db)
                 for (w, b), (dw, db) in zip(params, grads)]
    
    return new_params, loss_value

# Define parameter sharding strategy for each layer
def get_param_sharding_specs(layer_sizes, model_axis_size):
    """Create parameter partition specs for the 2D (data, model) layout."""
    param_specs = []

    for i in range(len(layer_sizes) - 1):
        out_dim = layer_sizes[i + 1]
        if out_dim % model_axis_size == 0:
            # Shard weight columns (output dim) and the bias along the model axis
            w_spec = P(None, model_axis)
            b_spec = P(model_axis)
        else:
            # An output dim that doesn't divide evenly (e.g. the final 16 -> 1 layer)
            # can't be split as a pjit input, so replicate this layer instead
            w_spec = P(None, None)
            b_spec = P(None)

        param_specs.append((w_spec, b_spec))

    return param_specs

# Get parameter sharding specs for the size of the model axis
param_specs = get_param_sharding_specs(layer_sizes, mesh_shape[1])

# Set up pjit for FSDP training
# Shard the batch across the data axis and the parameters according to param_specs
fsdp_train_step_pjit = pjit(
    fsdp_train_step,
    in_shardings=(param_specs, P(data_axis, None), P(data_axis, None), P()),
    out_shardings=(param_specs, P())
)

# Initialize parameters for FSDP
fsdp_params = init_mlp_params(layer_sizes, init_key)

# Training loop with FSDP
fsdp_losses = []
start_time = time.time()

with Mesh(fsdp_device_mesh, (data_axis, model_axis)):
    for epoch in range(num_epochs):
        # Shuffle data at the beginning of each epoch
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, num_samples)
        x_shuffled = x_data[perm]
        y_shuffled = y_data[perm]
        
        epoch_losses = []
        
        for step in range(dp_steps_per_epoch):  # Use same steps as DP for fair comparison
            # Get batch
            idx_start = step * global_batch_size
            idx_end = idx_start + global_batch_size
            x_batch = x_shuffled[idx_start:idx_end]
            y_batch = y_shuffled[idx_start:idx_end]
            
            # Perform FSDP training step
            fsdp_params, loss = fsdp_train_step_pjit(fsdp_params, x_batch, y_batch, learning_rate)
            epoch_losses.append(loss)
        
        # Compute average loss for this epoch
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        fsdp_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
fsdp_time = end_time - start_time
print(f"FSDP training time: {fsdp_time:.2f} seconds")
print(f"Speedup over single device: {single_device_time / fsdp_time:.2f}x")
print(f"Speedup over data parallelism: {dp_time / fsdp_time:.2f}x")
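
For comparison, here is a minimal sketch of classic FSDP, where a single mesh axis (hypothetically named `fsdp`) carries both the batch and the parameters. The toy bias-free two-layer model, its sizes and the learning rate are illustrative assumptions. Under `jax.jit`, XLA's sharding propagation is expected to gather each weight before it is used and to reduce the gradients back to the sharded layout, which is the FSDP communication pattern; the exact collectives are the compiler's choice. `out_shardings` pins the updated parameters to stay sharded.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
# One mesh axis shared by the batch AND the parameters
mesh = Mesh(np.array(jax.devices()), ('fsdp',))
param_sharding = NamedSharding(mesh, P('fsdp', None))  # split each weight's rows
data_sharding = NamedSharding(mesh, P('fsdp', None))   # split the batch's rows
replicated = NamedSharding(mesh, P())

def loss_fn(params, x, y):
    # Toy bias-free two-layer MLP, a stand-in for the tutorial's model
    w1, w2 = params
    h = jax.nn.relu(x @ w1)
    return jnp.mean((h @ w2 - y) ** 2)

def sgd_step(params, x, y, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Keep the updated parameters sharded instead of letting them become replicated
fsdp_step = jax.jit(sgd_step, out_shardings=(param_sharding, replicated))

d_in, hidden, batch = 4 * n, 16 * n, 8 * n  # sizes divisible by the device count
k1, k2, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
params = (0.1 * jax.random.normal(k1, (d_in, hidden)),
          0.1 * jax.random.normal(k2, (hidden, 1)))
x = jax.random.normal(kx, (batch, d_in))
y = jax.random.normal(ky, (batch, 1))

new_params, loss = fsdp_step(jax.device_put(params, param_sharding),
                             jax.device_put(x, data_sharding),
                             jax.device_put(y, data_sharding))
print([s.data.shape for s in new_params[0].addressable_shards])  # n shards of (4, 16n)

# Reference: the same step without any sharding
ref_params, ref_loss = sgd_step(params, x, y)
```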

### 2.3 Visualizing FSDP Strategy

Let's visualize how both data and model parameters are sharded in the FSDP approach.

In [None]:
def visualize_fsdp(mesh_shape):
    """Visualize both data and model sharding in FSDP."""
    data_dim, model_dim = mesh_shape
    num_devices = data_dim * model_dim
    
    # Create a grid to represent the devices
    fig, ax = plt.subplots(figsize=(12, 8))
    
    # Draw the 2D mesh grid (limited to the mesh area so it doesn't cross the labels)
    for i in range(data_dim + 1):
        ax.hlines(i, 0, model_dim, colors='black', linewidths=1)
    for j in range(model_dim + 1):
        ax.vlines(j, 0, data_dim, colors='black', linewidths=1)
    
    # Label each device in the mesh
    for i in range(data_dim):
        for j in range(model_dim):
            device_id = i * model_dim + j
            ax.text(j + 0.5, i + 0.5, f"Device {device_id}", 
                    ha='center', va='center', fontsize=12)
    
    # Add axis labels
    ax.text(model_dim / 2, -0.5, "Model Sharding Dimension", ha='center', fontsize=14)
    ax.text(-0.5, data_dim / 2, "Data Sharding\nDimension", va='center', fontsize=14, rotation=90)
    
    # Draw data sharding illustration
    ax.add_patch(plt.Rectangle((-3, 0), 2, data_dim, fill=True, 
                               edgecolor='black', facecolor='lightblue', alpha=0.7))
    for i in range(data_dim):
        ax.text(-2, i + 0.5, f"Data Shard {i}", ha='center', va='center', fontsize=10)
        ax.arrow(-1.5, i + 0.5, 1, 0, head_width=0.2, head_length=0.2, fc='black', ec='black')
    
    # Draw model sharding illustration
    ax.add_patch(plt.Rectangle((0, -3), model_dim, 2, fill=True, 
                               edgecolor='black', facecolor='lightgreen', alpha=0.7))
    for j in range(model_dim):
        ax.text(j + 0.5, -2, f"Model\nShard {j}", ha='center', va='center', fontsize=10)
        ax.arrow(j + 0.5, -1.5, 0, 1, head_width=0.2, head_length=0.2, fc='black', ec='black')
    
    # Set title
    ax.set_title("Fully Sharded Data Parallel (FSDP) Strategy", fontsize=16)
    
    # Set limits and turn off axis
    ax.set_xlim(-4, model_dim + 1)
    ax.set_ylim(-4, data_dim + 1)
    ax.set_aspect('equal')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize the FSDP strategy
visualize_fsdp(mesh_shape)

## 3. Data Parallel Training with Flax NNX

Now let's see how we can use Flax NNX to simplify distributed training. Flax provides high-level abstractions that make it easier to define and train neural networks in a distributed setting.

In [None]:
# Install Flax NNX if not already installed
try:
    import flax.nnx as nnx
except ImportError:
    !pip install -q flax
    import flax.nnx as nnx

### 3.1 Define a Flax MLP Model

Let's define our MLP using the Flax NNX module system.

In [None]:
from jax.sharding import NamedSharding

class MLP(nnx.Module):
    """A simple multi-layer perceptron using Flax NNX.

    If `model_axis_size` is given, each layer whose output dimension divides
    evenly by it gets `nnx.with_partitioning` annotations for sharding along
    `model_axis`; other layers (e.g. the final 16 -> 1 layer) stay replicated.
    """

    def __init__(self, layer_sizes, *, rngs: nnx.Rngs, model_axis_size=None):
        layers = []
        for i in range(len(layer_sizes) - 1):
            in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
            kernel_init = jax.nn.initializers.lecun_normal()
            bias_init = jax.nn.initializers.zeros
            if model_axis_size is not None and out_dim % model_axis_size == 0:
                # Shard kernel columns (output dim) and the bias along the model axis
                kernel_init = nnx.with_partitioning(kernel_init, (None, model_axis))
                bias_init = nnx.with_partitioning(bias_init, (model_axis,))
            layers.append(nnx.Linear(in_dim, out_dim, kernel_init=kernel_init,
                                     bias_init=bias_init, rngs=rngs))
            # ReLU after every layer except the last (linear output for regression)
            if i < len(layer_sizes) - 2:
                layers.append(jax.nn.relu)
        # nnx.Sequential registers the Linear layers as submodules of this model
        self.net = nnx.Sequential(*layers)

    def __call__(self, x):
        """Forward pass through the MLP."""
        return self.net(x)

def sgd_update(model, grads, learning_rate):
    """Plain SGD applied in place: p <- p - learning_rate * g for every nnx.Param."""
    params = nnx.state(model, nnx.Param)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    nnx.update(model, new_params)

# Define MSE loss function
def flax_mse_loss(model, x_batch, y_batch):
    """Mean squared error loss function for Flax model."""
    predictions = model(x_batch)
    return jnp.mean((predictions - y_batch) ** 2)

### 3.2 Data Parallel Training with Flax

In [None]:
# Initialize the model; without sharding annotations its parameters get replicated
flax_model = MLP(layer_sizes, rngs=nnx.Rngs(123))

dp_mesh = Mesh(device_mesh, (batch_dim,))
replicated_sharding = NamedSharding(dp_mesh, P())       # full copy on every device
batch_sharding = NamedSharding(dp_mesh, P(batch_dim))   # split the batch dimension

# Place a full copy of every parameter on every device
nnx.update(flax_model, jax.device_put(nnx.state(flax_model), replicated_sharding))

@nnx.jit
def flax_dp_train_step(model, x_batch, y_batch):
    """Data-parallel training step using Flax NNX.

    The batch arrives sharded and the parameters replicated, so XLA averages the
    gradients across devices automatically. The model is updated in place.
    """
    loss_value, grads = nnx.value_and_grad(flax_mse_loss)(model, x_batch, y_batch)
    sgd_update(model, grads, learning_rate)
    return loss_value

# Training loop with Flax data parallelism
flax_dp_losses = []
start_time = time.time()

for epoch in range(num_epochs):
    # Shuffle data at the beginning of each epoch
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]

    epoch_losses = []

    for step in range(dp_steps_per_epoch):
        # Get the global batch and shard it along its first dimension
        idx_start = step * global_batch_size
        idx_end = idx_start + global_batch_size
        x_batch = jax.device_put(x_shuffled[idx_start:idx_end], batch_sharding)
        y_batch = jax.device_put(y_shuffled[idx_start:idx_end], batch_sharding)

        # Perform data-parallel training step (updates flax_model in place)
        loss = flax_dp_train_step(flax_model, x_batch, y_batch)
        epoch_losses.append(loss)

    # Compute average loss for this epoch
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    flax_dp_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
flax_dp_time = end_time - start_time
print(f"Flax data-parallel training time: {flax_dp_time:.2f} seconds")
print(f"Speedup over single device: {single_device_time / flax_dp_time:.2f}x")

## 4. Fully Sharded Data Parallel Training with Flax NNX

Finally, let's implement FSDP using Flax NNX.

### 4.1 Defining FSDP Training with Flax

In [None]:
fsdp_mesh = Mesh(fsdp_device_mesh, (data_axis, model_axis))
fsdp_batch_sharding = NamedSharding(fsdp_mesh, P(data_axis, None))

@nnx.jit
def create_sharded_fsdp_model():
    """Build the annotated model and shard its parameters inside jit."""
    model = MLP(layer_sizes, rngs=nnx.Rngs(123), model_axis_size=mesh_shape[1])
    state = nnx.state(model)
    # Read the annotations back as PartitionSpecs (unannotated params -> replicated)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
    return model

# Bare PartitionSpecs are resolved against the mesh set by this context
with fsdp_mesh:
    flax_fsdp_model = create_sharded_fsdp_model()

# Define FSDP training step
@nnx.jit
def flax_fsdp_train_step(model, x_batch, y_batch):
    """Sharded training step using Flax NNX.

    nnx.jit compiles for the shardings the inputs already carry; XLA inserts the
    collectives needed for the sharded matmuls and the gradient averaging.
    """
    loss_value, grads = nnx.value_and_grad(flax_mse_loss)(model, x_batch, y_batch)
    sgd_update(model, grads, learning_rate)
    return loss_value

# Training loop with Flax FSDP
flax_fsdp_losses = []
start_time = time.time()

for epoch in range(num_epochs):
    # Shuffle data at the beginning of each epoch
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]

    epoch_losses = []

    for step in range(dp_steps_per_epoch):
        # Get the global batch and shard it along the data axis
        idx_start = step * global_batch_size
        idx_end = idx_start + global_batch_size
        x_batch = jax.device_put(x_shuffled[idx_start:idx_end], fsdp_batch_sharding)
        y_batch = jax.device_put(y_shuffled[idx_start:idx_end], fsdp_batch_sharding)

        # Perform FSDP training step (updates flax_fsdp_model in place)
        loss = flax_fsdp_train_step(flax_fsdp_model, x_batch, y_batch)
        epoch_losses.append(loss)

    # Compute average loss for this epoch
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    flax_fsdp_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
flax_fsdp_time = end_time - start_time
print(f"Flax FSDP training time: {flax_fsdp_time:.2f} seconds")
print(f"Speedup over single device: {single_device_time / flax_fsdp_time:.2f}x")
print(f"Speedup over Flax data parallelism: {flax_dp_time / flax_fsdp_time:.2f}x")

## 5. Performance Comparison

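Let's compare the performance of all training approaches we've implemented. Keep in mind that these wall-clock timings include JIT compilation of each training step and, for a model this small, are dominated by dispatch and communication overhead, so the multi-device variants may show little speedup or even a slowdown. Treat the numbers as a sanity check rather than a benchmark.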

In [None]:
# Collect timing results
timings = {
    'Single Device': single_device_time,
    'Data Parallel (JAX)': dp_time,
    'FSDP (JAX)': fsdp_time,
    'Data Parallel (Flax)': flax_dp_time,
    'FSDP (Flax)': flax_fsdp_time
}

# Calculate speedups relative to single device
speedups = {k: single_device_time / v for k, v in timings.items()}

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Plot training times
methods = list(timings.keys())
times = list(timings.values())
ax1.bar(methods, times, color='skyblue')
ax1.set_ylabel('Training Time (seconds)')
ax1.set_title('Training Time Comparison')
ax1.set_xticks(range(len(methods)))
ax1.set_xticklabels(methods, rotation=45, ha='right')
for i, v in enumerate(times):
    ax1.text(i, v + 0.1, f"{v:.2f}s", ha='center')

# Plot speedups
speeds = list(speedups.values())
ax2.bar(methods, speeds, color='lightgreen')
ax2.set_ylabel('Speedup (relative to single device)')
ax2.set_title('Speedup Comparison')
ax2.set_xticks(range(len(methods)))
ax2.set_xticklabels(methods, rotation=45, ha='right')
ax2.axhline(y=1, color='red', linestyle='--')
for i, v in enumerate(speeds):
    ax2.text(i, v + 0.1, f"{v:.2f}x", ha='center')

plt.tight_layout()
plt.show()
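
The timings plotted above include tracing and compilation (the first call of every compiled step), and because JAX dispatches work asynchronously, a timer can stop before the device has finished. A fairer measurement warms up first and waits with `jax.block_until_ready`. The sketch below assumes a step function of the form `step_fn(state, x, y) -> (state, loss)`; `time_steps` is a hypothetical helper, and the toy linear step only demonstrates the call.

```python
import time
import jax
import jax.numpy as jnp

def time_steps(step_fn, state, batches, warmup=1):
    """Return (final_state, seconds) for the batches after the first `warmup` ones."""
    # Warm-up calls trigger tracing and XLA compilation; keep them off the clock
    for x, y in batches[:warmup]:
        state, _ = step_fn(state, x, y)
    jax.block_until_ready(state)

    start = time.perf_counter()
    for x, y in batches[warmup:]:
        state, _ = step_fn(state, x, y)
    # Dispatch is asynchronous: wait for the device before stopping the clock
    jax.block_until_ready(state)
    return state, time.perf_counter() - start

@jax.jit
def toy_step(w, x, y):
    # Toy linear-regression SGD step, only here to demonstrate the helper
    loss, grad = jax.value_and_grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.1 * grad, loss

batches = [(jnp.ones((64, 4)), jnp.zeros((64, 1))) for _ in range(10)]
w_final, elapsed = time_steps(toy_step, jnp.zeros((4, 1)), batches)
print(f"{elapsed:.4f} s for {len(batches) - 1} timed steps")
```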

## 6. Conclusion

In this tutorial, we've explored different approaches to parallel training of neural networks in JAX:

1. **Data Parallel (DP) Training**
   - Replicates the model across devices
   - Shards the data batch across devices
   - Good for small to medium-sized models
   - Limited by model size fitting on a single device

2. **Fully Sharded Data Parallel (FSDP) Training**
   - Shards both the model parameters and data across devices
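   - Here combined with a second `model` mesh axis; classic FSDP shards the parameters along the data axis itself
   - Lowers per-device memory, since no device has to store every parameter (full FSDP also shards gradients and optimizer state)
   - Pays for that with extra communication to gather parameters when they are used, so it is most worthwhile when the model does not fit on one device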

3. **Using Flax NNX for Distributed Training**
   - Provides higher-level abstractions for model definition
   - Simplifies the implementation of training loops
   - Integrates well with JAX's parallelization primitives
   - Offers a more user-friendly interface for complex models

Key takeaways:

- Data parallel training is straightforward to implement and provides good speedup for many scenarios
- FSDP is more complex, but because no single device has to hold all the parameters, it can train models that would not fit on one device
- Flax NNX provides helpful abstractions that make distributed training more accessible
- The choice between DP and FSDP depends on your model size, device memory constraints, and scaling requirements

When scaling to very large models, like those used in NLP, FSDP becomes increasingly important as it allows you to train models that wouldn't fit in a single device's memory. This approach forms the foundation for training massive language models efficiently.