# Tutorial 3: Tensor Parallel and Transformers Scaling

[![Open in GitHub](https://img.shields.io/badge/Open%20in-GitHub-181717?style=flat-square&logo=github)](https://github.com/sshkhr/MinText/blob/main/docs/tutorials/3_Tensor_Parallel_and_Transformers.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sshkhr/MinText/blob/main/docs/tutorials/3_Tensor_Parallel_and_Transformers.ipynb)

In the previous tutorial, we learned about data parallelism and how to use it to shard data batches across devices. We also learned about Fully Sharded Data Parallel (FSDP) and how it can be used to shard model parameters, gradients and optimizer states across devices. In this part, we will cover tensor parallelism and how it can be used to shard model layers across devices. We will also learn how to use the different parallelism techniques together to scale up training of large transformer models.

### 0.1 Setup

Let's start by importing the necessary libraries and initializing our environment.

In [None]:
import os
# Uncomment to simulate 8 devices on CPU (not needed on a TPU runtime).
# The flag must be set before JAX is imported.
#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.sharding import PartitionSpec as P  # short alias used in Section 2
from jax.experimental import mesh_utils
from flax import nnx
import optax
import numpy as np
import matplotlib.pyplot as plt
import time
from functools import partial
import dataclasses

In [None]:
# Check available devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()[:4]}...")
print(f"Number of devices: {jax.device_count()}")

JAX version: 0.5.2
Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)]...
Number of devices: 8


## 1. Tensor Parallelism

Fully sharded data parallelism (FSDP) shards the model weights across devices and AllGathers them just before they are used, so its communication moves *weights*. Tensor parallelism, also known as "1D model parallelism" or Megatron sharding, takes a different approach: it shards the feedforward (hidden) dimension of each layer's weights and moves *activations* between devices during computation. Because its communication scales with the activations rather than the weights, it allows a smaller efficient batch size than FSDP alone, which makes it particularly useful for training very large models. The diagram below illustrates how a single matrix is partitioned across devices using this approach:

### 1.1 Tensor Parallelism Theory

**Sharding**: The weights of each layer are sharded along the feedforward dimension $F$ across the $Y$ devices. Between layers, activations are sharded along the model dimension $D$; the batch dimension is not sharded.

**Equation** (for our MLP example):
$$\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]$$

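where a subscript $Y$ on a dimension (as in $D_Y$ or $F_Y$) means that dimension is sharded across the $Y$ mesh axis, and $\cdot_D$, $\cdot_F$ denote contraction over $D$ and $F$. Dimensions without a subscript are not sharded.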
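To make the subscript notation concrete, the sketch below builds a one-axis mesh named `Y` over all available devices and places $\text{In}[B, D_Y]$, $W_\text{in}[D, F_Y]$ and $W_\text{out}[F_Y, D]$ on it with `NamedSharding`. The mesh name `tp_mesh` and the sizes are illustrative assumptions, chosen as multiples of the device count so every shard has the same shape.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named 'Y' spanning every available device (illustrative).
n_dev = jax.device_count()
tp_mesh = Mesh(np.array(jax.devices()), ('Y',))

# Illustrative sizes, multiples of the device count so shards are equal.
B, D, F = 4, 8 * n_dev, 32 * n_dev

# In[B, D_Y]: the model dimension D is split across 'Y'; B is not sharded.
x = jax.device_put(jnp.ones((B, D)), NamedSharding(tp_mesh, P(None, 'Y')))
# W_in[D, F_Y]: the feedforward dimension F is split across 'Y'.
w_in = jax.device_put(jnp.ones((D, F)), NamedSharding(tp_mesh, P(None, 'Y')))
# W_out[F_Y, D]: F is split across 'Y', this time on the first axis.
w_out = jax.device_put(jnp.ones((F, D)), NamedSharding(tp_mesh, P('Y', None)))

for name, arr in [('In', x), ('W_in', w_in), ('W_out', w_out)]:
    local_shape = arr.addressable_shards[0].data.shape
    print(f"{name}: global {arr.shape} -> per-device {local_shape}")
```

Each device holds a $[B, D/Y]$ block of the input and $[D, F/Y]$ / $[F/Y, D]$ blocks of the two weight matrices.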

![Tensor Parallel](https://jax-ml.github.io/scaling-book/assets/img/model-parallelism-1400.webp)

<sup> Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) </sup>

### 1.2 Tensor Parallelism Algorithm

The computation In[B, D_Y] * W_in[D, F_Y] * W_out[F_Y, D] → Out[B, D_Y] requires gathering the activations before the first matrix multiplication. This makes it cheaper than ZeRO sharding whenever the activations are smaller than the weights. In practice that usually holds only once some ZeRO (data) sharding is added as well, because sharding the batch shrinks the activations each device has to gather. This synergy is why ZeRO sharding and model parallelism are commonly used together.
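A back-of-the-envelope count makes this trade-off concrete. The sketch below estimates the bytes each device moves in one MLP forward pass under simplifying assumptions: ring-style collectives, where an AllGather or ReduceScatter of an N-element array moves about N(Y-1)/Y elements per device; bf16 (2-byte) values; and no overlap with compute. Tensor parallelism gathers and scatters the [B, D] activations, while ZeRO/FSDP gathers the [D, F] and [F, D] weights. The helper names and sizes are hypothetical.

```python
def collective_bytes(num_elems: int, Y: int, bytes_per_elem: int = 2) -> float:
    """Approximate bytes moved per device by a ring AllGather or ReduceScatter."""
    return num_elems * (Y - 1) / Y * bytes_per_elem

def tensor_parallel_forward_bytes(B: int, D: int, Y: int) -> float:
    # AllGather In[B, D_Y] -> In[B, D], then ReduceScatter Out[B, D] -> Out[B, D_Y].
    return 2 * collective_bytes(B * D, Y)

def fsdp_forward_bytes(D: int, F: int, Y: int) -> float:
    # AllGather both weight matrices, W_in[D, F] and W_out[F, D], before use.
    return 2 * collective_bytes(D * F, Y)

# Hypothetical sizes. The TP / FSDP ratio is B / F, so TP moves less when B < F.
D, F, Y = 8192, 32768, 8
for B in (1024, 65536):
    tp = tensor_parallel_forward_bytes(B, D, Y)
    fsdp = fsdp_forward_bytes(D, F, Y)
    print(f"B={B}: TP {tp / 1e6:.1f} MB vs FSDP {fsdp / 1e6:.1f} MB per device")
```

In this simplified model the tensor-parallel cost grows with the per-device batch, which is consistent with the observation that adding some data sharding (shrinking B inside each tensor-parallel group) helps tensor parallelism pay off.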

**Forward pass:** need to compute Loss[B]

1.  In[B, D] = **AllGather**(In[B, D<sub>Y</sub>]) *(on critical path)*
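2.  Tmp[B, F<sub>Y</sub>] = In[B, D] \*<sub>D</sub> W<sub>in</sub>[D, F<sub>Y</sub>] *(the contracting dimension D is not sharded, so no communication is needed)*
3.  Out[B, D] {U<sub>Y</sub>} = Tmp[B, F<sub>Y</sub>] \*<sub>F</sub> W<sub>out</sub>[F<sub>Y</sub>, D] *(F is sharded, so each device holds only a partial sum; {U<sub>Y</sub>} marks it as unreduced over Y)*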
4.  Out[B, D<sub>Y</sub>] = **ReduceScatter**(Out[B, D] {U<sub>Y</sub>}) *(on critical path)*
5.  Loss[B] = ...
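The forward steps above map almost one-to-one onto JAX collectives inside `shard_map`: `jax.lax.all_gather(..., tiled=True)` for step 1 and `jax.lax.psum_scatter(..., tiled=True)` for step 4. The sketch below is a minimal illustration, not the tutorial's own implementation. It assumes a one-axis mesh named `Y` over all devices and illustrative sizes, and like the steps above it omits the nonlinearity between the matmuls (an elementwise activation on Tmp would need no extra communication). It checks the result against the unsharded computation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases (e.g. 0.5.x) keep it in jax.experimental
    from jax.experimental.shard_map import shard_map

n_dev = jax.device_count()
tp_mesh = Mesh(np.array(jax.devices()), ('Y',))
B, D, F = 4, 8 * n_dev, 16 * n_dev  # illustrative sizes, divisible by n_dev

def tp_mlp_block(x_blk, w_in_blk, w_out_blk):
    # Per-device views: x_blk [B, D/Y], w_in_blk [D, F/Y], w_out_blk [F/Y, D].
    x_full = jax.lax.all_gather(x_blk, 'Y', axis=1, tiled=True)  # 1. In[B, D]
    tmp = x_full @ w_in_blk                                     # 2. Tmp[B, F_Y], no comms
    out_partial = tmp @ w_out_blk                               # 3. Out[B, D] {U_Y}
    return jax.lax.psum_scatter(out_partial, 'Y',
                                scatter_dimension=1, tiled=True)  # 4. Out[B, D_Y]

tp_mlp = jax.jit(shard_map(
    tp_mlp_block, mesh=tp_mesh,
    in_specs=(P(None, 'Y'), P(None, 'Y'), P('Y', None)),
    out_specs=P(None, 'Y'),
))

kx, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (B, D))
w_in = jax.random.normal(k1, (D, F)) / np.sqrt(D)
w_out = jax.random.normal(k2, (F, D)) / np.sqrt(F)

out = tp_mlp(x, w_in, w_out)
reference = x @ w_in @ w_out
print("max abs difference vs. unsharded:", float(jnp.max(jnp.abs(out - reference))))
```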

**Backward pass:** need to compute dW<sub>out</sub>[F<sub>Y</sub>, D], dW<sub>in</sub>[D, F<sub>Y</sub>]

1.  dOut[B, D<sub>Y</sub>] = ...
2.  dOut[B, D] = **AllGather**(dOut[B, D<sub>Y</sub>]) *(on critical path)*
3.  dW<sub>out</sub>[F<sub>Y</sub>, D] = Tmp[B, F<sub>Y</sub>] \*<sub>B</sub> dOut[B, D]
4.  dTmp[B, F<sub>Y</sub>] = dOut[B, D] \*<sub>D</sub> W<sub>out</sub>[F<sub>Y</sub>, D] *(can throw away dOut[B, D] here)*
5.  In[B, D] = **AllGather**(In[B, D<sub>Y</sub>]) *(this can be skipped by sharing with (1) from the forward pass)*
6.  dW<sub>in</sub>[D, F<sub>Y</sub>] = dTmp[B, F<sub>Y</sub>] \*<sub>B</sub> In[B, D]
7.  dIn[B, D] {U<sub>Y</sub>} = dTmp[B, F<sub>Y</sub>] \*<sub>F</sub> W<sub>in</sub>[D, F<sub>Y</sub>] *(needed for previous layers)*
8.  dIn[B, D<sub>Y</sub>] = **ReduceScatter**(dIn[B, D] {U<sub>Y</sub>}) *(on critical path)*
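In JAX these backward steps do not have to be written by hand: when differentiating through `shard_map`, the transpose of `all_gather` is a `psum_scatter` and vice versa, which corresponds to the AllGather of dOut (step 2) and the ReduceScatter of dIn (step 8). The self-contained sketch below reuses the same assumptions as a forward-pass sketch (a one-axis mesh `Y` over all devices, illustrative sizes, no nonlinearity) and compares the tensor-parallel gradients with those of the unsharded MLP.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it in jax.experimental
    from jax.experimental.shard_map import shard_map

n_dev = jax.device_count()
tp_mesh = Mesh(np.array(jax.devices()), ('Y',))
B, D, F = 4, 8 * n_dev, 16 * n_dev  # illustrative sizes

def tp_mlp_block(x_blk, w_in_blk, w_out_blk):
    x_full = jax.lax.all_gather(x_blk, 'Y', axis=1, tiled=True)
    out_partial = (x_full @ w_in_blk) @ w_out_blk
    return jax.lax.psum_scatter(out_partial, 'Y', scatter_dimension=1, tiled=True)

tp_mlp = shard_map(tp_mlp_block, mesh=tp_mesh,
                   in_specs=(P(None, 'Y'), P(None, 'Y'), P('Y', None)),
                   out_specs=P(None, 'Y'))

def loss_tp(w_in, w_out, x):
    # Toy loss on the sharded forward pass.
    return jnp.sum(tp_mlp(x, w_in, w_out) ** 2)

def loss_ref(w_in, w_out, x):
    # Same toy loss on the unsharded forward pass.
    return jnp.sum((x @ w_in @ w_out) ** 2)

kx, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (B, D))
w_in = jax.random.normal(k1, (D, F)) / np.sqrt(D)
w_out = jax.random.normal(k2, (F, D)) / np.sqrt(F)

# Gradients w.r.t. W_in and W_out, i.e. dW_in[D, F_Y] and dW_out[F_Y, D].
grads_tp = jax.jit(jax.grad(loss_tp, argnums=(0, 1)))(w_in, w_out, x)
grads_ref = jax.grad(loss_ref, argnums=(0, 1))(w_in, w_out, x)
for name, g_tp, g_ref in zip(('dW_in', 'dW_out'), grads_tp, grads_ref):
    print(name, "max abs difference:", float(jnp.max(jnp.abs(g_tp - g_ref))))
```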

The two back-to-back matmuls of our MLP are what make tensor parallelism work well here. If each matmul were sharded on its own, we would need an AllReduce after each one. Instead, W_in is sharded along its output dimension F and W_out along its input dimension F, so the intermediate Tmp[B, F_Y] stays sharded along F between the two matmuls and needs no communication. A single AllGather on the input at the start and a single ReduceScatter on the output at the end are enough, with no intermediate AllReduce.

## 2. Combining Parallelism Techniques

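In this section, we combine data parallelism and tensor parallelism to implement distributed training of a simple MLP model: the batch is sharded over a `data` mesh axis, the MLP weights are sharded over a `model` mesh axis, and the optimizer state is sharded exactly like the parameters it tracks.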


### 2.1 Mesh Definition

In [None]:
# Arrange the 8 devices into a 2x4 grid and give its axes logical names:
# the first dimension (size 2) is named 'data',
# the second dimension (size 4) is named 'model'.
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((2, 4)),
  ('data', 'model'),
)

### 2.2 Sharding Helper Functions

In [None]:
# A dataclass to hold sharding rules for different parts of the model/data.
# Makes it easy to manage and change sharding strategies.
@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  embed: str | None = None # Sharding rule for embedding-like dimensions
  mlp: str | None = None   # Sharding rule for MLP layers dimensions
  data: str | None = None  # Sharding rule for the data batch dimension

  # Allows calling the instance like `mesh_rules('embed', 'mlp')`
  # to get a tuple of the corresponding sharding rules.
  def __call__(self, *keys: str) -> tuple[str | None, ...]:
    return tuple(getattr(self, key) for key in keys)

# Create an instance of MeshRules defining the specific strategy:
# - 'embed' dimensions will be replicated (None).
# - 'mlp' dimensions will be sharded along the 'model' mesh axis.
# - 'data' dimensions will be sharded along the 'data' mesh axis.
mesh_rules = MeshRules(
  embed=None,
  mlp='model',
  data='data',
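)

# Helper used in the training loop: build a NamedSharding on `mesh` from
# mesh-axis names, e.g. named_sharding('data') -> NamedSharding(mesh, P('data')).
def named_sharding(*names: str | None) -> NamedSharding:
  return NamedSharding(mesh, P(*names))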

### 2.3 Define The Sharded Model

In [None]:
# Define the MLP using Flax NNX API.
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.w1 = nnx.Param(
      nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
      # ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
      sharding=mesh_rules('embed', 'mlp'),
    )
    self.b1 = nnx.Param(
      jnp.zeros((dmid,)),
      # Sharding: ('mlp',) -> ('model',) -> Shard dim 0 along 'model' axis.
      sharding=mesh_rules('mlp'),
    )
    self.w2 = nnx.Param(
      nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
      # Sharding: ('mlp', 'embed') -> ('model', None) -> Shard dim 0 (the hidden
      # dimension) along 'model', replicate dim 1, matching W_out[F_Y, D].
      sharding=mesh_rules('mlp', 'embed'),
    )

  def __call__(self, x: jax.Array):
    return nnx.relu(x @ self.w1 + self.b1) @ self.w2


### 2.4 Handling Sharded Optimizer State

In [None]:
# Define a custom type for SGD momentum state, inheriting from nnx.Variable.
# This allows it to be tracked as part of the NNX state tree.
class SGDState(nnx.Variable):
  pass

# Define the SGD optimizer using NNX API.
class SGD(nnx.Object):
  # Constructor takes the model parameters (as nnx.State), learning rate, and decay.
  def __init__(self, params: nnx.State, lr, decay=0.9):
    # Helper function to initialize momentum buffer for a given parameter.
    def init_optimizer_state(variable: nnx.Variable):
      # Create momentum state with zeros, same shape and metadata (incl. sharding)
      # as the parameter it corresponds to.
      return SGDState(
        jnp.zeros_like(variable.value), **variable.get_metadata()
      )

    self.lr = lr
    # Store a reference to the parameter State tree.
    self.params = params
    # Create the momentum state tree, mirroring the structure of 'params',
    # using the helper function. Momentum will have the same sharding as params.
    self.momentum = jax.tree.map(init_optimizer_state, self.params)
    self.decay = decay

  # Method to update parameters based on gradients.
  def update(self, grads: nnx.State):
    # Define the update logic for a single parameter/momentum/gradient triple.
    def update_fn(
      params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
    ):
      # Standard SGD with momentum update rule.
      # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
      momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value
      # θ_{t+1} = θ_t - α * v_t
      params.value -= self.lr * momentum.value # NOTE: Direct mutation of param value!

    # Apply the update function across the parameter, momentum, and gradient trees.
    # This performs the update in-place on the parameter values referenced by self.params.
    jax.tree.map(update_fn, self.params, self.momentum, grads)


### 2.5 Applying Sharding to the Model and Optimizer

In [None]:
# JIT-compile the model and optimizer creation function.
@nnx.jit
def create_model():
  # Instantiate the MLP model. rngs=nnx.Rngs(0) provides PRNG keys.
  model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
  # Create the optimizer. nnx.variables(model, nnx.Param) extracts
  # only the nnx.Param state variables from the model object.
  optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)

  # === Explicit Sharding Application ===
  # 1. Extract ALL state (model params + optimizer momentum) into a State pytree.
  state = nnx.state(optimizer)

  # 2. Define the target sharding for the state pytree.
  # This function maps state paths to NamedSharding objects based on stored metadata.
  def get_named_shardings(path: tuple, value: nnx.VariableState):
    # Assumes params and momentum use the sharding defined in their metadata.
    if path[0] in ('params', 'momentum'):
      # value.sharding contains the tuple like ('model',) or (None, 'model')
      # stored during Param/SGDState creation.
      return value.replace(NamedSharding(mesh, P(*value.sharding)))
    else:
      # Handle other state if necessary (e.g., learning rate if it were a Variable)
      raise ValueError(f'Unknown path: {path}')
  # Create the pytree of NamedSharding objects.
  named_shardings = state.map(get_named_shardings)

  # 3. Apply sharding constraint. This tells JAX how the 'state' pytree
  # SHOULD be sharded when computations involving it are run under jit/pjit.
  # It doesn't immediately move data but sets up the constraint for the compiler.
  sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)

  # 4. Update the original objects (model params, optimizer momentum)
  # with the constrained state values. This step makes the sharding
  # "stick" to the objects themselves for subsequent use outside this function.
  nnx.update(optimizer, sharded_state)

  # Return the model and optimizer objects, now containing sharded state variables.
  return model, optimizer

# Call the function to create the sharded model and optimizer.
model, optimizer = create_model()

# Visualize the sharding of the first weight's parameter tensor.
jax.debug.visualize_array_sharding(model.w1.value)
# Visualize the sharding of the first weight's momentum tensor.
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)


### 2.6 Distributed Training

In [None]:
# JIT-compile the training step function.
@nnx.jit
def train_step(model: MLP, optimizer: SGD, x, y):
  # Define the loss function (Mean Squared Error).
  # Takes the model object as input, consistent with nnx.value_and_grad.
  def loss_fn(model):
    y_pred = model(x) # Forward pass
    loss = jnp.mean((y - y_pred) ** 2)
    return loss

  # Calculate loss and gradients w.r.t the model's state (its nnx.Param variables).
  # 'grad' will be an nnx.State object mirroring model's Param structure.
  loss, grad = nnx.value_and_grad(loss_fn)(model)

  # Call the optimizer's update method to apply gradients.
  # This updates the model parameters in-place.
  optimizer.update(grad)

  # Return the calculated loss.
  return loss


### 2.7 Training Loop and Results

In [None]:
# Generate synthetic dataset: y = 0.8*x^2 + 0.1 + noise
X = np.linspace(-2, 2, 100)[:, None] # Input features
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) # Target values

# A generator function to yield batches of data for training.
def dataset(batch_size, num_steps):
  for _ in range(num_steps):
    # Randomly sample indices for the batch.
    idx = np.random.choice(len(X), size=batch_size)
    # Yield the corresponding input and target pairs.
    yield X[idx], Y[idx]

# --- Training Loop ---
losses = [] # To store loss values for plotting
# Iterate through the dataset generator for 10,000 steps.
for step, (x_batch, y_batch) in enumerate(
  dataset(batch_size=32, num_steps=10_000)
):
  # CRITICAL: Place the NumPy data onto JAX devices AND apply sharding.
  # named_sharding('data') -> Shard along the 'data' mesh axis (first dim, size 2).
  # Each device along the 'data' axis gets a slice of the batch.
  x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))

  # Execute the JIT-compiled training step with the sharded model, optimizer, and data.
  loss = train_step(model, optimizer, x_batch, y_batch)

  # Record the loss (move scalar loss back to host CPU).
  losses.append(float(loss))
  # Log progress periodically.
  if step % 1000 == 0:
    print(f'Step {step}: Loss = {loss}')

# --- Plotting Results ---
plt.figure()
plt.title("Training Loss")
plt.plot(losses[20:]) # Plot loss, skipping initial noisy steps
plt.xlabel("Step")
plt.ylabel("MSE Loss")

# Get model predictions on the full dataset (X is a host NumPy array).
# Calling the model outside jit runs eagerly on the devices; matplotlib
# converts the resulting jax.Array to NumPy when plotting.
y_pred = model(X)
plt.figure()
plt.title("Model Fit")
plt.scatter(X, Y, color='blue', label='Data') # Original data
plt.plot(X, y_pred, color='black', label='Prediction') # Model's predictions
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.show() # Display the plots


## 3. Profiling Parallel Training

## 4. Transformer Scaling with Tensor Parallelism

![Megatron-LM](https://docs.pytorch.org/tutorials/_images/megatron_lm.png)

<sup> Image Source: [Megatron-LM](https://arxiv.org/abs/1909.08053) </sup>

## 5. Scaling Transformers in Flax

In [None]:
# Define a 4x2 device mesh: axis 'x' (size 4) shards the batch, axis 'y' (size 2) shards heads and d_ff
mesh = jax.make_mesh((4, 2), ('x', 'y'))

# Define logical axis names and their mapping to device mesh dimensions
# This maps conceptual dimensions to physical device axes
sharding_rules = [
    ('batch', 'x'),
    ('sequence', None),
    ('d_model', None),
    ('query_heads', 'y'),
    ('key_heads', 'y'),
    ('key_dim', None),
    ('d_ff', 'y'),
    ('vocab', None)
]
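Each parameter is annotated with logical axis names (for example `('d_model', 'd_ff')`), and the rules above translate those names into mesh axes. The sketch below uses a hypothetical helper, `logical_to_mesh_spec`, to perform this lookup by hand. It only illustrates the mapping and is not Flax's internal code; in this sketch, names without a rule are passed through unchanged.

```python
from jax.sharding import PartitionSpec

sharding_rules = [
    ('batch', 'x'), ('sequence', None), ('d_model', None),
    ('query_heads', 'y'), ('key_heads', 'y'), ('key_dim', None),
    ('d_ff', 'y'), ('vocab', None),
]

def logical_to_mesh_spec(logical_axes, rules):
    """Hypothetical helper: map logical axis names to a PartitionSpec of mesh axes."""
    lookup = dict(rules)
    # Names without a rule are passed through unchanged in this sketch.
    return PartitionSpec(*(lookup.get(name, name) for name in logical_axes))

print(logical_to_mesh_spec(('d_model', 'query_heads', 'key_dim'), sharding_rules))  # Q projection
print(logical_to_mesh_spec(('d_ff', 'd_model'), sharding_rules))                     # FFN w2
print(logical_to_mesh_spec(('batch', 'sequence', 'd_model'), sharding_rules))       # activations
```

Under these rules the activations resolve to `('x', None, None)`: the batch is split over `x`, and the sequence and model dimensions are replicated.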

In [None]:
# Configuration class (a deliberately small model so the example runs quickly)
class TransformerConfig:
    def __init__(self):
        self.d_model = 512
        self.ffw_multiplier = 4   # d_ff = d_model * ffw_multiplier
        self.num_layers = 2
        self.query_heads = 8
        self.kv_heads = 8
        self.key_dim = 64
        self.vocab_size = 1024
        self.dtype = jnp.bfloat16

In [None]:
# RMS Normalization layer
class RMSNorm(nnx.Module):
    def __init__(self, dim, rngs):
        # Initialize gamma parameter with proper sharding annotation
        self.gamma = nnx.Param(
            nnx.with_metadata(
                jax.nn.initializers.ones,
                sharding=('d_model',),
                sharding_rules=sharding_rules
            )(rngs.params(), (dim,))
        )
    
    def __call__(self, x):
        rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
        return self.gamma.value * x * rms

In [None]:
# Transformer layer implementation
class TransformerLayer(nnx.Module):
    def __init__(self, cfg, rngs):
        # He-normal initialization for all projection weights
        init_fn = jax.nn.initializers.he_normal()
        
        # QKV projections with appropriate sharding annotations
        self.q = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('d_model', 'query_heads', 'key_dim'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.d_model, cfg.query_heads, cfg.key_dim))
        )
        
        self.k = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('d_model', 'key_heads', 'key_dim'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.d_model, cfg.kv_heads, cfg.key_dim))
        )
        
        self.v = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('d_model', 'key_heads', 'key_dim'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.d_model, cfg.kv_heads, cfg.key_dim))
        )
        
        # Output projection
        self.proj = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('query_heads', 'key_dim', 'd_model'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.query_heads, cfg.key_dim, cfg.d_model))
        )
        
        # FFN parameters
        self.w1 = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('d_model', 'd_ff'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.d_model, cfg.d_model * cfg.ffw_multiplier))
        )
        
        self.w2 = nnx.Param(
            nnx.with_metadata(
                init_fn,
                sharding=('d_ff', 'd_model'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.d_model * cfg.ffw_multiplier, cfg.d_model))
        )
        
        # Layer normalization parameters
        self.norm1 = RMSNorm(cfg.d_model, rngs)
        self.norm2 = RMSNorm(cfg.d_model, rngs)
        
        # Store config values needed during forward pass
        self.query_heads = cfg.query_heads
        self.kv_heads = cfg.kv_heads
    
    def __call__(self, x):
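        # Constrain activations [batch, seq, d_model] to match the rules above:
        # batch over 'x', sequence and d_model replicated.
        # (A mesh axis may appear at most once in a PartitionSpec.)
        x = jax.lax.with_sharding_constraint(x, PartitionSpec('x', None, None))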
        
        # First normalization and attention block
        with jax.named_scope('pre_attn_norm'):
            attn_in = self.norm1(x)
        
        # QKV projections
        with jax.named_scope('qkv_matmul'):
            q = jnp.einsum('btd,dhq->bhtq', attn_in, self.q.value)
            k = jnp.einsum('btd,dhk->bhtk', attn_in, self.k.value)
            v = jnp.einsum('btd,dhv->bhtv', attn_in, self.v.value)
        
        # Attention computation
        with jax.named_scope('attn'):
            scale = q.shape[-1] ** -0.5
            num_query_heads, num_kv_heads = self.query_heads, self.kv_heads
            
            if num_query_heads == num_kv_heads:
                # Standard multi-head attention: one K/V head per query head.
                qk = jnp.einsum('bhtd,bhsd->bhts', q, k) * scale
                logits = jax.nn.softmax(qk.astype(jnp.float32), axis=-1)
                attn_vec = jnp.einsum('bhsd,bhts->bhtd', v, logits)
            else:
                # Grouped-query attention: each K/V head is shared by a group
                # of num_query_heads // num_kv_heads query heads.
                assert num_query_heads % num_kv_heads == 0
                q = q.reshape(q.shape[0:1] +
                              (num_kv_heads, num_query_heads // num_kv_heads) +
                              q.shape[2:])  # [b, h_kv, g, t, d]
                qk = jnp.einsum('bhgtd,bhsd->bhgts', q, k) * scale
                logits = jax.nn.softmax(qk.astype(jnp.float32), axis=-1)
                attn_vec = jnp.einsum('bhsd,bhgts->bhgtd', v, logits)
                attn_vec = attn_vec.reshape(attn_vec.shape[0:1] + (num_query_heads,) + attn_vec.shape[3:])
        
        # Attention projection
        with jax.named_scope('attn_proj'):
            attn_out = jnp.einsum('bhtv,hvd->btd', attn_vec, self.proj.value)
        
        # First residual connection
        with jax.named_scope('residual'):
            x = x + attn_out
        
        # Second normalization
        with jax.named_scope('ffn_pre_norm'):
            ffw_in = self.norm2(x)
        
        # Constrain the FFN input to the activation sharding (batch over 'x')
        ffw_in = jax.lax.with_sharding_constraint(ffw_in, PartitionSpec('x', None, None))
        
        # FFN block
        with jax.named_scope('ffw'):
            # Add explicit sharding constraints on weights
            w1 = jax.lax.with_sharding_constraint(
                self.w1.value,
                PartitionSpec(None, 'y')
            )
            w2 = jax.lax.with_sharding_constraint(
                self.w2.value,
                PartitionSpec('y', None)
            )
            
            ffw_out = jnp.einsum('btd,df->btf', ffw_in, w1).astype(jnp.bfloat16)
            
            # Add sharding constraint on intermediate activations
            ffw_out = jax.lax.with_sharding_constraint(
                ffw_out, 
                PartitionSpec('x', None, 'y')
            )
            
            ffw_out = jax.nn.gelu(ffw_out)
            ffw_out = jnp.einsum('btf,fd->btd', ffw_out, w2).astype(jnp.bfloat16)
        
        # Second residual connection
        with jax.named_scope('residual'):
            x = x + ffw_out
        
        # Final sharding constraint on the layer output (batch over 'x')
        x = jax.lax.with_sharding_constraint(x, PartitionSpec('x', None, None))
        
        return x
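The grouped-query branch splits the query heads into `num_kv_heads` groups of `num_query_heads // num_kv_heads` heads that share one key/value head. One way to sanity-check such an einsum, sketched here with arbitrary illustrative sizes and independently of the layer above, is to compare it with plain multi-head attention after repeating each K/V head with `jnp.repeat`:

```python
import jax
import jax.numpy as jnp

def gqa_attention(q, k, v):
    """q: [b, hq, t, d]; k, v: [b, hkv, s, d] with hq divisible by hkv."""
    b, hq, t, d = q.shape
    hkv = k.shape[1]
    g = hq // hkv
    # Group query heads: [b, hkv, g, t, d]; query head i = kv_head * g + group.
    qg = q.reshape(b, hkv, g, t, d)
    logits = jnp.einsum('bhgtd,bhsd->bhgts', qg, k) * d ** -0.5
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhsd,bhgts->bhgtd', v, probs)
    return out.reshape(b, hq, t, d)

def mha_attention(q, k, v):
    """Standard multi-head attention with matching head counts."""
    logits = jnp.einsum('bhtd,bhsd->bhts', q, k) * q.shape[-1] ** -0.5
    return jnp.einsum('bhsd,bhts->bhtd', v, jax.nn.softmax(logits, axis=-1))

b, hq, hkv, t, d = 2, 8, 2, 5, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, hq, t, d))
k = jax.random.normal(kk, (b, hkv, t, d))
v = jax.random.normal(kv, (b, hkv, t, d))

# Repeat each KV head g times so that query head i uses KV head i // g.
g = hq // hkv
expected = mha_attention(q, jnp.repeat(k, g, axis=1), jnp.repeat(v, g, axis=1))
print("max abs difference:", float(jnp.max(jnp.abs(gqa_attention(q, k, v) - expected))))
```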


In [None]:
# Full transformer model
class Transformer(nnx.Module):
    def __init__(self, cfg, rngs):
        # Token embedding with sharding annotation
        self.embedding = nnx.Param(
            nnx.with_metadata(
                jax.nn.initializers.he_normal(),
                sharding=('vocab', 'd_model'),
                sharding_rules=sharding_rules
            )(rngs.params(), (cfg.vocab_size, cfg.d_model))
        )
        
        # Create all transformer layers
        self.layers = [TransformerLayer(cfg, rngs) for _ in range(cfg.num_layers)]
        self.vocab_size = cfg.vocab_size
    
    def __call__(self, x):
        # Input embedding
        one_hot = jax.nn.one_hot(x, self.vocab_size)
        x = jnp.einsum('vd,btv->btd', self.embedding.value, one_hot)
        
        # Process through each layer
        for idx, layer in enumerate(self.layers):
            with jax.named_scope(f'layer_{idx}'):
                x = layer(x)
        
        # Final projection to logits
        logits = jnp.einsum('vd,btd->btv', self.embedding.value, x)
        return jax.nn.log_softmax(logits, axis=-1)
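The model looks up token embeddings by multiplying a one-hot encoding with the embedding table, and reuses the same table (transposed) to produce the output logits, i.e. tied input and output embeddings. The one-hot matmul gives the same result as indexing rows of the table; a plausible motivation in sharded code is that an einsum is straightforward for the compiler to partition, at the cost of materializing a [B, T, V] one-hot array. The short check below uses illustrative sizes:

```python
import jax
import jax.numpy as jnp

vocab_size, d_model, batch, seq = 32, 8, 2, 5
k_emb, k_tok = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(k_emb, (vocab_size, d_model))
tokens = jax.random.randint(k_tok, (batch, seq), 0, vocab_size)

# One-hot matmul, as in Transformer.__call__.
one_hot = jax.nn.one_hot(tokens, vocab_size)
via_einsum = jnp.einsum('vd,btv->btd', embedding, one_hot)
# Direct row lookup into the embedding table.
via_gather = embedding[tokens]
print("max abs difference:", float(jnp.max(jnp.abs(via_einsum - via_gather))))

# Tied output projection: logits[b, t, v] = <x[b, t], embedding[v]>.
logits = jnp.einsum('vd,btd->btv', embedding, via_gather)
print("logits shape:", logits.shape)
```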


In [None]:
# Function to create a sharded model
@nnx.jit
def create_sharded_model():
    # Initialize configuration
    cfg = TransformerConfig()
    
    # Create unsharded model
    model = Transformer(cfg, rngs=nnx.Rngs(42))
    
    # Get model state and partition specs from annotations
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    
    # Apply sharding constraints to state
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    
    # Update model with sharded state
    nnx.update(model, sharded_state)
    
    return model

In [None]:
# Create sample data with appropriate sharding
def create_sample_data(batch_size, seq_len):
    data_sharding = NamedSharding(mesh, PartitionSpec('x', None))
    x = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
    return jax.device_put(x, data_sharding)

# Training step function
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits = model(x)
        
        # Cross-entropy loss
        labels_onehot = jax.nn.one_hot(y, model.vocab_size)
        loss = -jnp.sum(labels_onehot * logits, axis=-1)
        return jnp.mean(loss)
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    
    return loss

# Block until every array in a pytree has been computed. JAX dispatches work
# asynchronously, so timings must wait for the results.
def block_all(xs):
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
    return xs
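Because `Transformer.__call__` returns log-probabilities (`log_softmax`), the loss in `train_step` is the standard cross-entropy: multiplying by the one-hot labels and summing over the vocabulary picks out the log-probability of the correct token. The sketch below, with illustrative shapes, checks that this matches gathering the target log-probability directly with `jnp.take_along_axis`, which avoids materializing the one-hot labels:

```python
import jax
import jax.numpy as jnp

batch, seq, vocab = 2, 4, 11
k_logits, k_labels = jax.random.split(jax.random.PRNGKey(0))
log_probs = jax.nn.log_softmax(jax.random.normal(k_logits, (batch, seq, vocab)), axis=-1)
labels = jax.random.randint(k_labels, (batch, seq), 0, vocab)

# One-hot formulation, as in train_step.
loss_one_hot = jnp.mean(-jnp.sum(jax.nn.one_hot(labels, vocab) * log_probs, axis=-1))

# Equivalent gather formulation: pick log p(label) at each position.
picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
loss_gather = jnp.mean(-picked)
print(float(loss_one_hot), float(loss_gather))
```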

In [None]:
# Main execution block
with mesh:
    # Create sharded model
    print("Creating sharded model...")
    sharded_model = create_sharded_model()
    
    # Create optimizer
    print("Creating optimizer...")
    optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))
    
    # Create sample data
    print("Creating sample data...")
    batch_size, seq_len = 8, 1024
    input_ids = create_sample_data(batch_size, seq_len)
    target_ids = create_sample_data(batch_size, seq_len)
    
    # Train for a few steps
    print("Starting training...")
    for i in range(5):
        loss = block_all(train_step(sharded_model, optimizer, input_ids, target_ids))
        print(f"Step {i+1}, Loss: {loss}")
    
    # Optionally profile the training step
    print("\nProfiling training step...")
    
    # Warmup
    _ = block_all(train_step(sharded_model, optimizer, input_ids, target_ids))
    
    # Measure
    start_time = time.time()
    for _ in range(10):
        _ = block_all(train_step(sharded_model, optimizer, input_ids, target_ids))
    end_time = time.time()
    
    print(f"Average step time: {(end_time - start_time) / 10:.4f} seconds")
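Since JAX dispatches work asynchronously, a timer has to wait for the results (as `block_all` does) or it only measures dispatch time. The timing loop above could be packaged into a small reusable helper. The sketch below is hypothetical (`time_jitted` is not part of JAX or Flax). It uses `time.perf_counter` as a monotonic, high-resolution clock and runs a warm-up call so compilation is excluded from the average.

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(fn, *args, warmup: int = 1, iters: int = 10) -> float:
    """Hypothetical helper: average wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))  # first call compiles; not timed
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(fn(*args))  # wait for the device to finish
    return (time.perf_counter() - start) / iters

@jax.jit
def matmul_chain(a, b):
    return jnp.tanh(a @ b) @ b

a = jnp.ones((256, 256))
b = jnp.ones((256, 256)) * 0.01
print(f"average: {time_jitted(matmul_chain, a, b):.6f} s")
```

With the objects from the cell above, `time_jitted(train_step, sharded_model, optimizer, input_ids, target_ids)` would time the training step; keep in mind that every timed call also updates the model.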