# README Examples Validation

This notebook contains all code examples from the README.md file.
Run this notebook to verify that:
1. All examples execute without errors
2. Outputs match what's documented in the README

**After making changes to blox**, run this notebook and update the README if outputs have changed.
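
The second check can be partly automated. The sketch below is one possible way to compare a documented output block against what a cell actually printed, using only the standard library. The names `normalize`, `outputs_match`, and `show_diff` are hypothetical helpers, not part of blox, and the matching rule (ignore trailing whitespace and surrounding blank lines) is an assumption about what should count as "the same output".

```python
import difflib


def normalize(text: str) -> list[str]:
  """Drops trailing whitespace per line and blank lines at either end."""
  lines = [line.rstrip() for line in text.splitlines()]
  while lines and not lines[0]:
    lines.pop(0)
  while lines and not lines[-1]:
    lines.pop()
  return lines


def outputs_match(documented: str, actual: str) -> bool:
  """True if the two outputs are equal after normalization."""
  return normalize(documented) == normalize(actual)


def show_diff(documented: str, actual: str) -> str:
  """Returns a unified diff between the README text and the cell output."""
  return '\n'.join(
      difflib.unified_diff(
          normalize(documented),
          normalize(actual),
          fromfile='README',
          tofile='notebook',
          lineterm='',
      )
  )
```

Leading indentation is kept on purpose, since some outputs in this notebook (for example the per-batch dropout counts) are indented.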

In [1]:
# Setup: ensure blox is importable and configure devices
import sys

sys.path.insert(0, '../src')

# Set up 4 CPU devices for sharding examples.
# Must be done before JAX initializes its backends, i.e. before any JAX
# computation runs. Importing chex (which imports jax) is fine.
import chex

chex.set_n_cpu_devices(4)

## Quick Start: Define your layers

From README section: "Define your layers"

In [2]:
import blox as bx
import jax
import jax.numpy as jnp


class CustomLinear(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      output_size: int,
  ) -> None:
    super().__init__(graph)
    self.output_size = output_size

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Param initialization is lazy, which serves two important purposes:
    # 1. Avoids the need to specify input dimensions at construction.
    # 2. Prevents accidental allocation of params on device.
    kernel, params = self.get_param(
        params=params,
        name='kernel',
        shape=(inputs.shape[-1], self.output_size),
        init=jax.nn.initializers.glorot_uniform(),
    )
    bias, params = self.get_param(
        params=params,
        name='bias',
        shape=(self.output_size,),
        init=jax.nn.initializers.zeros,
    )
    return inputs @ kernel + bias, params


print('CustomLinear defined successfully!')

CustomLinear defined successfully!
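
The lazy, functional get-or-create pattern described in the comments above can be illustrated without blox. The following is a toy sketch only: the standalone `get_param` function, the `ToyParams` alias, and the `zeros` initializer are hypothetical stand-ins with simplified signatures, and nothing here reflects how `bx.Module.get_param` is actually implemented.

```python
from typing import Callable

# Toy stand-in for a parameter container: a dict that is never mutated in place.
ToyParams = dict[str, list[float]]


def get_param(
    params: ToyParams,
    name: str,
    size: int,
    init: Callable[[int], list[float]],
) -> tuple[list[float], ToyParams]:
  """Returns the named value, creating it on first use."""
  if name in params:
    # Already initialized: reuse the value and leave params unchanged.
    return params[name], params
  value = init(size)
  # Return a new dict rather than mutating the input.
  return value, {**params, name: value}


def zeros(size: int) -> list[float]:
  return [0.0] * size


toy_params: ToyParams = {}
bias, toy_params = get_param(toy_params, 'bias', 3, zeros)   # Created here.
again, toy_params = get_param(toy_params, 'bias', 3, zeros)  # Reused.
```

Because the size is only needed at the first call, the caller can derive it from the inputs at that point, which is the first benefit listed in the `CustomLinear` comments.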


## Composition & Dependency Injection

From README section: "Composition & Dependency Injection"

In [3]:
class CustomMLP(bx.Module):

  def __init__(
      self,
      graph: bx.Graph,
      hidden_size: int,
      # We can inject externally created modules...
      output_projection: bx.Module,
  ) -> None:
    super().__init__(graph)
    # ... or create new ones internally.
    self.hidden_proj = CustomLinear(graph.child('hidden'), hidden_size)
    self.output_projection = output_projection

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    # Chain the functional transformations.
    hidden, params = self.hidden_proj(params, inputs)
    hidden = jax.nn.relu(hidden)
    return self.output_projection(params, hidden)


print('CustomMLP defined successfully!')

CustomMLP defined successfully!


## Initialization & Inspection

From README section: "Initialization & Inspection"

**Important**: The output of `bx.display()` should match the README. If it differs, update the README!

In [4]:
# Define the structure for wiring modules.
graph = bx.Graph('net')

# Create the output layer explicitly and use it to create our CustomMLP.
readout = CustomLinear(graph.child('readout'), output_size=1)
model = CustomMLP(graph.child('mlp'), hidden_size=32, output_projection=readout)

# Create dummy input data to infer shapes.
inputs = jnp.ones((1, 10))

# Initialize the parameters.
# Params requires an Rng module for handling randomness.
rng = bx.Rng(graph.child('rng'), seed=42)
params = bx.Params(rng=rng)

# Run a forward pass to trigger lazy initialization.
unused_outputs, params = model(params, inputs)

# Finalize Params to prevent accidental structure changes later.
params = params.finalized()

# Visualize the full graph and parameter structure.
bx.display(graph, params)

**Expected output structure** (values may vary due to random initialization):
- `readout` and `mlp` should be siblings
- `hidden` should be nested inside `mlp`
- `output_projection` in `mlp` should show as a reference to `readout`

## Parallel Execution (vmap & shard_map)

From README section: "Parallel Execution"

In [5]:
# Create a fresh graph for this example
graph2 = bx.Graph('dropout_net')
dropout = bx.Dropout(
    graph2.child('dropout'),
    rate=0.5,
    rng=bx.Rng(graph2.child('dropout_rng'), seed=0),
)
rng2 = bx.Rng(graph2.child('rng'), seed=42)
params2 = bx.Params(rng=rng2)

# Initialize dropout
dummy_input = jnp.ones((4, 8))
_, params2 = dropout(params2, dummy_input, is_training=True)
params2 = params2.finalized()


def apply_model(params, inputs):
  # Fold in the batch axis so each batch element gets a unique RNG stream.
  params = params.fold_in_axes('batch')
  outputs, params = dropout(params, inputs, is_training=True)
  # Fold out before returning to restore the replicated state structure.
  return outputs, params.fold_out_axes('batch')


# Note that params (including the Rng) are replicated.
inputs_batch = jnp.ones((4, 8))  # 4 batch elements
batched_outputs, _ = jax.vmap(
    apply_model, in_axes=(None, 0), out_axes=(0, None), axis_name='batch'
)(params2, inputs_batch)

print('Batched outputs shape:', batched_outputs.shape)
print('Each batch element has different dropout mask (check non-zero counts):')
for i in range(4):
  print(f'  Batch {i}: {jnp.sum(batched_outputs[i] != 0).item()} non-zero')

Batched outputs shape: (4, 8)
Each batch element has different dropout mask (check non-zero counts):
  Batch 0: 5 non-zero
  Batch 1: 5 non-zero
  Batch 2: 5 non-zero
  Batch 3: 1 non-zero
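
The counts above are plausible for a dropout rate of 0.5. Assuming each of the 8 features is kept independently with probability 0.5 (the usual dropout formulation; this is an assumption about `bx.Dropout`, not something verified here), the number of non-zero entries in a row follows a Binomial(8, 0.5) distribution. The hypothetical helper below computes the probability of a given count.

```python
import math


def prob_nonzero_count(k: int, n: int = 8, keep_prob: float = 0.5) -> float:
  """Probability that exactly k of n entries survive dropout."""
  return math.comb(n, k) * keep_prob**k * (1 - keep_prob) ** (n - k)


for k in (1, 5):
  print(f'P({k} non-zero of 8) = {prob_nonzero_count(k):.5f}')
```

Under that assumption a row with 5 survivors has probability 56/256 (about 22%) and a row with a single survivor has probability 8/256 (about 3%), so the last row is unusual but not a sign of a bug. Equal counts across rows do not imply equal masks, so comparing the rows directly is a stricter check.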


## Parameter Metadata & Sharding

From README section: "Parameter Metadata & Sharding"

In [6]:
import functools

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

graph3 = bx.Graph('net')
linear = bx.Linear(
    graph3.child('linear'),
    output_size=1024,
    kernel_metadata={'sharding': (None, 'model')},
    bias_metadata={'sharding': ('model',)},
)
rng3 = bx.Rng(graph3.child('rng'), seed=42)


# Define an initialization function.
def init(x):
  _, params = linear(bx.Params(rng=rng3), x)
  return params.finalized()


# Abstract evaluation to get the Params structure (no memory allocation).
inputs3 = jnp.ones((4, 4))
abstract_params = jax.eval_shape(init, inputs3)

# Create the sharding specification from metadata.
mesh = jax.make_mesh((4,), ('model',))

params_sharding = jax.tree.map(
    lambda p: NamedSharding(mesh, P(*p.sharding)),
    abstract_params,
    is_leaf=lambda x: isinstance(x, bx.Param),
)

# JIT-compile the init function with out_shardings.
# Params are created directly on the correct devices, without first
# materializing an unsharded copy.
sharded_init = jax.jit(init, out_shardings=params_sharding)
sharded_params = sharded_init(inputs3)


@functools.partial(jax.jit, in_shardings=(params_sharding, None))
def forward(params, x):
  return linear(params, x)


out, new_params = forward(sharded_params, inputs3)

print('Output shape:', out.shape)
print('Sharding example completed successfully!')


Output shape: (4, 1024)
Sharding example completed successfully!
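
To make the effect of the sharding metadata concrete, the sketch below computes the per-device shard shape implied by a partition spec. It is plain Python rather than a JAX API: `shard_shape` is a hypothetical helper, it assumes every sharded dimension divides evenly across its mesh axis, and it assumes the `bx.Linear` kernel has shape `(input_features, output_size)` as in `CustomLinear`.

```python
from typing import Optional


def shard_shape(
    shape: tuple[int, ...],
    spec: tuple[Optional[str], ...],
    mesh_axis_sizes: dict[str, int],
) -> tuple[int, ...]:
  """Per-device shape when each dim is split evenly over its mesh axis."""
  # Dimensions not covered by the spec are treated as replicated.
  padded_spec = spec + (None,) * (len(shape) - len(spec))
  result = []
  for dim, axis in zip(shape, padded_spec):
    parts = 1 if axis is None else mesh_axis_sizes[axis]
    if dim % parts:
      raise ValueError(f'dim {dim} is not divisible by {parts}')
    result.append(dim // parts)
  return tuple(result)


mesh_axis_sizes = {'model': 4}
kernel_shard = shard_shape((4, 1024), (None, 'model'), mesh_axis_sizes)
bias_shard = shard_shape((1024,), ('model',), mesh_axis_sizes)
```

With the 4-device `'model'` mesh used above, each device would hold a `(4, 256)` slice of the kernel and a `(256,)` slice of the bias, while the `(4, 1024)` output shape seen by the caller is unchanged.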


## Recurrence & Scanning

From README section: "Recurrence & Scanning"

In [7]:
# Create a fresh graph for LSTM example
graph4 = bx.Graph('lstm_net')
lstm = bx.LSTM(graph4.child('lstm'), hidden_size=128)
rng4 = bx.Rng(graph4.child('rng'), seed=42)
params4 = bx.Params(rng=rng4)

# Create sequence input [Batch, Time, Features]
inputs_sequence = jnp.ones((2, 10, 64))  # batch of 2, 10 timesteps, 64 features

# Initialize the LSTM state.
state, params4 = lstm.initial_state(
    params4, inputs_sequence[:, 0, :]
)  # Use single timestep for init

# Run an efficient compiled scan over a sequence [Batch, Time, Features].
# It automatically handles carry propagation.
(outputs, final_state), params4 = lstm.apply(
    params4, inputs_sequence, prev_state=state
)

print('LSTM outputs shape:', outputs.shape)
print('Final hidden state shape:', final_state.hidden.shape)
print('Final cell state shape:', final_state.cell.shape)

LSTM outputs shape: (2, 10, 128)
Final hidden state shape: (2, 128)
Final cell state shape: (2, 128)
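
Carry propagation in a scan can be pictured as an ordinary loop. The sketch below is a pure-Python reference for the general scan pattern (the same semantics the JAX documentation gives for `jax.lax.scan`); whether `lstm.apply` is built on `jax.lax.scan` is an assumption, and `scan` and `running_sum` here are illustrative names only.

```python
from typing import Callable, Sequence, TypeVar

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')


def scan(
    step: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: Sequence[X],
) -> tuple[Carry, list[Y]]:
  """Threads a carry through step() and collects the per-step outputs."""
  carry = init
  ys = []
  for x in xs:
    carry, y = step(carry, x)
    ys.append(y)
  return carry, ys


def running_sum(carry: float, x: float) -> tuple[float, float]:
  new_carry = carry + x
  # The new carry is also emitted as this step's output.
  return new_carry, new_carry


final, sums = scan(running_sum, 0.0, [1.0, 2.0, 3.0])
```

In the LSTM case the carry plays the role of the `(hidden, cell)` state and the collected outputs correspond to the per-timestep outputs, which is why the outputs above keep the time dimension while the final state does not.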


## Training (JIT & Gradients)

From README section: "Training (JIT & Gradients)"

In [8]:
# Re-create the Quick Start model on a fresh graph so that training starts
# from freshly initialized params.
graph5 = bx.Graph('train_net')
readout5 = CustomLinear(graph5.child('readout'), output_size=1)
model5 = CustomMLP(
    graph5.child('mlp'), hidden_size=32, output_projection=readout5
)
rng5 = bx.Rng(graph5.child('rng'), seed=42)
params5 = bx.Params(rng=rng5)

# Initialize
train_inputs = jnp.ones((8, 10))  # 8 samples, 10 features
_, params5 = model5(params5, train_inputs)
params5 = params5.finalized()


@jax.jit
def train_step(params, inputs, targets):
  # Split params into two sets.
  # Trainable: weights, biases (we want gradients for these).
  # Non-trainable: Rng, batch stats, EMA (we just want the updated values).
  trainable, non_trainable = params.split()

  def loss_fn(t, nt):
    # Merge parameters to run the forward pass.
    predictions, new_params = model5(t.merge(nt), inputs)

    # Calculate the loss.
    loss = jnp.mean((predictions - targets) ** 2)

    # Extract the updated non-trainable state to pass it out.
    _, new_non_trainable = new_params.split()
    return loss, new_non_trainable

  # Calculate gradients and capture the auxiliary state (non_trainable updates).
  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
      trainable, non_trainable
  )

  # Update the trainable weights using SGD.
  new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)

  # Merge the updated weights with the updated non-trainable state.
  return new_trainable.merge(new_non_trainable)


# Run a few training steps
targets = jnp.zeros((8, 1))  # Target outputs
for step in range(5):
  params5 = train_step(params5, train_inputs, targets)

# Compute the loss with the trained params.
predictions, _ = model5(params5, train_inputs)
final_loss = jnp.mean((predictions - targets) ** 2)
print(f'Final loss after 5 steps: {final_loss:.6f}')
print('Training example completed successfully!')

Final loss after 5 steps: 0.002367
Training example completed successfully!
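
The update rule `w - 0.01 * g` and the mean-squared-error loss used in `train_step` can be worked through by hand on a one-parameter model. The sketch below is plain Python with illustrative names (`mse`, `mse_grad`) and made-up toy data; it does not use blox or reproduce the numbers printed above.

```python
def mse(w: float, xs: list[float], ts: list[float]) -> float:
  """Mean squared error of the predictions w * x against targets t."""
  return sum((w * x - t) ** 2 for x, t in zip(xs, ts)) / len(xs)


def mse_grad(w: float, xs: list[float], ts: list[float]) -> float:
  """Analytic gradient: d/dw mean((w*x - t)^2) = mean(2 * (w*x - t) * x)."""
  return sum(2 * (w * x - t) * x for x, t in zip(xs, ts)) / len(xs)


w, xs, ts = 1.0, [1.0, 2.0], [0.0, 0.0]
loss_before = mse(w, xs, ts)          # (1 + 4) / 2 = 2.5
grad = mse_grad(w, xs, ts)            # (2 + 8) / 2 = 5.0
w_new = w - 0.01 * grad               # One SGD step with learning rate 0.01.
loss_after = mse(w_new, xs, ts)
```

Here one step moves `w` from 1.0 to 0.95 and the loss drops from 2.5 to about 2.256. Comparing the loss before and after training in the same way is a simple check that the steps are reducing it.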
