<a href="https://colab.research.google.com/github/pkraison/jax-playground/blob/main/Training_a_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Switch accelerator by clicking on `Runtime` in the top menu, then `Change runtime type`, and selecting `GPU` from the `Hardware accelerator` dropdown. If the runtime fails, feel free to disable the GPU and run the notebook on the CPU.

JAX provides a high-performance backend with the XLA (Accelerated Linear Algebra) compiler to optimize our computations on the available hardware. As JAX continues to be developed, more and more features that improve efficiency are being implemented. We can enable some of these new features via XLA flags. At the time of writing (JAX version 0.4.25, March 2024), the following flags are recommended in the JAX [GPU performance tips tutorial](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) and [PAX](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags):

```python
import os

# XLA reads these flags when JAX initializes its backend, so set them before importing JAX.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=false "
)
```

```python
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state

# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
```

## Mixed Precision Training

Mixed precision training is a technique that uses both 16-bit and 32-bit floating-point numbers to speed up training. The idea is to use 16-bit floating-point numbers for most of the computations, as they are faster and require less memory. However, 16-bit floating-point numbers have a smaller range and precision compared to 32-bit floating-point numbers. Therefore, we use 32-bit floating-point numbers for certain computations, such as the model's weight updates and the final loss computation, to avoid numerical instability.
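To make the split concrete, here is a minimal sketch using a hypothetical linear model `loss_fn` with a mean-squared-error loss (not the model used later in this notebook). The weights are stored in `float32`, cast to `bfloat16` only for the matrix multiplication, and the loss itself is computed in `float32`:

```python
import jax
import jax.numpy as jnp


def loss_fn(w: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array:
    # Cast the float32 "master" weights and the inputs to bfloat16 for the matmul.
    pred = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)
    # Go back to float32 before computing the loss.
    pred = pred.astype(jnp.float32)
    return jnp.mean((pred - y) ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (16, 1), dtype=jnp.float32)  # master weights in float32
x = jax.random.normal(key_x, (32, 16), dtype=jnp.float32)
y = jnp.zeros((32, 1), dtype=jnp.float32)

loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
# The gradient w.r.t. the float32 weights is returned in float32,
# so the weight update itself happens at full precision.
w = w - 0.1 * grads
print(loss.dtype, grads.dtype, w.dtype)  # float32 float32 float32
```

The cast inside `loss_fn` is differentiable, so JAX casts the gradient back to the dtype of the original `float32` weights without any extra work on our side.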

A potential problem with `float16` is that we can encounter underflow and overflow issues during training: gradients or activations can become too large (overflowing to infinity) or too small (underflowing to zero) to be represented in `float16`, and we lose information. Multiplying the loss by a constant factor scales all gradients in the backward pass by the same factor, which shifts small gradient values back into the representable range; the gradients are then divided by that factor again before the weight update. This is known as [loss scaling](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling), and it is a common technique used in mixed precision training.
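Below is a minimal sketch of static loss scaling, assuming a toy loss `loss_fn` whose `float16` gradients underflow and a hand-picked constant `LOSS_SCALE` of 2^10 (both hypothetical). Practical setups usually adjust the scale dynamically and skip steps whose gradients overflow; this sketch leaves that out.

```python
import jax
import jax.numpy as jnp

LOSS_SCALE = 2.0**10  # hypothetical static scale, chosen by hand for this toy example


def loss_fn(w: jax.Array, x: jax.Array) -> jax.Array:
    # Forward pass in float16; the weights themselves are kept in float32.
    h = w.astype(jnp.float16) * x
    return jnp.mean(h)


def scaled_loss_fn(w: jax.Array, x: jax.Array) -> jax.Array:
    # Scaling the loss scales every gradient in the backward pass by the same factor.
    return loss_fn(w, x) * LOSS_SCALE


w = jnp.ones((1024,), dtype=jnp.float32)
x = jnp.full((1024,), 1e-5, dtype=jnp.float16)

# Each true gradient is about 1e-5 / 1024 ~ 1e-8, below float16's smallest subnormal.
naive_grads = jax.grad(loss_fn)(w, x)
# Unscale in float32, after the gradients have left the float16 computation.
scaled_grads = jax.grad(scaled_loss_fn)(w, x) / LOSS_SCALE

print(jnp.count_nonzero(naive_grads))   # 0: every gradient underflowed
print(jnp.count_nonzero(scaled_grads))  # 1024: the scaled gradients survive
```

Since the division by `LOSS_SCALE` happens in `float32`, the small gradients remain representable after unscaling. If the scale is too large, the scaled `float16` values overflow to `inf` instead, which is why dynamic schemes lower the scale when that happens.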

As an alternative, JAX and other deep learning frameworks like [PyTorch](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/) also support the `bfloat16` format, which is a 16-bit floating-point format with 8 exponent bits and 7 mantissa bits. The `bfloat16` format has a larger range but lower precision compared to the IEEE half-precision type `float16`, and matches `float32` in terms of range. A closer comparison between the formats is shown in the figure below (figure credit: [Google Cloud Documentation](https://cloud.google.com/tpu/docs/bfloat16)):
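The range and precision of the three formats can also be inspected directly with `jnp.finfo`. This short sketch assumes a JAX version where `jnp.bfloat16` is available (provided via the `ml_dtypes` package in recent releases):

```python
import jax.numpy as jnp

formats = {"float16": jnp.float16, "bfloat16": jnp.bfloat16, "float32": jnp.float32}
for name, dtype in formats.items():
    info = jnp.finfo(dtype)
    # nexp/nmant: number of exponent/mantissa bits; max: largest finite value;
    # eps: gap between 1.0 and the next representable number.
    print(
        f"{name:>8}: exponent bits={info.nexp}, mantissa bits={info.nmant}, "
        f"max={float(info.max):.3e}, eps={float(info.eps):.3e}"
    )
```

`float16` tops out at 65504, whereas `bfloat16` and `float32` both reach roughly 3.4e38; in exchange, `bfloat16` has the largest machine epsilon of the three.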


The main benefit of using `bfloat16` is that it can be used without loss scaling, as it has a larger range compared to `float16`. This allows `bfloat16` to be used as a drop-in replacement for `float32` in many cases to save memory while achieving performance close to `float32` (see e.g. [Kalamkar et al., 2019](https://arxiv.org/abs/1905.12322)). When precision matters more than range, `float16` may be the better option. Besides memory efficiency, many accelerators like [TPUs](https://cloud.google.com/tpu/docs/bfloat16) and [GPUs](https://www.nvidia.com/en-us/data-center/tensor-cores/) have native support for `bfloat16`, which can lead to up to a 2x training speedup compared to `float32` on these devices. Hence, we will use `bfloat16` in this notebook.
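The trade-off also shows up on individual values. In this sketch, `1e-8` stands in for a small gradient entry and `1.001` for a value that needs fine resolution; both numbers are arbitrary choices for illustration.

```python
import jax.numpy as jnp

tiny_grad = 1e-8
# Range: float16 cannot represent 1e-8 (it rounds to zero), bfloat16 can.
print(jnp.asarray(tiny_grad, dtype=jnp.float16))   # underflows to zero
print(jnp.asarray(tiny_grad, dtype=jnp.bfloat16))  # stays nonzero, no loss scaling needed

value = 1.001
# Precision: near 1.0, float16 resolves steps of 2**-10, bfloat16 only steps of 2**-7.
print(jnp.asarray(value, dtype=jnp.float16))   # close to 1.001
print(jnp.asarray(value, dtype=jnp.bfloat16))  # rounded to exactly 1.0
```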

We implement mixed precision training by casting all features and activations within the model to `bfloat16`, while keeping the weights and optimizer states in `float32`. This keeps the weight updates and optimizer states at high precision, while using `bfloat16` for the forward and backward passes reduces the memory footprint and increases the training speed. Although the memory footprint of the model parameters themselves does not shrink, the activations take up half the memory they would in `float32`, which often leads to a significant overall reduction in memory consumption, typically without a noticeable effect on the model's performance.

```python
class MLPClassifier(nn.Module):
    dtype: Any  # Computation dtype, e.g. jnp.float32 or jnp.bfloat16
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.hidden_size,
            dtype=self.dtype,  # Computation in specified dtype, params stay in float32
        )(x)
        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.num_classes,
            dtype=self.dtype,
        )(x)
        # Cast the logits back to float32 so that log_softmax (and the loss) run in full precision.
        x = x.astype(jnp.float32)
        x = nn.log_softmax(x, axis=-1)
        return x
```

```python
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
```
