# Data Parallel and Fully Sharded Data Parallel Training

In the previous tutorial, we explored the basics of JAX parallelization, including device meshes, sharded matrices, and collective operations. In this tutorial, we'll build on those concepts to implement data parallel (DP) and fully sharded data parallel (FSDP) training for neural networks.

We'll cover:
1. 8-way data parallel training with plain JAX
2. Re-implementing the same with Flax NNX
3. Fully sharded data parallel (FSDP) training with Flax NNX

By the end of this tutorial, you'll understand how to scale neural network training using parallelization techniques and how to leverage Flax's high-level abstractions for efficient distributed training.

## Setup

Let's start by importing the necessary libraries and initializing our environment.

## Theory: Understanding Parallelism Strategies

Before diving into the implementation, let's understand the theoretical foundations of different parallelism strategies. This section draws from "How to Scale Your Model" by Jacob Austin and others.

### Model Representation

For simplicity, we approximate a Transformer as a stack of MLP blocks, since attention is a comparatively small fraction of the FLOPs for larger models. Each layer consists of two main operations:

![Simple Transformer](https://github.com/jax-ml/scaling-book/blob/main/assets/img/simple-transformer.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

We treat each layer as:
- **W<sub>in</sub>**: `bf16[D, F]` (up-projection) 
- **W<sub>out</sub>**: `bf16[F, D]` (down-projection)
- **Input**: `bf16[B, D]`

Where:
- **D** = d<sub>model</sub> (hidden dimension)
- **F** = d<sub>ff</sub> (feed-forward dimension) 
- **B** = batch size (total tokens)

### Communication vs Computation Trade-offs

The goal of scaling is to achieve **strong scaling**: linear increase in throughput with more chips. Performance depends on hiding inter-chip communication by overlapping it with useful FLOPs. 

We become **compute-bound** when:
$$\frac{T_{\text{math}}}{T_{\text{comms}}} > 1$$

The key insight is that computation time scales with batch size, while communication time is often independent of batch size (since we transfer model weights).
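As a rough illustration of this scaling argument, the sketch below models the backward pass of one MLP layer under data parallelism. The helper name `math_to_comms_ratio`, the layer sizes, and the hardware numbers are illustrative assumptions rather than measurements, and the cost model (four backward matmuls, one AllReduce per weight matrix) is a simplified accounting.

```python
def math_to_comms_ratio(batch, devices, d_model, d_ff, flops_per_s, bytes_per_s):
    """Return T_math / T_comms for one MLP layer's backward pass (data parallel)."""
    # Backward pass: four matmuls (two per weight matrix), 2*B*D*F FLOPs each,
    # with the batch split evenly across devices.
    t_math = 8 * batch * d_model * d_ff / (devices * flops_per_s)
    # AllReduce of two bf16 [D, F] gradient matrices (2 bytes per element);
    # an AllReduce moves roughly twice the array size.
    t_comms = 2 * 2 * 2 * d_model * d_ff / bytes_per_s
    return t_math / t_comms


# Illustrative accelerator: 4.6e14 FLOPs/s, 1.8e11 bytes/s between chips.
C, W = 4.6e14, 1.8e11
for per_device_batch in (512, 2048, 8192):
    ratio = math_to_comms_ratio(per_device_batch * 8, 8, 4096, 16384, C, W)
    print(f"batch/device={per_device_batch:5d}  T_math/T_comms={ratio:.2f}")
```

In this model the layer dimensions cancel out, so the ratio depends only on the per-device batch size and the hardware constants.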

In [None]:
import os
# Force JAX to see 8 devices for this tutorial
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
import numpy as np
import matplotlib.pyplot as plt
import time
from functools import partial

# Import our custom modules
from jax_mlp import (
    init_mlp_params, mlp_forward, mse_loss, train_step,
    generate_synthetic_data, train_model
)
from nnx_mlp import MLP, mse_loss as nnx_mse_loss, train_step as nnx_train_step

# Check available devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()[:4]}...")
print(f"Number of devices: {jax.device_count()}")

## 1. 8-way Data Parallel Training with Plain JAX

Data parallelism is a strategy where we replicate the model across multiple devices and shard the data batch. Each device processes a portion of the batch using its copy of the model, and then we aggregate the gradients across all devices.

Let's implement 8-way data parallel training, following the approach in `jax_data_parallel.py`.

### Data Parallelism Theory

**Definition**: Activations sharded along batch dimension, parameters replicated on each device. Communication only occurs during the backward pass.

**Mathematical representation**:
$$\text{In}[B_X, D] \cdot_D W_{\text{in}}[D, F] \cdot_F W_{\text{out}}[F, D] \rightarrow \text{Out}[B_X, D]$$

where $B_X$ indicates the batch is sharded across $X$ devices.

![Data Parallelism](https://github.com/jax-ml/scaling-book/blob/main/assets/img/data-parallelism.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

**Key properties**:
- Forward pass requires **no communication** 
- Backward pass requires **AllReduce on gradients**
- Model parameters and optimizer states are fully replicated
- Total memory usage grows linearly with the number of devices, because every device holds a full copy of the parameters and optimizer state

**Algorithm**:

**Forward pass:**
1. Tmp[B<sub>X</sub>, F] = In[B<sub>X</sub>, D] ×<sub>D</sub> W<sub>in</sub>[D, F]
2. Out[B<sub>X</sub>, D] = Tmp[B<sub>X</sub>, F] ×<sub>F</sub> W<sub>out</sub>[F, D]

**Backward pass:**
1. dW<sub>out</sub>[F, D] = **AllReduce**(Tmp[B<sub>X</sub>, F] ×<sub>B</sub> dOut[B<sub>X</sub>, D])
2. dW<sub>in</sub>[D, F] = **AllReduce**(In[B<sub>X</sub>, D] ×<sub>B</sub> dTmp[B<sub>X</sub>, F])
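The backward-pass AllReduce works because the weight gradient is a sum over batch rows, so per-shard partial products add up to the full-batch result. The toy sketch below checks this with plain Python lists; the helper names and the integer data are made up for illustration, and the two "devices" are simulated as list slices.

```python
def matmul_over_batch(a, b):
    """Contract [B, M] and [B, N] matrices over the batch dimension -> [M, N]."""
    m, n = len(a[0]), len(b[0])
    return [[sum(a[k][i] * b[k][j] for k in range(len(a))) for j in range(n)]
            for i in range(m)]


def all_reduce_sum(partials):
    """Element-wise sum of per-device [M, N] matrices."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]


# Toy integer data: batch of 4 rows, Tmp is [B, F=2], dOut is [B, D=3].
tmp = [[1, 2], [3, 4], [5, 6], [7, 8]]
d_out = [[1, 0, 2], [0, 1, 1], [2, 2, 0], [1, 1, 1]]

# Two simulated devices, each holding half of the batch.
shards = [(tmp[:2], d_out[:2]), (tmp[2:], d_out[2:])]
partials = [matmul_over_batch(t, g) for t, g in shards]

dw_out_sharded = all_reduce_sum(partials)
dw_out_full = matmul_over_batch(tmp, d_out)
print(dw_out_sharded == dw_out_full)
```

With a mean (rather than sum) loss, the `1/B` factor is already carried by `dOut`, so the same identity holds.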

**When do we become communication-bound?**

For TPUv5p with $C = 4.6 \times 10^{14}$ FLOPs/s and $W_{\text{ici}} = 2 \times 9 \times 10^{10}$ bytes/s:

$$\frac{B}{X} > \frac{C}{W_{\text{ici}}} \approx 2550$$

So our **batch size per chip must be at least 2,550** to avoid being communication-bound with 1D data parallelism.

**Limitations**: Largest model we can train has approximately $\text{HBM per device} / 10$ parameters (≈9B for TPUv5p with Adam optimizer).
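One common way to arrive at the divisor of 10 is to count 2 bytes per parameter for bf16 weights plus 8 bytes for Adam's two fp32 moment estimates. The sketch below applies that accounting; the function name, the byte counts, and the 96 GB HBM figure are assumptions for illustration, and activations and gradients are ignored.

```python
def max_trainable_params(hbm_bytes, weight_bytes=2, optimizer_bytes=8):
    """Upper bound on parameter count when weights and Adam state must fit in HBM."""
    return hbm_bytes // (weight_bytes + optimizer_bytes)


hbm_per_device = 96 * 10**9  # assumed HBM capacity in bytes
print(f"{max_trainable_params(hbm_per_device) / 1e9:.1f}B parameters")
```

Because every device holds a full replica under pure data parallelism, adding devices does not raise this limit.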

### 1.1 Generate Training Data

First, let's generate our synthetic dataset.

In [None]:
# Generate training data
seed = 42
key = jax.random.PRNGKey(seed)
num_samples = 10000
input_dim = 784  # Similar to MNIST input dimension

x_data, y_data = generate_synthetic_data(num_samples, input_dim, key)

print(f"Data shape: x={x_data.shape}, y={y_data.shape}")
print(f"Example: x[0] sum = {jnp.sum(x_data[0]**2):.4f}, y[0] = {y_data[0][0]:.4f}")

### 1.2 Single-Device Baseline

Let's first establish a baseline by training on a single device.

In [None]:
# Model architecture - larger model to demonstrate parallelization benefits
layer_sizes = [input_dim, 8192, 8192, 8192, 10]
batch_size = 8192  # Large batch size as in the reference
learning_rate = 1e-5
num_epochs = 5

# Initialize model parameters
init_key = jax.random.PRNGKey(123)
params = init_mlp_params(layer_sizes, init_key)

# Place everything on a single device
single_device = jax.devices()[0]
params_single = jax.device_put(params, single_device)
x_data_single = jax.device_put(x_data, single_device)
y_data_single = jax.device_put(y_data, single_device)

# Train on single device
print("Training on single device...")
params_single, losses_single, time_single = train_model(
    params_single, x_data_single, y_data_single,
    learning_rate, batch_size, num_epochs, key
)
print(f"Single-device training time: {time_single:.2f} seconds")

### 1.3 8-way Data Parallel Training

Now let's implement 8-way data parallel training where we'll shard the batch across 8 devices.

In [None]:
# Create an 8-device mesh for data parallelism
devices = jax.devices()
mesh = jax.make_mesh((8,), ('batch',))
print(f"Mesh shape: {mesh.shape}")
print(f"Mesh axis names: {mesh.axis_names}")

# Create sharding specifications
# Replicate parameters across all devices
replicated_sharding = NamedSharding(mesh, P())
# Shard data along the batch dimension
batch_sharding = NamedSharding(mesh, P('batch'))

In [None]:
# Define data-parallel training step
@jax.jit
def dp_train_step(params, x_batch, y_batch, learning_rate):
    """Data-parallel training step.
    
    This function is written against the full (global) batch. Under `jax.jit`,
    the compiler partitions the work to follow the input shardings, so each
    device processes its shard and gradients are aggregated automatically.
    """
    # Compute loss and gradients
    loss_value, grads = jax.value_and_grad(mse_loss)(params, x_batch, y_batch)
    
    # Update parameters using gradient descent
    new_params = [(w - learning_rate * dw, b - learning_rate * db)
                 for (w, b), (dw, db) in zip(params, grads)]
    
    return new_params, loss_value

# Place parameters and data with appropriate sharding
params_dp = jax.device_put(params, replicated_sharding)

# Ensure batch size is divisible by number of devices
assert batch_size % 8 == 0, "Batch size must be divisible by number of devices"

# Training loop
print("\nTraining with 8-way data parallelism...")
losses_dp = []
start_time = time.time()

for epoch in range(num_epochs):
    # Shuffle data
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]
    
    epoch_losses = []
    steps_per_epoch = num_samples // batch_size
    
    for step in range(steps_per_epoch):
        # Get batch
        idx_start = step * batch_size
        idx_end = idx_start + batch_size
        x_batch = x_shuffled[idx_start:idx_end]
        y_batch = y_shuffled[idx_start:idx_end]
        
        # Shard the batch across devices
        x_batch = jax.device_put(x_batch, batch_sharding)
        y_batch = jax.device_put(y_batch, batch_sharding)
        
        # Perform training step
        params_dp, loss = dp_train_step(params_dp, x_batch, y_batch, learning_rate)
        epoch_losses.append(float(loss))
    
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses_dp.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
time_dp = end_time - start_time
print(f"\nData-parallel training time: {time_dp:.2f} seconds")
print(f"Speedup: {time_single / time_dp:.2f}x")
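The training loop uses integer division to compute `steps_per_epoch`, so any samples beyond the last full batch are skipped in each epoch. The sketch below reproduces that slicing arithmetic for the values used in this section (10,000 samples, batch size 8,192); `batch_slices` is a hypothetical helper, not part of the tutorial's modules.

```python
def batch_slices(num_samples, batch_size):
    """Return (start, end) index pairs for every full batch in one epoch."""
    steps = num_samples // batch_size
    return [(s * batch_size, (s + 1) * batch_size) for s in range(steps)]


slices = batch_slices(10000, 8192)
used = slices[-1][1] if slices else 0
print(f"steps per epoch: {len(slices)}, samples skipped: {10000 - used}")
```

With these settings each epoch is a single optimizer step, and a different subset of samples is skipped each time because the data is reshuffled.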

### 1.4 Visualizing Data Sharding

Let's visualize how the batch is sharded across devices.

In [None]:
# Create a small example batch to visualize sharding
example_batch = jnp.arange(32).reshape(8, 4)  # 8 samples, 4 features
sharded_batch = jax.device_put(example_batch, batch_sharding)

print("Visualizing batch sharding across 8 devices:")
print("Original batch shape:", example_batch.shape)
print("\nEach device gets 1 sample:")
jax.debug.visualize_array_sharding(sharded_batch)
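To make the row assignment concrete, the sketch below computes which contiguous rows each position along the `'batch'` mesh axis would hold for an evenly divisible batch. `rows_per_device` is a hypothetical helper written in plain Python; the keys are positions along the mesh axis, not physical device IDs.

```python
def rows_per_device(num_rows, num_devices):
    """Map each position along the mesh axis to the contiguous rows it holds."""
    if num_rows % num_devices != 0:
        raise ValueError("rows must divide evenly across devices")
    per_device = num_rows // num_devices
    return {d: range(d * per_device, (d + 1) * per_device)
            for d in range(num_devices)}


example = rows_per_device(8, 8)        # the 8x4 example batch
training = rows_per_device(8192, 8)    # the training batch size used earlier
print([list(r) for r in example.values()])
print(len(training[0]), training[7].start, training[7].stop)
```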

### Fully Sharded Data Parallelism (FSDP) Theory

**Definition**: Activations are sharded along the batch dimension, as in pure data parallelism; weights and optimizer states are sharded across the same mesh axis. Weights are gathered just-in-time before use.

**Mathematical representation**:
$$\text{In}[B_X, D] \cdot_D W_{\text{in}}[D_X, F] \cdot_F W_{\text{out}}[F, D_X] \rightarrow \text{Out}[B_X, D]$$

where both batch and weight dimensions are sharded across $X$ devices.

![FSDP](https://github.com/jax-ml/scaling-book/blob/main/assets/img/fsdp.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

**Key properties**:
- **Drastically reduces per-device memory usage**
- Saves on backward pass FLOPs
- Decomposes AllReduce into AllGather + ReduceScatter
- Same communication cost as pure data parallelism
- Also called "ZeRO sharding" (ZeRO-3 shards parameters, gradients, and optimizer states)

**Algorithm**:

**Forward pass:**
1. W<sub>in</sub>[D, F] = **AllGather**(W<sub>in</sub>[D<sub>X</sub>, F])
2. Tmp[B<sub>X</sub>, F] = In[B<sub>X</sub>, D] ×<sub>D</sub> W<sub>in</sub>[D, F]
3. W<sub>out</sub>[F, D] = **AllGather**(W<sub>out</sub>[F, D<sub>X</sub>])
4. Out[B<sub>X</sub>, D] = Tmp[B<sub>X</sub>, F] ×<sub>F</sub> W<sub>out</sub>[F, D]

**Backward pass:**
1. dW<sub>out</sub>[F, D<sub>X</sub>] = **ReduceScatter**(Tmp[B<sub>X</sub>, F] ×<sub>B</sub> dOut[B<sub>X</sub>, D])
2. dW<sub>in</sub>[D<sub>X</sub>, F] = **ReduceScatter**(dTmp[B<sub>X</sub>, F] ×<sub>B</sub> In[B<sub>X</sub>, D])

**Communication Analysis**:

FSDP has the **same roofline as pure data parallelism** because:
- AllReduce = AllGather + ReduceScatter
- Total communication volume is identical
- Same condition: $\frac{B}{X} > \frac{C}{W_{\text{ici}}} \approx 2550$
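The identity AllReduce = AllGather + ReduceScatter can be checked with a small simulation. The sketch below models each device's buffer as a Python list; the three collective functions are simplified stand-ins written for this illustration, not JAX APIs.

```python
def reduce_scatter(per_device):
    """Device i ends up with the i-th chunk of the element-wise sum."""
    n, length = len(per_device), len(per_device[0])
    chunk = length // n
    summed = [sum(v[j] for v in per_device) for j in range(length)]
    return [summed[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_gather(chunks):
    """Every device ends up with the concatenation of all chunks."""
    full = [x for c in chunks for x in c]
    return [list(full) for _ in chunks]


def all_reduce(per_device):
    """Every device ends up with the full element-wise sum."""
    length = len(per_device[0])
    summed = [sum(v[j] for v in per_device) for j in range(length)]
    return [list(summed) for _ in per_device]


# Four simulated devices, each holding a local gradient of length 4.
grads = [[1, 2, 3, 4], [10, 20, 30, 40],
         [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(all_gather(reduce_scatter(grads)) == all_reduce(grads))
```

FSDP splits the two halves in time: the ReduceScatter runs in the backward pass, and the AllGather runs when the weights are next needed.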

**Benefits**:
- Memory reduction: Parameters and optimizer states sharded across devices
- Zero overhead: Same FLOPs-to-communication ratio
- Can upgrade from data parallelism without performance loss
- Essential for models > 9B parameters

## 2. Data Parallel Training with Flax NNX

Now let's implement the same 8-way data parallel training using Flax NNX, which provides higher-level abstractions.

In [None]:
# Import Flax NNX
try:
    import flax.nnx as nnx
    import optax
except ImportError:
    !pip install -q flax optax
    import flax.nnx as nnx
    import optax

In [None]:
# Create model and optimizer
rngs = nnx.Rngs(0)
model = MLP(layer_sizes, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate))

# Replicate model and optimizer state across devices
state = nnx.state((model, optimizer))
state = jax.device_put(state, replicated_sharding)
nnx.update((model, optimizer), state)

# Visualize model parameter sharding (should be replicated)
print("Model parameter sharding (replicated across all devices):")
jax.debug.visualize_array_sharding(model.layers[0].kernel.value)

In [None]:
# Define the training step for Flax NNX
@nnx.jit
def nnx_dp_train_step(model: MLP, optimizer: nnx.Optimizer, x_batch, y_batch):
    """Data-parallel training step using Flax NNX."""
    def loss_fn(model):
        predictions = model(x_batch)
        return jnp.mean((predictions - y_batch) ** 2)
    
    loss_value, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss_value

# Training loop with Flax NNX
print("\nTraining with Flax NNX 8-way data parallelism...")
losses_nnx_dp = []
start_time = time.time()

for epoch in range(num_epochs):
    # Shuffle data
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]
    
    epoch_losses = []
    steps_per_epoch = num_samples // batch_size
    
    for step in range(steps_per_epoch):
        # Get batch
        idx_start = step * batch_size
        idx_end = idx_start + batch_size
        x_batch = x_shuffled[idx_start:idx_end]
        y_batch = y_shuffled[idx_start:idx_end]
        
        # Shard the batch across devices
        x_batch = jax.device_put(x_batch, batch_sharding)
        y_batch = jax.device_put(y_batch, batch_sharding)
        
        # Visualize data sharding for first step
        if epoch == 0 and step == 0:
            print("\nData sharding (batch distributed across devices):")
            jax.debug.visualize_array_sharding(x_batch)
        
        # Perform training step
        loss = nnx_dp_train_step(model, optimizer, x_batch, y_batch)
        epoch_losses.append(float(loss))
    
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses_nnx_dp.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
time_nnx_dp = end_time - start_time
print(f"\nFlax NNX data-parallel training time: {time_nnx_dp:.2f} seconds")
print(f"Speedup over single device: {time_single / time_nnx_dp:.2f}x")

## 3. Fully Sharded Data Parallel (FSDP) Training with Flax NNX

Now let's implement FSDP where we shard both the data and model parameters across devices. This is especially useful for large models that don't fit on a single device.

### Understanding FSDP Memory Benefits

The diagram below shows how FSDP reduces memory usage compared to pure data parallelism:

![FSDP Memory Comparison](https://github.com/jax-ml/scaling-book/blob/main/assets/img/fsdp-figure.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

The rows show:
1. **Pure Data Parallelism**: Parameters and optimizer states fully replicated
2. **ZeRO-1**: Optimizer states sharded
3. **ZeRO-2**: Optimizer states and gradients sharded  
4. **ZeRO-3 (FSDP)**: Parameters, gradients, and optimizer states all sharded

**Why FSDP matters**: Standard data parallelism involves significant duplicated work and memory. With FSDP:
- Each device only stores 1/N of the parameters
- Each device only updates 1/N of the optimizer state
- AllGather parameters as needed for forward pass
- ReduceScatter gradients for efficient updates

This enables training much larger models that wouldn't fit with pure data parallelism.
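The four rows can be turned into a rough per-device memory estimate. The sketch below assumes 2 bytes per parameter for bf16 weights, 2 bytes for bf16 gradients, and 8 bytes for Adam state in fp32; these byte counts and the helper `bytes_per_device` are illustrative assumptions, and activations are ignored.

```python
def bytes_per_device(num_params, num_devices, stage,
                     param_bytes=2, grad_bytes=2, opt_bytes=8):
    """Per-device bytes for ZeRO stage 0 (pure DP), 1, 2, or 3 (FSDP)."""
    # Each stage shards one more component across the devices.
    opt = opt_bytes * num_params / (num_devices if stage >= 1 else 1)
    grad = grad_bytes * num_params / (num_devices if stage >= 2 else 1)
    param = param_bytes * num_params / (num_devices if stage >= 3 else 1)
    return param + grad + opt


names = ["Pure DP", "ZeRO-1", "ZeRO-2", "ZeRO-3 (FSDP)"]
for stage, name in enumerate(names):
    gb = bytes_per_device(10**9, 8, stage) / 1e9
    print(f"{name:14s} {gb:6.2f} GB per device")
```

For a hypothetical 1B-parameter model on 8 devices, most of the savings come from sharding the optimizer state, which is the largest component under these assumptions.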

In [None]:
# Create a 2D mesh for FSDP
# We'll use 2 devices for data parallelism and 4 for model parallelism
fsdp_mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((2, 4)),
    ('data', 'model')
)
print(f"FSDP mesh shape: {fsdp_mesh.shape}")
print(f"FSDP mesh axis names: {fsdp_mesh.axis_names}")

In [None]:
# Define a custom MLP with explicit parameter sharding for FSDP
import dataclasses
from typing import Optional

@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
    """Rules for how to shard different parts of the model."""
    input_dim: Optional[str] = None
    output_dim: Optional[str] = 'model'
    bias: Optional[str] = 'model'
    
mesh_rules = MeshRules()

class FSDP_MLP(nnx.Module):
    """MLP with explicit parameter sharding for FSDP."""
    
    def __init__(self, layer_sizes, rngs: nnx.Rngs):
        self.layer_sizes = layer_sizes
        
        # Create layers with sharded parameters
        self.weights = []
        self.biases = []
        
        for i in range(len(layer_sizes) - 1):
            in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
            
            # A size-1 output dimension cannot be split across the 'model'
            # axis, so the last layer's output dim and bias stay unsharded.
            shard_out = out_dim > 1

            # Initialize weight with sharding spec
            w = nnx.Param(
                nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)),
                sharding=(mesh_rules.input_dim,
                          mesh_rules.output_dim if shard_out else None)
            )

            # Initialize bias with sharding spec
            b = nnx.Param(
                jnp.zeros((out_dim,)),
                sharding=(mesh_rules.bias if shard_out else None,)
            )
            
            self.weights.append(w)
            self.biases.append(b)
    
    def __call__(self, x):
        activations = x
        
        # Apply each layer
        for i in range(len(self.weights) - 1):
            activations = jnp.dot(activations, self.weights[i]) + self.biases[i]
            activations = jax.nn.relu(activations)
        
        # Last layer
        return jnp.dot(activations, self.weights[-1]) + self.biases[-1]

In [None]:
# Custom SGD optimizer for FSDP (based on flax_nnx_dp_fsdp.py)
class SGDState(nnx.Variable):
    pass

class SGD(nnx.Object):
    def __init__(self, params: nnx.State, lr, decay=0.9):
        def init_optimizer_state(variable: nnx.Variable):
            return SGDState(
                jnp.zeros_like(variable.value), **variable.get_metadata()
            )
        
        self.lr = lr
        self.params = params
        self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
        self.decay = decay
    
    def update(self, grads: nnx.State):
        def update_fn(params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState):
            # Momentum update
            momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
            # Parameter update
            params.value -= self.lr * momentum
        
        jax.tree.map(update_fn, self.params, self.momentum, grads)
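The update rule inside `update_fn` is an exponential moving average of the gradients followed by a parameter step. The scalar sketch below isolates that arithmetic; `sgd_momentum_step` is a hypothetical helper written for illustration and is independent of Flax.

```python
def sgd_momentum_step(param, momentum, grad, lr, decay=0.9):
    """One step of the EMA-style momentum update used by the SGD class above."""
    momentum = decay * momentum + (1 - decay) * grad
    param = param - lr * momentum
    return param, momentum


# With a constant gradient, the momentum ramps up towards the gradient value.
param, momentum = 1.0, 0.0
for step in range(3):
    param, momentum = sgd_momentum_step(param, momentum, grad=1.0, lr=0.1)
    print(step, round(momentum, 3), round(param, 4))
```

Note that the `(1 - decay)` factor makes this differ from the classic heavy-ball form `m = decay * m + grad`, so effective step sizes are not directly comparable between the two.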

### Mixed FSDP and Tensor Parallelism

For even larger models or smaller batch sizes, we can combine FSDP with tensor parallelism using a 2D mesh:

![Mixed FSDP and Tensor Parallelism](https://github.com/jax-ml/scaling-book/blob/main/assets/img/mixed-fsdp-model-parallelism.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

**Mathematical representation**:
$$\text{In}[B_X, D_Y] \cdot_D W_{\text{in}}[D_X, F_Y] \cdot_F W_{\text{out}}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y]$$

where:
- $X$ = data/FSDP parallelism dimension
- $Y$ = model/tensor parallelism dimension

**Key insight**: 
- **FSDP moves weights** (communication scales with weight size)
- **Tensor parallelism moves activations** (communication scales with activation size)
- As batch size shrinks, activations get smaller → tensor parallelism becomes cheaper
- As we add more tensor parallelism, weight gathers get smaller → FSDP becomes cheaper

**Optimal sharding**: For $N = X \times Y$ total chips:
$$X_{\text{opt}} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N}$$

where $M_X$ and $M_Y$ are the number of mesh axes for FSDP and tensor parallelism respectively.

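**Compute-bound condition** (the batch size per chip needed to avoid being communication-bound):
$$\frac{B}{N} > \frac{4\alpha^2}{M_X M_Y F}$$

where $\alpha = C/W_{\text{ici}} \approx 2550$ for TPUv5p.

This allows batch sizes as low as ~400 tokens per chip, roughly 2× lower than the ~850 tokens per chip that pure FSDP needs when its communication is spread over three mesh axes.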
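The two formulas above are easy to evaluate numerically. The sketch below plugs in assumed values ($F = 32768$, $M_X = 2$, $M_Y = 1$, a batch of $2^{20}$ tokens, and 1,024 chips); these numbers and the function names are chosen for illustration only.

```python
import math


def min_batch_per_chip(alpha, d_ff, m_x, m_y):
    """Right-hand side of B/N > 4 * alpha^2 / (M_X * M_Y * F)."""
    return 4 * alpha**2 / (m_x * m_y * d_ff)


def optimal_fsdp_size(batch, d_ff, m_x, m_y, num_chips):
    """X_opt = sqrt((B / F) * (M_X / M_Y) * N)."""
    return math.sqrt(batch / d_ff * m_x / m_y * num_chips)


alpha = 2550
print(f"min batch per chip: {min_batch_per_chip(alpha, 32768, 2, 1):.0f}")

x_opt = optimal_fsdp_size(2**20, 32768, 2, 1, 1024)
print(f"X_opt = {x_opt:.0f}, Y = {1024 / x_opt:.0f}")
```

Under these assumptions the threshold lands just under 400 tokens per chip, and the 1,024 chips split into 256-way FSDP and 4-way tensor parallelism.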

In [None]:
# Create and initialize FSDP model
@nnx.jit
def create_fsdp_model():
    # Use smaller layer sizes for FSDP demo to avoid memory issues
    fsdp_layer_sizes = [input_dim, 1024, 512, 1]
    model = FSDP_MLP(fsdp_layer_sizes, rngs=nnx.Rngs(0))
    optimizer = SGD(nnx.variables(model, nnx.Param), learning_rate, decay=0.9)
    
    # Get sharding specifications for the state
    state = nnx.state(optimizer)
    
    def get_named_shardings(path: tuple, value: nnx.VariableState):
        if hasattr(value, 'sharding') and value.sharding is not None:
            return value.replace(NamedSharding(fsdp_mesh, P(*value.sharding)))
        return value
    
    named_shardings = nnx.map_state(get_named_shardings, state)
    sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
    nnx.update(optimizer, sharded_state)
    
    return model, optimizer

print("Creating FSDP model...")
fsdp_model, fsdp_optimizer = create_fsdp_model()

# Visualize parameter sharding
print("\nWeight sharding (distributed across model axis):")
jax.debug.visualize_array_sharding(fsdp_model.weights[0].value)
print("\nMomentum sharding:")
jax.debug.visualize_array_sharding(fsdp_optimizer.momentum['weights'][0].value)

In [None]:
# FSDP training step
@nnx.jit
def fsdp_train_step(model: FSDP_MLP, optimizer: SGD, x_batch, y_batch):
    """FSDP training step."""
    def loss_fn(model):
        predictions = model(x_batch)
        return jnp.mean((predictions - y_batch) ** 2)
    
    loss_value, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss_value

# Training with FSDP
print("\nTraining with FSDP...")
losses_fsdp = []
start_time = time.time()

# Use smaller batch size for FSDP demo
fsdp_batch_size = 256
steps_per_epoch = num_samples // fsdp_batch_size

for epoch in range(num_epochs):
    # Shuffle data
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, num_samples)
    x_shuffled = x_data[perm]
    y_shuffled = y_data[perm]
    
    epoch_losses = []
    
    for step in range(steps_per_epoch):
        # Get batch
        idx_start = step * fsdp_batch_size
        idx_end = idx_start + fsdp_batch_size
        x_batch = x_shuffled[idx_start:idx_end]
        y_batch = y_shuffled[idx_start:idx_end]
        
        # Shard data across the data axis only
        data_sharding = NamedSharding(fsdp_mesh, P('data', None))
        x_batch = jax.device_put(x_batch, data_sharding)
        y_batch = jax.device_put(y_batch, data_sharding)
        
        # Perform training step
        loss = fsdp_train_step(fsdp_model, fsdp_optimizer, x_batch, y_batch)
        epoch_losses.append(float(loss))
    
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses_fsdp.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

end_time = time.time()
time_fsdp = end_time - start_time
print(f"\nFSDP training time: {time_fsdp:.2f} seconds")

## 4. Performance Comparison

Let's compare the performance of all training approaches.

In [None]:
# Plot training curves
plt.figure(figsize=(12, 5))

# Plot loss curves
plt.subplot(1, 2, 1)
plt.plot(losses_single, label='Single Device', linewidth=2)
plt.plot(losses_dp, label='JAX Data Parallel (8-way)', linewidth=2)
plt.plot(losses_nnx_dp, label='Flax NNX Data Parallel (8-way)', linewidth=2)
plt.plot(losses_fsdp, label='Flax NNX FSDP', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Comparison')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot speedup comparison
plt.subplot(1, 2, 2)
methods = ['Single\nDevice', 'JAX\nDP', 'Flax NNX\nDP', 'FSDP']
times = [time_single, time_dp, time_nnx_dp, time_fsdp]
speedups = [time_single / t for t in times]
colors = ['gray', 'blue', 'green', 'red']

bars = plt.bar(methods, speedups, color=colors, alpha=0.7)
plt.axhline(y=1, color='black', linestyle='--', alpha=0.5)
plt.ylabel('Speedup (relative to single device)')
plt.title('Training Speedup Comparison')
plt.grid(True, alpha=0.3, axis='y')

# Add speedup values on bars
for bar, speedup in zip(bars, speedups):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
             f'{speedup:.2f}x', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Print summary
print("\n=== Performance Summary ===")
print(f"Single Device: {time_single:.2f}s")
print(f"JAX Data Parallel (8-way): {time_dp:.2f}s (speedup: {time_single/time_dp:.2f}x)")
print(f"Flax NNX Data Parallel (8-way): {time_nnx_dp:.2f}s (speedup: {time_single/time_nnx_dp:.2f}x)")
print(f"FSDP: {time_fsdp:.2f}s (speedup: {time_single/time_fsdp:.2f}x)")

### Communication vs Computation Analysis

The plot below shows how different parallelism strategies perform as batch size changes:

![Communication vs FLOPs Analysis](https://github.com/jax-ml/scaling-book/blob/main/assets/img/mixed-fsdp-comms-2.png?raw=true)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

Key observations:
- **Pure Data Parallelism/FSDP** (green): Ratio scales linearly with batch size. Best for large batches.
- **Tensor Parallelism** (blue): Fixed ratio independent of batch size. Limited by feed-forward dimension.
- **Mixed FSDP + Tensor Parallelism** (red): Ratio scales with √B. Optimal for intermediate batch sizes.

**The horizontal line at ratio = 1** marks the boundary between compute-bound (above) and communication-bound (below) regimes.

**Practical implications**:
- Large batch sizes (>850 per chip, i.e. the 2,550 bound spread over three mesh axes): Use pure FSDP
- Intermediate batch sizes (400-850 per chip): Use mixed FSDP + tensor parallelism
- Very small batch sizes: May become communication-bound regardless

This analysis helps choose the right parallelism strategy based on your model size, batch size, and available hardware.
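The practical implications above amount to a simple threshold rule. The sketch below encodes it as a hypothetical helper, `suggest_strategy`; the default thresholds are the per-chip figures quoted in this section and would need to be recomputed for other hardware or model shapes.

```python
def suggest_strategy(batch_per_chip, fsdp_min=850, mixed_min=400):
    """Pick a parallelism strategy from the per-chip batch size (in tokens)."""
    if batch_per_chip > fsdp_min:
        return "pure FSDP"
    if batch_per_chip >= mixed_min:
        return "mixed FSDP + tensor parallelism"
    return "likely communication-bound"


for b in (2048, 600, 128):
    print(b, "->", suggest_strategy(b))
```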

### Summary: Parallelism Strategies Comparison

Here's a comprehensive comparison of the parallelism strategies covered:

| **Strategy** | **Mathematical Formula** | **Communication** | **Memory** | **Best Use Case** |
|--------------|---------------------------|-------------------|------------|-------------------|
| **Data Parallelism** | In[B<sub>X</sub>, D] ⋅ W[D, F] → Out[B<sub>X</sub>, D] | AllReduce gradients (backward only) | Parameters replicated | Model fits on single device |
| **FSDP** | In[B<sub>X</sub>, D] ⋅ W[D<sub>X</sub>, F] → Out[B<sub>X</sub>, D] | AllGather weights + ReduceScatter grads | Parameters sharded | Large models, memory constraints |
| **Tensor Parallelism** | In[B, D<sub>Y</sub>] ⋅ W[D, F<sub>Y</sub>] → Out[B, D<sub>Y</sub>] | AllGather activations + ReduceScatter | Weights sharded by FF dim | Small batch sizes |
| **Mixed FSDP + TP** | In[B<sub>X</sub>, D<sub>Y</sub>] ⋅ W[D<sub>X</sub>, F<sub>Y</sub>] → Out[B<sub>X</sub>, D<sub>Y</sub>] | Both weight and activation movement | Both sharding strategies | Very large models, small batches |

**Conditions to stay compute-bound** (for TPUv5p):
- **Data Parallelism & FSDP**: Batch size per chip > 2,550 (over a single mesh axis)
- **Tensor Parallelism**: F > Y × 2,550 (typically 8-16 way max)
- **Mixed FSDP + TP**: Batch size per chip > 400 (optimal combination)
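The tensor-parallelism condition $F > Y \times 2{,}550$ can be rearranged into an upper bound on $Y$. The sketch below evaluates it for a few assumed feed-forward sizes; the helper name and the example values of $F$ are illustrative.

```python
def max_tensor_parallel_degree(d_ff, alpha=2550):
    """Upper bound on Y implied by F > Y * alpha."""
    return d_ff / alpha


for d_ff in (16384, 32768, 49152):
    print(f"F={d_ff:6d}  Y < {max_tensor_parallel_degree(d_ff):.1f}")
```

In practice the degree is usually rounded down to a value that matches the mesh shape, which is consistent with the 8-16 way range quoted above.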

**Key Insight**: The choice of parallelism strategy depends on the interplay between:
1. **Model size** (does it fit on one device?)
2. **Batch size** (how much data per device?)
3. **Hardware constraints** (memory, bandwidth, number of devices)

For this tutorial's examples, we demonstrated the progression from simple data parallelism to more sophisticated FSDP approaches, showing how each strategy addresses different scaling challenges.

## 5. Key Takeaways

In this tutorial, we've explored different approaches to parallel training:

1. **8-way Data Parallel with Plain JAX**
   - Simple to implement using `jax.jit` and sharding specifications
   - Model is replicated across all devices
   - Data is sharded along the batch dimension
   - Good speedup for compute-intensive workloads

2. **Data Parallel with Flax NNX**
   - Higher-level API makes implementation cleaner
   - Comparable performance to the plain JAX implementation (see the timings in Section 4)
   - Better for complex models with many components

3. **Fully Sharded Data Parallel (FSDP) with Flax NNX**
   - Shards both model parameters and data
   - Uses a 2D mesh (data × model axes)
   - Essential for large models that don't fit on single device
   - More complex but enables training of massive models

### When to Use Each Approach:

- **Data Parallel**: When your model fits on a single device and you want to scale training speed
- **FSDP**: When your model is too large for a single device or you need maximum memory efficiency

### References:
- Example implementations: `jax_data_parallel.py` and `flax_nnx_dp_fsdp.py` in the root directory
- JAX documentation on [distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html)
- Flax NNX documentation on [distributed training](https://flax.readthedocs.io/en/latest/nnx/index.html)