To switch the accelerator, click on `Runtime` in the top menu, then `Change runtime type`, and select `GPU` from the `Hardware accelerator` dropdown. If the runtime fails, feel free to disable the GPU and run the notebook on the CPU.

JAX provides a high-performance backend with the XLA (Accelerated Linear Algebra) compiler to optimize our computations on the available hardware. As JAX continues to be developed, more and more features that improve efficiency are being implemented, and some of these new features can be enabled via XLA flags. The `XLA_FLAGS` environment variable has to be set before JAX initializes its backend, which is why we set it before importing JAX. At the time of writing (JAX version 0.4.25, March 2024), the following flags are recommended in the JAX [GPU performance tips tutorial](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) and [PAX](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags):

In [1]:
import os

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=false "
)

In [2]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state

# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

## Mixed Precision Training

Mixed precision training is a technique that uses both 16-bit and 32-bit floating-point numbers to speed up training. The idea is to use 16-bit floating-point numbers for most of the computations, as they are faster and require less memory. However, 16-bit floating-point numbers have a smaller range and precision compared to 32-bit floating-point numbers. Therefore, we use 32-bit floating-point numbers for certain computations, such as the model's weight updates and the final loss computation, to avoid numerical instability.

A potential problem with `float16` is that we can encounter underflow and overflow during training: gradients or activations become too small or too large to be represented in `float16`, and we lose information. Multiplying the loss by a constant scale factor before backpropagation, and dividing the resulting gradients by the same factor afterwards, shifts small gradient values back into the representable range. This is known as [loss scaling](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling), and it is a common technique in mixed precision training.
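The following is a minimal sketch of *static* loss scaling, assuming a fixed, hand-picked scale factor. The names `LOSS_SCALE`, `scaled_value_and_grad` and `toy_loss` are hypothetical and not taken from the notebook or any library. The sketch first shows a tiny gradient value underflowing to zero in `float16`, and then wraps `jax.value_and_grad` so that the loss is multiplied by the scale before differentiation and the gradients are divided by it afterwards.

```python
import jax
import jax.numpy as jnp

# A gradient value below float16's smallest subnormal (~6e-8) underflows to zero.
tiny_grad_fp16 = jnp.asarray(1e-8, dtype=jnp.float16)

# Hypothetical static loss scale. Scaled by it, the same value is a normal float16.
LOSS_SCALE = 2.0**13
scaled_fp16 = jnp.asarray(1e-8 * LOSS_SCALE, dtype=jnp.float16)  # ~8.2e-5
recovered = scaled_fp16.astype(jnp.float32) / LOSS_SCALE  # unscale in float32, ~1e-8


def scaled_value_and_grad(loss_fn, params, *args, loss_scale=LOSS_SCALE):
    """Static loss scaling: scale the loss up, differentiate, then unscale the gradients."""

    def scaled_loss_fn(p):
        return loss_fn(p, *args) * loss_scale

    scaled_loss, grads = jax.value_and_grad(scaled_loss_fn)(params)
    # Unscale in float32 so the optimizer sees gradients of the original loss.
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / loss_scale, grads)
    return scaled_loss / loss_scale, grads


def toy_loss(w, x):
    # Hypothetical toy loss: float32 weights, forward pass computed in float16.
    pred = x.astype(jnp.float16) * w.astype(jnp.float16)
    return jnp.mean(pred.astype(jnp.float32) ** 2)


w = jnp.array([0.5, -1.0, 2.0], dtype=jnp.float32)
x = jnp.array([0.5, 1.0, 1.5], dtype=jnp.float32)
loss, grads = scaled_value_and_grad(toy_loss, w, x)
print(tiny_grad_fp16, recovered, loss, grads)
```

The scale is chosen small enough that the toy example's `float16` backward pass does not overflow. Practical setups usually adapt the scale during training (dynamic loss scaling), lowering it when non-finite gradients appear; that logic is not shown here.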

As an alternative, JAX and other deep learning frameworks like [PyTorch](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/) also support the `bfloat16` format, which is a 16-bit floating-point format with 8 exponent bits and 7 mantissa bits. The `bfloat16` format has a larger range but lower precision compared to the IEEE half-precision type `float16`, and matches `float32` in terms of range. A closer comparison between the formats is shown in the figure below (figure credit: [Google Cloud Documentation](https://cloud.google.com/tpu/docs/bfloat16)):


The main benefit of using `bfloat16` is that it can be used without loss scaling, as it has a larger range than `float16`. This allows `bfloat16` to serve as a drop-in replacement for `float32` in many cases, saving memory while achieving performance close to `float32` (see e.g. [Kalamkar et al., 2019](https://arxiv.org/abs/1905.12322)). For situations where precision matters more than range, `float16` may be the better option. Besides memory efficiency, many accelerators like [TPUs](https://cloud.google.com/tpu/docs/bfloat16) and [GPUs](https://www.nvidia.com/en-us/data-center/tensor-cores/) have native support for `bfloat16`, which can lead to up to a 2x speedup in training compared to `float32` on these devices. Hence, we will use `bfloat16` in this notebook.
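To make the range/precision trade-off concrete, the small check below queries `jnp.finfo` for the three formats and casts two example values (70000 and 1.001 are arbitrary illustrations). We expect `float16` to overflow where `bfloat16` does not, and `bfloat16` to lose a small increment that `float16` preserves.

```python
import jax.numpy as jnp

# Exponent bits, mantissa bits, and largest finite value of each format.
for dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
    info = jnp.finfo(dtype)
    print(f"{info.dtype}: exponent bits={info.nexp}, mantissa bits={info.nmant}, max={float(info.max):.4g}")

# Range: 70000 exceeds float16's maximum (65504) but fits into bfloat16's range.
big = jnp.asarray(70000.0, dtype=jnp.float32)
big_fp16 = big.astype(jnp.float16)  # overflows to inf
big_bf16 = big.astype(jnp.bfloat16)  # finite, rounded to a nearby value

# Precision: bfloat16 (7 mantissa bits) rounds 1.001 to 1.0, float16 (10 bits) does not.
near_one = jnp.asarray(1.001, dtype=jnp.float32)
near_one_fp16 = near_one.astype(jnp.float16)
near_one_bf16 = near_one.astype(jnp.bfloat16)
print(big_fp16, big_bf16, near_one_fp16, near_one_bf16)
```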

We implement mixed precision training by casting all features and activations within the model to `bfloat16`, while keeping the weights and optimizer states in `float32`. This keeps the weight updates and optimizer states in high precision, while the forward and backward passes run in `bfloat16` to reduce memory consumption and increase training speed. While this does not reduce the memory footprint of the model parameters themselves, it often significantly reduces overall memory consumption, because the activations stored for the backward pass take half the memory, typically without noticeably affecting the model's performance.
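Outside of Flax, the same pattern can be written directly in JAX. This minimal sketch uses a hypothetical single-layer model (`hidden`, `loss_fn`) and a hypothetical plain SGD step: the parameters stay in `float32`, parameters and inputs are cast to `bfloat16` for the matrix multiplication and activation, and only the loss reduction is cast back to `float32`. Because the gradient of `astype` casts the cotangent back to the original dtype, the gradients arrive in `float32`, matching the master weights.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# Master parameters of a hypothetical single layer, stored in float32.
params = {
    "w": 0.05 * jax.random.normal(key_w, (128, 256), dtype=jnp.float32),
    "b": jnp.zeros((256,), dtype=jnp.float32),
}
x = jax.random.normal(key_x, (512, 128), dtype=jnp.float32)


def hidden(params, x, compute_dtype=jnp.bfloat16):
    # Cast weights and inputs to the compute dtype; the activation is produced in that dtype.
    w = params["w"].astype(compute_dtype)
    b = params["b"].astype(compute_dtype)
    return jax.nn.silu(x.astype(compute_dtype) @ w + b)


def loss_fn(params, x):
    h = hidden(params, x)
    # Reduce in float32 for numerical stability.
    return jnp.mean(h.astype(jnp.float32) ** 2)


loss, grads = jax.value_and_grad(loss_fn)(params, x)
# Hypothetical SGD update, applied in float32 to the master weights.
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

h_bf16 = hidden(params, x)
h_f32 = hidden(params, x, compute_dtype=jnp.float32)
print(grads["w"].dtype, h_bf16.dtype, h_bf16.nbytes, h_f32.nbytes)
```

The last line shows that the `bfloat16` activation occupies half the bytes of its `float32` counterpart, which is where the memory savings described above come from.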

In [3]:
class MLPClassifier(nn.Module):
    dtype: Any
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.hidden_size,
            dtype=self.dtype,  # Computation in specified dtype, params stay in float32
        )(x)
        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.num_classes,
            dtype=self.dtype,
        )(x)
        x = x.astype(jnp.float32)
        x = nn.log_softmax(x, axis=-1)
        return x

In [4]:
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})

In [5]:
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})


## Gradient Checkpointing

Gradient checkpointing is a technique that trades compute for memory by recomputing some activations during the backward pass. The idea is to store only a subset of the activations during the forward pass, and recompute the rest during the backward pass. This is useful when the memory consumed by activations limits the model size, and recomputing those activations is cheaper than storing them. This is often the case for models with a large memory footprint, such as Transformers, where the activations can make up a significant portion of the memory consumption.

In JAX and Flax, we can implement gradient checkpointing using the `remat` function (`jax.remat` is an alias of `jax.checkpoint`). The `remat` function allows us to control which intermediate arrays are saved on the forward pass and which are recomputed on the backward pass. As a simple example, consider the following function that computes the GELU activation function manually using its tanh approximation (see e.g. [Hendrycks and Gimpel, 2016](https://arxiv.org/abs/1606.08415)). In practice, we would use the already optimized `gelu` function from `flax.linen` (`nn.gelu`), but we use this example to illustrate the concept of gradient checkpointing:


In [6]:
def gelu(x: jax.Array) -> jax.Array:
    """GeLU activation function with approximate tanh."""
    # This will be printed once every time the function is executed.
    jax.debug.print("Executing GeLU")
    # See https://arxiv.org/abs/1606.08415 for details.
    x3 = jnp.power(x, 3)
    tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3)
    return 0.5 * x * (1 + jnp.tanh(tanh_input))

This function creates several intermediate tensors, some of which would normally be stored for the backward pass, which is expensive for large inputs. At the same time, computing them is cheap, so we would rather recompute these tensors during the backward pass than store them. We can use the `remat` function to control this as follows:

In [7]:
def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
    act_fn = gelu
    if remat:
        act_fn = jax.remat(act_fn)
    return jnp.mean(act_fn(x))

If we now transform this function with `jax.grad`, we see that JAX executes the function twice (the `Executing GeLU` print statement appears twice). This is because JAX computes the forward pass, releases the intermediate tensors, and then recomputes them in the backward pass.

In [8]:
x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
_ = grad_fn(x, remat=True)

Executing GeLU
Executing GeLU


If we run the same gradient computation without `remat`, the `Executing GeLU` print statement appears only once, as JAX stores the intermediate tensors during the forward pass and does not need to recompute them in the backward pass.

In [9]:
_ = grad_fn(x, remat=False)

Executing GeLU


This shows that the `remat` function is controlling which tensors are stored and which are recomputed during the backward pass. We will see in the later Transformer example how we can use it in a neural network layer.

When we jit a function, the XLA compiler may also apply rematerialization automatically, for example when the program would otherwise not fit into device memory. In such cases, we do not necessarily need to use `remat` explicitly. However, it can still be beneficial to use `remat`, for example inside `jax.lax.scan` loops (see the [practical notes on remat](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes)) or to manually control which tensors are stored and which are recomputed.
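One way to exercise this manual control is through named checkpoints. The sketch below uses a hypothetical two-layer block `mlp_block`: it tags the first matrix multiplication with `jax.ad_checkpoint.checkpoint_name` and applies the `jax.checkpoint_policies.save_only_these_names` policy, so that only this tagged tensor is saved on the forward pass while the cheap `tanh` is recomputed. Rematerialization changes only what is stored, not the result, so the gradients should match the plain version.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def mlp_block(w1, w2, x):
    # Hypothetical block. Tag the first matmul output so a policy can refer to it by name.
    h = checkpoint_name(x @ w1, "hidden")
    # tanh(h) is cheap, so we are happy to recompute it in the backward pass.
    return jnp.tanh(h) @ w2


# Save only the tensor named "hidden" on the forward pass; other intermediates
# inside mlp_block are recomputed during the backward pass.
mlp_block_remat = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.save_only_these_names("hidden")
)


def make_loss(block_fn):
    def loss(w1, w2, x):
        return jnp.mean(block_fn(w1, w2, x) ** 2)

    return loss


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (16, 32))
w2 = jax.random.normal(k2, (32, 4))
x = jax.random.normal(k3, (8, 16))

grads_plain = jax.grad(make_loss(mlp_block), argnums=(0, 1))(w1, w2, x)
grads_remat = jax.grad(make_loss(mlp_block_remat), argnums=(0, 1))(w1, w2, x)
```

To inspect which residuals are actually saved under a given policy, the JAX remat tutorial linked above uses `jax.ad_checkpoint.print_saved_residuals`.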

## Gradient Accumulation

A common trade-off in training large models is the batch size. A larger batch size gives a more accurate estimate of the gradient, but also requires more memory. In some cases, the batch size is limited by the memory of the accelerator, and we cannot increase it further. In these cases, we can use gradient accumulation to simulate a larger batch size by accumulating the gradients over multiple sub-batches. Each sub-batch is processed independently, and we perform a single optimizer step once all sub-batches have been processed. Gradient accumulation is useful when the memory consumed by activations is the limiting factor, but we require a larger batch size for training. A disadvantage is that the sub-batches are processed sequentially rather than in parallel, so we need to make sure that the smaller sub-batches still utilize the accelerator to its full potential.
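Why does averaging the sub-batch gradients reproduce the full-batch gradient? For a mean-reduced loss and equally sized sub-batches, the full-batch mean equals the average of the sub-batch means, and the gradient is linear. The sketch below checks this numerically on a hypothetical linear-regression loss (`mse_loss`). It assumes no layers whose output depends on other examples in the batch (such as batch normalization) and no randomness that differs between sub-batches, such as dropout.

```python
import jax
import jax.numpy as jnp


def mse_loss(w, x, y):
    # Hypothetical linear-regression loss, averaged over the (sub-)batch.
    return jnp.mean((x @ w - y) ** 2)


key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (64, 8))
y = jax.random.normal(key_y, (64,))
w = jax.random.normal(key_w, (8,))

# Gradient on the full batch of 64 examples.
full_grad = jax.grad(mse_loss)(w, x, y)

# Gradient accumulation over 4 equally sized sub-batches of 16 examples each.
num_minibatches = 4
minibatch_size = x.shape[0] // num_minibatches
acc_grad = jnp.zeros_like(w)
for i in range(num_minibatches):
    sl = slice(i * minibatch_size, (i + 1) * minibatch_size)
    acc_grad = acc_grad + jax.grad(mse_loss)(w, x[sl], y[sl])
acc_grad = acc_grad / num_minibatches  # average of sub-batch means = full-batch mean

print(jnp.max(jnp.abs(full_grad - acc_grad)))
```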



In JAX and Flax, we have easy control over the gradient accumulation process, since we explicitly calculate the gradients via `jax.grad`. Let's implement this process for the simple classification MLP from the mixed precision section. We first extend Flax's `TrainState` with an RNG key for easier handling of dropout, and define a small `Batch` dataclass holding the inputs and labels.

In [10]:
class TrainState(train_state.TrainState):
    rng: jax.Array

In [11]:
@dataclass
class Batch:
    inputs: jax.Array
    labels: jax.Array

In [12]:
def classification_loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Classification loss function with cross-entropy."""
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    loss = loss.mean()
    return loss, step_metrics

In [13]:
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # Define gradient function for single minibatch.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Prepare loop variables.
    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree_util.tree_map(lambda x: x[start:end], batch)
            # Calculate gradients and metrics for the minibatch.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            # Accumulate gradients and metrics across minibatches.
            if grads is None:
                grads = step_grads
                metrics = step_metrics
            else:
                grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
                metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    # Average gradients over minibatches.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, metrics

A disadvantage of the implementation above is that the Python for-loop is unrolled when the training step is traced under `jax.jit`: the compiled program contains a separate copy of the forward and backward pass for every minibatch, so compilation time grows with the number of minibatches. We can avoid this by using the `scan` transformation in JAX ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)), which lets us write a loop whose body is traced and compiled only once. The `scan` transformation requires the step function to take two inputs: the `carry` and the input `x`. The `carry` is the state passed between steps, and `x` is the input to the current step. The function returns the new `carry` and any output that we want to gather per step. In our case, the `carry` consists of the gradients and metrics accumulated over all previous steps, and `x` is the current minibatch index, which we use to select the minibatch and RNG key. As the new carry, we return the updated accumulated gradients and metrics, and we do not need a per-step output. We implement gradient accumulation with `scan` below:
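A minimal, self-contained illustration of the `carry`/`x` contract of `jax.lax.scan` follows. The step function `accumulate_step` is hypothetical: its carry is a tuple of running sums, mirroring how `_scan_step` carries the accumulated gradients and metrics, and it returns `None` as its per-step output.

```python
import jax
import jax.numpy as jnp


def accumulate_step(carry, x):
    # Hypothetical step. carry: running (sum, sum of squares); x: current element of `xs`.
    total, total_sq = carry
    new_carry = (total + x, total_sq + x**2)
    # Like `_scan_step`, return no per-step output.
    return new_carry, None


xs = jnp.arange(1.0, 5.0, dtype=jnp.float32)  # [1., 2., 3., 4.]
init = (jnp.zeros((), jnp.float32), jnp.zeros((), jnp.float32))
(total, total_sq), ys = jax.lax.scan(accumulate_step, init, xs)
print(total, total_sq, ys)  # total = 10, sum of squares = 30, ys is None
```

Returning an array instead of `None` would make `scan` stack those per-step outputs along a new leading axis.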

In [14]:
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Determine gradients and metrics for a single minibatch."""
        minibatch = jax.tree_util.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(  # Slicing with variable index (jax.Array).
                x, start_index=minibatch_idx * minibatch_size, slice_size=minibatch_size, axis=0
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan step function for looping over minibatches."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        carry = jax.tree_util.tree_map(jnp.add, carry, (step_grads, step_metrics))
        return carry, None

    # Determine initial shapes for gradients and metrics.
    grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
    # Loop over minibatches to determine gradients and metrics.
    (grads, metrics), _ = jax.lax.scan(
        _scan_step, init=(grads, metrics), xs=jnp.arange(num_minibatches), length=num_minibatches
    )
    # Average gradients over minibatches.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, metrics

In [15]:
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    if use_scan:
        return accumulate_gradients_scan(*args, **kwargs)
    else:
        return accumulate_gradients_loop(*args, **kwargs)

In [16]:
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [17]:
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0

rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)

In [18]:
# Disable dropout so that training with and without gradient accumulation can be compared for equality.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)

In [19]:
_, metric_shapes = jax.eval_shape(
    functools.partial(train_step, num_minibatches=4),
    state,
    None,
    batch,
)
print("Metric shapes:")
pprint(metric_shapes)


Metric shapes:
{'accuracy': (ShapeDtypeStruct(shape=(), dtype=int32),
              ShapeDtypeStruct(shape=(), dtype=int32)),
 'loss': (ShapeDtypeStruct(shape=(), dtype=float32),
          ShapeDtypeStruct(shape=(), dtype=int32))}


In [20]:
train_step_jit = jax.jit(
    train_step,
    static_argnames="num_minibatches",
)

In [21]:
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Small helper function for training loop."""
    train_metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
    for _ in range(num_train_steps):
        state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
    return state, train_metrics

In [22]:
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
    """Prints metrics with an optional title."""
    metrics = jax.device_get(metrics)
    lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()]
    if title:
        title = f" {title} "
        max_len = max(len(title), max(map(len, lines)))
        lines = [title.center(max_len, "=")] + lines
    print("\n".join(lines))

In [23]:
state_mini1, metrics_mini1 = train_with_minibatches(
    state, batch, num_minibatches=1, num_train_steps=5
)
state_mini4, metrics_mini4 = train_with_minibatches(
    state, batch, num_minibatches=4, num_train_steps=5
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")


== Minibatch 1 ===
accuracy: 0.026953
loss: 4.593171
== Minibatch 4 ===
accuracy: 0.026953
loss: 4.593171


## JAX-Specific Structures

In JAX, we can also use some JAX-specific structures to reduce the memory footprint of the model and help train larger models. These may not apply to other frameworks like PyTorch, but they are good to keep in mind for JAX users. We cover two aspects: donating buffers and scanning.

### Donating buffers

In JAX, we follow the paradigm of functional programming, where all functions need to be stateless and pure. This means that we cannot modify the input arguments or any global variables in place. This also holds for the model parameters, which are passed as arguments to the training step and returned with updated values. As a result, the device needs memory for at least two copies of the model parameters and optimizer state during a training step: the inputs and the updated outputs. As the model grows in size, this can become a significant limitation. To mitigate this, JAX provides a mechanism to donate buffers, which allows the memory of input arguments to be reused for the outputs. This is possible when the input and output arguments have the same shape and data type, and we do not need the inputs after the function has been executed. This is typically the case for the model parameters and optimizer state, which we no longer need once the optimizer step has produced their updated versions. We can donate buffers via the `donate_argnums`/`donate_argnames` arguments of `jax.jit`. We implement this below for the training step, donating both the training state and the accumulated metrics:


In [24]:
train_step_donated = jax.jit(
    train_step,
    static_argnames="num_minibatches",
    donate_argnames=(
        "state",
        "metrics",
    ),
)
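As an illustration outside the training loop, the following sketch applies the same `donate_argnums` mechanism to a hypothetical plain SGD update (`sgd_update` is not part of the notebook). Since the updated parameters have the same shapes and dtypes as the inputs, XLA can write them into the donated buffers. The caller must treat the donated input as consumed, which is why we rebind `params` to the result. On backends where donation is not supported, JAX ignores the request and may emit a warning.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_update(params, grads):
    # Hypothetical plain SGD step with a fixed learning rate. The output has the same
    # shapes and dtypes as `params`, so XLA can reuse the donated input buffers for it.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.ones((256, 256)), "b": jnp.zeros((256,))}
grads = {"w": jnp.full((256, 256), 0.5), "b": jnp.ones((256,))}

# The donated `params` must not be used after the call, so we rebind the name.
params = sgd_update(params, grads)
print(params["w"][0, 0], params["b"][0])
```

On backends that honor donation, a reference kept to one of the old arrays should report `is_deleted()` as `True` after the call, and using it raises an error.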

## Intermediate Summary

In this notebook, we have discussed several techniques to train larger models on a single device. We have implemented mixed precision training and gradient accumulation for a simple MLP model, and illustrated gradient checkpointing on a GeLU activation function. We have also discussed JAX-specific structures that reduce the memory footprint of the model and help train larger models. In the next part ([Part 1.2](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/scaling/JAX/single_gpu_transformer.html)), we will combine these techniques to train a larger Transformer model on a single GPU, and explore the benefits and trade-offs of each technique. We will also profile the model to get further insights into the efficiency of these techniques.

## References and Resources

\[Chen et al., 2016\] Chen, T., Xu, B., Zhang, C. and Guestrin, C., 2016. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174. [Paper link](https://arxiv.org/abs/1604.06174)

\[Micikevicius et al., 2018\] Micikevicius, P., Narang, S., Alben, J., Diamos, G., Elsen, E., Garcia, D., Ginsburg, B., Houston, M., Kuchaiev, O., Venkatesh, G. and Wu, H., 2018, February. Mixed Precision Training. In International Conference on Learning Representations. [Paper link](https://arxiv.org/abs/1710.03740)

\[Bulatov, 2018\] Bulatov, Y., 2018. Fitting larger networks into memory. [Blog post link](https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9)

\[Kalamkar et al., 2019\] Kalamkar, D., Mudigere, D., Mellempudi, N., Das, D., Banerjee, K., Avancha, S., Vooturi, D.T., Jammalamadaka, N., Huang, J., Yuen, H. and Yang, J., 2019. A study of BFLOAT16 for deep learning training. arXiv preprint arXiv:1905.12322. [Paper link](https://arxiv.org/abs/1905.12322)

\[Ahmed et al., 2022\] Ahmed, S., Sarofeen, C., Ruberry, M., et al., 2022. What Every User Should Know About Mixed Precision Training in PyTorch. [Tutorial link](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/)

\[Weng et al., 2022\] Weng, L., Brockman, G., 2022. Techniques for training large neural networks. [Blog link](https://openai.com/research/techniques-for-training-large-neural-networks)

\[Raschka, 2023\] Raschka, S., 2023. Optimizing Memory Usage for Training LLMs and Vision Transformers in PyTorch. [Tutorial link](https://lightning.ai/pages/community/tutorial/pytorch-memory-vit-llm/) (gives more details for the topics here in PyTorch)

\[HuggingFace, 2024\] HuggingFace, 2024. Performance and Scalability: How To Fit a Bigger Model and Train It Faster. [Tutorial link](https://huggingface.co/docs/transformers/v4.18.0/en/performance)

\[NVIDIA, 2024\] NVIDIA, 2024. Mixed Precision Training. [Documentation link](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html)

\[NVIDIA, 2024\] NVIDIA, 2024. Performance Guide for Training. [Documentation link](https://docs.nvidia.com/deeplearning/performance/index.html)

\[Google, 2024\] JAX Team Google, 2024. Control autodiff’s saved values with jax.checkpoint (aka jax.remat). [Tutorial link](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html)

\[Google, 2024\] JAX Team Google, 2024. Profiling JAX programs. [Tutorial link](https://jax.readthedocs.io/en/latest/profiling.html)

\[Google, 2024\] JAX Team Google, 2024. GPU performance tips. [Tutorial link](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html)