# JAX Loop Utils

<a href="https://colab.research.google.com/github/Astera-org/jax_loop_utils/blob/master/synopsis.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


    pip install jax-loop-utils

https://github.com/garymm/jax_loop_utils

Writing and maintaining your own training loop gives lots of
flexibility, but it also quickly leads to a non-trivial amount of code that is
repeated in every project.

`jax_loop_utils` provides small independent helpers to make the training loop shorter and
easier to read, while keeping maximum flexibility.

**This notebook** walks you through the different modules of `jax_loop_utils` with simple
example code that showcases the important concepts and that you can paste into your
training loop to get started using `jax_loop_utils`.

### Setup

In [1]:
%pip install -q "jax-loop-utils[tf-data]" flax

In [4]:
%pip show jax-loop-utils

Name: jax-loop-utils
Version: 0.0.12
Location: /Users/garymm/src/Astera-org/jax_loop_utils/.venv/lib/python3.12/site-packages
Editable project location: /Users/garymm/src/Astera-org/jax_loop_utils
Requires: absl-py, etils, jax, jaxlib, ml-collections, numpy, packaging, typing-extensions, wrapt
Required-by:


In [1]:
import jax
import jax.numpy as jnp

In [2]:
import chex

chex.set_n_cpu_devices(2)  # Simulate 2 local devices in a CPU Colab runtime.

### `jax_loop_utils.metric_writers`

The module [`metric_writers`] provides a simple [interface] to write time series
metrics in a unified way.

Metric writers provided:

- `SummaryWriter`: Uses `tf.summary` to write summary files for display in
  TensorBoard.
- `LoggingWriter`: Simply writes values to the INFO log. This only supports
  data types that can be converted to text, but it is still helpful for
  seeing the training progress on the command line.
- `TorchTensorboardWriter`: Uses `torch.utils.tensorboard` to write summary
  files. Use this writer for PyTorch-based code.

Additionally, we provide metric writers to combine multiple metric writers
(`MultiWriter`) and to move the write operation to a background thread
(`AsyncWriter`).

[`metric_writers`]: https://github.com/Astera-org/jax_loop_utils/blob/master/jax_loop_utils/metric_writers/__init__.py
[interface]: https://github.com/Astera-org/jax_loop_utils/blob/master/jax_loop_utils/metric_writers/interface.py
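To make the fan-out and background-thread ideas concrete, here is a minimal stdlib-only sketch. `ListWriter`, `FanOutWriter`, and `BackgroundWriter` are hypothetical illustration classes, not the library's `LoggingWriter`, `MultiWriter`, or `AsyncWriter`; they only mimic the `write_scalars(step, scalars)` call used later in this notebook, and the real classes may differ in detail.

```python
import queue
import threading
from typing import Mapping


class ListWriter:
    """Hypothetical writer that records formatted scalars in a list (a stand-in for a log)."""

    def __init__(self) -> None:
        self.lines: list[str] = []

    def write_scalars(self, step: int, scalars: Mapping[str, float]) -> None:
        for name, value in scalars.items():
            self.lines.append(f"[{step}] {name}={value:g}")

    def flush(self) -> None:
        pass


class FanOutWriter:
    """Hypothetical MultiWriter-like writer: forwards every call to all child writers."""

    def __init__(self, writers) -> None:
        self.writers = list(writers)

    def write_scalars(self, step, scalars) -> None:
        for w in self.writers:
            w.write_scalars(step, scalars)

    def flush(self) -> None:
        for w in self.writers:
            w.flush()


class BackgroundWriter:
    """Hypothetical AsyncWriter-like wrapper: the wrapped writer runs on a worker thread."""

    def __init__(self, writer) -> None:
        self._writer = writer
        self._queue: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        while True:
            step, scalars = self._queue.get()
            try:
                self._writer.write_scalars(step, scalars)
            finally:
                self._queue.task_done()

    def write_scalars(self, step, scalars) -> None:
        # Returns immediately; the worker thread performs the actual write.
        self._queue.put((step, dict(scalars)))

    def flush(self) -> None:
        # Blocks until every queued write has been processed.
        self._queue.join()
        self._writer.flush()


log_a, log_b = ListWriter(), ListWriter()
demo_writer = BackgroundWriter(FanOutWriter([log_a, log_b]))
for step in range(3):
    demo_writer.write_scalars(step, {"loss": 0.9**step})
demo_writer.flush()
print(log_a.lines)  # ['[0] loss=1', '[1] loss=0.9', '[2] loss=0.81']
```

Calling `flush()` before reading the results matters: without it, the training loop could inspect the child writers while writes are still queued on the background thread.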


In [3]:
from absl import logging

logging.set_verbosity(logging.INFO)

In [4]:
logdir = "./metrics"

In [5]:
import os
import pathlib

from jax_loop_utils import metric_writers
from jax_loop_utils.metric_writers.tf import SummaryWriter

# Handy shortcut to create an async logging/TensorBoard writer.


def create_default_writer(
    logdir: os.PathLike | None = None,
    *,
    just_logging: bool = False,
    asynchronous: bool = True,
    collection: str | None = None,
) -> metric_writers.MultiWriter:
    """Create the default writer for the platform.

    On most platforms this will create a MultiWriter that writes to multiple back
    ends (logging, TF summaries etc.).

    Args:
      logdir: Logging dir to use for TF summary files. If None, the returned
        writer will not write TF summary files.
      just_logging: If True, only use a LoggingWriter. This is useful in
        multi-host setups when only the first host should write metrics and all
        other hosts should only write to their own logs.
      asynchronous: If True, return an AsyncMultiWriter so that writing metrics
        does not block.
      collection: A string which, if provided, provides an indication that the
        provided metrics should all be written to the same collection, or
        grouping.

    Returns:
      A `MetricWriter` according to the platform and arguments.
    """
    if just_logging:
        if asynchronous:
            return metric_writers.AsyncMultiWriter(
                [metric_writers.LoggingWriter(collection=collection)]
            )
        else:
            return metric_writers.MultiWriter([metric_writers.LoggingWriter(collection=collection)])
    writers = [metric_writers.LoggingWriter(collection=collection)]
    if logdir is not None:
        logdir = pathlib.Path(logdir)
        if collection is not None:
            logdir /= collection
        writers.append(SummaryWriter(os.fspath(logdir)))
    if asynchronous:
        return metric_writers.AsyncMultiWriter(writers)
    return metric_writers.MultiWriter(writers)


writer = create_default_writer(logdir)
for step in range(10):
    writer.write_scalars(step, dict(loss=0.9**step))

INFO:absl:[0] loss=1
INFO:absl:[1] loss=0.9
INFO:absl:[2] loss=0.81
INFO:absl:[3] loss=0.729
INFO:absl:[4] loss=0.6561
INFO:absl:[5] loss=0.59049
INFO:absl:[6] loss=0.531441
INFO:absl:[7] loss=0.478297


INFO:absl:[8] loss=0.430467
INFO:absl:[9] loss=0.38742
INFO:absl:[10] steps_per_sec=459770
INFO:absl:[10] uptime=0.0005575
INFO:absl:[20] steps_per_sec=439.342
INFO:absl:[20] uptime=0.0238017
INFO:absl:[30] steps_per_sec=9074.06
INFO:absl:[30] uptime=0.0257181
INFO:absl:[40] steps_per_sec=5392.65
INFO:absl:[40] uptime=0.028389
INFO:absl:[50] steps_per_sec=3739.37
INFO:absl:[50] uptime=0.0304817
INFO:absl:[60] steps_per_sec=4778.5
INFO:absl:[60] uptime=0.0324869
INFO:absl:[70] steps_per_sec=5000.73
INFO:absl:[70] uptime=0.0336201
INFO:absl:[80] steps_per_sec=8857.07
INFO:absl:[80] uptime=0.0361314
INFO:absl:[90] steps_per_sec=3984.86
INFO:absl:[90] uptime=0.0388259


In [None]:
%load_ext tensorboard
%tensorboard --logdir=./metrics

### `jax_loop_utils.periodic_actions`

[`periodic_actions`] are simple helpers that let you perform actions in the training
loop at regular intervals. Currently we support:

- `PeriodicAction`, `PeriodicCallback`: To implement your own actions.
- `Profile`: To create TensorBoard compatible profiles.
- `ReportProgress`: To continuously print progress status updates.

[`periodic_actions`]: https://github.com/Astera-org/jax_loop_utils/blob/master/jax_loop_utils/periodic_actions.py
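The core idea behind a periodic action is a callable that the loop invokes on every step and that only does work once an interval has elapsed. The sketch below is a hypothetical stdlib-only stand-in (`EveryNSteps` is not part of `jax_loop_utils`). Like the throughput lines in the logs above, it measures steps per second between firings, but the library's `ReportProgress` may compute and report this differently.

```python
import time
from typing import Callable


class EveryNSteps:
    """Hypothetical periodic hook: calls `callback(step, steps_per_sec)` every `every_steps` steps."""

    def __init__(self, every_steps: int, callback: Callable[[int, float], None]) -> None:
        self.every_steps = every_steps
        self.callback = callback
        self._last_step = 0
        self._last_time = time.perf_counter()

    def __call__(self, step: int) -> bool:
        # Fire only on non-zero multiples of every_steps.
        if step == 0 or step % self.every_steps != 0:
            return False
        now = time.perf_counter()
        elapsed = max(now - self._last_time, 1e-9)  # guard against division by zero
        steps_per_sec = (step - self._last_step) / elapsed
        self.callback(step, steps_per_sec)
        self._last_step, self._last_time = step, now
        return True


fired_at = []
demo_hook = EveryNSteps(every_steps=10, callback=lambda step, sps: fired_at.append(step))
for step in range(100):
    demo_hook(step)
print(fired_at)  # [10, 20, 30, 40, 50, 60, 70, 80, 90]
```

Because the hook keeps its own state, the training loop stays a plain `for` loop that calls every hook on every step, exactly as in the cell below.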

In [6]:
from jax_loop_utils import periodic_actions

total_steps = 100
hooks = [
    # Outputs progress via metric writer (in this case logs & TensorBoard).
    periodic_actions.ReportProgress(num_train_steps=total_steps, every_steps=10, writer=writer),
    periodic_actions.Profile(logdir=logdir),
]

for step in range(total_steps):
    for hook in hooks:
        hook(step)

INFO:absl:Setting work unit notes: 459770.3 steps/s, 10.0% (10/100), ETA: 0m
INFO:absl:Setting work unit notes: 439.3 steps/s, 20.0% (20/100), ETA: 0m
INFO:absl:Setting work unit notes: 9074.1 steps/s, 30.0% (30/100), ETA: 0m
INFO:absl:Setting work unit notes: 5392.7 steps/s, 40.0% (40/100), ETA: 0m
INFO:absl:Setting work unit notes: 3739.4 steps/s, 50.0% (50/100), ETA: 0m
INFO:absl:Setting work unit notes: 4778.5 steps/s, 60.0% (60/100), ETA: 0m
INFO:absl:Setting work unit notes: 5000.7 steps/s, 70.0% (70/100), ETA: 0m
INFO:absl:Setting work unit notes: 8857.1 steps/s, 80.0% (80/100), ETA: 0m
INFO:absl:Setting work unit notes: 3984.9 steps/s, 90.0% (90/100), ETA: 0m


In [7]:
# If you click "refresh" in the TensorBoard above, you'll now see a new
# "steps_per_sec" metric...
!ls -lh metrics

total 8
-rw-r--r--  1 garymm  staff   1.9K Nov 15 16:41 events.out.tfevents.1731717685.Garys-MacBook-Pro.local.3562.0.v2




### `jax_loop_utils.metrics`

The [`metrics`] module provides a framework for functional metric computation.
Note that this module does **not** include the actual metric definitions (other
than `metrics.Accuracy` that is provided for demonstration purposes), but
rather provides abstractions that can be used to compute metrics in a
distributed environment.

This section is a bit longer than the previous sections and walks you through
the following parts:

1. Computing a `metrics.Metric`, and defining "averageable" metrics.
2. Using `metrics.Collection` to compute several metrics at once.
3. Aggregating in an evaluation step that is transformed by `pmap()`.
4. Defining a new metric with custom aggregation (i.e. one that is not "averageable").


[`metrics`]: https://github.com/Astera-org/jax_loop_utils/blob/master/jax_loop_utils/metrics.py
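Before looking at the real API, here is a plain-Python sketch of the three-step pattern the module is built around: compute intermediate values per batch, merge them, and only then compute the final value. `CountAccuracy` is a hypothetical class for illustration and does not claim to mirror how `metrics.Accuracy` is implemented; it uses the same two batches as the example that follows.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class CountAccuracy:
    """Hypothetical accuracy metric that keeps sufficient statistics, not a ratio."""

    correct: int
    count: int

    @classmethod
    def from_model_output(cls, logits: list[list[float]], labels: list[int]) -> CountAccuracy:
        # Step 1: per-example argmax over the class dimension, then count matches.
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        return cls(correct=sum(p == y for p, y in zip(preds, labels)), count=len(labels))

    def merge(self, other: CountAccuracy) -> CountAccuracy:
        # Step 2: merging adds the statistics; it never averages two accuracies.
        return CountAccuracy(self.correct + other.correct, self.count + other.count)

    def compute(self) -> float:
        # Step 3: the final value is derived only once, from the merged statistics.
        return self.correct / self.count


logits = [[-1.0, 1.0], [1.0, -1.0]]
batch1 = CountAccuracy.from_model_output(logits, labels=[0, 0])  # 1 of 2 correct
batch2 = CountAccuracy.from_model_output(logits, labels=[1, 0])  # 2 of 2 correct
print(batch1.merge(batch2).compute())  # 0.75
```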

In [8]:
import flax

from jax_loop_utils import metrics

# Metrics are computed in three steps:

# 1. Compute intermediate values from model outputs
accuracy_batch1 = metrics.Accuracy.from_model_output(
    logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
    labels=jnp.array([0, 0]),  # i.e. 1st incorrect, 2nd correct
)
accuracy_batch2 = metrics.Accuracy.from_model_output(
    logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
    labels=jnp.array([1, 0]),  # i.e. both correct
)

# 2. Intermediate values are aggregated
accuracy = accuracy_batch1
accuracy = accuracy.merge(accuracy_batch2)

# 3. Final metrics are computed from aggregated intermediate values:
accuracy.compute()

Array(0.75, dtype=float32)

In [9]:
# It's easy to define your own metrics if they are "averageable":

AverageLoss = metrics.Average.from_output("loss")

AverageLoss.from_model_output(loss=jnp.array([1.1, 3.3])).compute()

Array(2.2, dtype=float32)

In [10]:
# You can provide a function to derive the value to be averaged:

# Note that our metric only uses the model output named "loss". There can be an
# arbitrary number of additional model outputs that we don't need here (**_).
AverageSquaredLoss = metrics.Average.from_fun(lambda loss, **_: loss**2)

AverageSquaredLoss.from_model_output(loss=jnp.array([1.1**0.5, 3.3**0.5])).compute()

Array(2.1999998, dtype=float32)
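An averageable metric only needs a running total and a count, which is why a single function of the model outputs is enough to define one. The sketch below is a hypothetical plain-Python analogue of `from_output`/`from_fun` (`make_average` and `SumCount` are illustrative names, not library API). It shows that merging totals and counts weights every example equally, even when batches differ in size.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class SumCount:
    """Hypothetical averageable state: a running total and the number of values."""

    total: float
    count: int

    def merge(self, other: SumCount) -> SumCount:
        return SumCount(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


def make_average(fun: Callable[..., Sequence[float]]) -> Callable[..., SumCount]:
    """Return a constructor that averages whatever `fun` derives from the model outputs."""

    def from_model_output(**model_outputs) -> SumCount:
        values = list(fun(**model_outputs))
        return SumCount(total=sum(values), count=len(values))

    return from_model_output


# Analogue of from_output("loss"): pick one named output and ignore the rest (**_).
average_loss = make_average(lambda loss, **_: loss)
# Analogue of from_fun(...): derive the value to be averaged from the outputs.
average_squared_loss = make_average(lambda loss, **_: [x**2 for x in loss])

small = average_loss(loss=[1.0], logits=None)  # the extra output is ignored
large = average_loss(loss=[2.0, 2.0, 2.0])
print(small.merge(large).compute())  # 1.75, i.e. (1 + 2 + 2 + 2) / 4, not (1.0 + 2.0) / 2
```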

In [11]:
# Usually you would want to compute a collection of metrics from model outputs:


@flax.struct.dataclass  # <-- required for JAX transformations
class MyMetrics(metrics.Collection):
    loss: metrics.Average.from_output("loss")  # type: ignore[invalid-type-form]
    accuracy: metrics.Accuracy


# 1. Compute intermediate values from model outputs
my_metrics_batch1 = MyMetrics.single_from_model_output(
    logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
    labels=jnp.array([0, 0]),  # i.e. 1st incorrect, 2nd correct
    loss=jnp.array([3.3, 2.2]),
)
my_metrics_batch2 = MyMetrics.single_from_model_output(
    logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
    labels=jnp.array([1, 0]),  # i.e. both correct
    loss=jnp.array([2.2, 1.1]),
)

# 2. Intermediate values are aggregated
my_metrics = my_metrics_batch1.merge(my_metrics_batch2)

# 3. Final metrics are computed from aggregated intermediate values:
my_metrics.compute()

{'loss': Array(2.2, dtype=float32), 'accuracy': Array(0.75, dtype=float32)}
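A collection can be thought of as a record of metrics that is merged field by field. Below is a hypothetical plain-Python sketch of that idea (`PairCollection`, `Mean`, and `Acc` are illustrative and not how `metrics.Collection` is implemented). Every metric receives all keyword outputs and picks the ones it needs, and the result matches the `{'loss': 2.2, 'accuracy': 0.75}` output above up to floating-point rounding.

```python
from __future__ import annotations

from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Mean:
    """Hypothetical averageable metric over the `loss` output."""

    total: float
    count: int

    @classmethod
    def from_model_output(cls, *, loss, **_):
        return cls(sum(loss), len(loss))

    def merge(self, other):
        return Mean(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count


@dataclass(frozen=True)
class Acc:
    """Hypothetical accuracy metric over the `logits` and `labels` outputs."""

    correct: int
    count: int

    @classmethod
    def from_model_output(cls, *, logits, labels, **_):
        preds = [row.index(max(row)) for row in logits]  # argmax per example
        return cls(sum(p == y for p, y in zip(preds, labels)), len(labels))

    def merge(self, other):
        return Acc(self.correct + other.correct, self.count + other.count)

    def compute(self):
        return self.correct / self.count


@dataclass(frozen=True)
class PairCollection:
    """Hypothetical collection: every field is a metric, merged field by field."""

    loss: Mean
    accuracy: Acc

    @classmethod
    def from_model_output(cls, **outputs):
        # Each metric receives all outputs and uses only the ones it needs.
        return cls(loss=Mean.from_model_output(**outputs), accuracy=Acc.from_model_output(**outputs))

    def merge(self, other):
        return type(self)(**{f.name: getattr(self, f.name).merge(getattr(other, f.name)) for f in fields(self)})

    def compute(self):
        return {f.name: getattr(self, f.name).compute() for f in fields(self)}


logits = [[-1.0, 1.0], [1.0, -1.0]]
c1 = PairCollection.from_model_output(logits=logits, labels=[0, 0], loss=[3.3, 2.2])
c2 = PairCollection.from_model_output(logits=logits, labels=[1, 0], loss=[2.2, 1.1])
print(c1.merge(c2).compute())  # loss is about 2.2 (float rounding), accuracy is 0.75
```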

In [12]:
# Often you want to compute these metrics inside a pmap(). The framework
# provides the handy `Collection.gather_from_model_output` that will first
# compute the intermediate values, then call `jax.lax.all_gather()` to gather
# the intermediate values from all the devices (in a multi-host setup that's
# all the devices in the mesh, not only the local devices), and then reduce them
# by calling `Metric.merge()` in a `jax.lax.scan()` loop.

# Sounds complicated? Using it is actually surprisingly simple:


def fake_model(params, batch):
    del params  # Fake.
    return batch


def eval_step(my_metrics, params, batch):
    model_outputs = fake_model(params, batch)
    # IMPORTANT: If you called `.single_from_model_output()` here, then all values
    # from devices after the first device would be ignored for the metric
    # computation.
    return my_metrics.merge(MyMetrics.gather_from_model_output(**model_outputs))


eval_step_p = jax.pmap(eval_step, axis_name="batch")

my_metrics = flax.jax_utils.replicate(MyMetrics.empty())

for batch in [
    # Single batch of data pmapped on two devices in parallel.
    dict(
        logits=jnp.array(
            [
                # Batch for device 1
                [[-1.0, 1.0], [1.0, -1.0]],
                # Batch for device 2
                [[-1.0, 1.0], [1.0, -1.0]],
            ]
        ),
        labels=jnp.array(
            [
                # Batch for device 1
                [0, 0],
                # Batch for device 2
                [1, 0],
            ]
        ),
        loss=jnp.array(
            [
                # Batch for device 1
                [3.3, 2.2],
                # Batch for device 2
                [2.2, 1.1],
            ]
        ),
    ),
]:
    my_metrics = eval_step_p(my_metrics, None, batch)

# Note that up to this point all inputs/outputs to `eval_step_p()` are
# replicated such that their leading dimension == number of local devices == 2.
my_metrics.unreplicate().compute()

{'loss': Array(2.2, dtype=float32), 'accuracy': Array(0.75, dtype=float32)}
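Conceptually, `gather_from_model_output` gives every device the intermediate values of all devices and folds them together with `merge()`. The plain-Python sketch below simulates that: a list stands in for the gathered device axis and `functools.reduce` stands in for the `jax.lax.scan()` loop. `Stats` is a hypothetical stand-in for one metric's intermediate values. The last line shows the pitfall the comment in `eval_step` warns about: keeping only the first device's values changes the result.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class Stats:
    """Hypothetical intermediate values of an accuracy metric."""

    correct: int
    count: int

    def merge(self, other: "Stats") -> "Stats":
        return Stats(self.correct + other.correct, self.count + other.count)

    def compute(self) -> float:
        return self.correct / self.count


# One entry per device, as after gathering along the "batch" axis
# (the values correspond to the two per-device batches in the cell above).
gathered = [Stats(correct=1, count=2), Stats(correct=2, count=2)]

# Fold over the device axis with merge(), as a scan over the gathered values would.
merged = functools.reduce(lambda acc, s: acc.merge(s), gathered)
print(merged.compute())  # 0.75

# Using only the first device's values silently drops the second device.
print(gathered[0].compute())  # 0.5
```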

In [13]:
try:
    my_metrics.compute()
    raise RuntimeError("Expected ValueError!")
except ValueError as e:
    print("Note that not calling `.unreplicate()` raises an error:", e)

Note that not calling `.unreplicate()` raises an error: Collection is still replicated (ndim=1). Maybe you forgot to call a flax.jax_utils.unreplicate() or a Collections.reduce()?


In [14]:
# You can also provide your own aggregation logic:


@flax.struct.dataclass
class Precision(metrics.Metric):
    """Computes the precision from model outputs `logits` and `labels`."""

    true_positives: jax.Array
    pred_positives: jax.Array

    @classmethod
    def from_model_output(cls, *, logits: jax.Array, labels: jax.Array, **_) -> metrics.Metric:
        assert logits.shape[-1] == 2, "Expected binary logits."
        preds = logits.argmax(axis=-1)
        return cls(
            true_positives=((preds == 1) & (labels == 1)).sum(),
            pred_positives=(preds == 1).sum(),
        )

    def merge(self, other: metrics.Metric) -> metrics.Metric:
        # Note that for precision we cannot average metric values because the
        # denominator of the metric value is pred_positives and not every batch of
        # examples has the same number of pred_positives (as opposed to e.g.
        # accuracy, whose denominator is the number of examples).
        return type(self)(
            true_positives=self.true_positives + other.true_positives,
            pred_positives=self.pred_positives + other.pred_positives,
        )

    def compute(self):
        return self.true_positives / self.pred_positives


Precision.from_model_output(
    # 1 TP, 1 FN -- 1 pred_positive -- precision = 1.0
    logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
    labels=jnp.array([1, 1]),  # i.e. 1st correct, 2nd incorrect
).merge(
    Precision.from_model_output(
        # 1 TP, 1 FP -- 2 pred_positives -- precision = 0.5
        logits=jnp.array([[-1.0, 1.0], [-1.0, 1.0]]),
        labels=jnp.array([1, 0]),  # i.e. 1st correct, 2nd incorrect
    )
).compute()

# If one incorrectly used metrics.Average to aggregate the metric, the final
# value would be 0.75 because both batches have the same weight in terms of
# examples. But the second batch contains 2 pred_positives and should thus be
# weighted 2x, resulting in the correct (1 + 1) / (1 + 2) == 0.667.

Array(0.6666667, dtype=float32)
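The weighting argument above can be checked with exact fractions. This stdlib sketch compares pooling the counts (which is what `Precision.merge` makes possible) with naively averaging the per-batch precisions; the true-positive and predicted-positive counts are taken from the two batches in the cell above.

```python
from fractions import Fraction

# (true_positives, pred_positives) per batch, from the Precision example above.
batches = [(1, 1), (1, 2)]

# Pooled: sum the counts first, then divide once.
pooled = Fraction(sum(tp for tp, _ in batches), sum(pp for _, pp in batches))

# Naive: average the per-batch precisions, giving each batch equal weight.
naive = sum(Fraction(tp, pp) for tp, pp in batches) / len(batches)

print(pooled, float(pooled))  # 2/3 0.6666666666666666
print(naive, float(naive))  # 3/4 0.75
```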