What is gradient checkpointing?
Gradient checkpointing enables users to train large models with relatively small memory resources. Here, a "large" model can be one of two kinds:
1. Models with large variables, i.e., weight matrices. As a consequence, such models have correspondingly large gradients and optimizer states. The activations (intermediate outputs from the model layers) tend to be relatively small (this depends on the batch size). Fully connected networks and RNNs typically fall under this category.
2. Models with small weights but large activations. CNNs and transformers tend to fall under this category.
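It is important to note that gradient checkpointing is meant to help with models of type 2, because it reduces only the memory needed to store activations. Models of type 1, whose memory is dominated by weights, gradients and optimizer states, do not stand to gain much benefit.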

In [None]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets
#from tensorflow.python.ops import custom_gradient
from datetime import datetime
from packaging import version
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import time

Model definition
To make use of `tf.recompute_grad`, one has to manually split the 'large' model into manageable 'partitions' (each partition corresponds to a logical 'checkpoint'). Below is an example of a simple 'large' CNN model.

In [None]:
def get_big_cnn_model(img_dim, n_channels, num_partitions, blocks_per_partition):
    """Build the unpartitioned 'large' CNN as a single Sequential model.

    The network stacks ``num_partitions`` groups of ``blocks_per_partition``
    identical blocks. Each block is three 5x5 'same'-padded ReLU convolutions
    (10, 40 and 20 filters), each followed by a 1x1 max-pooling layer. A
    Flatten / Dense(32, relu) / Dense(10) head is appended at the end.

    Args:
        img_dim: Height and width of the square input images.
        n_channels: Number of input channels.
        num_partitions: Number of logical partitions (groups of blocks).
        blocks_per_partition: Number of conv blocks in each partition.

    Returns:
        A ``tf.keras.Sequential`` model ending in a 10-unit Dense layer.
    """
    conv_filters = (10, 40, 20)

    net = tf.keras.Sequential()
    net.add(layers.Input(shape=(img_dim, img_dim, n_channels)))
    for _partition in range(num_partitions):
        for _block in range(blocks_per_partition):
            # One block: a conv + 1x1 max-pool pair for every filter count.
            for n_filters in conv_filters:
                net.add(layers.Conv2D(n_filters, 5, padding='same', activation=tf.nn.relu))
                net.add(layers.MaxPooling2D((1, 1), padding='same'))

    # Classification head.
    net.add(layers.Flatten())
    net.add(layers.Dense(32, activation=tf.nn.relu))
    net.add(layers.Dense(10))
    return net

Here is the same large CNN model split into `num_partitions` separate sub-models (3 partitions are used in the training example further below).

In [None]:
def get_split_cnn_model(img_dim, n_channels, num_partitions, blocks_per_partition):
    """Build the 'large' CNN as a list of ``num_partitions`` Sequential models.

    Every partition holds ``blocks_per_partition`` blocks of three 5x5
    'same'-padded ReLU convolutions (10, 40 and 20 filters), each followed by
    a 1x1 max-pooling layer. The first partition takes the image as input;
    each later partition takes an input shaped like the output of the last
    layer of the partition before it. The final partition additionally ends
    with a Flatten / Dense(32, relu) / Dense(10) head.

    Args:
        img_dim: Height and width of the square input images.
        n_channels: Number of input channels.
        num_partitions: Number of sub-models to create.
        blocks_per_partition: Number of conv blocks in each sub-model.

    Returns:
        A list of ``tf.keras.Sequential`` models, to be applied in order.
    """
    conv_filters = (10, 40, 20)

    partitions = [tf.keras.Sequential() for _ in range(num_partitions)]
    partitions[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))

    for idx, part in enumerate(partitions):
        if idx > 0:
            # Chain partitions: input shape = previous output shape minus batch dim.
            prev_out_shape = partitions[idx - 1].layers[-1].output_shape
            part.add(layers.Input(shape=prev_out_shape[1:]))
        for _block in range(blocks_per_partition):
            for n_filters in conv_filters:
                part.add(layers.Conv2D(n_filters, 5, padding='same', activation=tf.nn.relu))
                part.add(layers.MaxPooling2D((1, 1), padding='same'))

    # The classification head lives in the last partition only.
    head = partitions[-1]
    head.add(layers.Flatten())
    head.add(layers.Dense(32, activation=tf.nn.relu))
    head.add(layers.Dense(10))
    return partitions

In [None]:
def compute_loss(logits, labels):
    """Return the mean sparse softmax cross-entropy of ``logits`` vs ``labels``.

    Args:
        logits: Unnormalized class scores.
        labels: Class labels passed to
            ``tf.nn.sparse_softmax_cross_entropy_with_logits``.

    Returns:
        A scalar tensor: the per-example losses reduced with ``tf.reduce_mean``.
    """
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss)

In [None]:
def get_data(img_dim, n_channels, batch_size):
    """Create one dummy batch of all-ones images and all-ones int64 labels.

    Args:
        img_dim: Height and width of the square images.
        n_channels: Number of image channels.
        batch_size: Number of examples in the batch.

    Returns:
        A tuple ``(inputs, labels)`` where ``inputs`` has shape
        ``[batch_size, img_dim, img_dim, n_channels]`` and ``labels`` has
        shape ``[batch_size]`` with dtype ``tf.int64``.
    """
    image_batch_shape = [batch_size, img_dim, img_dim, n_channels]
    dummy_images = tf.ones(image_batch_shape)
    dummy_labels = tf.ones([batch_size], dtype=tf.int64)
    return dummy_images, dummy_labels

In [None]:
# This training loop should produce an OOM exception on a GPU with 16GB memory
def train(n_steps):
    """Run ``n_steps`` SGD steps on the unpartitioned CNN (no checkpointing).

    A fixed dummy batch (16 single-channel 256x256 images) is fed to the
    model built by ``get_big_cnn_model`` with 3 partitions of 9 blocks each.
    The loss is printed at every step.

    Args:
        n_steps: Number of training steps to run.

    Returns:
        A list containing the loss tensor computed at each step.
    """
    tf.random.set_seed(123)
    img_dim, n_channels, batch_size = 256, 1, 16
    inputs, labels = get_data(img_dim, n_channels, batch_size)
    net = get_big_cnn_model(img_dim, n_channels, num_partitions=3, blocks_per_partition=9)
    sgd = optimizers.SGD()
    variables = net.trainable_variables
    loss_history = []
    for _step in range(n_steps):
        with tf.GradientTape() as tape:
            predictions = net(inputs)
            step_loss = compute_loss(predictions, labels)
            print('loss ', step_loss)
            loss_history.append(step_loss)
        step_grads = tape.gradient(step_loss, variables)
        sgd.apply_gradients(zip(step_grads, variables))
        # Drop the gradient tensors before the next forward pass.
        del step_grads
    return loss_history

In [None]:
# This training loop should be able to run successfully. In fact, you can more than double the model size
# by setting blocks_per_partition = 20 and still train successfully.
def train_tf_recompute_split(n_steps, *, num_partitions=3, blocks_per_partition=9):
    """Run ``n_steps`` SGD steps on the partitioned CNN with gradient checkpointing.

    The model is built as a list of sub-models by ``get_split_cnn_model`` and
    each sub-model is wrapped in ``tf.recompute_grad``. The partitions are
    then applied one after another to a fixed dummy batch (16 single-channel
    256x256 images). The loss is printed at every step.

    Args:
        n_steps: Number of training steps to run.
        num_partitions: Number of sub-models (checkpoints) to split the
            network into. Defaults to 3.
        blocks_per_partition: Number of conv blocks in each sub-model.
            Defaults to 9.

    Returns:
        A list containing the loss tensor computed at each step.
    """
    tf.random.set_seed(123)
    img_dim, n_channels, batch_size = 256, 1, 16
    x, y = get_data(img_dim, n_channels, batch_size)
    partitions = get_split_cnn_model(
        img_dim,
        n_channels,
        num_partitions=num_partitions,
        blocks_per_partition=blocks_per_partition,
    )
    # One recompute-wrapped callable per partition, in application order.
    checkpointed = [tf.recompute_grad(part) for part in partitions]
    optimizer = optimizers.SGD()
    # Flatten the variables of all partitions, preserving partition order.
    tr_vars = [var for part in partitions for var in part.trainable_variables]
    losses = []
    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            # Chain the partitions: each one consumes the previous output.
            outputs = x
            for stage in checkpointed:
                outputs = stage(outputs)
            loss = compute_loss(outputs, y)
        print('loss ', loss)
        losses.append(loss)
        grads = tape.gradient(loss, tr_vars)
        optimizer.apply_gradients(zip(grads, tr_vars))
        # Drop the gradient tensors before the next forward pass.
        del grads
    return losses

In [None]:
# Time one call to train_tf_recompute_split(1). perf_counter() is a monotonic,
# high-resolution clock, so the measured interval is not affected by system
# clock adjustments the way time.time() can be.
start = time.perf_counter()
train_tf_recompute_split(1)
end = time.perf_counter()
print('Time elapsed is ', end - start, ' seconds')

In [None]:
# num_part = 3 and blocks_per_part = 7
# recomp
# losstf.Tensor(2.3025851, shape=(), dtype=float32)
# losstf.Tensor(2.292471, shape=(), dtype=float32)
# losstf.Tensor(2.2751424, shape=(), dtype=float32)

# no recomp
# losstf.Tensor(2.3025851, shape=(), dtype=float32)
# losstf.Tensor(2.2924523, shape=(), dtype=float32)
# losstf.Tensor(2.2754436, shape=(), dtype=float32)
# Time elapsed is  17.34099841117859  seconds