##### Copyright 2020 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gradient checkpointing

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/addons/tutorials/training_gradient_checkpointing"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/training_gradient_checkpointing.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/addons/blob/master/docs/tutorials/training_gradient_checkpointing.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/training_gradient_checkpointing.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

### What is gradient checkpointing?

Gradient checkpointing enables users to train large models with relatively small memory resources. Large models could refer to:
1. Models with large variables, i.e. weight matrices. As a consequence, such models have correspondingly large gradients and optimizer states. The activations (intermediate outputs from the model layers) tend to be relatively small (depending on the batch size). Fully connected networks and RNNs typically fall under this category.
2. Models with small weights but large activations. CNNs and transformers tend to fall under this category.

It is important to note that gradient checkpointing is meant to help with models of type 2; models of type 1 gain little from it.

### How is it implemented?

1. *Recompute step* - This is the core step in gradient checkpointing. Recompute allows one to calculate the forward activations while doing the backward pass, so there is no need to store intermediate activations during the forward pass, which saves memory. However, during the backward pass the forward activations for every layer need to be recomputed starting from the first layer, which increases training time.
2. *Checkpoint step* - This step allows one to designate (either manually or automatically) certain layers in the model as ‘checkpoint layers’ whose activations are kept in memory during the forward pass. This lets the user balance the memory-versus-time tradeoff: by checkpointing certain layers, activations only need to be recomputed starting from the last checkpoint layer, which reduces training time.

This tutorial demonstrates how to use the recompute step implemented in TensorFlow Addons to train a `tf.keras` model. Note that at this time the implementation works only in eager mode for TF 2.x and only for sequential networks.

## Setup

In [2]:
try:
  %tensorflow_version 2.x
except:
  pass

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, optimizers

In [3]:
# Build the model: a sample deep sequential CNN, the kind of activation-heavy
# network that gradient checkpointing is meant to help with.
def get_cnn_model(img_dim, n_channels):
    """Return a sequential CNN for `img_dim` x `img_dim` x `n_channels` inputs.

    Layout: a Reshape input layer, then eleven 5x5 'same'-padded ReLU
    convolutions, each followed by a 'same'-padded max-pooling layer (1x1 for
    all but the last, which is 2x2), then Flatten, a 32-unit ReLU Dense layer
    and a 10-unit Dense output layer (logits, no activation).
    """
    conv_filters = [10, 20, 40, 60, 40, 40, 40, 40, 40, 20, 20]
    last_index = len(conv_filters) - 1

    stack = [
        layers.Reshape(
            target_shape=[img_dim, img_dim, n_channels],
            input_shape=(img_dim, img_dim, n_channels)),
    ]
    for position, filters in enumerate(conv_filters):
        pool_size = (2, 2) if position == last_index else (1, 1)
        stack.append(
            layers.Conv2D(filters, 5, padding='same', activation=tf.nn.relu))
        stack.append(layers.MaxPooling2D(pool_size, padding='same'))
    stack.extend([
        layers.Flatten(),
        layers.Dense(32, activation=tf.nn.relu),
        layers.Dense(10),
    ])
    return tf.keras.Sequential(stack)

In [4]:
# Define the loss function.
def compute_loss(logits, labels):
  """Return the batch mean of the sparse softmax cross-entropy.

  Args:
    logits: unnormalized class scores produced by the model.
    labels: integer class indices, one per example.
  """
  per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  return tf.reduce_mean(per_example_loss)

In [5]:
# This is how you wrap your model with the 'recompute' decorator.
@tfa.training.recompute_sequential
def model_fn(model, x):
    """Run the forward pass of `model` on `x` and return its outputs.

    The recompute_sequential decorator wraps this call so that, per this
    tutorial, forward activations are recomputed during the backward pass
    instead of being stored.
    """
    outputs = model(x)
    return outputs

In [6]:
# Perform training with dummy inputs.
def train(img_dim=256, n_channels=1, batch_size=16, num_steps=5):
    """Train the sample CNN on constant dummy data, printing the loss each step.

    Args:
      img_dim: height and width of the square dummy images.
      n_channels: number of channels in the dummy images (must match the
        model's input, so it is used for both the data and the model).
      batch_size: number of examples per step.
      num_steps: number of optimizer steps to run.
    """
    # Dummy batch: all-ones images with every label set to class 1. The channel
    # dimension uses n_channels (previously hard-coded to 1) so the data always
    # matches the model's input_shape.
    x = tf.ones([batch_size, img_dim, img_dim, n_channels])
    y = tf.ones([batch_size], dtype=tf.int64)
    model = get_cnn_model(img_dim, n_channels)
    optimizer = optimizers.SGD()
    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            # `_watch_vars` is an extra keyword consumed by the
            # recompute_sequential wrapper; presumably it names the variables
            # gradients are needed for — confirm against the tfa docs.
            logits = model_fn(model, x, _watch_vars=model.trainable_variables)
            # To train without recompute, remove the recompute_sequential
            # decorator from model_fn and use this line instead:
            # logits = model_fn(model, x)
            loss = compute_loss(logits, y)
        print('loss', loss)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Drop the reference to the gradients before the next forward pass so
        # they need not be held alongside its activations.
        del grads

In [7]:
# Run the training demo with its default settings; the loss printed at each
# step appears in the output below.
train()

loss tf.Tensor(2.3119702, shape=(), dtype=float32)
loss tf.Tensor(2.2820306, shape=(), dtype=float32)
loss tf.Tensor(2.22289, shape=(), dtype=float32)
loss tf.Tensor(1.9491141, shape=(), dtype=float32)
loss tf.Tensor(0.48555315, shape=(), dtype=float32)
